Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
iperov
GitHub Repository: iperov/deepfacelab
Path: blob/master/models/ModelBase.py
628 views
1
import colorsys
2
import inspect
3
import json
4
import multiprocessing
5
import operator
6
import os
7
import pickle
8
import shutil
9
import tempfile
10
import time
11
from pathlib import Path
12
13
import cv2
14
import numpy as np
15
16
from core import imagelib, pathex
17
from core.cv2ex import *
18
from core.interact import interact as io
19
from core.leras import nn
20
from samplelib import SampleGeneratorBase
21
22
23
class ModelBase(object):
    """Base class for DeepFaceLab models.

    Handles model naming/selection, load/save of persisted model data
    (*_data.dat), per-class default options, device configuration, preview
    history and rotating autobackups. Subclasses implement the overridable
    on_* / onXxx hooks.
    """
    def __init__(self, is_training=False,
                       is_exporting=False,
                       saved_models_path=None,
                       training_data_src_path=None,
                       training_data_dst_path=None,
                       pretraining_data_path=None,
                       pretrained_model_path=None,
                       no_preview=False,
                       force_model_name=None,
                       force_gpu_idxs=None,
                       cpu_only=False,
                       debug=False,
                       force_model_class_name=None,
                       silent_start=False,
                       **kwargs):
        self.is_training = is_training
        self.is_exporting = is_exporting
        self.saved_models_path = saved_models_path
        self.training_data_src_path = training_data_src_path
        self.training_data_dst_path = training_data_dst_path
        self.pretraining_data_path = pretraining_data_path
        self.pretrained_model_path = pretrained_model_path
        self.no_preview = no_preview
        self.debug = debug

        # Model class name is derived from the package folder name,
        # e.g. "Model_SAEHD" -> "SAEHD".
        self.model_class_name = model_class_name = Path(inspect.getmodule(self).__file__).parent.name.rsplit("_", 1)[1]

        if force_model_class_name is None:
            if force_model_name is not None:
                self.model_name = force_model_name
            else:
                # Interactive model selection; loops so rename/delete can
                # re-list the models afterwards.
                while True:
                    # gather all model dat files
                    saved_models_names = []
                    for filepath in pathex.get_file_paths(saved_models_path):
                        filepath_name = filepath.name
                        if filepath_name.endswith(f'{model_class_name}_data.dat'):
                            saved_models_names += [ (filepath_name.split('_')[0], os.path.getmtime(filepath)) ]

                    # sort by modified datetime
                    saved_models_names = sorted(saved_models_names, key=operator.itemgetter(1), reverse=True )
                    saved_models_names = [ x[0] for x in saved_models_names ]

                    if len(saved_models_names) != 0:
                        if silent_start:
                            # take the most recently modified model
                            self.model_name = saved_models_names[0]
                            io.log_info(f'Silent start: choosed model "{self.model_name}"')
                        else:
                            io.log_info ("Choose one of saved models, or enter a name to create a new model.")
                            io.log_info ("[r] : rename")
                            io.log_info ("[d] : delete")
                            io.log_info ("")
                            for i, model_name in enumerate(saved_models_names):
                                s = f"[{i}] : {model_name} "
                                if i == 0:
                                    s += "- latest"
                                io.log_info (s)

                            inp = io.input_str(f"", "0", show_default_value=False )
                            model_idx = -1
                            try:
                                model_idx = np.clip ( int(inp), 0, len(saved_models_names)-1 )
                            except:
                                # non-numeric input: treated as a command or a new model name
                                pass

                            if model_idx == -1:
                                if len(inp) == 1:
                                    is_rename = inp[0] == 'r'
                                    is_delete = inp[0] == 'd'

                                    if is_rename or is_delete:
                                        if len(saved_models_names) != 0:

                                            if is_rename:
                                                name = io.input_str(f"Enter the name of the model you want to rename")
                                            elif is_delete:
                                                name = io.input_str(f"Enter the name of the model you want to delete")

                                            if name in saved_models_names:

                                                if is_rename:
                                                    new_model_name = io.input_str(f"Enter new name of the model")

                                                # rename/delete every file whose prefix matches the model name
                                                for filepath in pathex.get_paths(saved_models_path):
                                                    filepath_name = filepath.name

                                                    model_filename, remain_filename = filepath_name.split('_', 1)
                                                    if model_filename == name:

                                                        if is_rename:
                                                            new_filepath = filepath.parent / ( new_model_name + '_' + remain_filename )
                                                            filepath.rename (new_filepath)
                                                        elif is_delete:
                                                            filepath.unlink()
                                        # after a rename/delete, show the model list again
                                        continue

                                # free-form input becomes a new model name
                                self.model_name = inp
                            else:
                                self.model_name = saved_models_names[model_idx]

                    else:
                        self.model_name = io.input_str(f"No saved models found. Enter a name of a new model", "new")
                        # '_' is reserved as the name/class separator in filenames
                        self.model_name = self.model_name.replace('_', ' ')
                    break

            self.model_name = self.model_name + '_' + self.model_class_name
        else:
            self.model_name = force_model_class_name

        self.iter = 0
        self.options = {}
        self.options_show_override = {}
        self.loss_history = []
        self.sample_for_preview = None
        self.choosed_gpu_indexes = None

        # Restore persisted training state, if any.
        model_data = {}
        self.model_data_path = Path( self.get_strpath_storage_for_file('data.dat') )
        if self.model_data_path.exists():
            io.log_info (f"Loading {self.model_name} model...")
            model_data = pickle.loads ( self.model_data_path.read_bytes() )
            self.iter = model_data.get('iter',0)
            if self.iter != 0:
                self.options = model_data['options']
                self.loss_history = model_data.get('loss_history', [])
                self.sample_for_preview = model_data.get('sample_for_preview', None)
                self.choosed_gpu_indexes = model_data.get('choosed_gpu_indexes', None)

        if self.is_first_run():
            io.log_info ("\nModel first run.")

        # Device selection: best GPU automatically on silent start, otherwise
        # forced indexes / interactive choice / CPU.
        if silent_start:
            self.device_config = nn.DeviceConfig.BestGPU()
            io.log_info (f"Silent start: choosed device {'CPU' if self.device_config.cpu_only else self.device_config.devices[0].name}")
        else:
            self.device_config = nn.DeviceConfig.GPUIndexes( force_gpu_idxs or nn.ask_choose_device_idxs(suggest_best_multi_gpu=True)) \
                                 if not cpu_only else nn.DeviceConfig.CPU()

        nn.initialize(self.device_config)

        #### Per-model-class default options, shared by all models of this class.
        self.default_options_path = saved_models_path / f'{self.model_class_name}_default_options.dat'
        self.default_options = {}
        if self.default_options_path.exists():
            try:
                self.default_options = pickle.loads ( self.default_options_path.read_bytes() )
            except:
                # corrupted defaults file is not fatal; fall back to built-ins
                pass

        self.choose_preview_history = False
        self.batch_size = self.load_or_def_option('batch_size', 1)
        #####

        io.input_skip_pending()
        self.on_initialize_options()

        if self.is_first_run():
            # save as default options only for first run model initialize
            self.default_options_path.write_bytes( pickle.dumps (self.options) )

        self.autobackup_hour = self.options.get('autobackup_hour', 0)
        self.write_preview_history = self.options.get('write_preview_history', False)
        self.target_iter = self.options.get('target_iter',0)
        self.random_flip = self.options.get('random_flip',True)
        self.random_src_flip = self.options.get('random_src_flip', False)
        self.random_dst_flip = self.options.get('random_dst_flip', True)

        self.on_initialize()
        self.options['batch_size'] = self.batch_size

        self.preview_history_writer = None
        if self.is_training:
            self.preview_history_path = self.saved_models_path / ( f'{self.get_model_name()}_history' )
            self.autobackups_path = self.saved_models_path / ( f'{self.get_model_name()}_autobackups' )

            if self.write_preview_history or io.is_colab():
                if not self.preview_history_path.exists():
                    self.preview_history_path.mkdir(exist_ok=True)
                else:
                    if self.iter == 0:
                        # fresh model: drop stale previews from a previous run
                        for filename in pathex.get_image_paths(self.preview_history_path):
                            Path(filename).unlink()

            # on_initialize() is expected to have registered the generators
            if self.generator_list is None:
                raise ValueError( 'You didnt set_training_data_generators()')
            else:
                for i, generator in enumerate(self.generator_list):
                    if not isinstance(generator, SampleGeneratorBase):
                        raise ValueError('training data generator is not subclass of SampleGeneratorBase')

            self.update_sample_for_preview(choose_preview_history=self.choose_preview_history)

            if self.autobackup_hour != 0:
                self.autobackup_start_time = time.time()

                if not self.autobackups_path.exists():
                    self.autobackups_path.mkdir(exist_ok=True)

        io.log_info( self.get_summary_text() )
    def update_sample_for_preview(self, choose_preview_history=False, force_new=False):
        """Refresh the fixed sample used for history previews.

        When choose_preview_history is True on Windows, opens an interactive
        window cycling through candidate samples ([p] next sample, [space]
        next preview type, [enter] confirm); otherwise just draws a new
        sample. Also refreshes self.last_sample.
        """
        if self.sample_for_preview is None or choose_preview_history or force_new:
            if choose_preview_history and io.is_support_windows():
                wnd_name = "[p] - next. [space] - switch preview type. [enter] - confirm."
                io.log_info (f"Choose image for the preview history. {wnd_name}")
                io.named_window(wnd_name)
                io.capture_keys(wnd_name)
                choosed = False
                preview_id_counter = 0
                while not choosed:
                    # draw a fresh candidate batch and show one of its previews
                    self.sample_for_preview = self.generate_next_samples()
                    previews = self.get_history_previews()

                    io.show_image( wnd_name, ( previews[preview_id_counter % len(previews) ][1] *255).astype(np.uint8) )

                    while True:
                        key_events = io.get_key_events(wnd_name)
                        key, chr_key, ctrl_pressed, alt_pressed, shift_pressed = key_events[-1] if len(key_events) > 0 else (0,0,False,False,False)
                        if key == ord('\n') or key == ord('\r'):
                            # confirm current sample
                            choosed = True
                            break
                        elif key == ord(' '):
                            # switch preview type, keep the same sample
                            preview_id_counter += 1
                            break
                        elif key == ord('p'):
                            # next candidate sample
                            break

                        try:
                            io.process_messages(0.1)
                        except KeyboardInterrupt:
                            choosed = True

                io.destroy_window(wnd_name)
            else:
                self.sample_for_preview = self.generate_next_samples()

            # if the chosen sample cannot be rendered, fall back to a new one
            try:
                self.get_history_previews()
            except:
                self.sample_for_preview = self.generate_next_samples()

        self.last_sample = self.sample_for_preview
def load_or_def_option(self, name, def_value):
270
options_val = self.options.get(name, None)
271
if options_val is not None:
272
return options_val
273
274
def_opt_val = self.default_options.get(name, None)
275
if def_opt_val is not None:
276
return def_opt_val
277
278
return def_value
279
280
def ask_override(self):
281
return self.is_training and self.iter != 0 and io.input_in_time ("Press enter in 2 seconds to override model settings.", 5 if io.is_colab() else 2 )
282
283
def ask_autobackup_hour(self, default_value=0):
284
default_autobackup_hour = self.options['autobackup_hour'] = self.load_or_def_option('autobackup_hour', default_value)
285
self.options['autobackup_hour'] = io.input_int(f"Autobackup every N hour", default_autobackup_hour, add_info="0..24", help_message="Autobackup model files with preview every N hour. Latest backup located in model/<>_autobackups/01")
286
287
def ask_write_preview_history(self, default_value=False):
288
default_write_preview_history = self.load_or_def_option('write_preview_history', default_value)
289
self.options['write_preview_history'] = io.input_bool(f"Write preview history", default_write_preview_history, help_message="Preview history will be writed to <ModelName>_history folder.")
290
291
if self.options['write_preview_history']:
292
if io.is_support_windows():
293
self.choose_preview_history = io.input_bool("Choose image for the preview history", False)
294
elif io.is_colab():
295
self.choose_preview_history = io.input_bool("Randomly choose new image for preview history", False, help_message="Preview image history will stay stuck with old faces if you reuse the same model on different celebs. Choose no unless you are changing src/dst to a new person")
296
297
def ask_target_iter(self, default_value=0):
298
default_target_iter = self.load_or_def_option('target_iter', default_value)
299
self.options['target_iter'] = max(0, io.input_int("Target iteration", default_target_iter))
300
301
def ask_random_flip(self):
302
default_random_flip = self.load_or_def_option('random_flip', True)
303
self.options['random_flip'] = io.input_bool("Flip faces randomly", default_random_flip, help_message="Predicted face will look more naturally without this option, but src faceset should cover all face directions as dst faceset.")
304
305
def ask_random_src_flip(self):
306
default_random_src_flip = self.load_or_def_option('random_src_flip', False)
307
self.options['random_src_flip'] = io.input_bool("Flip SRC faces randomly", default_random_src_flip, help_message="Random horizontal flip SRC faceset. Covers more angles, but the face may look less naturally.")
308
309
def ask_random_dst_flip(self):
310
default_random_dst_flip = self.load_or_def_option('random_dst_flip', True)
311
self.options['random_dst_flip'] = io.input_bool("Flip DST faces randomly", default_random_dst_flip, help_message="Random horizontal flip DST faceset. Makes generalization of src->dst better, if src random flip is not enabled.")
312
313
def ask_batch_size(self, suggest_batch_size=None, range=None):
314
default_batch_size = self.load_or_def_option('batch_size', suggest_batch_size or self.batch_size)
315
316
batch_size = max(0, io.input_int("Batch_size", default_batch_size, valid_range=range, help_message="Larger batch size is better for NN's generalization, but it can cause Out of Memory error. Tune this value for your videocard manually."))
317
318
if range is not None:
319
batch_size = np.clip(batch_size, range[0], range[1])
320
321
self.options['batch_size'] = self.batch_size = batch_size
322
323
324
#overridable
325
def on_initialize_options(self):
326
pass
327
328
#overridable
329
def on_initialize(self):
330
'''
331
initialize your models
332
333
store and retrieve your model options in self.options['']
334
335
check example
336
'''
337
pass
338
339
#overridable
340
def onSave(self):
341
#save your models here
342
pass
343
344
#overridable
345
def onTrainOneIter(self, sample, generator_list):
346
#train your models here
347
348
#return array of losses
349
return ( ('loss_src', 0), ('loss_dst', 0) )
350
351
#overridable
352
def onGetPreview(self, sample, for_history=False):
353
#you can return multiple previews
354
#return [ ('preview_name',preview_rgb), ... ]
355
return []
356
357
#overridable if you want model name differs from folder name
358
def get_model_name(self):
359
return self.model_name
360
361
#overridable , return [ [model, filename],... ] list
362
def get_model_filename_list(self):
363
return []
364
365
#overridable
366
def get_MergerConfig(self):
367
#return predictor_func, predictor_input_shape, MergerConfig() for the model
368
raise NotImplementedError
369
370
def get_pretraining_data_path(self):
371
return self.pretraining_data_path
372
373
def get_target_iter(self):
374
return self.target_iter
375
376
def is_reached_iter_goal(self):
377
return self.target_iter != 0 and self.iter >= self.target_iter
378
379
def get_previews(self):
380
return self.onGetPreview ( self.last_sample )
381
382
def get_history_previews(self):
383
return self.onGetPreview (self.sample_for_preview, for_history=True)
384
385
def get_preview_history_writer(self):
386
if self.preview_history_writer is None:
387
self.preview_history_writer = PreviewHistoryWriter()
388
return self.preview_history_writer
389
390
    def save(self):
        """Persist summary text, model weights (via onSave) and training
        state; trigger a rotating autobackup when the interval has elapsed."""
        Path( self.get_summary_path() ).write_text( self.get_summary_text() )

        self.onSave()

        # Everything needed to resume training on the next run.
        model_data = {
            'iter': self.iter,
            'options': self.options,
            'loss_history': self.loss_history,
            'sample_for_preview' : self.sample_for_preview,
            'choosed_gpu_indexes' : self.choosed_gpu_indexes,
        }
        pathex.write_bytes_safe (self.model_data_path, pickle.dumps(model_data) )

        if self.autobackup_hour != 0:
            diff_hour = int ( (time.time() - self.autobackup_start_time) // 3600 )

            if diff_hour > 0 and diff_hour % self.autobackup_hour == 0:
                # advance the reference time by one interval so the next
                # backup is scheduled relative to this one
                self.autobackup_start_time += self.autobackup_hour*3600
                self.create_backup()
    def create_backup(self):
        """Rotate the autobackup slots and write a fresh backup into slot 01.

        Slots are directories 01..24 under autobackups_path; the oldest (24)
        is emptied, every other slot is shifted up by one, and the current
        model files plus previews are written to 01.
        """
        io.log_info ("Creating backup...", end='\r')

        if not self.autobackups_path.exists():
            self.autobackups_path.mkdir(exist_ok=True)

        # files to copy: every model weight file plus summary and data.dat
        bckp_filename_list = [ self.get_strpath_storage_for_file(filename) for _, filename in self.get_model_filename_list() ]
        bckp_filename_list += [ str(self.get_summary_path()), str(self.model_data_path) ]

        # walk slots from oldest to newest so moves don't collide
        for i in range(24,0,-1):
            idx_str = '%.2d' % i
            next_idx_str = '%.2d' % (i+1)

            idx_backup_path = self.autobackups_path / idx_str
            next_idx_packup_path = self.autobackups_path / next_idx_str

            if idx_backup_path.exists():
                if i == 24:
                    # oldest slot falls off the end
                    pathex.delete_all_files(idx_backup_path)
                else:
                    next_idx_packup_path.mkdir(exist_ok=True)
                    pathex.move_all_files (idx_backup_path, next_idx_packup_path)

            if i == 1:
                # newest slot receives the current files and previews
                idx_backup_path.mkdir(exist_ok=True)
                for filename in bckp_filename_list:
                    shutil.copy ( str(filename), str(idx_backup_path / Path(filename).name) )

                # NOTE(review): the inner loop reuses the outer index `i`;
                # harmless here because the outer loop ends at i == 1.
                previews = self.get_previews()
                plist = []
                for i in range(len(previews)):
                    name, bgr = previews[i]
                    plist += [ (bgr, idx_backup_path / ( ('preview_%s.jpg') % (name)) ) ]

                if len(plist) != 0:
                    self.get_preview_history_writer().post(plist, self.loss_history, self.iter)
def debug_one_iter(self):
449
images = []
450
for generator in self.generator_list:
451
for i,batch in enumerate(next(generator)):
452
if len(batch.shape) == 4:
453
images.append( batch[0] )
454
455
return imagelib.equalize_and_stack_square (images)
456
457
def generate_next_samples(self):
458
sample = []
459
for generator in self.generator_list:
460
if generator.is_initialized():
461
sample.append ( generator.generate_next() )
462
else:
463
sample.append ( [] )
464
self.last_sample = sample
465
return sample
466
467
#overridable
468
def should_save_preview_history(self):
469
return (not io.is_colab() and self.iter % 10 == 0) or (io.is_colab() and self.iter % 100 == 0)
470
471
    def train_one_iter(self):
        """Run one training iteration, record its losses, periodically queue
        preview images, and advance the iteration counter.

        Returns (new_iter, iter_time_seconds).
        """
        iter_time = time.time()
        losses = self.onTrainOneIter()
        iter_time = time.time() - iter_time

        # losses is an iterable of (name, value) pairs; keep only the values
        self.loss_history.append ( [float(loss[1]) for loss in losses] )

        if self.should_save_preview_history():
            plist = []

            # on Colab, also overwrite a fixed per-model preview file
            if io.is_colab():
                previews = self.get_previews()
                for i in range(len(previews)):
                    name, bgr = previews[i]
                    plist += [ (bgr, self.get_strpath_storage_for_file('preview_%s.jpg' % (name) ) ) ]

            if self.write_preview_history:
                previews = self.get_history_previews()
                for i in range(len(previews)):
                    name, bgr = previews[i]
                    path = self.preview_history_path / name
                    # iteration-numbered frame for the history timeline
                    plist += [ ( bgr, str ( path / ( f'{self.iter:07d}.jpg') ) ) ]
                    if not io.is_colab():
                        # also refresh a "_last" snapshot for quick viewing
                        plist += [ ( bgr, str ( path / ( '_last.jpg' ) )) ]

            if len(plist) != 0:
                # written asynchronously by the background writer process
                self.get_preview_history_writer().post(plist, self.loss_history, self.iter)

        self.iter += 1

        return self.iter, iter_time
def pass_one_iter(self):
505
self.generate_next_samples()
506
507
def finalize(self):
508
nn.close_session()
509
510
def is_first_run(self):
511
return self.iter == 0
512
513
def is_debug(self):
514
return self.debug
515
516
def set_batch_size(self, batch_size):
517
self.batch_size = batch_size
518
519
def get_batch_size(self):
520
return self.batch_size
521
522
def get_iter(self):
523
return self.iter
524
525
def set_iter(self, iter):
526
self.iter = iter
527
self.loss_history = self.loss_history[:iter]
528
529
def get_loss_history(self):
530
return self.loss_history
531
532
def set_training_data_generators (self, generator_list):
533
self.generator_list = generator_list
534
535
def get_training_data_generators (self):
536
return self.generator_list
537
538
def get_model_root_path(self):
539
return self.saved_models_path
540
541
def get_strpath_storage_for_file(self, filename):
542
return str( self.saved_models_path / ( self.get_model_name() + '_' + filename) )
543
544
def get_summary_path(self):
545
return self.get_strpath_storage_for_file('summary.txt')
546
547
    def get_summary_text(self):
        """Build the bordered, column-aligned text table describing the
        model name, current iteration, visible options and devices."""
        # options_show_override lets subclasses mask/replace displayed values
        visible_options = self.options.copy()
        visible_options.update(self.options_show_override)

        ###Generate text summary of model hyperparameters
        #Find the longest key name and value string. Used as column widths.
        width_name = max([len(k) for k in visible_options.keys()] + [17]) + 1 # Single space buffer to left edge. Minimum of 17, the length of the longest static string used "Current iteration"
        width_value = max([len(str(x)) for x in visible_options.values()] + [len(str(self.get_iter())), len(self.get_model_name())]) + 1 # Single space buffer to right edge
        if len(self.device_config.devices) != 0: #Check length of GPU names
            width_value = max([len(device.name)+1 for device in self.device_config.devices] + [width_value])
        width_total = width_name + width_value + 2 #Plus 2 for ": "

        summary_text = []
        summary_text += [f'=={" Model Summary ":=^{width_total}}=='] # Model/status summary
        summary_text += [f'=={" "*width_total}==']
        summary_text += [f'=={"Model name": >{width_name}}: {self.get_model_name(): <{width_value}}=='] # Name
        summary_text += [f'=={" "*width_total}==']
        summary_text += [f'=={"Current iteration": >{width_name}}: {str(self.get_iter()): <{width_value}}=='] # Iter
        summary_text += [f'=={" "*width_total}==']

        summary_text += [f'=={" Model Options ":-^{width_total}}=='] # Model options
        summary_text += [f'=={" "*width_total}==']
        for key in visible_options.keys():
            summary_text += [f'=={key: >{width_name}}: {str(visible_options[key]): <{width_value}}=='] # visible_options key/value pairs
        summary_text += [f'=={" "*width_total}==']

        summary_text += [f'=={" Running On ":-^{width_total}}=='] # Training hardware info
        summary_text += [f'=={" "*width_total}==']
        if len(self.device_config.devices) == 0:
            summary_text += [f'=={"Using device": >{width_name}}: {"CPU": <{width_value}}=='] # cpu_only
        else:
            for device in self.device_config.devices:
                summary_text += [f'=={"Device index": >{width_name}}: {device.index: <{width_value}}=='] # GPU hardware device index
                summary_text += [f'=={"Name": >{width_name}}: {device.name: <{width_value}}=='] # GPU name
                vram_str = f'{device.total_mem_gb:.2f}GB' # GPU VRAM - Formated as #.## (or ##.##)
                summary_text += [f'=={"VRAM": >{width_name}}: {vram_str: <{width_value}}==']

        summary_text += [f'=={" "*width_total}==']
        summary_text += [f'=={"="*width_total}==']
        summary_text = "\n".join (summary_text)
        return summary_text
    @staticmethod
    def get_loss_history_preview(loss_history, iter, w, c):
        """Render the loss history as a (100, w, c) float image in [0..1].

        Each column aggregates lh_len/w consecutive iterations into a
        min..max band per loss channel (hue-coded per channel); horizontal
        grid lines and the iteration number are drawn on top.
        """
        loss_history = np.array (loss_history.copy())

        lh_height = 100
        lh_img = np.ones ( (lh_height,w,c) ) * 0.1  # dark background

        if len(loss_history) != 0:
            loss_count = len(loss_history[0])   # number of loss channels per iteration
            lh_len = len(loss_history)

            l_per_col = lh_len / w              # iterations represented by one pixel column
            # per-column, per-channel maximum over that column's iteration span
            plist_max = [ [ max (0.0, loss_history[int(col*l_per_col)][p],
                                      *[ loss_history[i_ab][p]
                                         for i_ab in range( int(col*l_per_col), int((col+1)*l_per_col) )
                                       ]
                                )
                            for p in range(loss_count)
                          ]
                          for col in range(w)
                        ]

            # per-column, per-channel minimum, never above the column maximum
            plist_min = [ [ min (plist_max[col][p], loss_history[int(col*l_per_col)][p],
                                      *[ loss_history[i_ab][p]
                                         for i_ab in range( int(col*l_per_col), int((col+1)*l_per_col) )
                                       ]
                                )
                            for p in range(loss_count)
                          ]
                          for col in range(w)
                        ]

            # vertical scale: twice the mean over the last 80% of history,
            # so early (typically large) losses don't flatten the graph
            plist_abs_max = np.mean(loss_history[ len(loss_history) // 5 : ]) * 2

            for col in range(0, w):
                for p in range(0,loss_count):
                    # distinct hue per loss channel; extra channels stay 1.0
                    point_color = [1.0]*c
                    point_color[0:3] = colorsys.hsv_to_rgb ( p * (1.0/loss_count), 1.0, 1.0 )

                    ph_max = int ( (plist_max[col][p] / plist_abs_max) * (lh_height-1) )
                    ph_max = np.clip( ph_max, 0, lh_height-1 )

                    ph_min = int ( (plist_min[col][p] / plist_abs_max) * (lh_height-1) )
                    ph_min = np.clip( ph_min, 0, lh_height-1 )

                    # draw the min..max band for this column (y grows downward)
                    for ph in range(ph_min, ph_max+1):
                        lh_img[ (lh_height-ph-1), col ] = point_color

        # horizontal grid lines
        lh_lines = 5
        lh_line_height = (lh_height-1)/lh_lines
        for i in range(0,lh_lines+1):
            lh_img[ int(i*lh_line_height), : ] = (0.8,)*c

        # bottom band carries the iteration label
        last_line_t = int((lh_lines-1)*lh_line_height)
        last_line_b = int(lh_lines*lh_line_height)

        lh_text = 'Iter: %d' % (iter) if iter != 0 else ''

        lh_img[last_line_t:last_line_b, 0:w] += imagelib.get_text_image ( (last_line_b-last_line_t,w,c), lh_text, color=[0.8]*c )
        return lh_img
class PreviewHistoryWriter():
    """Writes preview images (with a loss-history graph stacked on top) to
    disk from a background daemon process, so training is not blocked by
    image encoding and file I/O."""
    def __init__(self):
        self.sq = multiprocessing.Queue()   # job queue of (plist, loss_history, iter)
        self.p = multiprocessing.Process(target=self.process, args=( self.sq, ))
        self.p.daemon = True                # do not block interpreter exit
        self.p.start()

    def process(self, sq):
        """Worker loop: drain the queue, render and write each preview."""
        while True:
            while not sq.empty():
                plist, loss_history, iter = sq.get()

                # render the loss graph only once per distinct preview size
                preview_lh_cache = {}
                for preview, filepath in plist:
                    filepath = Path(filepath)
                    # cache key built from preview.shape[1] and shape[2]
                    i = (preview.shape[1], preview.shape[2])

                    preview_lh = preview_lh_cache.get(i, None)
                    if preview_lh is None:
                        preview_lh = ModelBase.get_loss_history_preview(loss_history, iter, preview.shape[1], preview.shape[2])
                        preview_lh_cache[i] = preview_lh

                    # stack the graph above the preview; convert 0..1 floats to uint8
                    img = (np.concatenate ( [preview_lh, preview], axis=0 ) * 255).astype(np.uint8)

                    filepath.parent.mkdir(parents=True, exist_ok=True)
                    cv2_imwrite (filepath, img )

            time.sleep(0.01)  # idle poll between queue checks

    def post(self, plist, loss_history, iter):
        """Queue previews for asynchronous writing."""
        self.sq.put ( (plist, loss_history, iter) )

    # disable pickling
    # (the worker target is a bound method, so `self` is pickled when the
    # process starts; dropping state avoids pickling the Queue/Process handles)
    def __getstate__(self):
        return dict()
    def __setstate__(self, d):
        self.__dict__.update(d)