Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
iperov
GitHub Repository: iperov/deepfacelab
Path: blob/master/samplelib/SampleGeneratorFacePerson.py
628 views
1
import copy
import multiprocessing
import threading
import time
import traceback

import cv2
import numpy as np

from core import mplib
from core.joblib import SubprocessGenerator, ThisThreadGenerator
from facelib import LandmarksProcessor
from samplelib import (SampleGeneratorBase, SampleLoader, SampleProcessor,
                       SampleType)
13
14
15
16
class Index2DHost():
    """
    Provides random shuffled 2D indexes for multiprocesses.

    A daemon thread owned by this object answers index requests sent by
    ``Index2DHost.Cli`` objects obtained from ``create_cli``. A client holds
    only the shared request queue and its own reply queue, so it can be handed
    to worker processes; the host object itself refuses to be pickled.

    indexes2D : sequence of groups; ``indexes2D[i]`` is the list of items that
                belong to group ``i``.

    Every pool is sampled as "shuffle, pop until empty, reshuffle", i.e.
    without replacement within one pass over the pool.
    """

    # Command codes carried in request tuples (cq_id, cmd, *args).
    _CMD_GET_1D = 0
    _CMD_GET_2D = 1

    def __init__(self, indexes2D):
        self.sq = multiprocessing.Queue()   # shared request queue (all clients -> host)
        self.cqs = []                       # reply queues, indexed by client cq_id
        self.clis = []
        self.thread = threading.Thread(target=self.host_thread, args=(indexes2D,) )
        self.thread.daemon = True
        self.thread.start()

    def host_thread(self, indexes2D):
        """
        Serve index requests forever; runs on the daemon thread started in __init__.

        (cq_id, 0, count)            -> reply with ``count`` group indexes drawn
                                        from range(len(indexes2D)).
        (cq_id, 1, targ_idxs, count) -> reply with one list of ``count`` items
                                        per group index in ``targ_idxs``.
        """
        groups = [indexes2D[i] for i in range(len(indexes2D))]
        all_group_idxs = [*range(len(groups))]

        shuffle_idxs = []                              # pool for get_1D
        shuffle_idxs_2D = [ [] for _ in groups ]       # one pool per group for get_2D

        sq = self.sq
        while True:
            # Block until a request arrives instead of polling empty() + sleep.
            obj = sq.get()
            cq_id, cmd = obj[0], obj[1]

            if cmd == Index2DHost._CMD_GET_1D:
                count = obj[2]
                result = []
                for _ in range(count):
                    if not shuffle_idxs:
                        shuffle_idxs = all_group_idxs.copy()
                        np.random.shuffle(shuffle_idxs)
                    result.append(shuffle_idxs.pop())
                self.cqs[cq_id].put(result)

            elif cmd == Index2DHost._CMD_GET_2D:
                targ_idxs, count = obj[2], obj[3]
                result = []
                for targ_idx in targ_idxs:
                    sub_idxs = []
                    for _ in range(count):
                        ar = shuffle_idxs_2D[targ_idx]
                        if not ar:
                            # list() refills from any sequence type (list, tuple, ...).
                            # NOTE: an empty group makes ar.pop() raise IndexError,
                            # which ends this thread.
                            ar = shuffle_idxs_2D[targ_idx] = list(groups[targ_idx])
                            np.random.shuffle(ar)
                        sub_idxs.append(ar.pop())
                    result.append(sub_idxs)
                self.cqs[cq_id].put(result)

    def create_cli(self):
        """Register a new reply queue and return a client bound to it."""
        cq = multiprocessing.Queue()
        self.cqs.append ( cq )
        cq_id = len(self.cqs)-1
        return Index2DHost.Cli(self.sq, cq, cq_id)

    # disable pickling
    def __getstate__(self):
        return dict()
    def __setstate__(self, d):
        self.__dict__.update(d)

    class Cli():
        """Client handle holding only the request queue and its own reply queue."""

        def __init__(self, sq, cq, cq_id):
            self.sq = sq
            self.cq = cq
            self.cq_id = cq_id

        def _request(self, msg):
            """Send one request to the host and block until its reply arrives."""
            self.sq.put(msg)
            return self.cq.get()

        def get_1D(self, count):
            """Return a list of ``count`` shuffled group indexes."""
            return self._request( (self.cq_id, Index2DHost._CMD_GET_1D, count) )

        def get_2D(self, idxs, count):
            """Return, for each group index in ``idxs``, a list of ``count`` shuffled items of that group."""
            return self._request( (self.cq_id, Index2DHost._CMD_GET_2D, idxs, count) )
106
107
class SampleGeneratorFacePerson(SampleGeneratorBase):
    """
    Sample generator that pairs each processed face sample with a person index.

    The constructor currently raises NotImplementedError right after storing
    its options, so the loading / generator-creation code after that raise
    is unreachable until the raise is removed.

    Each batch yielded by ``batch_func`` is a list of numpy arrays: one per
    output of ``SampleProcessor.process`` (stacked over the batch), followed
    by one array of shape (batch_size, 1) holding, for every sample, the index
    of the person group it was drawn from. Person groups are ordered by
    person name so these indexes are stable across runs.

    arg
    output_sample_types = [
                            [SampleProcessor.TypeFlags, size, (optional) {} opts ] ,
                            ...
                          ]
    """

    def __init__ (self, samples_path, debug=False, batch_size=1,
                        sample_process_options=None,
                        output_sample_types=None,
                        person_id_mode=1,
                        **kwargs):
        """
        samples_path           : path given to ``SampleLoader.mp_host``.
        debug                  : if True a single ``ThisThreadGenerator`` is used,
                                 otherwise 2-4 ``SubprocessGenerator`` workers.
        batch_size             : number of samples per yielded batch.
        sample_process_options : ``SampleProcessor.Options``; a fresh default
                                 instance is created when omitted.
        output_sample_types    : processor output description (see class
                                 docstring); defaults to a fresh empty list.
        person_id_mode         : stored on ``self.person_id_mode``; not used
                                 elsewhere in this class.
        kwargs                 : accepted and ignored.

        Raises NotImplementedError unconditionally.
        """
        # assumes SampleGeneratorBase stores self.debug / self.batch_size - TODO confirm
        super().__init__(debug, batch_size)

        # Build defaults per call: default-argument objects are created once at
        # definition time and would otherwise be shared by every instance.
        if sample_process_options is None:
            sample_process_options = SampleProcessor.Options()
        if output_sample_types is None:
            output_sample_types = []

        self.sample_process_options = sample_process_options
        self.output_sample_types = output_sample_types
        self.person_id_mode = person_id_mode

        raise NotImplementedError("Currently SampleGeneratorFacePerson is not implemented.")

        # ---- unreachable until the raise above is removed ----
        samples_host = SampleLoader.mp_host (SampleType.FACE, samples_path)
        samples = samples_host.get_list()
        self.samples_len = len(samples)

        if self.samples_len == 0:
            raise ValueError('No training data provided.')

        # Group sample indexes by person name. Ordering the groups by name (not
        # by set iteration order, which varies between runs for str keys) keeps
        # the person index = position in indexes2D stable across runs.
        persons_name_idxs = {}
        for i, sample in enumerate(samples):
            persons_name_idxs.setdefault(sample.person_name, []).append(i)
        indexes2D = [ persons_name_idxs[person_name]
                      for person_name in sorted(persons_name_idxs, key=str) ]
        index2d_host = Index2DHost(indexes2D)

        if self.debug:
            self.generators_count = 1
            self.generators = [ThisThreadGenerator ( self.batch_func, (samples_host.create_cli(), index2d_host.create_cli(),) )]
        else:
            self.generators_count = np.clip(multiprocessing.cpu_count(), 2, 4)
            self.generators = [SubprocessGenerator ( self.batch_func, (samples_host.create_cli(), index2d_host.create_cli(),) ) for i in range(self.generators_count) ]

        self.generator_counter = -1

    def __iter__(self):
        return self

    def __next__(self):
        """Return the next batch, taking generators in round-robin order."""
        self.generator_counter += 1
        generator = self.generators[self.generator_counter % len(self.generators) ]
        return next(generator)

    def batch_func(self, param ):
        """
        Infinite generator yielding one batch per iteration.

        param : (samples, index2d_host) - an object indexable by sample index
                (the sample-loader client) and an ``Index2DHost.Cli``.
        """
        samples, index2d_host, = param
        bs = self.batch_size

        while True:
            # Draw bs person groups, then one sample index from each group.
            person_idxs = index2d_host.get_1D(bs)
            samples_idxs = index2d_host.get_2D(person_idxs, 1)

            batches = None
            for n_batch in range(bs):
                person_id = person_idxs[n_batch]
                sample_idx = samples_idxs[n_batch][0]

                sample = samples[ sample_idx ]
                try:
                    x, = SampleProcessor.process ([sample], self.sample_process_options, self.output_sample_types, self.debug)
                except Exception as e:
                    raise Exception ("Exception occured in sample %s. Error: %s" % (sample.filename, traceback.format_exc() ) ) from e

                if batches is None:
                    # One column per processor output, plus one trailing
                    # person-id column shared by the whole batch.
                    batches = [ [] for _ in range(len(x)) ]
                    batches.append([])
                    i_person_id = len(batches)-1

                for i, item in enumerate(x):
                    batches[i].append(item)

                batches[i_person_id].append ( np.array([person_id]) )

            yield [ np.array(batch) for batch in batches]

    @staticmethod
    def get_person_id_max_count(samples_path):
        """Delegate to ``SampleLoader.get_person_id_max_count``."""
        return SampleLoader.get_person_id_max_count(samples_path)
194
195
"""
196
if self.person_id_mode==1:
197
samples_len = len(samples)
198
samples_idxs = [*range(samples_len)]
199
shuffle_idxs = []
200
elif self.person_id_mode==2:
201
persons_count = len(samples)
202
203
person_idxs = []
204
for j in range(persons_count):
205
for i in range(j+1,persons_count):
206
person_idxs += [ [i,j] ]
207
208
shuffle_person_idxs = []
209
210
samples_idxs = [None]*persons_count
211
shuffle_idxs = [None]*persons_count
212
213
for i in range(persons_count):
214
samples_idxs[i] = [*range(len(samples[i]))]
215
shuffle_idxs[i] = []
216
elif self.person_id_mode==3:
217
persons_count = len(samples)
218
219
person_idxs = [ *range(persons_count) ]
220
shuffle_person_idxs = []
221
222
samples_idxs = [None]*persons_count
223
shuffle_idxs = [None]*persons_count
224
225
for i in range(persons_count):
226
samples_idxs[i] = [*range(len(samples[i]))]
227
shuffle_idxs[i] = []
228
229
if self.person_id_mode==2:
230
if len(shuffle_person_idxs) == 0:
231
shuffle_person_idxs = person_idxs.copy()
232
np.random.shuffle(shuffle_person_idxs)
233
person_ids = shuffle_person_idxs.pop()
234
235
236
batches = None
237
for n_batch in range(self.batch_size):
238
239
if self.person_id_mode==1:
240
if len(shuffle_idxs) == 0:
241
shuffle_idxs = samples_idxs.copy()
242
np.random.shuffle(shuffle_idxs) ###
243
244
idx = shuffle_idxs.pop()
245
sample = samples[ idx ]
246
247
try:
248
x, = SampleProcessor.process ([sample], self.sample_process_options, self.output_sample_types, self.debug)
249
except:
250
raise Exception ("Exception occured in sample %s. Error: %s" % (sample.filename, traceback.format_exc() ) )
251
252
if type(x) != tuple and type(x) != list:
253
raise Exception('SampleProcessor.process returns NOT tuple/list')
254
255
if batches is None:
256
batches = [ [] for _ in range(len(x)) ]
257
258
batches += [ [] ]
259
i_person_id = len(batches)-1
260
261
for i in range(len(x)):
262
batches[i].append ( x[i] )
263
264
batches[i_person_id].append ( np.array([sample.person_id]) )
265
266
267
elif self.person_id_mode==2:
268
person_id1, person_id2 = person_ids
269
270
if len(shuffle_idxs[person_id1]) == 0:
271
shuffle_idxs[person_id1] = samples_idxs[person_id1].copy()
272
np.random.shuffle(shuffle_idxs[person_id1])
273
274
idx = shuffle_idxs[person_id1].pop()
275
sample1 = samples[person_id1][idx]
276
277
if len(shuffle_idxs[person_id2]) == 0:
278
shuffle_idxs[person_id2] = samples_idxs[person_id2].copy()
279
np.random.shuffle(shuffle_idxs[person_id2])
280
281
idx = shuffle_idxs[person_id2].pop()
282
sample2 = samples[person_id2][idx]
283
284
if sample1 is not None and sample2 is not None:
285
try:
286
x1, = SampleProcessor.process ([sample1], self.sample_process_options, self.output_sample_types, self.debug)
287
except:
288
raise Exception ("Exception occured in sample %s. Error: %s" % (sample1.filename, traceback.format_exc() ) )
289
290
try:
291
x2, = SampleProcessor.process ([sample2], self.sample_process_options, self.output_sample_types, self.debug)
292
except:
293
raise Exception ("Exception occured in sample %s. Error: %s" % (sample2.filename, traceback.format_exc() ) )
294
295
x1_len = len(x1)
296
if batches is None:
297
batches = [ [] for _ in range(x1_len) ]
298
batches += [ [] ]
299
i_person_id1 = len(batches)-1
300
301
batches += [ [] for _ in range(len(x2)) ]
302
batches += [ [] ]
303
i_person_id2 = len(batches)-1
304
305
for i in range(x1_len):
306
batches[i].append ( x1[i] )
307
308
for i in range(len(x2)):
309
batches[x1_len+1+i].append ( x2[i] )
310
311
batches[i_person_id1].append ( np.array([sample1.person_id]) )
312
313
batches[i_person_id2].append ( np.array([sample2.person_id]) )
314
315
elif self.person_id_mode==3:
316
if len(shuffle_person_idxs) == 0:
317
shuffle_person_idxs = person_idxs.copy()
318
np.random.shuffle(shuffle_person_idxs)
319
person_id = shuffle_person_idxs.pop()
320
321
if len(shuffle_idxs[person_id]) == 0:
322
shuffle_idxs[person_id] = samples_idxs[person_id].copy()
323
np.random.shuffle(shuffle_idxs[person_id])
324
325
idx = shuffle_idxs[person_id].pop()
326
sample1 = samples[person_id][idx]
327
328
if len(shuffle_idxs[person_id]) == 0:
329
shuffle_idxs[person_id] = samples_idxs[person_id].copy()
330
np.random.shuffle(shuffle_idxs[person_id])
331
332
idx = shuffle_idxs[person_id].pop()
333
sample2 = samples[person_id][idx]
334
335
if sample1 is not None and sample2 is not None:
336
try:
337
x1, = SampleProcessor.process ([sample1], self.sample_process_options, self.output_sample_types, self.debug)
338
except:
339
raise Exception ("Exception occured in sample %s. Error: %s" % (sample1.filename, traceback.format_exc() ) )
340
341
try:
342
x2, = SampleProcessor.process ([sample2], self.sample_process_options, self.output_sample_types, self.debug)
343
except:
344
raise Exception ("Exception occured in sample %s. Error: %s" % (sample2.filename, traceback.format_exc() ) )
345
346
x1_len = len(x1)
347
if batches is None:
348
batches = [ [] for _ in range(x1_len) ]
349
batches += [ [] ]
350
i_person_id1 = len(batches)-1
351
352
batches += [ [] for _ in range(len(x2)) ]
353
batches += [ [] ]
354
i_person_id2 = len(batches)-1
355
356
for i in range(x1_len):
357
batches[i].append ( x1[i] )
358
359
for i in range(len(x2)):
360
batches[x1_len+1+i].append ( x2[i] )
361
362
batches[i_person_id1].append ( np.array([sample1.person_id]) )
363
364
batches[i_person_id2].append ( np.array([sample2.person_id]) )
365
"""
366
367