Book a Demo!
CoCalc Logo Icon
Store | Features | Docs | Share | Support | News | About | Policies | Sign Up | Sign In
iperov
GitHub Repository: iperov/deepfacelab
Path: blob/master/samplelib/SampleGeneratorFace.py
628 views
1
import multiprocessing
2
import time
3
import traceback
4
5
import cv2
6
import numpy as np
7
8
from core import mplib
9
from core.interact import interact as io
10
from core.joblib import SubprocessGenerator, ThisThreadGenerator
11
from facelib import LandmarksProcessor
12
from samplelib import (SampleGeneratorBase, SampleLoader, SampleProcessor,
13
SampleType)
14
15
16
class SampleGeneratorFace(SampleGeneratorBase):
    """Batch generator over a set of face samples.

    Samples are loaded with ``SampleLoader.load(SampleType.FACE, samples_path)``
    and every sample is turned into output arrays by ``SampleProcessor.process``.
    Batches come from one in-thread generator when ``debug`` is set, otherwise
    from ``generators_count`` subprocess generators polled round-robin by
    ``__next__``.

    ``output_sample_types`` describes the requested outputs::

        output_sample_types = [
            [SampleProcessor.TypeFlags, size, (optional) {} opts],
            ...
        ]

    Args:
        samples_path: path handed to ``SampleLoader.load``.
        debug: use a single ``ThisThreadGenerator`` instead of subprocesses.
        batch_size: number of samples per yielded batch.
        random_ct_samples_path: optional second sample path; when given, one of
            its samples is drawn per training sample and passed to
            ``SampleProcessor.process`` as ``ct_sample``.
        sample_process_options: forwarded to ``SampleProcessor.process``; a new
            ``SampleProcessor.Options()`` is created per instance when omitted.
        output_sample_types: forwarded to ``SampleProcessor.process``; a new
            empty list is used per instance when omitted.
        uniform_yaw_distribution: group sample indices into yaw buckets and hand
            them to ``mplib.Index2DHost`` instead of ``mplib.IndexHost``
            (presumably to balance draws across head poses - see mplib).
        generators_count: number of subprocess generators (at least 1; always 1
            in debug mode).
        raise_on_no_data: raise ``ValueError`` when no samples are found;
            otherwise the generator stays uninitialized and ``next()`` returns ``[]``.
        **kwargs: accepted and ignored.
    """

    def __init__(self, samples_path, debug=False, batch_size=1,
                 random_ct_samples_path=None,
                 sample_process_options=None,
                 output_sample_types=None,
                 uniform_yaw_distribution=False,
                 generators_count=4,
                 raise_on_no_data=True,
                 **kwargs):
        super().__init__(debug, batch_size)
        self.initialized = False

        # Build defaults per instance: a default evaluated in the signature would be
        # a single object shared by every generator constructed without the argument.
        if sample_process_options is None:
            sample_process_options = SampleProcessor.Options()
        if output_sample_types is None:
            output_sample_types = []
        self.sample_process_options = sample_process_options
        self.output_sample_types = output_sample_types

        self.generators_count = 1 if self.debug else max(1, generators_count)

        samples = SampleLoader.load(SampleType.FACE, samples_path)
        self.samples_len = len(samples)

        if self.samples_len == 0:
            if raise_on_no_data:
                raise ValueError('No training data provided.')
            # Stay uninitialized: __next__ then returns [] and no generators exist.
            return

        if uniform_yaw_distribution:
            index_host = mplib.Index2DHost(self._yaw_buckets(samples))
        else:
            index_host = mplib.IndexHost(self.samples_len)

        if random_ct_samples_path is not None:
            ct_samples = SampleLoader.load(SampleType.FACE, random_ct_samples_path)
            ct_index_host = mplib.IndexHost(len(ct_samples))
        else:
            ct_samples = None
            ct_index_host = None

        def make_args():
            # Each generator gets its own create_cli() handle(s), created in this order.
            ct_cli = ct_index_host.create_cli() if ct_index_host is not None else None
            return (samples, index_host.create_cli(), ct_samples, ct_cli)

        if self.debug:
            self.generators = [ThisThreadGenerator(self.batch_func, make_args())]
        else:
            self.generators = [SubprocessGenerator(self.batch_func, make_args(), start_now=False)
                               for _ in range(self.generators_count)]
            SubprocessGenerator.start_in_parallel(self.generators)

        self.generator_counter = -1

        self.initialized = True

    @staticmethod
    def _yaw_buckets(samples, grads=128, yaw_limit=1.2):
        """Group sample indices into ``grads`` yaw buckets; return the non-empty ones.

        The bucketed value is ``-sample.get_pitch_yaw_roll()[1]``. Edges are
        ``np.linspace(-yaw_limit, yaw_limit, grads)``; bucket ``g`` holds yaws in
        ``[edges[g], edges[g+1])``, the first bucket also takes everything below
        ``edges[1]`` and the last one everything at or above ``edges[-1]``.
        NaN yaws fall into no bucket. Buckets keep ascending order, indices inside
        a bucket are ascending Python ints, and empty buckets are dropped.
        """
        # Instead of math.pi / 2, -1.2..+1.2 is used because the maximum yaw
        # for 2DFAN landmarks is actually about -1.2..+1.2.
        edges = np.linspace(-yaw_limit, yaw_limit, grads)
        yaws = np.array([-sample.get_pitch_yaw_roll()[1] for sample in samples], dtype=np.float64)

        # One pass instead of `grads` passes: digitize against edges[1:] yields g with
        # edges[g] <= yaw < edges[g+1], 0 below edges[1] and grads-1 at/above edges[-1].
        # NaN fails every comparison, so it is tagged -1 and never collected.
        bucket_of = np.full(len(yaws), -1, dtype=np.int64)
        valid = ~np.isnan(yaws)
        bucket_of[valid] = np.digitize(yaws[valid], edges[1:])

        buckets = []
        for g in io.progress_bar_generator(range(grads), "Sort by yaw"):
            members = np.flatnonzero(bucket_of == g)
            if members.size > 0:
                buckets.append(members.tolist())
        return buckets

    #overridable
    def is_initialized(self):
        """Return True once __init__ has created the generators."""
        return self.initialized

    def __iter__(self):
        return self

    def __next__(self):
        """Return the next batch, taking generators in round-robin order.

        Returns ``[]`` when the instance was left uninitialized (no samples and
        ``raise_on_no_data=False``).
        """
        if not self.initialized:
            return []

        self.generator_counter += 1
        generator = self.generators[self.generator_counter % len(self.generators)]
        return next(generator)

    def batch_func(self, param):
        """Endless generator of batches, driven by each worker generator.

        Args:
            param: ``(samples, index_host, ct_samples, ct_index_host)`` as built in
                ``__init__``; the ct entries are ``None`` when no ct samples are used.

        Yields:
            A list with one ``np.array`` per output of ``SampleProcessor.process``,
            each stacking that output over ``batch_size`` samples.

        Raises:
            Exception: wrapping any ``Exception`` raised while processing a sample;
                the message names the sample file and embeds the traceback.
        """
        samples, index_host, ct_samples, ct_index_host = param

        bs = self.batch_size
        while True:
            batches = None

            indexes = index_host.multi_get(bs)
            ct_indexes = ct_index_host.multi_get(bs) if ct_samples is not None else None

            for n_batch in range(bs):
                sample = samples[indexes[n_batch]]
                ct_sample = ct_samples[ct_indexes[n_batch]] if ct_samples is not None else None

                try:
                    x, = SampleProcessor.process([sample], self.sample_process_options, self.output_sample_types, self.debug, ct_sample=ct_sample)
                except Exception as e:
                    # The formatted traceback stays in the message (presumably so it
                    # survives leaving a subprocess worker - verify in core.joblib);
                    # `from e` additionally chains the cause in-process.
                    # KeyboardInterrupt/SystemExit are no longer swallowed here.
                    raise Exception("Exception occured in sample %s. Error: %s" % (sample.filename, traceback.format_exc())) from e

                if batches is None:
                    batches = [[] for _ in range(len(x))]

                for i, item in enumerate(x):
                    batches[i].append(item)

            yield [np.array(batch) for batch in batches]
145
146