Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
iperov
GitHub Repository: iperov/deepfacelab
Path: blob/master/samplelib/SampleGeneratorImageTemporal.py
628 views
1
import traceback
2
3
import cv2
4
import numpy as np
5
6
from core.joblib import SubprocessGenerator, ThisThreadGenerator
7
from samplelib import (SampleGeneratorBase, SampleLoader, SampleProcessor,
8
SampleType)
9
10
11
class SampleGeneratorImageTemporal(SampleGeneratorBase):
    """Generate batches of "temporal lines" of image samples.

    A temporal line is ``temporal_image_count`` samples taken from the loaded
    image set at indices ``idx, idx + mult, idx + 2*mult, ...``, where ``idx``
    is drawn from a shuffled pool of valid start positions and ``mult`` is a
    random stride in ``[1, 4]``. Every sample in the line is passed through
    ``SampleProcessor.process`` and the per-sample outputs are concatenated in
    order, so each yielded batch is a list of ``np.ndarray`` indexed
    position-major: with ``K`` outputs per sample, entry ``t*K + k`` holds
    output ``k`` of temporal position ``t`` for every item in the batch
    (assumes ``process`` returns the same number of outputs for every
    sample — TODO confirm).

    Expected layout of ``output_sample_types`` (from the original module
    note; presumably still matches SampleProcessor — verify)::

        output_sample_types = [
            [SampleProcessor.TypeFlags, size, (optional)random_sub_size],
            ...
        ]

    Args:
        samples_path: path handed to ``SampleLoader.load`` (as ``SampleType.IMAGE``).
        debug: forwarded to the base class and to ``SampleProcessor.process``;
            also selects in-thread (True) vs subprocess (False) generation.
        batch_size: number of temporal lines per yielded batch.
        temporal_image_count: number of images in each temporal line.
        sample_process_options: ``SampleProcessor.Options``; a fresh default
            instance is created when omitted.
        output_sample_types: output spec for ``SampleProcessor.process``;
            a fresh empty list is used when omitted.
        **kwargs: accepted and ignored.
    """

    def __init__(self, samples_path, debug, batch_size, temporal_image_count,
                 sample_process_options=None, output_sample_types=None, **kwargs):
        super().__init__(debug, batch_size)

        # Build defaults per instance so they are never shared between generators.
        if sample_process_options is None:
            sample_process_options = SampleProcessor.Options()
        if output_sample_types is None:
            output_sample_types = []

        self.temporal_image_count = temporal_image_count
        self.sample_process_options = sample_process_options
        self.output_sample_types = output_sample_types

        self.samples = SampleLoader.load(SampleType.IMAGE, samples_path)

        # One sample pool per generator; batch_func receives the pool index.
        self.generator_samples = [self.samples]
        # Debug mode runs batch_func in this thread, otherwise in a subprocess.
        # Both classes are imported directly from core.joblib.
        generator_cls = ThisThreadGenerator if self.debug else SubprocessGenerator
        self.generators = [generator_cls(self.batch_func, 0)]

        # Round-robin cursor over self.generators; starts at -1 so the first
        # __next__ call uses generator 0.
        self.generator_counter = -1

    def __iter__(self):
        return self

    def __next__(self):
        """Return the next batch, cycling through the generators round-robin."""
        self.generator_counter += 1
        generator = self.generators[self.generator_counter % len(self.generators)]
        return next(generator)

    def batch_func(self, generator_id):
        """Endlessly yield batches built from ``self.generator_samples[generator_id]``.

        Raises:
            ValueError: if the sample pool is empty, or too short to hold a
                temporal line of ``temporal_image_count`` images at the
                maximum stride.
            Exception: wraps any error raised by ``SampleProcessor.process``,
                naming the offending sample's filename and including the
                traceback; the original error is chained as ``__cause__``.
        """
        samples = self.generator_samples[generator_id]
        samples_len = len(samples)
        if samples_len == 0:
            raise ValueError('No training data provided.')

        mult_max = 4  # largest stride between consecutive images in a line
        # A line starting at idx with stride mult reads up to idx + (T-1)*mult,
        # so every start in [0, samples_len - (T-1)*mult_max) stays in range
        # even with the largest stride.
        samples_sub_len = samples_len - (self.temporal_image_count - 1) * mult_max
        if samples_sub_len <= 0:
            raise ValueError('Not enough samples to fit temporal line.')

        shuffle_idxs = []
        while True:
            batches = None
            for _ in range(self.batch_size):
                # Refill and reshuffle the start-index pool once it is exhausted,
                # so every start position is used once per epoch.
                if not shuffle_idxs:
                    shuffle_idxs = list(range(samples_sub_len))
                    np.random.shuffle(shuffle_idxs)
                idx = shuffle_idxs.pop()

                mult = np.random.randint(mult_max) + 1  # stride in [1, mult_max]
                temporal_samples = []
                for i in range(self.temporal_image_count):
                    sample = samples[idx + i * mult]
                    try:
                        temporal_samples += SampleProcessor.process(
                            [sample], self.sample_process_options,
                            self.output_sample_types, self.debug)[0]
                    except Exception as e:
                        raise Exception("Exception occured in sample %s. Error: %s" % (sample.filename, traceback.format_exc())) from e

                # Size the per-output accumulators from the first line processed.
                if batches is None:
                    batches = [[] for _ in range(len(temporal_samples))]
                for i, item in enumerate(temporal_samples):
                    batches[i].append(item)

            yield [np.array(batch) for batch in batches]
82
83