GitHub Repository: iperov/deepfacelab
Path: blob/master/samplelib/SampleGeneratorImage.py
import traceback

import cv2
import numpy as np

from core.joblib import SubprocessGenerator, ThisThreadGenerator
from samplelib import (SampleGeneratorBase, SampleLoader, SampleProcessor,
                       SampleType)


class SampleGeneratorImage(SampleGeneratorBase):
    """
    Endless generator of shuffled image batches.

    Loads IMAGE-type samples from samples_path and yields batches produced by
    SampleProcessor. In debug mode batches are built on the calling thread;
    otherwise they are built in a background subprocess.
    """
    def __init__(self, samples_path, debug, batch_size,
                 sample_process_options=SampleProcessor.Options(),
                 output_sample_types=[],
                 raise_on_no_data=True, **kwargs):
        super().__init__(debug, batch_size)
        self.initialized = False
        self.sample_process_options = sample_process_options
        self.output_sample_types = output_sample_types

        samples = SampleLoader.load(SampleType.IMAGE, samples_path)

        if len(samples) == 0:
            if raise_on_no_data:
                raise ValueError('No training data provided.')
            return

        # In debug mode run batch_func on this thread; otherwise offload
        # batch generation to a subprocess.
        self.generators = [ThisThreadGenerator(self.batch_func, samples)] if self.debug else \
                          [SubprocessGenerator(self.batch_func, samples)]

        self.generator_counter = -1
        self.initialized = True

    def __iter__(self):
        return self

    def __next__(self):
        # Round-robin over the underlying generators (a single one is created here).
        self.generator_counter += 1
        generator = self.generators[self.generator_counter % len(self.generators)]
        return next(generator)

    def batch_func(self, samples):
        samples_len = len(samples)

        idxs = [*range(samples_len)]
        shuffle_idxs = []

        while True:
            batches = None
            for n_batch in range(self.batch_size):
                # Refill and reshuffle the index pool once it is exhausted, so every
                # sample is visited once per pass in random order.
                if len(shuffle_idxs) == 0:
                    shuffle_idxs = idxs.copy()
                    np.random.shuffle(shuffle_idxs)

                idx = shuffle_idxs.pop()
                sample = samples[idx]

                # SampleProcessor returns one result per input sample; each result holds
                # one array per requested output_sample_type.
                x, = SampleProcessor.process([sample], self.sample_process_options,
                                             self.output_sample_types, self.debug)

                if batches is None:
                    batches = [[] for _ in range(len(x))]

                for i in range(len(x)):
                    batches[i].append(x[i])

            # One ndarray per output sample type, each stacking batch_size entries.
            yield [np.array(batch) for batch in batches]
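
For reference, a minimal usage sketch (the path and batch size below are hypothetical placeholders, and output_sample_types is left empty here; in DeepFaceLab the model code passes its own SampleProcessor type dictionaries):

gen = SampleGeneratorImage('/path/to/images', debug=False, batch_size=4,
                           sample_process_options=SampleProcessor.Options(),
                           output_sample_types=[])  # model-specific type dicts go here
batch = next(gen)  # list of np.ndarray, one per requested output sample type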