GitHub Repository: iperov/deepfacelab
Path: blob/master/models/Model_AMP/Model.py
import multiprocessing
import operator
from functools import partial

import numpy as np

from core import mathlib
from core.interact import interact as io
from core.leras import nn
from facelib import FaceType
from models import ModelBase
from samplelib import *
from core.cv2ex import *

class AMPModel(ModelBase):

    #override
    def on_initialize_options(self):
        default_resolution = self.options['resolution'] = self.load_or_def_option('resolution', 224)
        default_face_type = self.options['face_type'] = self.load_or_def_option('face_type', 'wf')
        default_models_opt_on_gpu = self.options['models_opt_on_gpu'] = self.load_or_def_option('models_opt_on_gpu', True)

        default_ae_dims = self.options['ae_dims'] = self.load_or_def_option('ae_dims', 256)
        default_inter_dims = self.options['inter_dims'] = self.load_or_def_option('inter_dims', 1024)

        default_e_dims = self.options['e_dims'] = self.load_or_def_option('e_dims', 64)
        default_d_dims = self.options['d_dims'] = self.options.get('d_dims', None)
        default_d_mask_dims = self.options['d_mask_dims'] = self.options.get('d_mask_dims', None)
        default_morph_factor = self.options['morph_factor'] = self.options.get('morph_factor', 0.5)
        default_uniform_yaw = self.options['uniform_yaw'] = self.load_or_def_option('uniform_yaw', False)
        default_blur_out_mask = self.options['blur_out_mask'] = self.load_or_def_option('blur_out_mask', False)
        default_lr_dropout = self.options['lr_dropout'] = self.load_or_def_option('lr_dropout', 'n')
        default_random_warp = self.options['random_warp'] = self.load_or_def_option('random_warp', True)
        default_ct_mode = self.options['ct_mode'] = self.load_or_def_option('ct_mode', 'none')
        default_clipgrad = self.options['clipgrad'] = self.load_or_def_option('clipgrad', False)

        ask_override = self.ask_override()
        if self.is_first_run() or ask_override:
            self.ask_autobackup_hour()
            self.ask_write_preview_history()
            self.ask_target_iter()
            self.ask_random_src_flip()
            self.ask_random_dst_flip()
            self.ask_batch_size(8)

        if self.is_first_run():
            resolution = io.input_int("Resolution", default_resolution, add_info="64-640", help_message="Higher resolution requires more VRAM and training time. The value will be adjusted to a multiple of 32.")
            resolution = np.clip ( (resolution // 32) * 32, 64, 640)
            self.options['resolution'] = resolution
            self.options['face_type'] = io.input_str ("Face type", default_face_type, ['f','wf','head'], help_message="whole face / head").lower()

        default_d_dims = self.options['d_dims'] = self.load_or_def_option('d_dims', 64)

        default_d_mask_dims = default_d_dims // 3
        default_d_mask_dims += default_d_mask_dims % 2
        default_d_mask_dims = self.options['d_mask_dims'] = self.load_or_def_option('d_mask_dims', default_d_mask_dims)

        if self.is_first_run():
            self.options['ae_dims'] = np.clip ( io.input_int("AutoEncoder dimensions", default_ae_dims, add_info="32-1024", help_message="All face information is packed into AE dims. If the number of AE dims is too low, then, for example, closed eyes may not be recognized. More dims are better, but require more VRAM. You can fine-tune model size to fit your GPU."), 32, 1024 )
            self.options['inter_dims'] = np.clip ( io.input_int("Inter dimensions", default_inter_dims, add_info="32-2048", help_message="Should be equal to or greater than AutoEncoder dimensions. More dims are better, but require more VRAM. You can fine-tune model size to fit your GPU."), 32, 2048 )

            e_dims = np.clip ( io.input_int("Encoder dimensions", default_e_dims, add_info="16-256", help_message="More dims help to recognize more facial features and achieve a sharper result, but require more VRAM. You can fine-tune model size to fit your GPU."), 16, 256 )
            self.options['e_dims'] = e_dims + e_dims % 2

            d_dims = np.clip ( io.input_int("Decoder dimensions", default_d_dims, add_info="16-256", help_message="More dims help to recognize more facial features and achieve a sharper result, but require more VRAM. You can fine-tune model size to fit your GPU."), 16, 256 )
            self.options['d_dims'] = d_dims + d_dims % 2

            d_mask_dims = np.clip ( io.input_int("Decoder mask dimensions", default_d_mask_dims, add_info="16-256", help_message="Typical mask dimensions = decoder dimensions / 3. If you manually cut out obstacles from the dst mask, you can increase this parameter for better quality."), 16, 256 )
            self.options['d_mask_dims'] = d_mask_dims + d_mask_dims % 2

            morph_factor = np.clip ( io.input_number ("Morph factor", default_morph_factor, add_info="0.1 .. 0.5", help_message="Typical fine value is 0.5"), 0.1, 0.5 )
            self.options['morph_factor'] = morph_factor

        if self.is_first_run() or ask_override:
            self.options['uniform_yaw'] = io.input_bool ("Uniform yaw distribution of samples", default_uniform_yaw, help_message='Helps to fix blurry side faces caused by the small number of them in the faceset.')
            self.options['blur_out_mask'] = io.input_bool ("Blur out mask", default_blur_out_mask, help_message='Blurs the area just outside the applied face mask of training samples. As a result, the background near the face is smoothed and less noticeable on the swapped face. Exact xseg masks in the src and dst facesets are required.')
            self.options['lr_dropout'] = io.input_str (f"Use learning rate dropout", default_lr_dropout, ['n','y','cpu'], help_message="When the face is trained enough, you can enable this option to get extra sharpness and reduce subpixel shake in fewer iterations. Enable it before `disable random warp` and before GAN. \nn - disabled.\ny - enabled\ncpu - enabled on CPU. This avoids using extra VRAM, sacrificing 20% of iteration time.")

        default_gan_power = self.options['gan_power'] = self.load_or_def_option('gan_power', 0.0)
        default_gan_patch_size = self.options['gan_patch_size'] = self.load_or_def_option('gan_patch_size', self.options['resolution'] // 8)
        default_gan_dims = self.options['gan_dims'] = self.load_or_def_option('gan_dims', 16)

        if self.is_first_run() or ask_override:
            self.options['models_opt_on_gpu'] = io.input_bool ("Place models and optimizer on GPU", default_models_opt_on_gpu, help_message="When you train on one GPU, the model and optimizer weights are placed on the GPU by default to accelerate the process. You can place them on the CPU to free up extra VRAM and thus set bigger dimensions.")

            self.options['random_warp'] = io.input_bool ("Enable random warp of samples", default_random_warp, help_message="Random warp is required to generalize facial expressions of both faces. When the face is trained enough, you can disable it to get extra sharpness and reduce subpixel shake in fewer iterations.")

            self.options['gan_power'] = np.clip ( io.input_number ("GAN power", default_gan_power, add_info="0.0 .. 5.0", help_message="Forces the neural network to learn small details of the face. Enable it only when the face is trained enough with random_warp off, and do not disable it afterwards. The higher the value, the higher the chance of artifacts. Typical fine value is 0.1"), 0.0, 5.0 )

            if self.options['gan_power'] != 0.0:
                gan_patch_size = np.clip ( io.input_int("GAN patch size", default_gan_patch_size, add_info="3-640", help_message="The higher the patch size, the higher the quality and the more VRAM required. You can get sharper edges even at the lowest setting. Typical fine value is resolution / 8."), 3, 640 )
                self.options['gan_patch_size'] = gan_patch_size

                gan_dims = np.clip ( io.input_int("GAN dimensions", default_gan_dims, add_info="4-512", help_message="The dimensions of the GAN network. The higher the dimensions, the more VRAM is required. You can get sharper edges even at the lowest setting. Typical fine value is 16."), 4, 512 )
                self.options['gan_dims'] = gan_dims

            self.options['ct_mode'] = io.input_str (f"Color transfer for src faceset", default_ct_mode, ['none','rct','lct','mkl','idt','sot'], help_message="Changes the color distribution of src samples to be closer to dst samples. If the src faceset is diverse enough, then lct mode is fine in most cases.")
            self.options['clipgrad'] = io.input_bool ("Enable gradient clipping", default_clipgrad, help_message="Gradient clipping reduces the chance of model collapse, sacrificing training speed.")

        self.gan_model_changed = (default_gan_patch_size != self.options['gan_patch_size']) or (default_gan_dims != self.options['gan_dims'])

    #override
    def on_initialize(self):
        device_config = nn.getCurrentDeviceConfig()
        devices = device_config.devices
        self.model_data_format = "NCHW"
        nn.initialize(data_format=self.model_data_format)
        tf = nn.tf

        input_ch=3
        resolution = self.resolution = self.options['resolution']
        e_dims = self.options['e_dims']
        ae_dims = self.options['ae_dims']
        inter_dims = self.inter_dims = self.options['inter_dims']
        inter_res = self.inter_res = resolution // 32
        d_dims = self.options['d_dims']
        d_mask_dims = self.options['d_mask_dims']
        face_type = self.face_type = {'f'    : FaceType.FULL,
                                      'wf'   : FaceType.WHOLE_FACE,
                                      'head' : FaceType.HEAD}[ self.options['face_type'] ]
        morph_factor = self.options['morph_factor']
        gan_power = self.gan_power = self.options['gan_power']
        random_warp = self.options['random_warp']

        blur_out_mask = self.options['blur_out_mask']

        ct_mode = self.options['ct_mode']
        if ct_mode == 'none':
            ct_mode = None

        use_fp16 = False
        if self.is_exporting:
            use_fp16 = io.input_bool ("Export quantized?", False, help_message='Makes the exported model faster. If you have problems, disable this option.')

        conv_dtype = tf.float16 if use_fp16 else tf.float32

        class Downscale(nn.ModelBase):
            def on_build(self, in_ch, out_ch, kernel_size=5 ):
                self.conv1 = nn.Conv2D( in_ch, out_ch, kernel_size=kernel_size, strides=2, padding='SAME', dtype=conv_dtype)

            def forward(self, x):
                return tf.nn.leaky_relu(self.conv1(x), 0.1)

        class Upscale(nn.ModelBase):
            def on_build(self, in_ch, out_ch, kernel_size=3 ):
                self.conv1 = nn.Conv2D(in_ch, out_ch*4, kernel_size=kernel_size, padding='SAME', dtype=conv_dtype)

            def forward(self, x):
                x = nn.depth_to_space(tf.nn.leaky_relu(self.conv1(x), 0.1), 2)
                return x

        class ResidualBlock(nn.ModelBase):
            def on_build(self, ch, kernel_size=3 ):
                self.conv1 = nn.Conv2D( ch, ch, kernel_size=kernel_size, padding='SAME', dtype=conv_dtype)
                self.conv2 = nn.Conv2D( ch, ch, kernel_size=kernel_size, padding='SAME', dtype=conv_dtype)

            def forward(self, inp):
                x = self.conv1(inp)
                x = tf.nn.leaky_relu(x, 0.2)
                x = self.conv2(x)
                x = tf.nn.leaky_relu(inp+x, 0.2)
                return x

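        # Note: Upscale is a pixel-shuffle upsampler. The conv emits out_ch*4
        # channels and nn.depth_to_space(..., 2) rearranges each group of 4
        # channels into a 2x2 spatial block, turning (N, out_ch*4, H, W) into
        # (N, out_ch, H*2, W*2).
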
        class Encoder(nn.ModelBase):
            def on_build(self):
                self.down1 = Downscale(input_ch, e_dims, kernel_size=5)
                self.res1 = ResidualBlock(e_dims)
                self.down2 = Downscale(e_dims, e_dims*2, kernel_size=5)
                self.down3 = Downscale(e_dims*2, e_dims*4, kernel_size=5)
                self.down4 = Downscale(e_dims*4, e_dims*8, kernel_size=5)
                self.down5 = Downscale(e_dims*8, e_dims*8, kernel_size=5)
                self.res5 = ResidualBlock(e_dims*8)
                self.dense1 = nn.Dense( (( resolution//(2**5) )**2) * e_dims*8, ae_dims )

            def forward(self, x):
                if use_fp16:
                    x = tf.cast(x, tf.float16)
                x = self.down1(x)
                x = self.res1(x)
                x = self.down2(x)
                x = self.down3(x)
                x = self.down4(x)
                x = self.down5(x)
                x = self.res5(x)
                if use_fp16:
                    x = tf.cast(x, tf.float32)
                x = nn.pixel_norm(nn.flatten(x), axes=-1)
                x = self.dense1(x)
                return x

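        # Encoder shape walk-through (a sketch assuming the defaults resolution=224,
        # e_dims=64, ae_dims=256; each Downscale halves H and W):
        #   (N, 3, 224, 224) -> down1 (N, 64, 112, 112) -> down2 (N, 128, 56, 56)
        #   -> down3 (N, 256, 28, 28) -> down4 (N, 512, 14, 14) -> down5 (N, 512, 7, 7)
        #   -> flatten + pixel_norm (N, 7*7*512) -> dense1 (N, 256)
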
        class Inter(nn.ModelBase):
            def on_build(self):
                self.dense2 = nn.Dense(ae_dims, inter_res * inter_res * inter_dims)

            def forward(self, inp):
                x = inp
                x = self.dense2(x)
                x = nn.reshape_4D (x, inter_res, inter_res, inter_dims)
                return x

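        # Inter projects the ae_dims latent back to a spatial tensor of shape
        # (N, inter_dims, inter_res, inter_res). Two instances are created below
        # (inter_src / inter_dst); morphing mixes their output channels.
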
        class Decoder(nn.ModelBase):
            def on_build(self ):
                self.upscale0 = Upscale(inter_dims, d_dims*8, kernel_size=3)
                self.upscale1 = Upscale(d_dims*8, d_dims*8, kernel_size=3)
                self.upscale2 = Upscale(d_dims*8, d_dims*4, kernel_size=3)
                self.upscale3 = Upscale(d_dims*4, d_dims*2, kernel_size=3)

                self.res0 = ResidualBlock(d_dims*8, kernel_size=3)
                self.res1 = ResidualBlock(d_dims*8, kernel_size=3)
                self.res2 = ResidualBlock(d_dims*4, kernel_size=3)
                self.res3 = ResidualBlock(d_dims*2, kernel_size=3)

                self.upscalem0 = Upscale(inter_dims, d_mask_dims*8, kernel_size=3)
                self.upscalem1 = Upscale(d_mask_dims*8, d_mask_dims*8, kernel_size=3)
                self.upscalem2 = Upscale(d_mask_dims*8, d_mask_dims*4, kernel_size=3)
                self.upscalem3 = Upscale(d_mask_dims*4, d_mask_dims*2, kernel_size=3)
                self.upscalem4 = Upscale(d_mask_dims*2, d_mask_dims*1, kernel_size=3)
                self.out_convm = nn.Conv2D( d_mask_dims*1, 1, kernel_size=1, padding='SAME', dtype=conv_dtype)

                self.out_conv  = nn.Conv2D( d_dims*2, 3, kernel_size=1, padding='SAME', dtype=conv_dtype)
                self.out_conv1 = nn.Conv2D( d_dims*2, 3, kernel_size=3, padding='SAME', dtype=conv_dtype)
                self.out_conv2 = nn.Conv2D( d_dims*2, 3, kernel_size=3, padding='SAME', dtype=conv_dtype)
                self.out_conv3 = nn.Conv2D( d_dims*2, 3, kernel_size=3, padding='SAME', dtype=conv_dtype)

            def forward(self, z):
                if use_fp16:
                    z = tf.cast(z, tf.float16)

                x = self.upscale0(z)
                x = self.res0(x)
                x = self.upscale1(x)
                x = self.res1(x)
                x = self.upscale2(x)
                x = self.res2(x)
                x = self.upscale3(x)
                x = self.res3(x)

                x = tf.nn.sigmoid( nn.depth_to_space(tf.concat( (self.out_conv(x),
                                                                 self.out_conv1(x),
                                                                 self.out_conv2(x),
                                                                 self.out_conv3(x)), nn.conv2d_ch_axis), 2) )
                m = self.upscalem0(z)
                m = self.upscalem1(m)
                m = self.upscalem2(m)
                m = self.upscalem3(m)
                m = self.upscalem4(m)
                m = tf.nn.sigmoid(self.out_convm(m))

                if use_fp16:
                    x = tf.cast(x, tf.float32)
                    m = tf.cast(m, tf.float32)
                return x, m

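        # The color head concatenates four 3-channel convs into 12 channels and
        # pixel-shuffles them for a final 2x upscale, so both heads reach full
        # resolution: inter_res * 2**4 * 2 for color and inter_res * 2**5 for
        # the mask, with inter_res = resolution // 32.
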
        models_opt_on_gpu = False if len(devices) == 0 else self.options['models_opt_on_gpu']
        models_opt_device = nn.tf_default_device_name if models_opt_on_gpu and self.is_training else '/CPU:0'
        optimizer_vars_on_cpu = models_opt_device=='/CPU:0'

        bgr_shape = self.bgr_shape = nn.get4Dshape(resolution,resolution,input_ch)
        mask_shape = nn.get4Dshape(resolution,resolution,1)
        self.model_filename_list = []

        with tf.device ('/CPU:0'):
            # Placeholders on CPU
            self.warped_src = tf.placeholder (nn.floatx, bgr_shape, name='warped_src')
            self.warped_dst = tf.placeholder (nn.floatx, bgr_shape, name='warped_dst')

            self.target_src = tf.placeholder (nn.floatx, bgr_shape, name='target_src')
            self.target_dst = tf.placeholder (nn.floatx, bgr_shape, name='target_dst')

            self.target_srcm = tf.placeholder (nn.floatx, mask_shape, name='target_srcm')
            self.target_srcm_em = tf.placeholder (nn.floatx, mask_shape, name='target_srcm_em')
            self.target_dstm = tf.placeholder (nn.floatx, mask_shape, name='target_dstm')
            self.target_dstm_em = tf.placeholder (nn.floatx, mask_shape, name='target_dstm_em')

            self.morph_value_t = tf.placeholder (nn.floatx, (1,), name='morph_value_t')

        # Initializing model classes
        with tf.device (models_opt_device):
            self.encoder = Encoder(name='encoder')
            self.inter_src = Inter(name='inter_src')
            self.inter_dst = Inter(name='inter_dst')
            self.decoder = Decoder(name='decoder')

            self.model_filename_list += [ [self.encoder,   'encoder.npy'],
                                          [self.inter_src, 'inter_src.npy'],
                                          [self.inter_dst, 'inter_dst.npy'],
                                          [self.decoder,   'decoder.npy'] ]

            if self.is_training:
                # Initialize optimizers
                clipnorm = 1.0 if self.options['clipgrad'] else 0.0
                if self.options['lr_dropout'] in ['y','cpu']:
                    lr_cos = 500
                    lr_dropout = 0.3
                else:
                    lr_cos = 0
                    lr_dropout = 1.0
                self.G_weights = self.encoder.get_weights() + self.decoder.get_weights()

                self.src_dst_opt = nn.AdaBelief(lr=5e-5, lr_dropout=lr_dropout, lr_cos=lr_cos, clipnorm=clipnorm, name='src_dst_opt')
                self.src_dst_opt.initialize_variables (self.G_weights, vars_on_cpu=optimizer_vars_on_cpu)
                self.model_filename_list += [ (self.src_dst_opt, 'src_dst_opt.npy') ]

                if gan_power != 0:
                    self.GAN = nn.UNetPatchDiscriminator(patch_size=self.options['gan_patch_size'], in_ch=input_ch, base_ch=self.options['gan_dims'], name="GAN")
                    self.GAN_opt = nn.AdaBelief(lr=5e-5, lr_dropout=lr_dropout, lr_cos=lr_cos, clipnorm=clipnorm, name='GAN_opt')
                    self.GAN_opt.initialize_variables ( self.GAN.get_weights(), vars_on_cpu=optimizer_vars_on_cpu)
                    self.model_filename_list += [ [self.GAN,     'GAN.npy'],
                                                  [self.GAN_opt, 'GAN_opt.npy'] ]

        if self.is_training:
            # Adjust batch size for multiple GPU
            gpu_count = max(1, len(devices) )
            bs_per_gpu = max(1, self.get_batch_size() // gpu_count)
            self.set_batch_size( gpu_count*bs_per_gpu)

            # Compute losses per GPU
            gpu_pred_src_src_list = []
            gpu_pred_dst_dst_list = []
            gpu_pred_src_dst_list = []
            gpu_pred_src_srcm_list = []
            gpu_pred_dst_dstm_list = []
            gpu_pred_src_dstm_list = []

            gpu_src_losses = []
            gpu_dst_losses = []
            gpu_G_loss_gradients = []
            gpu_GAN_loss_gradients = []

            def DLossOnes(logits):
                return tf.reduce_mean( tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.ones_like(logits), logits=logits), axis=[1,2,3])

            def DLossZeros(logits):
                return tf.reduce_mean( tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.zeros_like(logits), logits=logits), axis=[1,2,3])

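            # DLossOnes/DLossZeros are the two halves of the usual non-saturating
            # GAN objective: per-sample mean sigmoid cross-entropy of the patch
            # logits against all-real (1) or all-fake (0) labels.
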
            for gpu_id in range(gpu_count):
                with tf.device( f'/{devices[gpu_id].tf_dev_type}:{gpu_id}' if len(devices) != 0 else f'/CPU:0' ):
                    with tf.device(f'/CPU:0'):
                        # slice on CPU, otherwise all batch data will be transferred to GPU first
                        batch_slice = slice( gpu_id*bs_per_gpu, (gpu_id+1)*bs_per_gpu )
                        gpu_warped_src     = self.warped_src [batch_slice,:,:,:]
                        gpu_warped_dst     = self.warped_dst [batch_slice,:,:,:]
                        gpu_target_src     = self.target_src [batch_slice,:,:,:]
                        gpu_target_dst     = self.target_dst [batch_slice,:,:,:]
                        gpu_target_srcm    = self.target_srcm[batch_slice,:,:,:]
                        gpu_target_srcm_em = self.target_srcm_em[batch_slice,:,:,:]
                        gpu_target_dstm    = self.target_dstm[batch_slice,:,:,:]
                        gpu_target_dstm_em = self.target_dstm_em[batch_slice,:,:,:]

                    # process model tensors
                    gpu_src_code = self.encoder (gpu_warped_src)
                    gpu_dst_code = self.encoder (gpu_warped_dst)

                    gpu_src_inter_src_code, gpu_src_inter_dst_code = self.inter_src (gpu_src_code), self.inter_dst (gpu_src_code)
                    gpu_dst_inter_src_code, gpu_dst_inter_dst_code = self.inter_src (gpu_dst_code), self.inter_dst (gpu_dst_code)

                    inter_dims_bin = int(inter_dims*morph_factor)
                    with tf.device(f'/CPU:0'):
                        inter_rnd_binomial = tf.stack([tf.random.shuffle(tf.concat([tf.tile(tf.constant([1], tf.float32), ( inter_dims_bin, )),
                                                                                    tf.tile(tf.constant([0], tf.float32), ( inter_dims-inter_dims_bin, ))], 0 )) for _ in range(bs_per_gpu)], 0)

                    inter_rnd_binomial = tf.stop_gradient(inter_rnd_binomial[...,None,None])

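                    # inter_rnd_binomial is, per sample, a shuffled 0/1 vector over the
                    # inter_dims channels with exactly int(inter_dims*morph_factor) ones;
                    # stop_gradient keeps the random mask itself out of backprop.
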
                    gpu_src_code = gpu_src_inter_src_code * inter_rnd_binomial + gpu_src_inter_dst_code * (1-inter_rnd_binomial)
                    gpu_dst_code = gpu_dst_inter_dst_code

                    inter_dims_slice = tf.cast(inter_dims*self.morph_value_t[0], tf.int32)
                    gpu_src_dst_code = tf.concat( (tf.slice(gpu_dst_inter_src_code, [0,0,0,0],                [-1, inter_dims_slice,            inter_res, inter_res]),
                                                   tf.slice(gpu_dst_inter_dst_code, [0,inter_dims_slice,0,0], [-1, inter_dims-inter_dims_slice, inter_res, inter_res]) ), 1 )

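                    # The morph value selects how many leading inter channels come from
                    # inter_src versus inter_dst: at 0.0 the decode is pure dst, at 1.0
                    # the full src identity is applied to the dst face.
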
                    gpu_pred_src_src, gpu_pred_src_srcm = self.decoder(gpu_src_code)
                    gpu_pred_dst_dst, gpu_pred_dst_dstm = self.decoder(gpu_dst_code)
                    gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder(gpu_src_dst_code)

                    gpu_pred_src_src_list.append(gpu_pred_src_src), gpu_pred_src_srcm_list.append(gpu_pred_src_srcm)
                    gpu_pred_dst_dst_list.append(gpu_pred_dst_dst), gpu_pred_dst_dstm_list.append(gpu_pred_dst_dstm)
                    gpu_pred_src_dst_list.append(gpu_pred_src_dst), gpu_pred_src_dstm_list.append(gpu_pred_src_dstm)

                    gpu_target_srcm_anti = 1-gpu_target_srcm
                    gpu_target_dstm_anti = 1-gpu_target_dstm

                    gpu_target_srcm_gblur = nn.gaussian_blur(gpu_target_srcm, resolution // 32)
                    gpu_target_dstm_gblur = nn.gaussian_blur(gpu_target_dstm, resolution // 32)

                    gpu_target_srcm_blur = tf.clip_by_value(gpu_target_srcm_gblur, 0, 0.5) * 2
                    gpu_target_dstm_blur = tf.clip_by_value(gpu_target_dstm_gblur, 0, 0.5) * 2
                    gpu_target_srcm_anti_blur = 1.0-gpu_target_srcm_blur
                    gpu_target_dstm_anti_blur = 1.0-gpu_target_dstm_blur

                    if blur_out_mask:
                        sigma = resolution / 128

                        x = nn.gaussian_blur(gpu_target_src*gpu_target_srcm_anti, sigma)
                        y = 1-nn.gaussian_blur(gpu_target_srcm, sigma)
                        y = tf.where(tf.equal(y, 0), tf.ones_like(y), y)
                        gpu_target_src = gpu_target_src*gpu_target_srcm + (x/y)*gpu_target_srcm_anti

                        x = nn.gaussian_blur(gpu_target_dst*gpu_target_dstm_anti, sigma)
                        y = 1-nn.gaussian_blur(gpu_target_dstm, sigma)
                        y = tf.where(tf.equal(y, 0), tf.ones_like(y), y)
                        gpu_target_dst = gpu_target_dst*gpu_target_dstm + (x/y)*gpu_target_dstm_anti

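                    # The x/y division is a normalized blur of the background only:
                    # blurring the pre-masked image and dividing by the blurred
                    # anti-mask keeps brightness correct near the mask edge instead
                    # of letting it fade toward black.
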
                    gpu_target_src_masked = gpu_target_src*gpu_target_srcm_blur
                    gpu_target_dst_masked = gpu_target_dst*gpu_target_dstm_blur
                    gpu_target_src_anti_masked = gpu_target_src*gpu_target_srcm_anti_blur
                    gpu_target_dst_anti_masked = gpu_target_dst*gpu_target_dstm_anti_blur

                    gpu_pred_src_src_masked = gpu_pred_src_src*gpu_target_srcm_blur
                    gpu_pred_dst_dst_masked = gpu_pred_dst_dst*gpu_target_dstm_blur
                    gpu_pred_src_src_anti_masked = gpu_pred_src_src*gpu_target_srcm_anti_blur
                    gpu_pred_dst_dst_anti_masked = gpu_pred_dst_dst*gpu_target_dstm_anti_blur

                    # Structural loss
                    gpu_src_loss =  tf.reduce_mean (5*nn.dssim(gpu_target_src_masked, gpu_pred_src_src_masked, max_val=1.0, filter_size=int(resolution/11.6)), axis=[1])
                    gpu_src_loss += tf.reduce_mean (5*nn.dssim(gpu_target_src_masked, gpu_pred_src_src_masked, max_val=1.0, filter_size=int(resolution/23.2)), axis=[1])
                    gpu_dst_loss =  tf.reduce_mean (5*nn.dssim(gpu_target_dst_masked, gpu_pred_dst_dst_masked, max_val=1.0, filter_size=int(resolution/11.6) ), axis=[1])
                    gpu_dst_loss += tf.reduce_mean (5*nn.dssim(gpu_target_dst_masked, gpu_pred_dst_dst_masked, max_val=1.0, filter_size=int(resolution/23.2) ), axis=[1])

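                    # DSSIM is taken at two window sizes (resolution/11.6 and
                    # resolution/23.2), matching structure at a coarse and a finer
                    # scale before the plain per-pixel terms below.
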
                    # Pixel loss
                    gpu_src_loss += tf.reduce_mean (10*tf.square(gpu_target_src_masked-gpu_pred_src_src_masked), axis=[1,2,3])
                    gpu_dst_loss += tf.reduce_mean (10*tf.square(gpu_target_dst_masked-gpu_pred_dst_dst_masked), axis=[1,2,3])

                    # Eyes+mouth prio loss
                    gpu_src_loss += tf.reduce_mean (300*tf.abs (gpu_target_src*gpu_target_srcm_em-gpu_pred_src_src*gpu_target_srcm_em), axis=[1,2,3])
                    gpu_dst_loss += tf.reduce_mean (300*tf.abs (gpu_target_dst*gpu_target_dstm_em-gpu_pred_dst_dst*gpu_target_dstm_em), axis=[1,2,3])

                    # Mask loss
                    gpu_src_loss += tf.reduce_mean ( 10*tf.square( gpu_target_srcm - gpu_pred_src_srcm ),axis=[1,2,3] )
                    gpu_dst_loss += tf.reduce_mean ( 10*tf.square( gpu_target_dstm - gpu_pred_dst_dstm ),axis=[1,2,3] )

                    gpu_src_losses += [gpu_src_loss]
                    gpu_dst_losses += [gpu_dst_loss]
                    gpu_G_loss = gpu_src_loss + gpu_dst_loss

                    # dst-dst background weak loss
                    gpu_G_loss += tf.reduce_mean(0.1*tf.square(gpu_pred_dst_dst_anti_masked-gpu_target_dst_anti_masked),axis=[1,2,3] )
                    gpu_G_loss += 0.000001*nn.total_variation_mse(gpu_pred_dst_dst_anti_masked)

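                    # self.GAN returns two logit tensors (the UNet patch discriminator
                    # appears to score patches at two scales). The discriminator is
                    # trained with real=1 / fake=0 labels, while the generator is
                    # pushed toward 1 on its own outputs, scaled by gan_power.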
                    if gan_power != 0:
                        gpu_pred_src_src_d, gpu_pred_src_src_d2 = self.GAN(gpu_pred_src_src_masked)
                        gpu_pred_dst_dst_d, gpu_pred_dst_dst_d2 = self.GAN(gpu_pred_dst_dst_masked)
                        gpu_target_src_d,   gpu_target_src_d2   = self.GAN(gpu_target_src_masked)
                        gpu_target_dst_d,   gpu_target_dst_d2   = self.GAN(gpu_target_dst_masked)

                        gpu_GAN_loss = (DLossOnes (gpu_target_src_d)   + DLossOnes (gpu_target_src_d2) + \
                                        DLossZeros(gpu_pred_src_src_d) + DLossZeros(gpu_pred_src_src_d2) + \
                                        DLossOnes (gpu_target_dst_d)   + DLossOnes (gpu_target_dst_d2) + \
                                        DLossZeros(gpu_pred_dst_dst_d) + DLossZeros(gpu_pred_dst_dst_d2)
                                       ) * (1.0 / 8)

                        gpu_GAN_loss_gradients += [ nn.gradients (gpu_GAN_loss, self.GAN.get_weights() ) ]

                        gpu_G_loss += (DLossOnes(gpu_pred_src_src_d) + DLossOnes(gpu_pred_src_src_d2) + \
                                       DLossOnes(gpu_pred_dst_dst_d) + DLossOnes(gpu_pred_dst_dst_d2)
                                      ) * gan_power

                        # Minimal src-src background reconstruction plus total_variation_mse
                        # to suppress random bright dots from the GAN
                        gpu_G_loss += 0.000001*nn.total_variation_mse(gpu_pred_src_src)
                        gpu_G_loss += 0.02*tf.reduce_mean(tf.square(gpu_pred_src_src_anti_masked-gpu_target_src_anti_masked),axis=[1,2,3] )

                    gpu_G_loss_gradients += [ nn.gradients ( gpu_G_loss, self.G_weights ) ]

            # Average losses and gradients, and create optimizer update ops
            with tf.device(f'/CPU:0'):
                pred_src_src  = nn.concat(gpu_pred_src_src_list, 0)
                pred_dst_dst  = nn.concat(gpu_pred_dst_dst_list, 0)
                pred_src_dst  = nn.concat(gpu_pred_src_dst_list, 0)
                pred_src_srcm = nn.concat(gpu_pred_src_srcm_list, 0)
                pred_dst_dstm = nn.concat(gpu_pred_dst_dstm_list, 0)
                pred_src_dstm = nn.concat(gpu_pred_src_dstm_list, 0)

            with tf.device (models_opt_device):
                src_loss = tf.concat(gpu_src_losses, 0)
                dst_loss = tf.concat(gpu_dst_losses, 0)
                train_op = self.src_dst_opt.get_update_op (nn.average_gv_list (gpu_G_loss_gradients))

                if gan_power != 0:
                    GAN_train_op = self.GAN_opt.get_update_op (nn.average_gv_list(gpu_GAN_loss_gradients) )

            # Initializing training and view functions
            def train(warped_src, target_src, target_srcm, target_srcm_em, \
                      warped_dst, target_dst, target_dstm, target_dstm_em):
                s, d, _ = nn.tf_sess.run ([src_loss, dst_loss, train_op],
                                          feed_dict={self.warped_src    :warped_src,
                                                     self.target_src    :target_src,
                                                     self.target_srcm   :target_srcm,
                                                     self.target_srcm_em:target_srcm_em,
                                                     self.warped_dst    :warped_dst,
                                                     self.target_dst    :target_dst,
                                                     self.target_dstm   :target_dstm,
                                                     self.target_dstm_em:target_dstm_em,
                                                     })
                return s, d
            self.train = train

            if gan_power != 0:
                def GAN_train(warped_src, target_src, target_srcm, target_srcm_em, \
                              warped_dst, target_dst, target_dstm, target_dstm_em):
                    nn.tf_sess.run ([GAN_train_op], feed_dict={self.warped_src    :warped_src,
                                                               self.target_src    :target_src,
                                                               self.target_srcm   :target_srcm,
                                                               self.target_srcm_em:target_srcm_em,
                                                               self.warped_dst    :warped_dst,
                                                               self.target_dst    :target_dst,
                                                               self.target_dstm   :target_dstm,
                                                               self.target_dstm_em:target_dstm_em})
                self.GAN_train = GAN_train

            def AE_view(warped_src, warped_dst, morph_value):
                return nn.tf_sess.run ( [pred_src_src, pred_dst_dst, pred_dst_dstm, pred_src_dst, pred_src_dstm],
                                        feed_dict={self.warped_src:warped_src, self.warped_dst:warped_dst, self.morph_value_t:[morph_value] })

            self.AE_view = AE_view
        else:
            # Initializing merge function
            with tf.device( nn.tf_default_device_name if len(devices) != 0 else f'/CPU:0'):
                gpu_dst_code = self.encoder (self.warped_dst)
                gpu_dst_inter_src_code = self.inter_src (gpu_dst_code)
                gpu_dst_inter_dst_code = self.inter_dst (gpu_dst_code)

                inter_dims_slice = tf.cast(inter_dims*self.morph_value_t[0], tf.int32)
                gpu_src_dst_code = tf.concat( (tf.slice(gpu_dst_inter_src_code, [0,0,0,0],                [-1, inter_dims_slice,            inter_res, inter_res]),
                                               tf.slice(gpu_dst_inter_dst_code, [0,inter_dims_slice,0,0], [-1, inter_dims-inter_dims_slice, inter_res, inter_res]) ), 1 )

                gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder(gpu_src_dst_code)
                _, gpu_pred_dst_dstm = self.decoder(gpu_dst_inter_dst_code)

            def AE_merge(warped_dst, morph_value):
                return nn.tf_sess.run ( [gpu_pred_src_dst, gpu_pred_dst_dstm, gpu_pred_src_dstm], feed_dict={self.warped_dst:warped_dst, self.morph_value_t:[morph_value] })

            self.AE_merge = AE_merge

        # Loading/initializing all models/optimizers weights
        for model, filename in io.progress_bar_generator(self.model_filename_list, "Initializing models"):
            do_init = self.is_first_run()
            if self.is_training and gan_power != 0 and model == self.GAN:
                if self.gan_model_changed:
                    do_init = True
            if not do_init:
                do_init = not model.load_weights( self.get_strpath_storage_for_file(filename) )
            if do_init:
                model.init_weights()
        ###############

        # Initializing sample generators
        if self.is_training:
            training_data_src_path = self.training_data_src_path #if not self.pretrain else self.get_pretraining_data_path()
            training_data_dst_path = self.training_data_dst_path #if not self.pretrain else self.get_pretraining_data_path()

            random_ct_samples_path = training_data_dst_path if ct_mode is not None else None #and not self.pretrain

            cpu_count = multiprocessing.cpu_count()
            src_generators_count = cpu_count // 2
            dst_generators_count = cpu_count // 2
            if ct_mode is not None:
                src_generators_count = int(src_generators_count * 1.5)

            self.set_training_data_generators ([
                SampleGeneratorFace(training_data_src_path, random_ct_samples_path=random_ct_samples_path, debug=self.is_debug(), batch_size=self.get_batch_size(),
                    sample_process_options=SampleProcessor.Options(scale_range=[-0.15, 0.15], random_flip=self.random_src_flip),
                    output_sample_types = [ {'sample_type': SampleProcessor.SampleType.FACE_IMAGE, 'warp':random_warp, 'transform':True, 'channel_type': SampleProcessor.ChannelType.BGR, 'ct_mode': ct_mode, 'face_type':face_type, 'data_format':nn.data_format, 'resolution': resolution},
                                            {'sample_type': SampleProcessor.SampleType.FACE_IMAGE, 'warp':False,       'transform':True, 'channel_type': SampleProcessor.ChannelType.BGR, 'ct_mode': ct_mode, 'face_type':face_type, 'data_format':nn.data_format, 'resolution': resolution},
                                            {'sample_type': SampleProcessor.SampleType.FACE_MASK,  'warp':False,       'transform':True, 'channel_type': SampleProcessor.ChannelType.G, 'face_mask_type': SampleProcessor.FaceMaskType.FULL_FACE,  'face_type':face_type, 'data_format':nn.data_format, 'resolution': resolution},
                                            {'sample_type': SampleProcessor.SampleType.FACE_MASK,  'warp':False,       'transform':True, 'channel_type': SampleProcessor.ChannelType.G, 'face_mask_type': SampleProcessor.FaceMaskType.EYES_MOUTH, 'face_type':face_type, 'data_format':nn.data_format, 'resolution': resolution},
                                          ],
                    uniform_yaw_distribution=self.options['uniform_yaw'],# or self.pretrain,
                    generators_count=src_generators_count ),

                SampleGeneratorFace(training_data_dst_path, debug=self.is_debug(), batch_size=self.get_batch_size(),
                    sample_process_options=SampleProcessor.Options(scale_range=[-0.15, 0.15], random_flip=self.random_dst_flip),
                    output_sample_types = [ {'sample_type': SampleProcessor.SampleType.FACE_IMAGE, 'warp':random_warp, 'transform':True, 'channel_type': SampleProcessor.ChannelType.BGR, 'face_type':face_type, 'data_format':nn.data_format, 'resolution': resolution},
                                            {'sample_type': SampleProcessor.SampleType.FACE_IMAGE, 'warp':False,       'transform':True, 'channel_type': SampleProcessor.ChannelType.BGR, 'face_type':face_type, 'data_format':nn.data_format, 'resolution': resolution},
                                            {'sample_type': SampleProcessor.SampleType.FACE_MASK,  'warp':False,       'transform':True, 'channel_type': SampleProcessor.ChannelType.G, 'face_mask_type': SampleProcessor.FaceMaskType.FULL_FACE,  'face_type':face_type, 'data_format':nn.data_format, 'resolution': resolution},
                                            {'sample_type': SampleProcessor.SampleType.FACE_MASK,  'warp':False,       'transform':True, 'channel_type': SampleProcessor.ChannelType.G, 'face_mask_type': SampleProcessor.FaceMaskType.EYES_MOUTH, 'face_type':face_type, 'data_format':nn.data_format, 'resolution': resolution},
                                          ],
                    uniform_yaw_distribution=self.options['uniform_yaw'],# or self.pretrain,
                    generators_count=dst_generators_count )
                ])

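            # Each generator yields four tensors per sample: the (optionally)
            # warped face, the unwarped target face, the full-face mask, and the
            # eyes+mouth mask used by the priority loss.
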
    def export_dfm (self):
        output_path = self.get_strpath_storage_for_file('model.dfm')

        io.log_info(f'Dumping .dfm to {output_path}')

        tf = nn.tf
        with tf.device (nn.tf_default_device_name):
            warped_dst = tf.placeholder (nn.floatx, (None, self.resolution, self.resolution, 3), name='in_face')
            warped_dst = tf.transpose(warped_dst, (0,3,1,2))
            morph_value = tf.placeholder (nn.floatx, (1,), name='morph_value')

            gpu_dst_code = self.encoder (warped_dst)
            gpu_dst_inter_src_code = self.inter_src ( gpu_dst_code)
            gpu_dst_inter_dst_code = self.inter_dst ( gpu_dst_code)

            inter_dims_slice = tf.cast(self.inter_dims*morph_value[0], tf.int32)
            gpu_src_dst_code = tf.concat( (tf.slice(gpu_dst_inter_src_code, [0,0,0,0],                [-1, inter_dims_slice,                 self.inter_res, self.inter_res]),
                                           tf.slice(gpu_dst_inter_dst_code, [0,inter_dims_slice,0,0], [-1, self.inter_dims-inter_dims_slice, self.inter_res, self.inter_res]) ), 1 )

            gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder(gpu_src_dst_code)
            _, gpu_pred_dst_dstm = self.decoder(gpu_dst_inter_dst_code)

            gpu_pred_src_dst  = tf.transpose(gpu_pred_src_dst, (0,2,3,1))
            gpu_pred_dst_dstm = tf.transpose(gpu_pred_dst_dstm, (0,2,3,1))
            gpu_pred_src_dstm = tf.transpose(gpu_pred_src_dstm, (0,2,3,1))

        tf.identity(gpu_pred_dst_dstm, name='out_face_mask')
        tf.identity(gpu_pred_src_dst, name='out_celeb_face')
        tf.identity(gpu_pred_src_dstm, name='out_celeb_face_mask')

        output_graph_def = tf.graph_util.convert_variables_to_constants(
            nn.tf_sess,
            tf.get_default_graph().as_graph_def(),
            ['out_face_mask','out_celeb_face','out_celeb_face_mask']
        )

        import tf2onnx
        with tf.device("/CPU:0"):
            model_proto, _ = tf2onnx.convert._convert_common(
                output_graph_def,
                name='AMP',
                input_names=['in_face:0','morph_value:0'],
                output_names=['out_face_mask:0','out_celeb_face:0','out_celeb_face_mask:0'],
                opset=12,
                output_path=output_path)

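    # The exported .dfm graph takes an NHWC 'in_face' plus a scalar 'morph_value'
    # and returns the swapped face and both masks; the graph is frozen and
    # converted to ONNX (opset 12), presumably for use in DeepFaceLive.
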
    #override
    def get_model_filename_list(self):
        return self.model_filename_list

    #override
    def onSave(self):
        for model, filename in io.progress_bar_generator(self.get_model_filename_list(), "Saving", leave=False):
            model.save_weights ( self.get_strpath_storage_for_file(filename) )

    #override
    def should_save_preview_history(self):
        return (not io.is_colab() and self.iter % ( 10*(max(1,self.resolution // 64)) ) == 0) or \
               (io.is_colab() and self.iter % 100 == 0)

    #override
    def onTrainOneIter(self):
        bs = self.get_batch_size()

        ( (warped_src, target_src, target_srcm, target_srcm_em), \
          (warped_dst, target_dst, target_dstm, target_dstm_em) ) = self.generate_next_samples()

        src_loss, dst_loss = self.train (warped_src, target_src, target_srcm, target_srcm_em, warped_dst, target_dst, target_dstm, target_dstm_em)

        if self.gan_power != 0:
            self.GAN_train (warped_src, target_src, target_srcm, target_srcm_em, warped_dst, target_dst, target_dstm, target_dstm_em)

        return ( ('src_loss', np.mean(src_loss) ), ('dst_loss', np.mean(dst_loss) ), )

    #override
    def onGetPreview(self, samples, for_history=False):
        ( (warped_src, target_src, target_srcm, target_srcm_em),
          (warped_dst, target_dst, target_dstm, target_dstm_em) ) = samples

        S, D, SS, DD, DDM_000, _, _ = [ np.clip( nn.to_data_format(x,"NHWC", self.model_data_format), 0.0, 1.0) for x in ([target_src,target_dst] + self.AE_view (target_src, target_dst, 0.0) ) ]

        _, _, DDM_025, SD_025, SDM_025 = [ np.clip( nn.to_data_format(x,"NHWC", self.model_data_format), 0.0, 1.0) for x in self.AE_view (target_src, target_dst, 0.25) ]
        _, _, DDM_050, SD_050, SDM_050 = [ np.clip( nn.to_data_format(x,"NHWC", self.model_data_format), 0.0, 1.0) for x in self.AE_view (target_src, target_dst, 0.50) ]
        _, _, DDM_065, SD_065, SDM_065 = [ np.clip( nn.to_data_format(x,"NHWC", self.model_data_format), 0.0, 1.0) for x in self.AE_view (target_src, target_dst, 0.65) ]
        _, _, DDM_075, SD_075, SDM_075 = [ np.clip( nn.to_data_format(x,"NHWC", self.model_data_format), 0.0, 1.0) for x in self.AE_view (target_src, target_dst, 0.75) ]
        _, _, DDM_100, SD_100, SDM_100 = [ np.clip( nn.to_data_format(x,"NHWC", self.model_data_format), 0.0, 1.0) for x in self.AE_view (target_src, target_dst, 1.00) ]

        (DDM_000,
         DDM_025, SDM_025,
         DDM_050, SDM_050,
         DDM_065, SDM_065,
         DDM_075, SDM_075,
         DDM_100, SDM_100) = [ np.repeat (x, (3,), -1) for x in (DDM_000,
                                                                 DDM_025, SDM_025,
                                                                 DDM_050, SDM_050,
                                                                 DDM_065, SDM_065,
                                                                 DDM_075, SDM_075,
                                                                 DDM_100, SDM_100) ]

        target_srcm, target_dstm = [ nn.to_data_format(x,"NHWC", self.model_data_format) for x in ([target_srcm, target_dstm] )]

        n_samples = min(4, self.get_batch_size(), 800 // self.resolution )

        result = []

        i = np.random.randint(n_samples) if not for_history else 0

        st  = [ np.concatenate ((S[i], D[i], DD[i]*DDM_000[i]), axis=1) ]
        st += [ np.concatenate ((SS[i], DD[i], SD_100[i] ), axis=1) ]

        result += [ ('AMP morph 1.0', np.concatenate (st, axis=0 )), ]

        st  = [ np.concatenate ((DD[i], SD_025[i], SD_050[i]), axis=1) ]
        st += [ np.concatenate ((SD_065[i], SD_075[i], SD_100[i]), axis=1) ]
        result += [ ('AMP morph list', np.concatenate (st, axis=0 )), ]

        st  = [ np.concatenate ((DD[i], SD_025[i]*DDM_025[i]*SDM_025[i], SD_050[i]*DDM_050[i]*SDM_050[i]), axis=1) ]
        st += [ np.concatenate ((SD_065[i]*DDM_065[i]*SDM_065[i], SD_075[i]*DDM_075[i]*SDM_075[i], SD_100[i]*DDM_100[i]*SDM_100[i]), axis=1) ]
        result += [ ('AMP morph list masked', np.concatenate (st, axis=0 )), ]

        return result

    def predictor_func (self, face, morph_value):
        face = nn.to_data_format(face[None,...], self.model_data_format, "NHWC")

        bgr, mask_dst_dstm, mask_src_dstm = [ nn.to_data_format(x, "NHWC", self.model_data_format).astype(np.float32) for x in self.AE_merge (face, morph_value) ]

        return bgr[0], mask_src_dstm[0][...,0], mask_dst_dstm[0][...,0]

    #override
    def get_MergerConfig(self):
        morph_factor = np.clip ( io.input_number ("Morph factor", 1.0, add_info="0.0 .. 1.0"), 0.0, 1.0 )

        def predictor_morph(face):
            return self.predictor_func(face, morph_factor)

        import merger
        return predictor_morph, (self.options['resolution'], self.options['resolution'], 3), merger.MergerConfigMasked(face_type=self.face_type, default_mode = 'overlay')

Model = AMPModel