GitHub Repository: iperov/deepfacelab
Path: blob/master/models/Model_SAEHD/Model.py
import multiprocessing
import operator
from functools import partial

import numpy as np

from core import mathlib
from core.interact import interact as io
from core.leras import nn
from facelib import FaceType
from models import ModelBase
from samplelib import *

class SAEHDModel(ModelBase):

    #override
    def on_initialize_options(self):
        device_config = nn.getCurrentDeviceConfig()

        lowest_vram = 2
        if len(device_config.devices) != 0:
            lowest_vram = device_config.devices.get_worst_device().total_mem_gb

        if lowest_vram >= 4:
            suggest_batch_size = 8
        else:
            suggest_batch_size = 4

        yn_str = {True:'y',False:'n'}
        min_res = 64
        max_res = 640

        #default_usefp16 = self.options['use_fp16'] = self.load_or_def_option('use_fp16', False)
        default_resolution = self.options['resolution'] = self.load_or_def_option('resolution', 128)
        default_face_type = self.options['face_type'] = self.load_or_def_option('face_type', 'f')
        default_models_opt_on_gpu = self.options['models_opt_on_gpu'] = self.load_or_def_option('models_opt_on_gpu', True)

        default_archi = self.options['archi'] = self.load_or_def_option('archi', 'liae-ud')

        default_ae_dims = self.options['ae_dims'] = self.load_or_def_option('ae_dims', 256)
        default_e_dims = self.options['e_dims'] = self.load_or_def_option('e_dims', 64)
        default_d_dims = self.options['d_dims'] = self.options.get('d_dims', None)
        default_d_mask_dims = self.options['d_mask_dims'] = self.options.get('d_mask_dims', None)
        default_masked_training = self.options['masked_training'] = self.load_or_def_option('masked_training', True)
        default_eyes_mouth_prio = self.options['eyes_mouth_prio'] = self.load_or_def_option('eyes_mouth_prio', False)
        default_uniform_yaw = self.options['uniform_yaw'] = self.load_or_def_option('uniform_yaw', False)
        default_blur_out_mask = self.options['blur_out_mask'] = self.load_or_def_option('blur_out_mask', False)

        default_adabelief = self.options['adabelief'] = self.load_or_def_option('adabelief', True)

        lr_dropout = self.load_or_def_option('lr_dropout', 'n')
        lr_dropout = {True:'y', False:'n'}.get(lr_dropout, lr_dropout) #backward comp
        default_lr_dropout = self.options['lr_dropout'] = lr_dropout

        default_random_warp = self.options['random_warp'] = self.load_or_def_option('random_warp', True)
        default_random_hsv_power = self.options['random_hsv_power'] = self.load_or_def_option('random_hsv_power', 0.0)
        default_true_face_power = self.options['true_face_power'] = self.load_or_def_option('true_face_power', 0.0)
        default_face_style_power = self.options['face_style_power'] = self.load_or_def_option('face_style_power', 0.0)
        default_bg_style_power = self.options['bg_style_power'] = self.load_or_def_option('bg_style_power', 0.0)
        default_ct_mode = self.options['ct_mode'] = self.load_or_def_option('ct_mode', 'none')
        default_clipgrad = self.options['clipgrad'] = self.load_or_def_option('clipgrad', False)
        default_pretrain = self.options['pretrain'] = self.load_or_def_option('pretrain', False)

        ask_override = self.ask_override()
        if self.is_first_run() or ask_override:
            self.ask_autobackup_hour()
            self.ask_write_preview_history()
            self.ask_target_iter()
            self.ask_random_src_flip()
            self.ask_random_dst_flip()
            self.ask_batch_size(suggest_batch_size)
            #self.options['use_fp16'] = io.input_bool ("Use fp16", default_usefp16, help_message='Increases training/inference speed, reduces model size. Model may crash. Enable it after 1-5k iters.')

        if self.is_first_run():
            resolution = io.input_int("Resolution", default_resolution, add_info="64-640", help_message="Higher resolution requires more VRAM and training time. The value will be adjusted to a multiple of 16, or of 32 for a -d archi.")
            resolution = np.clip ( (resolution // 16) * 16, min_res, max_res)
            self.options['resolution'] = resolution



            self.options['face_type'] = io.input_str ("Face type", default_face_type, ['h','mf','f','wf','head'], help_message="Half / mid face / full face / whole face / head. Half face has better resolution but covers a smaller area of the cheeks. Mid face is 30% wider than half face. 'Whole face' covers the full area of the face, including the forehead. 'head' covers the whole head, but requires XSeg for the src and dst facesets.").lower()

            while True:
                archi = io.input_str ("AE architecture", default_archi, help_message=\
"""
'df' keeps the face more identity-preserved.
'liae' can fix overly different face shapes.
'-u' increases likeness of the face.
'-d' (experimental) doubles the resolution at the same computation cost.
Examples: df, liae, df-d, df-ud, liae-ud, ...
""").lower()

                archi_split = archi.split('-')

                if len(archi_split) == 2:
                    archi_type, archi_opts = archi_split
                elif len(archi_split) == 1:
                    archi_type, archi_opts = archi_split[0], None
                else:
                    continue

                if archi_type not in ['df', 'liae']:
                    continue

                if archi_opts is not None:
                    if len(archi_opts) == 0:
                        continue
                    if len([ 1 for opt in archi_opts if opt not in ['u','d','t','c'] ]) != 0:
                        continue

                    if 'd' in archi_opts:
                        self.options['resolution'] = np.clip ( (self.options['resolution'] // 32) * 32, min_res, max_res)

                break
            self.options['archi'] = archi
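            # At this point archi is a validated string of the form '<df|liae>[-<opts>]', where
            # opts is any non-empty combination of the letters u, d, t, c (e.g. 'liae-ud').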

        default_d_dims = self.options['d_dims'] = self.load_or_def_option('d_dims', 64)

        default_d_mask_dims = default_d_dims // 3
        default_d_mask_dims += default_d_mask_dims % 2
        default_d_mask_dims = self.options['d_mask_dims'] = self.load_or_def_option('d_mask_dims', default_d_mask_dims)

        if self.is_first_run():
            self.options['ae_dims'] = np.clip ( io.input_int("AutoEncoder dimensions", default_ae_dims, add_info="32-1024", help_message="All face information will be packed into the AE dims. If the number of AE dims is not enough, then for example closed eyes will not be recognized. More dims are better, but require more VRAM. You can fine-tune the model size to fit your GPU." ), 32, 1024 )

            e_dims = np.clip ( io.input_int("Encoder dimensions", default_e_dims, add_info="16-256", help_message="More dims help to recognize more facial features and achieve a sharper result, but require more VRAM. You can fine-tune the model size to fit your GPU." ), 16, 256 )
            self.options['e_dims'] = e_dims + e_dims % 2

            d_dims = np.clip ( io.input_int("Decoder dimensions", default_d_dims, add_info="16-256", help_message="More dims help to recognize more facial features and achieve a sharper result, but require more VRAM. You can fine-tune the model size to fit your GPU." ), 16, 256 )
            self.options['d_dims'] = d_dims + d_dims % 2

            d_mask_dims = np.clip ( io.input_int("Decoder mask dimensions", default_d_mask_dims, add_info="16-256", help_message="Typical mask dimensions = decoder dimensions / 3. If you manually cut out obstacles from the dst mask, you can increase this parameter to achieve better quality." ), 16, 256 )
            self.options['d_mask_dims'] = d_mask_dims + d_mask_dims % 2

        if self.is_first_run() or ask_override:
            if self.options['face_type'] == 'wf' or self.options['face_type'] == 'head':
                self.options['masked_training'] = io.input_bool ("Masked training", default_masked_training, help_message="This option is available only for the 'whole_face' or 'head' face type. Masked training clips the training area to the full_face mask or XSeg mask, so the network trains the face region properly.")

            self.options['eyes_mouth_prio'] = io.input_bool ("Eyes and mouth priority", default_eyes_mouth_prio, help_message='Helps to fix eye problems during training, like "alien eyes" and wrong eye direction. Also increases the detail of the teeth.')
            self.options['uniform_yaw'] = io.input_bool ("Uniform yaw distribution of samples", default_uniform_yaw, help_message='Helps to fix blurry side faces caused by the small number of them in the faceset.')
            self.options['blur_out_mask'] = io.input_bool ("Blur out mask", default_blur_out_mask, help_message='Blurs the area just outside the applied face mask of the training samples. As a result, the background near the face is smoothed and less noticeable on the swapped face. Exact XSeg masks in the src and dst facesets are required.')

        default_gan_power = self.options['gan_power'] = self.load_or_def_option('gan_power', 0.0)
        default_gan_patch_size = self.options['gan_patch_size'] = self.load_or_def_option('gan_patch_size', self.options['resolution'] // 8)
        default_gan_dims = self.options['gan_dims'] = self.load_or_def_option('gan_dims', 16)

        if self.is_first_run() or ask_override:
            self.options['models_opt_on_gpu'] = io.input_bool ("Place models and optimizer on GPU", default_models_opt_on_gpu, help_message="When you train on one GPU, the model and optimizer weights are placed on the GPU by default to accelerate the process. You can place them on the CPU instead to free up extra VRAM and thus set bigger dimensions.")

            self.options['adabelief'] = io.input_bool ("Use AdaBelief optimizer?", default_adabelief, help_message="Use AdaBelief optimizer. It requires more VRAM, but the accuracy and generalization of the model are higher.")

            self.options['lr_dropout'] = io.input_str (f"Use learning rate dropout", default_lr_dropout, ['n','y','cpu'], help_message="When the face is trained enough, you can enable this option to get extra sharpness and reduce subpixel shake in fewer iterations. Enable it before disabling random warp and before GAN. \nn - disabled.\ny - enabled\ncpu - enabled on CPU. This avoids using extra VRAM at the cost of about 20% longer iterations.")

            self.options['random_warp'] = io.input_bool ("Enable random warp of samples", default_random_warp, help_message="Random warp is required to generalize the facial expressions of both faces. When the face is trained enough, you can disable it to get extra sharpness and reduce subpixel shake in fewer iterations.")

            self.options['random_hsv_power'] = np.clip ( io.input_number ("Random hue/saturation/light intensity", default_random_hsv_power, add_info="0.0 .. 0.3", help_message="Random hue/saturation/light intensity applied to the src faceset only at the input of the neural network. Stabilizes color perturbations during face swapping. Reduces the quality of the color transfer by selecting the closest one in the src faceset. Thus the src faceset must be diverse enough. A typical fine value is 0.05."), 0.0, 0.3 )

            self.options['gan_power'] = np.clip ( io.input_number ("GAN power", default_gan_power, add_info="0.0 .. 5.0", help_message="Forces the neural network to learn small details of the face. Enable it only when the face is trained enough with lr_dropout(on) and random_warp(off), and do not disable it afterwards. The higher the value, the higher the chance of artifacts. A typical fine value is 0.1."), 0.0, 5.0 )

            if self.options['gan_power'] != 0.0:
                gan_patch_size = np.clip ( io.input_int("GAN patch size", default_gan_patch_size, add_info="3-640", help_message="The higher the patch size, the higher the quality and the more VRAM is required. You can get sharper edges even at the lowest setting. A typical fine value is resolution / 8." ), 3, 640 )
                self.options['gan_patch_size'] = gan_patch_size

                gan_dims = np.clip ( io.input_int("GAN dimensions", default_gan_dims, add_info="4-512", help_message="The dimensions of the GAN network. The higher the dimensions, the more VRAM is required. You can get sharper edges even at the lowest setting. A typical fine value is 16." ), 4, 512 )
                self.options['gan_dims'] = gan_dims

            if 'df' in self.options['archi']:
                self.options['true_face_power'] = np.clip ( io.input_number ("'True face' power.", default_true_face_power, add_info="0.0000 .. 1.0", help_message="Experimental option. Discriminates the result face to be more like the src face. Higher value - stronger discrimination. Typical value is 0.01. Comparison - https://i.imgur.com/czScS9q.png"), 0.0, 1.0 )
            else:
                self.options['true_face_power'] = 0.0

            self.options['face_style_power'] = np.clip ( io.input_number("Face style power", default_face_style_power, add_info="0.0..100.0", help_message="Learn the color of the predicted face to be the same as dst inside the mask. If you want to use this option with 'whole_face' you have to use an XSeg-trained mask. Warning: enable it only after 10k iters, when the predicted face is clear enough to start learning style. Start from a value of 0.001 and check the history changes. Enabling this option increases the chance of model collapse."), 0.0, 100.0 )
            self.options['bg_style_power'] = np.clip ( io.input_number("Background style power", default_bg_style_power, add_info="0.0..100.0", help_message="Learn the area outside the mask of the predicted face to be the same as dst. If you want to use this option with 'whole_face' you have to use an XSeg-trained mask. This can make the face more like dst. Enabling this option increases the chance of model collapse. Typical value is 2.0"), 0.0, 100.0 )

            self.options['ct_mode'] = io.input_str (f"Color transfer for src faceset", default_ct_mode, ['none','rct','lct','mkl','idt','sot'], help_message="Change the color distribution of src samples to be close to dst samples. Try all modes to find the best one.")
            self.options['clipgrad'] = io.input_bool ("Enable gradient clipping", default_clipgrad, help_message="Gradient clipping reduces the chance of model collapse, sacrificing training speed.")

            self.options['pretrain'] = io.input_bool ("Enable pretraining mode", default_pretrain, help_message="Pretrain the model with a large amount of various faces. After that, the model can be used to train fakes more quickly. Forces random_warp=N, random_flips=Y, gan_power=0.0, lr_dropout=N, styles=0.0, uniform_yaw=Y")

        if self.options['pretrain'] and self.get_pretraining_data_path() is None:
            raise Exception("pretraining_data_path is not defined")

        self.gan_model_changed = (default_gan_patch_size != self.options['gan_patch_size']) or (default_gan_dims != self.options['gan_dims'])

        self.pretrain_just_disabled = (default_pretrain == True and self.options['pretrain'] == False)

    #override
    def on_initialize(self):
        device_config = nn.getCurrentDeviceConfig()
        devices = device_config.devices
        self.model_data_format = "NCHW" if len(devices) != 0 and not self.is_debug() else "NHWC"
        nn.initialize(data_format=self.model_data_format)
        tf = nn.tf

        self.resolution = resolution = self.options['resolution']
        self.face_type = {'h'    : FaceType.HALF,
                          'mf'   : FaceType.MID_FULL,
                          'f'    : FaceType.FULL,
                          'wf'   : FaceType.WHOLE_FACE,
                          'head' : FaceType.HEAD}[ self.options['face_type'] ]

        if 'eyes_prio' in self.options:
            self.options.pop('eyes_prio')

        eyes_mouth_prio = self.options['eyes_mouth_prio']

        archi_split = self.options['archi'].split('-')

        if len(archi_split) == 2:
            archi_type, archi_opts = archi_split
        elif len(archi_split) == 1:
            archi_type, archi_opts = archi_split[0], None

        self.archi_type = archi_type

        ae_dims = self.options['ae_dims']
        e_dims = self.options['e_dims']
        d_dims = self.options['d_dims']
        d_mask_dims = self.options['d_mask_dims']
        self.pretrain = self.options['pretrain']
        if self.pretrain_just_disabled:
            self.set_iter(0)

        adabelief = self.options['adabelief']

        use_fp16 = False
        if self.is_exporting:
            use_fp16 = io.input_bool ("Export quantized?", False, help_message='Makes the exported model faster. If you have problems, disable this option.')

        self.gan_power = gan_power = 0.0 if self.pretrain else self.options['gan_power']
        random_warp = False if self.pretrain else self.options['random_warp']
        random_src_flip = self.random_src_flip if not self.pretrain else True
        random_dst_flip = self.random_dst_flip if not self.pretrain else True
        random_hsv_power = self.options['random_hsv_power'] if not self.pretrain else 0.0
        blur_out_mask = self.options['blur_out_mask']

        if self.pretrain:
            self.options_show_override['lr_dropout'] = 'n'
            self.options_show_override['random_warp'] = False
            self.options_show_override['gan_power'] = 0.0
            self.options_show_override['random_hsv_power'] = 0.0
            self.options_show_override['face_style_power'] = 0.0
            self.options_show_override['bg_style_power'] = 0.0
            self.options_show_override['uniform_yaw'] = True

        masked_training = self.options['masked_training']
        ct_mode = self.options['ct_mode']
        if ct_mode == 'none':
            ct_mode = None


        models_opt_on_gpu = False if len(devices) == 0 else self.options['models_opt_on_gpu']
        models_opt_device = nn.tf_default_device_name if models_opt_on_gpu and self.is_training else '/CPU:0'
        optimizer_vars_on_cpu = models_opt_device=='/CPU:0'

        input_ch=3
        bgr_shape = self.bgr_shape = nn.get4Dshape(resolution,resolution,input_ch)
        mask_shape = nn.get4Dshape(resolution,resolution,1)
        self.model_filename_list = []

        with tf.device ('/CPU:0'):
            #Place holders on CPU
            self.warped_src = tf.placeholder (nn.floatx, bgr_shape, name='warped_src')
            self.warped_dst = tf.placeholder (nn.floatx, bgr_shape, name='warped_dst')

            self.target_src = tf.placeholder (nn.floatx, bgr_shape, name='target_src')
            self.target_dst = tf.placeholder (nn.floatx, bgr_shape, name='target_dst')

            self.target_srcm = tf.placeholder (nn.floatx, mask_shape, name='target_srcm')
            self.target_srcm_em = tf.placeholder (nn.floatx, mask_shape, name='target_srcm_em')
            self.target_dstm = tf.placeholder (nn.floatx, mask_shape, name='target_dstm')
            self.target_dstm_em = tf.placeholder (nn.floatx, mask_shape, name='target_dstm_em')

        # Initializing model classes
        model_archi = nn.DeepFakeArchi(resolution, use_fp16=use_fp16, opts=archi_opts)

        with tf.device (models_opt_device):
            if 'df' in archi_type:
                self.encoder = model_archi.Encoder(in_ch=input_ch, e_ch=e_dims, name='encoder')
                encoder_out_ch = self.encoder.get_out_ch()*self.encoder.get_out_res(resolution)**2

                self.inter = model_archi.Inter (in_ch=encoder_out_ch, ae_ch=ae_dims, ae_out_ch=ae_dims, name='inter')
                inter_out_ch = self.inter.get_out_ch()

                self.decoder_src = model_archi.Decoder(in_ch=inter_out_ch, d_ch=d_dims, d_mask_ch=d_mask_dims, name='decoder_src')
                self.decoder_dst = model_archi.Decoder(in_ch=inter_out_ch, d_ch=d_dims, d_mask_ch=d_mask_dims, name='decoder_dst')

                self.model_filename_list += [ [self.encoder, 'encoder.npy' ],
                                              [self.inter, 'inter.npy' ],
                                              [self.decoder_src, 'decoder_src.npy'],
                                              [self.decoder_dst, 'decoder_dst.npy'] ]

                if self.is_training:
                    if self.options['true_face_power'] != 0:
                        self.code_discriminator = nn.CodeDiscriminator(ae_dims, code_res=self.inter.get_out_res(), name='dis' )
                        self.model_filename_list += [ [self.code_discriminator, 'code_discriminator.npy'] ]

            elif 'liae' in archi_type:
                self.encoder = model_archi.Encoder(in_ch=input_ch, e_ch=e_dims, name='encoder')
                encoder_out_ch = self.encoder.get_out_ch()*self.encoder.get_out_res(resolution)**2

                self.inter_AB = model_archi.Inter(in_ch=encoder_out_ch, ae_ch=ae_dims, ae_out_ch=ae_dims*2, name='inter_AB')
                self.inter_B = model_archi.Inter(in_ch=encoder_out_ch, ae_ch=ae_dims, ae_out_ch=ae_dims*2, name='inter_B')

                inter_out_ch = self.inter_AB.get_out_ch()
                inters_out_ch = inter_out_ch*2
                self.decoder = model_archi.Decoder(in_ch=inters_out_ch, d_ch=d_dims, d_mask_ch=d_mask_dims, name='decoder')

                self.model_filename_list += [ [self.encoder, 'encoder.npy'],
                                              [self.inter_AB, 'inter_AB.npy'],
                                              [self.inter_B , 'inter_B.npy'],
                                              [self.decoder , 'decoder.npy'] ]

            if self.is_training:
                if gan_power != 0:
                    self.D_src = nn.UNetPatchDiscriminator(patch_size=self.options['gan_patch_size'], in_ch=input_ch, base_ch=self.options['gan_dims'], name="D_src")
                    self.model_filename_list += [ [self.D_src, 'GAN.npy'] ]

                # Initialize optimizers
                lr=5e-5
                if self.options['lr_dropout'] in ['y','cpu'] and not self.pretrain:
                    lr_cos = 500
                    lr_dropout = 0.3
                else:
                    lr_cos = 0
                    lr_dropout = 1.0
                OptimizerClass = nn.AdaBelief if adabelief else nn.RMSprop
                clipnorm = 1.0 if self.options['clipgrad'] else 0.0
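                # lr_dropout and lr_cos are interpreted by the core.leras optimizers (nn.RMSprop /
                # nn.AdaBelief): a dropout value below 1.0 randomly skips a fraction of per-weight
                # updates each iteration, and lr_cos enables a cosine-based modulation of the
                # learning rate. clipnorm=1.0 turns on gradient clipping when 'clipgrad' is set.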

                if 'df' in archi_type:
                    self.src_dst_saveable_weights = self.encoder.get_weights() + self.inter.get_weights() + self.decoder_src.get_weights() + self.decoder_dst.get_weights()
                    self.src_dst_trainable_weights = self.src_dst_saveable_weights
                elif 'liae' in archi_type:
                    self.src_dst_saveable_weights = self.encoder.get_weights() + self.inter_AB.get_weights() + self.inter_B.get_weights() + self.decoder.get_weights()
                    if random_warp:
                        self.src_dst_trainable_weights = self.src_dst_saveable_weights
                    else:
                        self.src_dst_trainable_weights = self.encoder.get_weights() + self.inter_B.get_weights() + self.decoder.get_weights()

                self.src_dst_opt = OptimizerClass(lr=lr, lr_dropout=lr_dropout, lr_cos=lr_cos, clipnorm=clipnorm, name='src_dst_opt')
                self.src_dst_opt.initialize_variables (self.src_dst_saveable_weights, vars_on_cpu=optimizer_vars_on_cpu, lr_dropout_on_cpu=self.options['lr_dropout']=='cpu')
                self.model_filename_list += [ (self.src_dst_opt, 'src_dst_opt.npy') ]

                if self.options['true_face_power'] != 0:
                    self.D_code_opt = OptimizerClass(lr=lr, lr_dropout=lr_dropout, lr_cos=lr_cos, clipnorm=clipnorm, name='D_code_opt')
                    self.D_code_opt.initialize_variables ( self.code_discriminator.get_weights(), vars_on_cpu=optimizer_vars_on_cpu, lr_dropout_on_cpu=self.options['lr_dropout']=='cpu')
                    self.model_filename_list += [ (self.D_code_opt, 'D_code_opt.npy') ]

                if gan_power != 0:
                    self.D_src_dst_opt = OptimizerClass(lr=lr, lr_dropout=lr_dropout, lr_cos=lr_cos, clipnorm=clipnorm, name='GAN_opt')
                    self.D_src_dst_opt.initialize_variables ( self.D_src.get_weights(), vars_on_cpu=optimizer_vars_on_cpu, lr_dropout_on_cpu=self.options['lr_dropout']=='cpu')#+self.D_src_x2.get_weights()
                    self.model_filename_list += [ (self.D_src_dst_opt, 'GAN_opt.npy') ]

        if self.is_training:
            # Adjust batch size for multiple GPU
            gpu_count = max(1, len(devices) )
            bs_per_gpu = max(1, self.get_batch_size() // gpu_count)
            self.set_batch_size( gpu_count*bs_per_gpu)
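            # The requested batch size is rounded to a multiple of the GPU count: each tower
            # processes bs_per_gpu samples, and the per-tower gradients are averaged further below.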

            # Compute losses per GPU
            gpu_pred_src_src_list = []
            gpu_pred_dst_dst_list = []
            gpu_pred_src_dst_list = []
            gpu_pred_src_srcm_list = []
            gpu_pred_dst_dstm_list = []
            gpu_pred_src_dstm_list = []

            gpu_src_losses = []
            gpu_dst_losses = []
            gpu_G_loss_gvs = []
            gpu_D_code_loss_gvs = []
            gpu_D_src_dst_loss_gvs = []

            for gpu_id in range(gpu_count):
                with tf.device( f'/{devices[gpu_id].tf_dev_type}:{gpu_id}' if len(devices) != 0 else f'/CPU:0' ):
                    with tf.device(f'/CPU:0'):
                        # slice on CPU, otherwise all batch data would be transferred to the GPU first
                        batch_slice = slice( gpu_id*bs_per_gpu, (gpu_id+1)*bs_per_gpu )
                        gpu_warped_src = self.warped_src [batch_slice,:,:,:]
                        gpu_warped_dst = self.warped_dst [batch_slice,:,:,:]
                        gpu_target_src = self.target_src [batch_slice,:,:,:]
                        gpu_target_dst = self.target_dst [batch_slice,:,:,:]
                        gpu_target_srcm = self.target_srcm[batch_slice,:,:,:]
                        gpu_target_srcm_em = self.target_srcm_em[batch_slice,:,:,:]
                        gpu_target_dstm = self.target_dstm[batch_slice,:,:,:]
                        gpu_target_dstm_em = self.target_dstm_em[batch_slice,:,:,:]

                    gpu_target_srcm_anti = 1-gpu_target_srcm
                    gpu_target_dstm_anti = 1-gpu_target_dstm
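
                    # blur_out_mask below blends a blurred background back in just outside the face
                    # mask: x/y is a mask-normalized blur (blurred masked background divided by the
                    # blurred inverse mask), so hard edges around the mask are smoothed away.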

                    if blur_out_mask:
                        sigma = resolution / 128

                        x = nn.gaussian_blur(gpu_target_src*gpu_target_srcm_anti, sigma)
                        y = 1-nn.gaussian_blur(gpu_target_srcm, sigma)
                        y = tf.where(tf.equal(y, 0), tf.ones_like(y), y)
                        gpu_target_src = gpu_target_src*gpu_target_srcm + (x/y)*gpu_target_srcm_anti

                        x = nn.gaussian_blur(gpu_target_dst*gpu_target_dstm_anti, sigma)
                        y = 1-nn.gaussian_blur(gpu_target_dstm, sigma)
                        y = tf.where(tf.equal(y, 0), tf.ones_like(y), y)
                        gpu_target_dst = gpu_target_dst*gpu_target_dstm + (x/y)*gpu_target_dstm_anti


                    # process model tensors
                    if 'df' in archi_type:
                        gpu_src_code = self.inter(self.encoder(gpu_warped_src))
                        gpu_dst_code = self.inter(self.encoder(gpu_warped_dst))
                        gpu_pred_src_src, gpu_pred_src_srcm = self.decoder_src(gpu_src_code)
                        gpu_pred_dst_dst, gpu_pred_dst_dstm = self.decoder_dst(gpu_dst_code)
                        gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder_src(gpu_dst_code)
                        gpu_pred_src_dst_no_code_grad, _ = self.decoder_src(tf.stop_gradient(gpu_dst_code))

                    elif 'liae' in archi_type:
                        gpu_src_code = self.encoder (gpu_warped_src)
                        gpu_src_inter_AB_code = self.inter_AB (gpu_src_code)
                        gpu_src_code = tf.concat([gpu_src_inter_AB_code,gpu_src_inter_AB_code], nn.conv2d_ch_axis )
                        gpu_dst_code = self.encoder (gpu_warped_dst)
                        gpu_dst_inter_B_code = self.inter_B (gpu_dst_code)
                        gpu_dst_inter_AB_code = self.inter_AB (gpu_dst_code)
                        gpu_dst_code = tf.concat([gpu_dst_inter_B_code,gpu_dst_inter_AB_code], nn.conv2d_ch_axis )
                        gpu_src_dst_code = tf.concat([gpu_dst_inter_AB_code,gpu_dst_inter_AB_code], nn.conv2d_ch_axis )

                        gpu_pred_src_src, gpu_pred_src_srcm = self.decoder(gpu_src_code)
                        gpu_pred_dst_dst, gpu_pred_dst_dstm = self.decoder(gpu_dst_code)
                        gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder(gpu_src_dst_code)
                        gpu_pred_src_dst_no_code_grad, _ = self.decoder(tf.stop_gradient(gpu_src_dst_code))

                    gpu_pred_src_src_list.append(gpu_pred_src_src)
                    gpu_pred_dst_dst_list.append(gpu_pred_dst_dst)
                    gpu_pred_src_dst_list.append(gpu_pred_src_dst)

                    gpu_pred_src_srcm_list.append(gpu_pred_src_srcm)
                    gpu_pred_dst_dstm_list.append(gpu_pred_dst_dstm)
                    gpu_pred_src_dstm_list.append(gpu_pred_src_dstm)

                    gpu_target_srcm_blur = nn.gaussian_blur(gpu_target_srcm, max(1, resolution // 32) )
                    gpu_target_srcm_blur = tf.clip_by_value(gpu_target_srcm_blur, 0, 0.5) * 2
                    gpu_target_srcm_anti_blur = 1.0-gpu_target_srcm_blur

                    gpu_target_dstm_blur = nn.gaussian_blur(gpu_target_dstm, max(1, resolution // 32) )
                    gpu_target_dstm_blur = tf.clip_by_value(gpu_target_dstm_blur, 0, 0.5) * 2

                    gpu_style_mask_blur = nn.gaussian_blur(gpu_pred_src_dstm*gpu_pred_dst_dstm, max(1, resolution // 32) )
                    gpu_style_mask_blur = tf.stop_gradient(tf.clip_by_value(gpu_style_mask_blur, 0, 1.0))
                    gpu_style_mask_anti_blur = 1.0 - gpu_style_mask_blur

                    gpu_target_dst_masked = gpu_target_dst*gpu_target_dstm_blur

                    gpu_target_src_anti_masked = gpu_target_src*gpu_target_srcm_anti_blur
                    gpu_pred_src_src_anti_masked = gpu_pred_src_src*gpu_target_srcm_anti_blur

                    gpu_target_src_masked_opt = gpu_target_src*gpu_target_srcm_blur if masked_training else gpu_target_src
                    gpu_target_dst_masked_opt = gpu_target_dst_masked if masked_training else gpu_target_dst
                    gpu_pred_src_src_masked_opt = gpu_pred_src_src*gpu_target_srcm_blur if masked_training else gpu_pred_src_src
                    gpu_pred_dst_dst_masked_opt = gpu_pred_dst_dst*gpu_target_dstm_blur if masked_training else gpu_pred_dst_dst
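
                    # Reconstruction loss below: multi-scale DSSIM plus per-pixel MSE on the
                    # (optionally masked) faces, an extra weighted L1 term on the eyes/mouth
                    # region when eyes_mouth_prio is set, and an MSE term on the predicted mask.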

                    if resolution < 256:
                        gpu_src_loss = tf.reduce_mean ( 10*nn.dssim(gpu_target_src_masked_opt, gpu_pred_src_src_masked_opt, max_val=1.0, filter_size=int(resolution/11.6)), axis=[1])
                    else:
                        gpu_src_loss = tf.reduce_mean ( 5*nn.dssim(gpu_target_src_masked_opt, gpu_pred_src_src_masked_opt, max_val=1.0, filter_size=int(resolution/11.6)), axis=[1])
                        gpu_src_loss += tf.reduce_mean ( 5*nn.dssim(gpu_target_src_masked_opt, gpu_pred_src_src_masked_opt, max_val=1.0, filter_size=int(resolution/23.2)), axis=[1])
                    gpu_src_loss += tf.reduce_mean ( 10*tf.square ( gpu_target_src_masked_opt - gpu_pred_src_src_masked_opt ), axis=[1,2,3])

                    if eyes_mouth_prio:
                        gpu_src_loss += tf.reduce_mean ( 300*tf.abs ( gpu_target_src*gpu_target_srcm_em - gpu_pred_src_src*gpu_target_srcm_em ), axis=[1,2,3])

                    gpu_src_loss += tf.reduce_mean ( 10*tf.square( gpu_target_srcm - gpu_pred_src_srcm ),axis=[1,2,3] )

                    face_style_power = self.options['face_style_power'] / 100.0
                    if face_style_power != 0 and not self.pretrain:
                        gpu_src_loss += nn.style_loss(gpu_pred_src_dst_no_code_grad*tf.stop_gradient(gpu_pred_src_dstm), tf.stop_gradient(gpu_pred_dst_dst*gpu_pred_dst_dstm), gaussian_blur_radius=resolution//8, loss_weight=10000*face_style_power)

                    bg_style_power = self.options['bg_style_power'] / 100.0
                    if bg_style_power != 0 and not self.pretrain:
                        gpu_target_dst_style_anti_masked = gpu_target_dst*gpu_style_mask_anti_blur
                        gpu_psd_style_anti_masked = gpu_pred_src_dst*gpu_style_mask_anti_blur

                        gpu_src_loss += tf.reduce_mean( (10*bg_style_power)*nn.dssim( gpu_psd_style_anti_masked, gpu_target_dst_style_anti_masked, max_val=1.0, filter_size=int(resolution/11.6)), axis=[1])
                        gpu_src_loss += tf.reduce_mean( (10*bg_style_power)*tf.square(gpu_psd_style_anti_masked - gpu_target_dst_style_anti_masked), axis=[1,2,3] )

                    if resolution < 256:
                        gpu_dst_loss = tf.reduce_mean ( 10*nn.dssim(gpu_target_dst_masked_opt, gpu_pred_dst_dst_masked_opt, max_val=1.0, filter_size=int(resolution/11.6) ), axis=[1])
                    else:
                        gpu_dst_loss = tf.reduce_mean ( 5*nn.dssim(gpu_target_dst_masked_opt, gpu_pred_dst_dst_masked_opt, max_val=1.0, filter_size=int(resolution/11.6) ), axis=[1])
                        gpu_dst_loss += tf.reduce_mean ( 5*nn.dssim(gpu_target_dst_masked_opt, gpu_pred_dst_dst_masked_opt, max_val=1.0, filter_size=int(resolution/23.2) ), axis=[1])
                    gpu_dst_loss += tf.reduce_mean ( 10*tf.square( gpu_target_dst_masked_opt- gpu_pred_dst_dst_masked_opt ), axis=[1,2,3])

                    if eyes_mouth_prio:
                        gpu_dst_loss += tf.reduce_mean ( 300*tf.abs ( gpu_target_dst*gpu_target_dstm_em - gpu_pred_dst_dst*gpu_target_dstm_em ), axis=[1,2,3])

                    gpu_dst_loss += tf.reduce_mean ( 10*tf.square( gpu_target_dstm - gpu_pred_dst_dstm ),axis=[1,2,3] )

                    gpu_src_losses += [gpu_src_loss]
                    gpu_dst_losses += [gpu_dst_loss]

                    gpu_G_loss = gpu_src_loss + gpu_dst_loss

                    def DLoss(labels,logits):
                        return tf.reduce_mean( tf.nn.sigmoid_cross_entropy_with_logits(labels=labels, logits=logits), axis=[1,2,3])
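
                    # DLoss is a per-sample sigmoid cross-entropy, used both for the 'true face'
                    # code discriminator and for the two-scale UNet patch GAN below.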

                    if self.options['true_face_power'] != 0:
                        gpu_src_code_d = self.code_discriminator( gpu_src_code )
                        gpu_src_code_d_ones = tf.ones_like (gpu_src_code_d)
                        gpu_src_code_d_zeros = tf.zeros_like(gpu_src_code_d)
                        gpu_dst_code_d = self.code_discriminator( gpu_dst_code )
                        gpu_dst_code_d_ones = tf.ones_like(gpu_dst_code_d)

                        gpu_G_loss += self.options['true_face_power']*DLoss(gpu_src_code_d_ones, gpu_src_code_d)

                        gpu_D_code_loss = (DLoss(gpu_dst_code_d_ones , gpu_dst_code_d) + \
                                           DLoss(gpu_src_code_d_zeros, gpu_src_code_d) ) * 0.5

                        gpu_D_code_loss_gvs += [ nn.gradients (gpu_D_code_loss, self.code_discriminator.get_weights() ) ]

                    if gan_power != 0:
                        gpu_pred_src_src_d, \
                        gpu_pred_src_src_d2 = self.D_src(gpu_pred_src_src_masked_opt)

                        gpu_pred_src_src_d_ones = tf.ones_like (gpu_pred_src_src_d)
                        gpu_pred_src_src_d_zeros = tf.zeros_like(gpu_pred_src_src_d)

                        gpu_pred_src_src_d2_ones = tf.ones_like (gpu_pred_src_src_d2)
                        gpu_pred_src_src_d2_zeros = tf.zeros_like(gpu_pred_src_src_d2)

                        gpu_target_src_d, \
                        gpu_target_src_d2 = self.D_src(gpu_target_src_masked_opt)

                        gpu_target_src_d_ones = tf.ones_like(gpu_target_src_d)
                        gpu_target_src_d2_ones = tf.ones_like(gpu_target_src_d2)

                        gpu_D_src_dst_loss = (DLoss(gpu_target_src_d_ones    , gpu_target_src_d) + \
                                              DLoss(gpu_pred_src_src_d_zeros , gpu_pred_src_src_d) ) * 0.5 + \
                                             (DLoss(gpu_target_src_d2_ones   , gpu_target_src_d2) + \
                                              DLoss(gpu_pred_src_src_d2_zeros, gpu_pred_src_src_d2) ) * 0.5

                        gpu_D_src_dst_loss_gvs += [ nn.gradients (gpu_D_src_dst_loss, self.D_src.get_weights() ) ]#+self.D_src_x2.get_weights()

                        gpu_G_loss += gan_power*(DLoss(gpu_pred_src_src_d_ones, gpu_pred_src_src_d) + \
                                                 DLoss(gpu_pred_src_src_d2_ones, gpu_pred_src_src_d2))

                        if masked_training:
                            # Minimal src-src-bg rec with total_variation_mse to suppress random bright dots from gan
                            gpu_G_loss += 0.000001*nn.total_variation_mse(gpu_pred_src_src)
                            gpu_G_loss += 0.02*tf.reduce_mean(tf.square(gpu_pred_src_src_anti_masked-gpu_target_src_anti_masked),axis=[1,2,3] )

                    gpu_G_loss_gvs += [ nn.gradients ( gpu_G_loss, self.src_dst_trainable_weights )]


            # Average losses and gradients, and create optimizer update ops
            with tf.device(f'/CPU:0'):
                pred_src_src = nn.concat(gpu_pred_src_src_list, 0)
                pred_dst_dst = nn.concat(gpu_pred_dst_dst_list, 0)
                pred_src_dst = nn.concat(gpu_pred_src_dst_list, 0)
                pred_src_srcm = nn.concat(gpu_pred_src_srcm_list, 0)
                pred_dst_dstm = nn.concat(gpu_pred_dst_dstm_list, 0)
                pred_src_dstm = nn.concat(gpu_pred_src_dstm_list, 0)

            with tf.device (models_opt_device):
                src_loss = tf.concat(gpu_src_losses, 0)
                dst_loss = tf.concat(gpu_dst_losses, 0)
                src_dst_loss_gv_op = self.src_dst_opt.get_update_op (nn.average_gv_list (gpu_G_loss_gvs))
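                # Gradients from all GPU towers are averaged (nn.average_gv_list) before a single
                # optimizer update op is created, so multi-GPU training behaves like one larger batch.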

                if self.options['true_face_power'] != 0:
                    D_loss_gv_op = self.D_code_opt.get_update_op (nn.average_gv_list(gpu_D_code_loss_gvs))

                if gan_power != 0:
                    src_D_src_dst_loss_gv_op = self.D_src_dst_opt.get_update_op (nn.average_gv_list(gpu_D_src_dst_loss_gvs) )


            # Initializing training and view functions
            def src_dst_train(warped_src, target_src, target_srcm, target_srcm_em, \
                              warped_dst, target_dst, target_dstm, target_dstm_em, ):
                s, d = nn.tf_sess.run ( [ src_loss, dst_loss, src_dst_loss_gv_op],
                                        feed_dict={self.warped_src :warped_src,
                                                   self.target_src :target_src,
                                                   self.target_srcm:target_srcm,
                                                   self.target_srcm_em:target_srcm_em,
                                                   self.warped_dst :warped_dst,
                                                   self.target_dst :target_dst,
                                                   self.target_dstm:target_dstm,
                                                   self.target_dstm_em:target_dstm_em,
                                                   })[:2]
                return s, d
            self.src_dst_train = src_dst_train

            if self.options['true_face_power'] != 0:
                def D_train(warped_src, warped_dst):
                    nn.tf_sess.run ([D_loss_gv_op], feed_dict={self.warped_src: warped_src, self.warped_dst: warped_dst})
                self.D_train = D_train

            if gan_power != 0:
                def D_src_dst_train(warped_src, target_src, target_srcm, target_srcm_em, \
                                    warped_dst, target_dst, target_dstm, target_dstm_em, ):
                    nn.tf_sess.run ([src_D_src_dst_loss_gv_op], feed_dict={self.warped_src :warped_src,
                                                                           self.target_src :target_src,
                                                                           self.target_srcm:target_srcm,
                                                                           self.target_srcm_em:target_srcm_em,
                                                                           self.warped_dst :warped_dst,
                                                                           self.target_dst :target_dst,
                                                                           self.target_dstm:target_dstm,
                                                                           self.target_dstm_em:target_dstm_em})
                self.D_src_dst_train = D_src_dst_train


            def AE_view(warped_src, warped_dst):
                return nn.tf_sess.run ( [pred_src_src, pred_dst_dst, pred_dst_dstm, pred_src_dst, pred_src_dstm],
                                        feed_dict={self.warped_src:warped_src,
                                                   self.warped_dst:warped_dst})
            self.AE_view = AE_view
        else:
            # Initializing merge function
            with tf.device( nn.tf_default_device_name if len(devices) != 0 else f'/CPU:0'):
                if 'df' in archi_type:
                    gpu_dst_code = self.inter(self.encoder(self.warped_dst))
                    gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder_src(gpu_dst_code)
                    _, gpu_pred_dst_dstm = self.decoder_dst(gpu_dst_code)

                elif 'liae' in archi_type:
                    gpu_dst_code = self.encoder (self.warped_dst)
                    gpu_dst_inter_B_code = self.inter_B (gpu_dst_code)
                    gpu_dst_inter_AB_code = self.inter_AB (gpu_dst_code)
                    gpu_dst_code = tf.concat([gpu_dst_inter_B_code,gpu_dst_inter_AB_code], nn.conv2d_ch_axis)
                    gpu_src_dst_code = tf.concat([gpu_dst_inter_AB_code,gpu_dst_inter_AB_code], nn.conv2d_ch_axis)

                    gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder(gpu_src_dst_code)
                    _, gpu_pred_dst_dstm = self.decoder(gpu_dst_code)


            def AE_merge( warped_dst):
                return nn.tf_sess.run ( [gpu_pred_src_dst, gpu_pred_dst_dstm, gpu_pred_src_dstm], feed_dict={self.warped_dst:warped_dst})

            self.AE_merge = AE_merge

        # Loading/initializing all models/optimizers weights
        for model, filename in io.progress_bar_generator(self.model_filename_list, "Initializing models"):
            if self.pretrain_just_disabled:
                do_init = False
                if 'df' in archi_type:
                    if model == self.inter:
                        do_init = True
                elif 'liae' in archi_type:
                    if model == self.inter_AB or model == self.inter_B:
                        do_init = True
            else:
                do_init = self.is_first_run()
                if self.is_training and gan_power != 0 and model == self.D_src:
                    if self.gan_model_changed:
                        do_init = True

            if not do_init:
                do_init = not model.load_weights( self.get_strpath_storage_for_file(filename) )

            if do_init:
                model.init_weights()


        ###############

        # initializing sample generators
        if self.is_training:
            training_data_src_path = self.training_data_src_path if not self.pretrain else self.get_pretraining_data_path()
            training_data_dst_path = self.training_data_dst_path if not self.pretrain else self.get_pretraining_data_path()

            random_ct_samples_path = training_data_dst_path if ct_mode is not None and not self.pretrain else None

            cpu_count = multiprocessing.cpu_count()
            src_generators_count = cpu_count // 2
            dst_generators_count = cpu_count // 2
            if ct_mode is not None:
                src_generators_count = int(src_generators_count * 1.5)

            self.set_training_data_generators ([
                    SampleGeneratorFace(training_data_src_path, random_ct_samples_path=random_ct_samples_path, debug=self.is_debug(), batch_size=self.get_batch_size(),
                        sample_process_options=SampleProcessor.Options(scale_range=[-0.15, 0.15], random_flip=random_src_flip),
                        output_sample_types = [ {'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':random_warp, 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'ct_mode': ct_mode, 'random_hsv_shift_amount' : random_hsv_power, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution},
                                                {'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':False, 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'ct_mode': ct_mode, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution},
                                                {'sample_type': SampleProcessor.SampleType.FACE_MASK, 'warp':False, 'transform':True, 'channel_type' : SampleProcessor.ChannelType.G, 'face_mask_type' : SampleProcessor.FaceMaskType.FULL_FACE, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution},
                                                {'sample_type': SampleProcessor.SampleType.FACE_MASK, 'warp':False, 'transform':True, 'channel_type' : SampleProcessor.ChannelType.G, 'face_mask_type' : SampleProcessor.FaceMaskType.EYES_MOUTH, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution},
                                              ],
                        uniform_yaw_distribution=self.options['uniform_yaw'] or self.pretrain,
                        generators_count=src_generators_count ),

                    SampleGeneratorFace(training_data_dst_path, debug=self.is_debug(), batch_size=self.get_batch_size(),
                        sample_process_options=SampleProcessor.Options(scale_range=[-0.15, 0.15], random_flip=random_dst_flip),
                        output_sample_types = [ {'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':random_warp, 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution},
                                                {'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':False, 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution},
                                                {'sample_type': SampleProcessor.SampleType.FACE_MASK, 'warp':False, 'transform':True, 'channel_type' : SampleProcessor.ChannelType.G, 'face_mask_type' : SampleProcessor.FaceMaskType.FULL_FACE, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution},
                                                {'sample_type': SampleProcessor.SampleType.FACE_MASK, 'warp':False, 'transform':True, 'channel_type' : SampleProcessor.ChannelType.G, 'face_mask_type' : SampleProcessor.FaceMaskType.EYES_MOUTH, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution},
                                              ],
                        uniform_yaw_distribution=self.options['uniform_yaw'] or self.pretrain,
                        generators_count=dst_generators_count )
                ])

            if self.pretrain_just_disabled:
                self.update_sample_for_preview(force_new=True)

    def export_dfm (self):
        output_path = self.get_strpath_storage_for_file('model.dfm')

        io.log_info(f'Dumping .dfm to {output_path}')

        tf = nn.tf
        nn.set_data_format('NCHW')

        with tf.device (nn.tf_default_device_name):
            warped_dst = tf.placeholder (nn.floatx, (None, self.resolution, self.resolution, 3), name='in_face')
            warped_dst = tf.transpose(warped_dst, (0,3,1,2))


            if 'df' in self.archi_type:
                gpu_dst_code = self.inter(self.encoder(warped_dst))
                gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder_src(gpu_dst_code)
                _, gpu_pred_dst_dstm = self.decoder_dst(gpu_dst_code)

            elif 'liae' in self.archi_type:
                gpu_dst_code = self.encoder (warped_dst)
                gpu_dst_inter_B_code = self.inter_B (gpu_dst_code)
                gpu_dst_inter_AB_code = self.inter_AB (gpu_dst_code)
                gpu_dst_code = tf.concat([gpu_dst_inter_B_code,gpu_dst_inter_AB_code], nn.conv2d_ch_axis)
                gpu_src_dst_code = tf.concat([gpu_dst_inter_AB_code,gpu_dst_inter_AB_code], nn.conv2d_ch_axis)

                gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder(gpu_src_dst_code)
                _, gpu_pred_dst_dstm = self.decoder(gpu_dst_code)

        gpu_pred_src_dst = tf.transpose(gpu_pred_src_dst, (0,2,3,1))
        gpu_pred_dst_dstm = tf.transpose(gpu_pred_dst_dstm, (0,2,3,1))
        gpu_pred_src_dstm = tf.transpose(gpu_pred_src_dstm, (0,2,3,1))

        tf.identity(gpu_pred_dst_dstm, name='out_face_mask')
        tf.identity(gpu_pred_src_dst, name='out_celeb_face')
        tf.identity(gpu_pred_src_dstm, name='out_celeb_face_mask')

        output_graph_def = tf.graph_util.convert_variables_to_constants(
            nn.tf_sess,
            tf.get_default_graph().as_graph_def(),
            ['out_face_mask','out_celeb_face','out_celeb_face_mask']
        )

        import tf2onnx
        with tf.device("/CPU:0"):
            model_proto, _ = tf2onnx.convert._convert_common(
                output_graph_def,
                name='SAEHD',
                input_names=['in_face:0'],
                output_names=['out_face_mask:0','out_celeb_face:0','out_celeb_face_mask:0'],
                opset=12,
                output_path=output_path)
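
        # The frozen inference graph (variables folded into constants) is converted to ONNX
        # opset 12 via tf2onnx and written to model.dfm. Note that tf2onnx.convert._convert_common
        # is an internal tf2onnx helper, so this call may need adjusting for other tf2onnx versions.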

    #override
    def get_model_filename_list(self):
        return self.model_filename_list

    #override
    def onSave(self):
        for model, filename in io.progress_bar_generator(self.get_model_filename_list(), "Saving", leave=False):
            model.save_weights ( self.get_strpath_storage_for_file(filename) )

    #override
    def should_save_preview_history(self):
        return (not io.is_colab() and self.iter % ( 10*(max(1,self.resolution // 64)) ) == 0) or \
               (io.is_colab() and self.iter % 100 == 0)

    #override
    def onTrainOneIter(self):
        if self.get_iter() == 0 and not self.pretrain and not self.pretrain_just_disabled:
            io.log_info('You are training the model from scratch. It is strongly recommended to use a pretrained model to speed up the training and improve the quality.\n')

        ( (warped_src, target_src, target_srcm, target_srcm_em), \
          (warped_dst, target_dst, target_dstm, target_dstm_em) ) = self.generate_next_samples()

        src_loss, dst_loss = self.src_dst_train (warped_src, target_src, target_srcm, target_srcm_em, warped_dst, target_dst, target_dstm, target_dstm_em)

        if self.options['true_face_power'] != 0 and not self.pretrain:
            self.D_train (warped_src, warped_dst)

        if self.gan_power != 0:
            self.D_src_dst_train (warped_src, target_src, target_srcm, target_srcm_em, warped_dst, target_dst, target_dstm, target_dstm_em)

        return ( ('src_loss', np.mean(src_loss) ), ('dst_loss', np.mean(dst_loss) ), )

    #override
    def onGetPreview(self, samples, for_history=False):
        ( (warped_src, target_src, target_srcm, target_srcm_em),
          (warped_dst, target_dst, target_dstm, target_dstm_em) ) = samples

        S, D, SS, DD, DDM, SD, SDM = [ np.clip( nn.to_data_format(x,"NHWC", self.model_data_format), 0.0, 1.0) for x in ([target_src,target_dst] + self.AE_view (target_src, target_dst) ) ]
        DDM, SDM, = [ np.repeat (x, (3,), -1) for x in [DDM, SDM] ]

        target_srcm, target_dstm = [ nn.to_data_format(x,"NHWC", self.model_data_format) for x in ([target_srcm, target_dstm] )]

        n_samples = min(4, self.get_batch_size(), 800 // self.resolution )

        if self.resolution <= 256:
            result = []

            st = []
            for i in range(n_samples):
                ar = S[i], SS[i], D[i], DD[i], SD[i]
                st.append ( np.concatenate ( ar, axis=1) )
            result += [ ('SAEHD', np.concatenate (st, axis=0 )), ]


            st_m = []
            for i in range(n_samples):
                SD_mask = DDM[i]*SDM[i] if self.face_type < FaceType.HEAD else SDM[i]

                ar = S[i]*target_srcm[i], SS[i], D[i]*target_dstm[i], DD[i]*DDM[i], SD[i]*SD_mask
                st_m.append ( np.concatenate ( ar, axis=1) )

            result += [ ('SAEHD masked', np.concatenate (st_m, axis=0 )), ]
        else:
            result = []

            st = []
            for i in range(n_samples):
                ar = S[i], SS[i]
                st.append ( np.concatenate ( ar, axis=1) )
            result += [ ('SAEHD src-src', np.concatenate (st, axis=0 )), ]

            st = []
            for i in range(n_samples):
                ar = D[i], DD[i]
                st.append ( np.concatenate ( ar, axis=1) )
            result += [ ('SAEHD dst-dst', np.concatenate (st, axis=0 )), ]

            st = []
            for i in range(n_samples):
                ar = D[i], SD[i]
                st.append ( np.concatenate ( ar, axis=1) )
            result += [ ('SAEHD pred', np.concatenate (st, axis=0 )), ]


            st_m = []
            for i in range(n_samples):
                ar = S[i]*target_srcm[i], SS[i]
                st_m.append ( np.concatenate ( ar, axis=1) )
            result += [ ('SAEHD masked src-src', np.concatenate (st_m, axis=0 )), ]

            st_m = []
            for i in range(n_samples):
                ar = D[i]*target_dstm[i], DD[i]*DDM[i]
                st_m.append ( np.concatenate ( ar, axis=1) )
            result += [ ('SAEHD masked dst-dst', np.concatenate (st_m, axis=0 )), ]

            st_m = []
            for i in range(n_samples):
                SD_mask = DDM[i]*SDM[i] if self.face_type < FaceType.HEAD else SDM[i]
                ar = D[i]*target_dstm[i], SD[i]*SD_mask
                st_m.append ( np.concatenate ( ar, axis=1) )
            result += [ ('SAEHD masked pred', np.concatenate (st_m, axis=0 )), ]

        return result

    def predictor_func (self, face=None):
        face = nn.to_data_format(face[None,...], self.model_data_format, "NHWC")

        bgr, mask_dst_dstm, mask_src_dstm = [ nn.to_data_format(x,"NHWC", self.model_data_format).astype(np.float32) for x in self.AE_merge (face) ]

        return bgr[0], mask_src_dstm[0][...,0], mask_dst_dstm[0][...,0]

    #override
    def get_MergerConfig(self):
        import merger
        return self.predictor_func, (self.options['resolution'], self.options['resolution'], 3), merger.MergerConfigMasked(face_type=self.face_type, default_mode = 'overlay')

Model = SAEHDModel