Book a Demo!
CoCalc Logo Icon
Store Features Docs Share Support News About Policies Sign Up Sign In
iperov
GitHub Repository: iperov/deepfacelab
Path: blob/master/merger/InteractiveMergerSubprocessor.py
628 views
1
import multiprocessing
2
import os
3
import pickle
4
import sys
5
import traceback
6
from pathlib import Path
7
8
import numpy as np
9
10
from core import imagelib, pathex
11
from core.cv2ex import *
12
from core.interact import interact as io
13
from core.joblib import Subprocessor
14
from merger import MergeFaceAvatar, MergeMasked, MergerConfig
15
16
from .MergerScreen import Screen, ScreenManager
17
18
# Debug switch: when True only a single worker ("CPU0") is started and the
# host's stdin file descriptor is handed to it, so interactive prompts can be
# used from inside that worker process.
MERGER_DEBUG: bool = False
19
class InteractiveMergerSubprocessor(Subprocessor):
    """Merge faces into frames using worker subprocesses, with an optional interactive UI.

    The host keeps one ``Frame`` per input frame plus two index queues:
    ``frames_idxs`` (frames still to visit; the front is the current frame) and
    ``frames_done_idxs`` (frames already accepted; the back is the most recent).
    Workers (``Cli``) receive ``ProcessingFrame`` work items, run
    ``MergeMasked`` / ``MergeFaceAvatar`` and write the merged image and mask
    into ``output_path`` / ``output_mask_path``.

    In interactive mode the user edits each frame's ``MergerConfig`` with
    keyboard shortcuts and steps through frames. On exit the queues and the
    per-frame configs are pickled to ``merger_session_filepath`` so the session
    can be resumed later.
    """

    class Frame(object):
        """Host-side state of one frame to be merged (also pickled into the session file)."""

        def __init__(self, prev_temporal_frame_infos=None,
                           frame_info=None,
                           next_temporal_frame_infos=None):
            self.prev_temporal_frame_infos = prev_temporal_frame_infos
            self.frame_info = frame_info
            self.next_temporal_frame_infos = next_temporal_frame_infos
            self.output_filepath = None
            self.output_mask_filepath = None

            self.idx = None
            # MergerConfig for this frame; None means "inherit the previous frame's cfg"
            # (filled in by on_tick; get_data skips frames whose cfg is still None).
            self.cfg = None
            self.is_done = False
            self.is_processing = False
            self.is_shown = False
            # Cached merged image (BGR + mask channel), or None when not cached.
            self.image = None

    class ProcessingFrame(object):
        """Work item sent to a ``Cli`` worker; the worker returns it with the result attached."""

        def __init__(self, idx=None,
                           cfg=None,
                           prev_temporal_frame_infos=None,
                           frame_info=None,
                           next_temporal_frame_infos=None,
                           output_filepath=None,
                           output_mask_filepath=None,
                           need_return_image = False):
            self.idx = idx
            self.cfg = cfg
            self.prev_temporal_frame_infos = prev_temporal_frame_infos
            self.frame_info = frame_info
            self.next_temporal_frame_infos = next_temporal_frame_infos
            self.output_filepath = output_filepath
            self.output_mask_filepath = output_mask_filepath

            self.need_return_image = need_return_image
            # Always define `image` so readers (on_result) never hit AttributeError;
            # the worker only fills it in when need_return_image is True.
            self.image = None

    class Cli(Subprocessor.Cli):
        """Worker side: merges one ProcessingFrame per process_data() call."""

        #override
        def on_initialize(self, client_dict):
            """Store the per-worker settings produced by process_info_generator()."""
            self.log_info ('Running on %s.' % (client_dict['device_name']) )
            self.device_idx = client_dict['device_idx']
            self.device_name = client_dict['device_name']
            self.predictor_func = client_dict['predictor_func']
            self.predictor_input_shape = client_dict['predictor_input_shape']
            self.face_enhancer_func = client_dict['face_enhancer_func']
            self.xseg_256_extract_func = client_dict['xseg_256_extract_func']

            #transfer and set stdin in order to work code.interact in debug subprocess
            stdin_fd = client_dict['stdin_fd']
            if stdin_fd is not None:
                sys.stdin = os.fdopen(stdin_fd)

            return None

        #override
        def process_data(self, pf): #pf=ProcessingFrame
            """Merge one frame, write image + mask to disk and return the work item.

            Frames without detected faces are copied unchanged with an empty mask
            (or, in 'raw-predict' mode, replaced by a black image of the
            predictor's input size).

            Raises:
                Subprocessor.SilenceException: MergeMasked failed with a MemoryError.
                Exception: any other merge failure, or an unreadable input frame.
            """
            cfg = pf.cfg.copy()

            frame_info = pf.frame_info
            filepath = frame_info.filepath

            if len(frame_info.landmarks_list) == 0:

                if cfg.mode == 'raw-predict':
                    h,w,c = self.predictor_input_shape
                    img_bgr = np.zeros( (h,w,3), dtype=np.uint8)
                    img_mask = np.zeros( (h,w,1), dtype=np.uint8)
                else:
                    self.log_info (f'no faces found for {filepath.name}, copying without faces')
                    img_bgr = cv2_imread(filepath)
                    if img_bgr is None:
                        raise Exception( f'Unable to read file [{filepath}]' )
                    # normalize_channels returns the converted image (grayscale -> 3 ch,
                    # alpha dropped); the result must be reassigned or it is lost.
                    img_bgr = imagelib.normalize_channels(img_bgr, 3)
                    h,w,c = img_bgr.shape
                    img_mask = np.zeros( (h,w,1), dtype=img_bgr.dtype)

                cv2_imwrite (pf.output_filepath, img_bgr)
                cv2_imwrite (pf.output_mask_filepath, img_mask)

                if pf.need_return_image:
                    pf.image = np.concatenate ([img_bgr, img_mask], axis=-1)

            else:
                if cfg.type == MergerConfig.TYPE_MASKED:
                    try:
                        final_img = MergeMasked (self.predictor_func, self.predictor_input_shape,
                                                 face_enhancer_func=self.face_enhancer_func,
                                                 xseg_256_extract_func=self.xseg_256_extract_func,
                                                 cfg=cfg,
                                                 frame_info=frame_info)
                    except Exception:
                        e_str = traceback.format_exc()
                        if 'MemoryError' in e_str:
                            raise Subprocessor.SilenceException
                        else:
                            raise Exception( f'Error while merging file [{filepath}]: {e_str}' )

                elif cfg.type == MergerConfig.TYPE_FACE_AVATAR:
                    final_img = MergeFaceAvatar (self.predictor_func, self.predictor_input_shape,
                                                 cfg, pf.prev_temporal_frame_infos,
                                                      pf.frame_info,
                                                      pf.next_temporal_frame_infos )

                # final_img carries BGR in channels 0..2 and the mask in channel 3.
                cv2_imwrite (pf.output_filepath, final_img[...,0:3] )
                cv2_imwrite (pf.output_mask_filepath, final_img[...,3:4] )

                if pf.need_return_image:
                    pf.image = final_img

            return pf

        #overridable
        def get_data_name (self, pf):
            #return string identificator of your data
            return pf.frame_info.filepath

    #override
    def __init__(self, is_interactive, merger_session_filepath, predictor_func, predictor_input_shape, face_enhancer_func, xseg_256_extract_func, merger_config, frames, frames_root_path, output_path, output_mask_path, model_iter, subprocess_count=4):
        """Set up the frame queues, optionally resuming a saved interactive session.

        Args:
            is_interactive: show the preview UI and allow per-frame config editing.
            merger_session_filepath: file the interactive session is loaded from / saved to.
            predictor_func, predictor_input_shape, face_enhancer_func,
            xseg_256_extract_func: passed unchanged to every worker.
            merger_config: initial MergerConfig; applied to the first frame of a new session.
            frames: non-empty list of ``Frame`` objects to merge.
            frames_root_path: stored on the instance; not otherwise used in this class.
            output_path, output_mask_path: Path directories for merged images / masks.
            model_iter: model iteration; a session saved at another iteration is recomputed.
            subprocess_count: number of workers, also used as the prefetch window.

        Warning:
            When no saved session is used, every image already in ``output_path``
            and ``output_mask_path`` is deleted.

        Raises:
            ValueError: if ``frames`` is empty.
        """
        if len (frames) == 0:
            raise ValueError ("len (frames) == 0")

        super().__init__('Merger', InteractiveMergerSubprocessor.Cli, io_loop_sleep_time=0.001)

        self.is_interactive = is_interactive
        self.merger_session_filepath = Path(merger_session_filepath)
        self.merger_config = merger_config

        self.predictor_func = predictor_func
        self.predictor_input_shape = predictor_input_shape

        self.face_enhancer_func = face_enhancer_func
        self.xseg_256_extract_func = xseg_256_extract_func

        self.frames_root_path = frames_root_path
        self.output_path = output_path
        self.output_mask_path = output_mask_path
        self.model_iter = model_iter

        self.prefetch_frame_count = self.process_count = subprocess_count

        session_data = None
        if self.is_interactive and self.merger_session_filepath.exists():
            io.input_skip_pending()
            if io.input_bool ("Use saved session?", True):
                # NOTE(security): pickle can execute arbitrary code while loading;
                # only resume session files produced by this tool on this machine.
                try:
                    session_data = pickle.loads( self.merger_session_filepath.read_bytes() )
                except Exception as e:
                    # A corrupt or incompatible session is not fatal: report it and start fresh.
                    io.log_info (f'Unable to load saved session ({e}), starting a new one.')

        rewind_to_frame_idx = None
        self.frames = frames
        self.frames_idxs = [ *range(len(self.frames)) ]
        self.frames_done_idxs = []

        if self.is_interactive and session_data is not None:
            # Loaded session data, check it
            s_frames = session_data.get('frames', None)
            s_frames_idxs = session_data.get('frames_idxs', None)
            s_frames_done_idxs = session_data.get('frames_done_idxs', None)
            s_model_iter = session_data.get('model_iter', None)

            frames_equal = (s_frames is not None) and \
                           (s_frames_idxs is not None) and \
                           (s_frames_done_idxs is not None) and \
                           (s_model_iter is not None) and \
                           (len(frames) == len(s_frames)) # frames count must match

            if frames_equal:
                # frames filenames must match
                frames_equal = all( frame.frame_info.filepath.name == s_frame.frame_info.filepath.name
                                    for frame, s_frame in zip(frames, s_frames) )

            if frames_equal:
                io.log_info ('Using saved session from ' + '/'.join (self.merger_session_filepath.parts[-2:]) )

                for frame in s_frames:
                    if frame.cfg is not None:
                        # recreate MergerConfig class using constructor with get_config() as dict params
                        # so if any new param will be added, old merger session will work properly
                        frame.cfg = frame.cfg.__class__( **frame.cfg.get_config() )

                self.frames = s_frames
                self.frames_idxs = s_frames_idxs
                self.frames_done_idxs = s_frames_done_idxs

                if self.model_iter != s_model_iter:
                    # model was more trained, recompute all frames
                    rewind_to_frame_idx = -1
                    for frame in self.frames:
                        frame.is_done = False
                elif len(self.frames_idxs) == 0:
                    # all frames are done?
                    rewind_to_frame_idx = -1

                if len(self.frames_idxs) != 0:
                    cur_frame = self.frames[self.frames_idxs[0]]
                    cur_frame.is_shown = False

            if not frames_equal:
                session_data = None

        if session_data is None:
            for filename in pathex.get_image_paths(self.output_path): #remove all images in output_path
                Path(filename).unlink()

            for filename in pathex.get_image_paths(self.output_mask_path): #remove all images in output_mask_path
                Path(filename).unlink()

            # Only the first frame gets an explicit cfg; later frames inherit it
            # lazily as the user advances (see on_tick).
            self.frames[0].cfg = self.merger_config.copy()

        for i, frame in enumerate(self.frames):
            frame.idx = i
            frame.output_filepath = self.output_path / ( frame.frame_info.filepath.stem + '.png' )
            frame.output_mask_filepath = self.output_mask_path / ( frame.frame_info.filepath.stem + '.png' )

            if not frame.output_filepath.exists() or \
               not frame.output_mask_filepath.exists():
                # if some frame does not exist, recompute and rewind
                frame.is_done = False
                frame.is_shown = False

                if rewind_to_frame_idx is None:
                    rewind_to_frame_idx = i-1
                else:
                    rewind_to_frame_idx = min(rewind_to_frame_idx, i-1)

        if rewind_to_frame_idx is not None:
            # Move every done frame after rewind_to_frame_idx back to the front of the work queue.
            while len(self.frames_done_idxs) > 0:
                if self.frames_done_idxs[-1] > rewind_to_frame_idx:
                    prev_frame = self.frames[self.frames_done_idxs.pop()]
                    self.frames_idxs.insert(0, prev_frame.idx)
                else:
                    break

    #override
    def process_info_generator(self):
        """Yield (name, host_dict, client_dict) for each worker to start."""
        r = [0] if MERGER_DEBUG else range(self.process_count)

        for i in r:
            yield 'CPU%d' % (i), {}, {'device_idx': i,
                                      'device_name': 'CPU%d' % (i),
                                      'predictor_func': self.predictor_func,
                                      'predictor_input_shape' : self.predictor_input_shape,
                                      'face_enhancer_func': self.face_enhancer_func,
                                      'xseg_256_extract_func' : self.xseg_256_extract_func,
                                      'stdin_fd': sys.stdin.fileno() if MERGER_DEBUG else None
                                      }

    #overridable optional
    def on_clients_initialized(self):
        """Open the progress bar and, in interactive mode, the preview/help screens and key map."""
        io.progress_bar ("Merging", len(self.frames_idxs)+len(self.frames_done_idxs), initial=len(self.frames_done_idxs) )

        self.process_remain_frames = not self.is_interactive
        self.is_interactive_quitting = not self.is_interactive

        if self.is_interactive:
            help_images = {
                    MergerConfig.TYPE_MASKED : cv2_imread ( str(Path(__file__).parent / 'gfx' / 'help_merger_masked.jpg') ),
                    MergerConfig.TYPE_FACE_AVATAR : cv2_imread ( str(Path(__file__).parent / 'gfx' / 'help_merger_face_avatar.jpg') ),
                }

            self.main_screen = Screen(initial_scale_to_width=1368, image=None, waiting_icon=True)
            self.help_screen = Screen(initial_scale_to_height=768, image=help_images[self.merger_config.type], waiting_icon=False)
            self.screen_manager = ScreenManager( "Merger", [self.main_screen, self.help_screen], capture_keys=True )
            self.screen_manager.set_current (self.help_screen)
            self.screen_manager.show_current()

            # Key -> cfg mutator for TYPE_MASKED. Shift usually means a step of 5
            # instead of 1; color degrade uses ';' / ':' (-1 / -5) for decrease.
            self.masked_keys_funcs = {
                    '`' : lambda cfg,shift_pressed: cfg.set_mode(0),
                    '1' : lambda cfg,shift_pressed: cfg.set_mode(1),
                    '2' : lambda cfg,shift_pressed: cfg.set_mode(2),
                    '3' : lambda cfg,shift_pressed: cfg.set_mode(3),
                    '4' : lambda cfg,shift_pressed: cfg.set_mode(4),
                    '5' : lambda cfg,shift_pressed: cfg.set_mode(5),
                    '6' : lambda cfg,shift_pressed: cfg.set_mode(6),
                    'q' : lambda cfg,shift_pressed: cfg.add_hist_match_threshold(1 if not shift_pressed else 5),
                    'a' : lambda cfg,shift_pressed: cfg.add_hist_match_threshold(-1 if not shift_pressed else -5),
                    'w' : lambda cfg,shift_pressed: cfg.add_erode_mask_modifier(1 if not shift_pressed else 5),
                    's' : lambda cfg,shift_pressed: cfg.add_erode_mask_modifier(-1 if not shift_pressed else -5),
                    'e' : lambda cfg,shift_pressed: cfg.add_blur_mask_modifier(1 if not shift_pressed else 5),
                    'd' : lambda cfg,shift_pressed: cfg.add_blur_mask_modifier(-1 if not shift_pressed else -5),
                    'r' : lambda cfg,shift_pressed: cfg.add_motion_blur_power(1 if not shift_pressed else 5),
                    'f' : lambda cfg,shift_pressed: cfg.add_motion_blur_power(-1 if not shift_pressed else -5),
                    't' : lambda cfg,shift_pressed: cfg.add_super_resolution_power(1 if not shift_pressed else 5),
                    'g' : lambda cfg,shift_pressed: cfg.add_super_resolution_power(-1 if not shift_pressed else -5),
                    'y' : lambda cfg,shift_pressed: cfg.add_blursharpen_amount(1 if not shift_pressed else 5),
                    'h' : lambda cfg,shift_pressed: cfg.add_blursharpen_amount(-1 if not shift_pressed else -5),
                    'u' : lambda cfg,shift_pressed: cfg.add_output_face_scale(1 if not shift_pressed else 5),
                    'j' : lambda cfg,shift_pressed: cfg.add_output_face_scale(-1 if not shift_pressed else -5),
                    'i' : lambda cfg,shift_pressed: cfg.add_image_denoise_power(1 if not shift_pressed else 5),
                    'k' : lambda cfg,shift_pressed: cfg.add_image_denoise_power(-1 if not shift_pressed else -5),
                    'o' : lambda cfg,shift_pressed: cfg.add_bicubic_degrade_power(1 if not shift_pressed else 5),
                    'l' : lambda cfg,shift_pressed: cfg.add_bicubic_degrade_power(-1 if not shift_pressed else -5),
                    'p' : lambda cfg,shift_pressed: cfg.add_color_degrade_power(1 if not shift_pressed else 5),
                    ';' : lambda cfg,shift_pressed: cfg.add_color_degrade_power(-1),
                    ':' : lambda cfg,shift_pressed: cfg.add_color_degrade_power(-5),
                    'z' : lambda cfg,shift_pressed: cfg.toggle_masked_hist_match(),
                    'x' : lambda cfg,shift_pressed: cfg.toggle_mask_mode(),
                    'c' : lambda cfg,shift_pressed: cfg.toggle_color_transfer_mode(),
                    'n' : lambda cfg,shift_pressed: cfg.toggle_sharpen_mode(),
                    }
            self.masked_keys = list(self.masked_keys_funcs.keys())

    #overridable optional
    def on_clients_finalized(self):
        """Close the UI and, in interactive mode, pickle the session for later resumption."""
        io.progress_bar_close()

        if self.is_interactive:
            self.screen_manager.finalize()

            # Paths are rebuilt on load (see __init__) and cached images are large,
            # so drop them before pickling.
            for frame in self.frames:
                frame.output_filepath = None
                frame.output_mask_filepath = None
                frame.image = None

            session_data = {
                'frames': self.frames,
                'frames_idxs': self.frames_idxs,
                'frames_done_idxs': self.frames_done_idxs,
                'model_iter' : self.model_iter,
            }
            self.merger_session_filepath.write_bytes( pickle.dumps(session_data) )

            io.log_info ("Session is saved to " + '/'.join (self.merger_session_filepath.parts[-2:]) )

    #override
    def on_tick(self):
        """Run one host-loop iteration: draw the preview, handle keys, move between frames.

        Returns:
            True when merging should stop: the user quit (interactive), or no
            frames remain (non-interactive). Presumably the base Subprocessor
            ends its loop on True — confirm in core.joblib.
        """
        io.process_messages()

        go_prev_frame = False
        go_first_frame = False
        go_prev_frame_overriding_cfg = False
        go_first_frame_overriding_cfg = False

        # In auto-processing mode every tick tries to advance to the next frame.
        go_next_frame = self.process_remain_frames
        go_next_frame_overriding_cfg = False
        go_last_frame_overriding_cfg = False

        cur_frame = None
        if len(self.frames_idxs) != 0:
            cur_frame = self.frames[self.frames_idxs[0]]

        if self.is_interactive:

            screen_image = None if self.process_remain_frames else \
                                   self.main_screen.get_image()

            self.main_screen.set_waiting_icon( self.process_remain_frames or \
                                               self.is_interactive_quitting )

            if cur_frame is not None and not self.is_interactive_quitting:

                if not self.process_remain_frames:
                    if cur_frame.is_done:
                        if not cur_frame.is_shown:
                            if cur_frame.image is None:
                                # Not cached (e.g. resumed session): load the result from disk.
                                image = cv2_imread (cur_frame.output_filepath, verbose=False)
                                image_mask = cv2_imread (cur_frame.output_mask_filepath, verbose=False)
                                if image is None or image_mask is None:
                                    # unable to read? recompute then
                                    cur_frame.is_done = False
                                else:
                                    image = imagelib.normalize_channels(image, 3)
                                    image_mask = imagelib.normalize_channels(image_mask, 1)
                                    cur_frame.image = np.concatenate([image, image_mask], -1)

                            if cur_frame.is_done:
                                io.log_info (cur_frame.cfg.to_string( cur_frame.frame_info.filepath.name) )
                                cur_frame.is_shown = True
                                screen_image = cur_frame.image
                    else:
                        self.main_screen.set_waiting_icon(True)

            self.main_screen.set_image(screen_image)
            self.screen_manager.show_current()

            key_events = self.screen_manager.get_key_events()
            key, chr_key, ctrl_pressed, alt_pressed, shift_pressed = key_events[-1] if len(key_events) > 0 else (0,0,False,False,False)

            if key == 9: #tab
                self.screen_manager.switch_screens()
            else:
                if key == 27: #esc
                    self.is_interactive_quitting = True
                elif self.screen_manager.get_current() is self.main_screen:

                    if self.merger_config.type == MergerConfig.TYPE_MASKED and chr_key in self.masked_keys:
                        # Config key: edit the current frame's cfg and recompute it if it changed.
                        self.process_remain_frames = False

                        if cur_frame is not None:
                            cfg = cur_frame.cfg
                            prev_cfg = cfg.copy()

                            if cfg.type == MergerConfig.TYPE_MASKED:
                                self.masked_keys_funcs[chr_key](cfg, shift_pressed)

                            if prev_cfg != cfg:
                                io.log_info ( cfg.to_string(cur_frame.frame_info.filepath.name) )
                                cur_frame.is_done = False
                                cur_frame.is_shown = False
                    else:

                        if chr_key == ',' or chr_key == 'm':
                            # ',' previous frame, shift+',' first frame;
                            # 'm' / shift+'m' the same but copying the current cfg onto the frames visited.
                            self.process_remain_frames = False
                            go_prev_frame = True

                            if chr_key == ',':
                                if shift_pressed:
                                    go_first_frame = True

                            elif chr_key == 'm':
                                if not shift_pressed:
                                    go_prev_frame_overriding_cfg = True
                                else:
                                    go_first_frame_overriding_cfg = True

                        elif chr_key == '.' or chr_key == '/':
                            # '.' next frame, shift+'.' toggles auto-processing of the remaining frames;
                            # '/' next frame forced to inherit the current cfg, shift+'/' all remaining frames.
                            self.process_remain_frames = False
                            go_next_frame = True

                            if chr_key == '.':
                                if shift_pressed:
                                    self.process_remain_frames = not self.process_remain_frames

                            elif chr_key == '/':
                                if not shift_pressed:
                                    go_next_frame_overriding_cfg = True
                                else:
                                    go_last_frame_overriding_cfg = True

                        elif chr_key == '-':
                            self.screen_manager.get_current().diff_scale(-0.1)
                        elif chr_key == '=':
                            self.screen_manager.get_current().diff_scale(0.1)
                        elif chr_key == 'v':
                            self.screen_manager.get_current().toggle_show_checker_board()

        if go_prev_frame:
            # Only navigate away once the current frame has finished processing.
            if cur_frame is None or cur_frame.is_done:
                if cur_frame is not None:
                    cur_frame.image = None

                while True:
                    if len(self.frames_done_idxs) > 0:
                        prev_frame = self.frames[self.frames_done_idxs.pop()]
                        self.frames_idxs.insert(0, prev_frame.idx)
                        prev_frame.is_shown = False
                        io.progress_bar_inc(-1)

                        if cur_frame is not None and (go_prev_frame_overriding_cfg or go_first_frame_overriding_cfg):
                            if prev_frame.cfg != cur_frame.cfg:
                                prev_frame.cfg = cur_frame.cfg.copy()
                                prev_frame.is_done = False

                        cur_frame = prev_frame

                    # "first frame" variants keep stepping back until no done frames remain.
                    if go_first_frame_overriding_cfg or go_first_frame:
                        if len(self.frames_done_idxs) > 0:
                            continue
                    break

        elif go_next_frame:
            if cur_frame is not None and cur_frame.is_done:
                cur_frame.image = None
                cur_frame.is_shown = True
                self.frames_done_idxs.append(cur_frame.idx)
                self.frames_idxs.pop(0)
                io.progress_bar_inc(1)

                f = self.frames

                if len(self.frames_idxs) != 0:
                    next_frame = f[ self.frames_idxs[0] ]
                    next_frame.is_shown = False

                    if go_next_frame_overriding_cfg or go_last_frame_overriding_cfg:
                        # Clear cfgs so the prefetch loop below re-inherits the current cfg.
                        if go_next_frame_overriding_cfg:
                            to_frames = next_frame.idx+1
                        else:
                            to_frames = len(f)

                        for i in range( next_frame.idx, to_frames ):
                            f[i].cfg = None

                    # Give frames in the prefetch window without a cfg a copy of their
                    # predecessor's cfg so get_data() can dispatch them.
                    for i in range( min(len(self.frames_idxs), self.prefetch_frame_count) ):
                        frame = f[ self.frames_idxs[i] ]
                        if frame.cfg is None:
                            if i == 0:
                                frame.cfg = cur_frame.cfg.copy()
                            else:
                                frame.cfg = f[ self.frames_idxs[i-1] ].cfg.copy()

                            frame.is_done = False #initiate solve again
                            frame.is_shown = False

            if len(self.frames_idxs) == 0:
                self.process_remain_frames = False

        return (self.is_interactive and self.is_interactive_quitting) or \
               (not self.is_interactive and self.process_remain_frames == False)

    #override
    def on_data_return (self, host_dict, pf):
        """A work item came back without a result (presumably a worker failure); mark it pending again."""
        frame = self.frames[pf.idx]
        frame.is_done = False
        frame.is_processing = False

    #override
    def on_result (self, host_dict, pf_sent, pf_result):
        """Accept a worker result unless the frame's cfg changed while it was being processed."""
        frame = self.frames[pf_result.idx]
        frame.is_processing = False
        if frame.cfg == pf_result.cfg:
            frame.is_done = True
            frame.image = pf_result.image

    #override
    def get_data(self, host_dict):
        """Return the next work item from the prefetch window, or None if nothing is ready."""
        if self.is_interactive and self.is_interactive_quitting:
            return None

        for i in range ( min(len(self.frames_idxs), self.prefetch_frame_count) ):
            frame = self.frames[ self.frames_idxs[i] ]

            if not frame.is_done and not frame.is_processing and frame.cfg is not None:
                frame.is_processing = True
                return InteractiveMergerSubprocessor.ProcessingFrame(idx=frame.idx,
                                                           cfg=frame.cfg.copy(),
                                                           prev_temporal_frame_infos=frame.prev_temporal_frame_infos,
                                                           frame_info=frame.frame_info,
                                                           next_temporal_frame_infos=frame.next_temporal_frame_infos,
                                                           output_filepath=frame.output_filepath,
                                                           output_mask_filepath=frame.output_mask_filepath,
                                                           need_return_image=True )

        return None

    #override
    def get_result(self):
        return 0
575