GitHub Repository: iperov/deepfacelab
Path: blob/master/models/Model_Quick96/Model.py
import multiprocessing
from functools import partial

import numpy as np

from core import mathlib
from core.interact import interact as io
from core.leras import nn
from facelib import FaceType
from models import ModelBase
from samplelib import *

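# Quick96: a fixed-configuration face-swap model built on nn.DeepFakeArchi, with a
# shared encoder/inter and separate src/dst decoders trained at 96x96 resolution.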
class QModel(ModelBase):
    #override
    def on_initialize(self):
        device_config = nn.getCurrentDeviceConfig()
        devices = device_config.devices
        self.model_data_format = "NCHW" if len(devices) != 0 and not self.is_debug() else "NHWC"
        nn.initialize(data_format=self.model_data_format)
        tf = nn.tf

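        # Fixed hyperparameters; Quick96 exposes no user-configurable options.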
        resolution = self.resolution = 96
        self.face_type = FaceType.FULL
        ae_dims = 128
        e_dims = 64
        d_dims = 64
        d_mask_dims = 16
        self.pretrain = False
        self.pretrain_just_disabled = False

        masked_training = True

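        # Keep the model and optimizer on the GPU only when every device has at least
        # 4 GB of memory; otherwise the optimizer variables stay on the CPU.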
        models_opt_on_gpu = len(devices) >= 1 and all([dev.total_mem_gb >= 4 for dev in devices])
        models_opt_device = nn.tf_default_device_name if models_opt_on_gpu and self.is_training else '/CPU:0'
        optimizer_vars_on_cpu = models_opt_device=='/CPU:0'

        input_ch = 3
        bgr_shape = nn.get4Dshape(resolution,resolution,input_ch)
        mask_shape = nn.get4Dshape(resolution,resolution,1)

        self.model_filename_list = []

        model_archi = nn.DeepFakeArchi(resolution, opts='ud')

        with tf.device ('/CPU:0'):
            # Placeholders on CPU
            self.warped_src = tf.placeholder (nn.floatx, bgr_shape)
            self.warped_dst = tf.placeholder (nn.floatx, bgr_shape)

            self.target_src = tf.placeholder (nn.floatx, bgr_shape)
            self.target_dst = tf.placeholder (nn.floatx, bgr_shape)

            self.target_srcm = tf.placeholder (nn.floatx, mask_shape)
            self.target_dstm = tf.placeholder (nn.floatx, mask_shape)

        # Initializing model classes
        with tf.device (models_opt_device):
            self.encoder = model_archi.Encoder(in_ch=input_ch, e_ch=e_dims, name='encoder')
            encoder_out_ch = self.encoder.get_out_ch()*self.encoder.get_out_res(resolution)**2

            self.inter = model_archi.Inter (in_ch=encoder_out_ch, ae_ch=ae_dims, ae_out_ch=ae_dims, name='inter')
            inter_out_ch = self.inter.get_out_ch()

            self.decoder_src = model_archi.Decoder(in_ch=inter_out_ch, d_ch=d_dims, d_mask_ch=d_mask_dims, name='decoder_src')
            self.decoder_dst = model_archi.Decoder(in_ch=inter_out_ch, d_ch=d_dims, d_mask_ch=d_mask_dims, name='decoder_dst')

            self.model_filename_list += [ [self.encoder,     'encoder.npy'    ],
                                          [self.inter,       'inter.npy'      ],
                                          [self.decoder_src, 'decoder_src.npy'],
                                          [self.decoder_dst, 'decoder_dst.npy'] ]

            if self.is_training:
                self.src_dst_trainable_weights = self.encoder.get_weights() + self.inter.get_weights() + self.decoder_src.get_weights() + self.decoder_dst.get_weights()

                # Initialize optimizers
                self.src_dst_opt = nn.RMSprop(lr=2e-4, lr_dropout=0.3, name='src_dst_opt')
                self.src_dst_opt.initialize_variables(self.src_dst_trainable_weights, vars_on_cpu=optimizer_vars_on_cpu )
                self.model_filename_list += [ (self.src_dst_opt, 'src_dst_opt.npy') ]

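        # Training graph: split the batch across the available GPUs, compute per-GPU
        # losses and gradients, then average them on the optimizer device.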
        if self.is_training:
            # Adjust batch size for multiple GPUs
            gpu_count = max(1, len(devices) )
            bs_per_gpu = max(1, 4 // gpu_count)
            self.set_batch_size( gpu_count*bs_per_gpu)

            # Compute losses per GPU
            gpu_pred_src_src_list = []
            gpu_pred_dst_dst_list = []
            gpu_pred_src_dst_list = []
            gpu_pred_src_srcm_list = []
            gpu_pred_dst_dstm_list = []
            gpu_pred_src_dstm_list = []

            gpu_src_losses = []
            gpu_dst_losses = []
            gpu_src_dst_loss_gvs = []

            for gpu_id in range(gpu_count):
                with tf.device( f'/{devices[gpu_id].tf_dev_type}:{gpu_id}' if len(devices) != 0 else f'/CPU:0' ):
                    batch_slice = slice( gpu_id*bs_per_gpu, (gpu_id+1)*bs_per_gpu )
                    with tf.device(f'/CPU:0'):
                        # slice on CPU, otherwise all batch data will be transferred to the GPU first
                        gpu_warped_src  = self.warped_src [batch_slice,:,:,:]
                        gpu_warped_dst  = self.warped_dst [batch_slice,:,:,:]
                        gpu_target_src  = self.target_src [batch_slice,:,:,:]
                        gpu_target_dst  = self.target_dst [batch_slice,:,:,:]
                        gpu_target_srcm = self.target_srcm[batch_slice,:,:,:]
                        gpu_target_dstm = self.target_dstm[batch_slice,:,:,:]

                    # process model tensors
                    gpu_src_code = self.inter(self.encoder(gpu_warped_src))
                    gpu_dst_code = self.inter(self.encoder(gpu_warped_dst))
                    gpu_pred_src_src, gpu_pred_src_srcm = self.decoder_src(gpu_src_code)
                    gpu_pred_dst_dst, gpu_pred_dst_dstm = self.decoder_dst(gpu_dst_code)
                    gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder_src(gpu_dst_code)
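                    # decoder_src applied to the dst code is the actual face swap: it
                    # predicts the src face (and mask) over the dst expression and pose.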

                    gpu_pred_src_src_list.append(gpu_pred_src_src)
                    gpu_pred_dst_dst_list.append(gpu_pred_dst_dst)
                    gpu_pred_src_dst_list.append(gpu_pred_src_dst)

                    gpu_pred_src_srcm_list.append(gpu_pred_src_srcm)
                    gpu_pred_dst_dstm_list.append(gpu_pred_dst_dstm)
                    gpu_pred_src_dstm_list.append(gpu_pred_src_dstm)

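                    # Blur the target masks so the masked loss terms have soft edges
                    # around the face boundary.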
                    gpu_target_srcm_blur = nn.gaussian_blur(gpu_target_srcm, max(1, resolution // 32) )
                    gpu_target_dstm_blur = nn.gaussian_blur(gpu_target_dstm, max(1, resolution // 32) )

                    gpu_target_dst_masked      = gpu_target_dst*gpu_target_dstm_blur
                    gpu_target_dst_anti_masked = gpu_target_dst*(1.0 - gpu_target_dstm_blur)

                    gpu_target_src_masked_opt = gpu_target_src*gpu_target_srcm_blur if masked_training else gpu_target_src
                    gpu_target_dst_masked_opt = gpu_target_dst_masked if masked_training else gpu_target_dst

                    gpu_pred_src_src_masked_opt = gpu_pred_src_src*gpu_target_srcm_blur if masked_training else gpu_pred_src_src
                    gpu_pred_dst_dst_masked_opt = gpu_pred_dst_dst*gpu_target_dstm_blur if masked_training else gpu_pred_dst_dst

                    gpu_psd_target_dst_masked      = gpu_pred_src_dst*gpu_target_dstm_blur
                    gpu_psd_target_dst_anti_masked = gpu_pred_src_dst*(1.0 - gpu_target_dstm_blur)

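                    # Per-sample loss: 10*DSSIM + 10*MSE on the (optionally masked) face,
                    # plus an MSE term that trains the mask decoder output.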
                    gpu_src_loss  = tf.reduce_mean ( 10*nn.dssim(gpu_target_src_masked_opt, gpu_pred_src_src_masked_opt, max_val=1.0, filter_size=int(resolution/11.6)), axis=[1])
                    gpu_src_loss += tf.reduce_mean ( 10*tf.square ( gpu_target_src_masked_opt - gpu_pred_src_src_masked_opt ), axis=[1,2,3])
                    gpu_src_loss += tf.reduce_mean ( 10*tf.square( gpu_target_srcm - gpu_pred_src_srcm ),axis=[1,2,3] )

                    gpu_dst_loss  = tf.reduce_mean ( 10*nn.dssim(gpu_target_dst_masked_opt, gpu_pred_dst_dst_masked_opt, max_val=1.0, filter_size=int(resolution/11.6) ), axis=[1])
                    gpu_dst_loss += tf.reduce_mean ( 10*tf.square( gpu_target_dst_masked_opt - gpu_pred_dst_dst_masked_opt ), axis=[1,2,3])
                    gpu_dst_loss += tf.reduce_mean ( 10*tf.square( gpu_target_dstm - gpu_pred_dst_dstm ),axis=[1,2,3] )

                    gpu_src_losses += [gpu_src_loss]
                    gpu_dst_losses += [gpu_dst_loss]

                    gpu_G_loss = gpu_src_loss + gpu_dst_loss
                    gpu_src_dst_loss_gvs += [ nn.gradients ( gpu_G_loss, self.src_dst_trainable_weights ) ]

            # Average losses and gradients, and create optimizer update ops
            with tf.device (models_opt_device):
                pred_src_src  = nn.concat(gpu_pred_src_src_list, 0)
                pred_dst_dst  = nn.concat(gpu_pred_dst_dst_list, 0)
                pred_src_dst  = nn.concat(gpu_pred_src_dst_list, 0)
                pred_src_srcm = nn.concat(gpu_pred_src_srcm_list, 0)
                pred_dst_dstm = nn.concat(gpu_pred_dst_dstm_list, 0)
                pred_src_dstm = nn.concat(gpu_pred_src_dstm_list, 0)

                src_loss = nn.average_tensor_list(gpu_src_losses)
                dst_loss = nn.average_tensor_list(gpu_dst_losses)
                src_dst_loss_gv = nn.average_gv_list (gpu_src_dst_loss_gvs)
                src_dst_loss_gv_op = self.src_dst_opt.get_update_op (src_dst_loss_gv)

            # Initializing training and view functions
            def src_dst_train(warped_src, target_src, target_srcm, \
                              warped_dst, target_dst, target_dstm):
                s, d, _ = nn.tf_sess.run ( [ src_loss, dst_loss, src_dst_loss_gv_op],
                                           feed_dict={self.warped_src :warped_src,
                                                      self.target_src :target_src,
                                                      self.target_srcm:target_srcm,
                                                      self.warped_dst :warped_dst,
                                                      self.target_dst :target_dst,
                                                      self.target_dstm:target_dstm,
                                                      })
                s = np.mean(s)
                d = np.mean(d)
                return s, d
            self.src_dst_train = src_dst_train
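            # Usage (see onTrainOneIter below): one optimizer step per call, returning
            # the mean src/dst losses, e.g.
            #   src_loss, dst_loss = self.src_dst_train(warped_src, target_src, target_srcm,
            #                                           warped_dst, target_dst, target_dstm)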

            def AE_view(warped_src, warped_dst):
                return nn.tf_sess.run ( [pred_src_src, pred_dst_dst, pred_dst_dstm, pred_src_dst, pred_src_dstm],
                                        feed_dict={self.warped_src:warped_src,
                                                   self.warped_dst:warped_dst})

            self.AE_view = AE_view
        else:
            # Initializing merge function
            with tf.device( nn.tf_default_device_name if len(devices) != 0 else f'/CPU:0'):
                gpu_dst_code = self.inter(self.encoder(self.warped_dst))
                gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder_src(gpu_dst_code)
                _, gpu_pred_dst_dstm = self.decoder_dst(gpu_dst_code)

            def AE_merge( warped_dst):
                return nn.tf_sess.run ( [gpu_pred_src_dst, gpu_pred_dst_dstm, gpu_pred_src_dstm], feed_dict={self.warped_dst:warped_dst})

            self.AE_merge = AE_merge

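        # Weight loading order: saved weights from this model's directory first, then
        # the optional pretrained model path, otherwise fresh initialization.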
        # Loading/initializing all models/optimizers weights
        for model, filename in io.progress_bar_generator(self.model_filename_list, "Initializing models"):
            if self.pretrain_just_disabled:
                do_init = False
                if model == self.inter:
                    do_init = True
            else:
                do_init = self.is_first_run()

            if not do_init:
                do_init = not model.load_weights( self.get_strpath_storage_for_file(filename) )

            if do_init and self.pretrained_model_path is not None:
                pretrained_filepath = self.pretrained_model_path / filename
                if pretrained_filepath.exists():
                    do_init = not model.load_weights(pretrained_filepath)

            if do_init:
                model.init_weights()

        # initializing sample generators
        if self.is_training:
            training_data_src_path = self.training_data_src_path if not self.pretrain else self.get_pretraining_data_path()
            training_data_dst_path = self.training_data_dst_path if not self.pretrain else self.get_pretraining_data_path()

            cpu_count = min(multiprocessing.cpu_count(), 8)
            src_generators_count = cpu_count // 2
            dst_generators_count = cpu_count // 2

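            # Each generator yields three tensors per sample: a randomly warped face
            # (network input), the unwarped target face, and the full-face target mask.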
            self.set_training_data_generators ([
                    SampleGeneratorFace(training_data_src_path, debug=self.is_debug(), batch_size=self.get_batch_size(),
                        sample_process_options=SampleProcessor.Options(random_flip=True if self.pretrain else False),
                        output_sample_types = [ {'sample_type': SampleProcessor.SampleType.FACE_IMAGE, 'warp':True,  'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution},
                                                {'sample_type': SampleProcessor.SampleType.FACE_IMAGE, 'warp':False, 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution},
                                                {'sample_type': SampleProcessor.SampleType.FACE_MASK,  'warp':False, 'transform':True, 'channel_type' : SampleProcessor.ChannelType.G,   'face_mask_type' : SampleProcessor.FaceMaskType.FULL_FACE, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution}
                                              ],
                        generators_count=src_generators_count ),

                    SampleGeneratorFace(training_data_dst_path, debug=self.is_debug(), batch_size=self.get_batch_size(),
                        sample_process_options=SampleProcessor.Options(random_flip=True if self.pretrain else False),
                        output_sample_types = [ {'sample_type': SampleProcessor.SampleType.FACE_IMAGE, 'warp':True,  'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution},
                                                {'sample_type': SampleProcessor.SampleType.FACE_IMAGE, 'warp':False, 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution},
                                                {'sample_type': SampleProcessor.SampleType.FACE_MASK,  'warp':False, 'transform':True, 'channel_type' : SampleProcessor.ChannelType.G,   'face_mask_type' : SampleProcessor.FaceMaskType.FULL_FACE, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution}
                                              ],
                        generators_count=dst_generators_count )
                ])

            self.last_samples = None

    #override
    def get_model_filename_list(self):
        return self.model_filename_list

    #override
    def onSave(self):
        for model, filename in io.progress_bar_generator(self.get_model_filename_list(), "Saving", leave=False):
            model.save_weights ( self.get_strpath_storage_for_file(filename) )

    #override
    def onTrainOneIter(self):
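        # Every third iteration re-feeds the previous batch with the unwarped targets as
        # inputs (input equals target), instead of generating fresh warped samples.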
        if self.get_iter() % 3 == 0 and self.last_samples is not None:
            ( (warped_src, target_src, target_srcm), \
              (warped_dst, target_dst, target_dstm) ) = self.last_samples
            warped_src = target_src
            warped_dst = target_dst
        else:
            samples = self.last_samples = self.generate_next_samples()
            ( (warped_src, target_src, target_srcm), \
              (warped_dst, target_dst, target_dstm) ) = samples

        src_loss, dst_loss = self.src_dst_train (warped_src, target_src, target_srcm,
                                                 warped_dst, target_dst, target_dstm)

        return ( ('src_loss', src_loss), ('dst_loss', dst_loss), )

    #override
    def onGetPreview(self, samples, for_history=False):
        ( (warped_src, target_src, target_srcm),
          (warped_dst, target_dst, target_dstm) ) = samples

        S, D, SS, DD, DDM, SD, SDM = [ np.clip( nn.to_data_format(x,"NHWC", self.model_data_format), 0.0, 1.0) for x in ([target_src,target_dst] + self.AE_view (target_src, target_dst) ) ]
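        # S/D: src/dst targets, SS/DD: their reconstructions, SD: the src face predicted
        # over dst, DDM/SDM: the corresponding predicted masks.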
        DDM, SDM, = [ np.repeat (x, (3,), -1) for x in [DDM, SDM] ]

        target_srcm, target_dstm = [ nn.to_data_format(x,"NHWC", self.model_data_format) for x in ([target_srcm, target_dstm] )]

        n_samples = min(4, self.get_batch_size() )
        result = []
        st = []
        for i in range(n_samples):
            ar = S[i], SS[i], D[i], DD[i], SD[i]
            st.append ( np.concatenate ( ar, axis=1) )

        result += [ ('Quick96', np.concatenate (st, axis=0 )), ]

        st_m = []
        for i in range(n_samples):
            ar = S[i]*target_srcm[i], SS[i], D[i]*target_dstm[i], DD[i]*DDM[i], SD[i]*(DDM[i]*SDM[i])
            st_m.append ( np.concatenate ( ar, axis=1) )

        result += [ ('Quick96 masked', np.concatenate (st_m, axis=0 )), ]

        return result

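    # Called through the merger (see get_MergerConfig): runs a single aligned face
    # through AE_merge and returns the swapped BGR image plus the predicted masks.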
    def predictor_func (self, face=None):
        face = nn.to_data_format(face[None,...], self.model_data_format, "NHWC")

        bgr, mask_dst_dstm, mask_src_dstm = [ nn.to_data_format(x, "NHWC", self.model_data_format).astype(np.float32) for x in self.AE_merge (face) ]
        return bgr[0], mask_src_dstm[0][...,0], mask_dst_dstm[0][...,0]

    #override
    def get_MergerConfig(self):
        import merger
        return self.predictor_func, (self.resolution, self.resolution, 3), merger.MergerConfigMasked(face_type=self.face_type,
                                                                                                     default_mode = 'overlay',
                                                                                                    )

Model = QModel