Book a Demo!
CoCalc Logo Icon
Store Features Docs Share Support News About Policies Sign Up Sign In
iperov
GitHub Repository: iperov/deepfacelab
Path: blob/master/models/Model_XSeg/Model.py
628 views
1
import multiprocessing
2
import operator
3
from functools import partial
4
5
import numpy as np
6
7
from core import mathlib
8
from core.interact import interact as io
9
from core.leras import nn
10
from facelib import FaceType, XSegNet
11
from models import ModelBase
12
from samplelib import *
13
14
class XSegModel(ModelBase):
    """Training wrapper around ``XSegNet`` (face mask segmentation).

    Normal mode trains the network's mask logits against labelled masks
    produced by ``SampleGeneratorFaceXSeg`` (src + dst paths) with a sigmoid
    cross-entropy loss. Two extra un-augmented src/dst generators are created
    only so ``onGetPreview`` can show predictions on unlabeled faces
    (``onTrainOneIter`` consumes only the first generator).

    Pretrain mode trains on faces from ``get_pretraining_data_path()`` with a
    DSSIM + pixel MSE loss against a ``ChannelType.G`` target image.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, force_model_class_name='XSeg', **kwargs)

    #override
    def on_initialize_options(self):
        """Load stored options and, when allowed, ask the user to change them.

        Sets ``self.options['face_type']`` (asked only on the first run), the
        batch size and ``self.options['pretrain']`` (asked on the first run or
        when the user chooses to override), and ``self.pretrain_just_disabled``.

        Raises:
            Exception: pretraining is enabled, the model is not being exported,
                and no pretraining data path is configured.
        """
        ask_override = self.ask_override()

        if not self.is_first_run() and ask_override:
            if io.input_bool("Restart training?", False, help_message="Reset model weights and start training from scratch."):
                self.set_iter(0)

        default_face_type = self.options['face_type'] = self.load_or_def_option('face_type', 'wf')
        default_pretrain = self.options['pretrain'] = self.load_or_def_option('pretrain', False)

        if self.is_first_run():
            self.options['face_type'] = io.input_str ("Face type", default_face_type, ['h','mf','f','wf','head'], help_message="Half / mid face / full face / whole face / head. Choose the same as your deepfake model.").lower()

        if self.is_first_run() or ask_override:
            self.ask_batch_size(4, range=[2,16])
            self.options['pretrain'] = io.input_bool ("Enable pretraining mode", default_pretrain)

        if not self.is_exporting and (self.options['pretrain'] and self.get_pretraining_data_path() is None):
            raise Exception("pretraining_data_path is not defined")

        # True when the stored option was pretraining and the user just turned
        # it off; on_initialize() then resets the iteration counter.
        self.pretrain_just_disabled = (default_pretrain == True and self.options['pretrain'] == False)

    #override
    def on_initialize(self):
        """Initialize the TF backend, build ``XSegNet`` and, when training,
        the loss/optimizer graph, ``self.train``/``self.view`` and the sample
        generators."""
        device_config = nn.getCurrentDeviceConfig()
        # NCHW when exporting, or when at least one device is available and we
        # are not in debug mode; NHWC otherwise.
        self.model_data_format = "NCHW" if self.is_exporting or (len(device_config.devices) != 0 and not self.is_debug()) else "NHWC"
        nn.initialize(data_format=self.model_data_format)

        # Re-query after nn.initialize(), which may (re)register the active device config.
        devices = nn.getCurrentDeviceConfig().devices

        self.resolution = resolution = 256

        self.face_type = {'h'    : FaceType.HALF,
                          'mf'   : FaceType.MID_FULL,
                          'f'    : FaceType.FULL,
                          'wf'   : FaceType.WHOLE_FACE,
                          'head' : FaceType.HEAD}[ self.options['face_type'] ]

        place_model_on_cpu = len(devices) == 0
        models_opt_device = '/CPU:0' if place_model_on_cpu else nn.tf_default_device_name

        # Initializing model classes
        self.model = XSegNet(name='XSeg',
                             resolution=resolution,
                             load_weights=not self.is_first_run(),
                             weights_file_root=self.get_model_root_path(),
                             training=True,
                             place_model_on_cpu=place_model_on_cpu,
                             optimizer=nn.RMSprop(lr=0.0001, lr_dropout=0.3, name='opt'),
                             data_format=nn.data_format)

        self.pretrain = self.options['pretrain']
        if self.pretrain_just_disabled:
            self.set_iter(0)

        if self.is_training:
            pred, loss, loss_gv_op = self._build_training_ops(devices, models_opt_device)

            # The feed is identical in both modes; only the loss graph differs.
            def train(input_np, target_np):
                """Run one optimizer step; return the per-sample losses."""
                loss_np, _ = nn.tf_sess.run ( [loss, loss_gv_op], feed_dict={self.model.input_t :input_np, self.model.target_t :target_np})
                return loss_np
            self.train = train

            def view(input_np):
                """Return ``[prediction]`` for ``input_np`` (a one-element list)."""
                return nn.tf_sess.run ( [pred], feed_dict={self.model.input_t :input_np})
            self.view = view

            self._init_sample_generators()

    def _build_training_ops(self, devices, models_opt_device):
        """Build the per-device loss graph and the optimizer update op.

        Rounds the batch size down to a multiple of the device count (at least
        one sample per device), slices the model placeholders per device,
        computes per-sample losses and gradients on each device, then
        concatenates predictions/losses and averages the gradients.

        Returns:
            ``(pred, loss, loss_gv_op)``: concatenated predictions, concatenated
            per-sample losses, and the optimizer update op.
        """
        tf = nn.tf
        resolution = self.resolution

        # Adjust batch size for multiple GPU
        gpu_count = max(1, len(devices))
        bs_per_gpu = max(1, self.get_batch_size() // gpu_count)
        self.set_batch_size(gpu_count*bs_per_gpu)

        gpu_pred_list = []
        gpu_losses = []
        gpu_loss_gvs = []

        for gpu_id in range(gpu_count):
            with tf.device(f'/{devices[gpu_id].tf_dev_type}:{gpu_id}' if len(devices) != 0 else '/CPU:0' ):
                with tf.device('/CPU:0'):
                    # slice on CPU, otherwise all batch data will be transferred to GPU first
                    batch_slice = slice( gpu_id*bs_per_gpu, (gpu_id+1)*bs_per_gpu )
                    gpu_input_t = self.model.input_t [batch_slice,:,:,:]
                    gpu_target_t = self.model.target_t [batch_slice,:,:,:]

                # process model tensors
                gpu_pred_logits_t, gpu_pred_t = self.model.flow(gpu_input_t, pretrain=self.pretrain)
                gpu_pred_list.append(gpu_pred_t)

                if self.pretrain:
                    # Structural loss at two DSSIM filter sizes
                    gpu_loss = tf.reduce_mean (5*nn.dssim(gpu_target_t, gpu_pred_t, max_val=1.0, filter_size=int(resolution/11.6)), axis=[1])
                    gpu_loss += tf.reduce_mean (5*nn.dssim(gpu_target_t, gpu_pred_t, max_val=1.0, filter_size=int(resolution/23.2)), axis=[1])
                    # Pixel loss
                    gpu_loss += tf.reduce_mean (10*tf.square(gpu_target_t-gpu_pred_t), axis=[1,2,3])
                else:
                    gpu_loss = tf.reduce_mean( tf.nn.sigmoid_cross_entropy_with_logits(labels=gpu_target_t, logits=gpu_pred_logits_t), axis=[1,2,3])

                gpu_losses.append(gpu_loss)
                gpu_loss_gvs.append( nn.gradients ( gpu_loss, self.model.get_weights() ) )

        # Concatenate predictions/losses, average gradients, create the update op.
        # NOTE: this was once pinned to '/CPU:0' as a temporary fix for an unknown
        # training-freeze bug starting from version 2.4.0 (2.3.1 was ok).
        with tf.device (models_opt_device):
            pred = tf.concat(gpu_pred_list, 0)
            loss = tf.concat(gpu_losses, 0)
            loss_gv_op = self.model.opt.get_update_op (nn.average_gv_list (gpu_loss_gvs))

        return pred, loss, loss_gv_op

    def _make_plain_face_generator(self, path, generators_count):
        """BGR face generator with no flip/warp/transform, used for previews only."""
        return SampleGeneratorFace(path, debug=self.is_debug(), batch_size=self.get_batch_size(),
                    sample_process_options=SampleProcessor.Options(random_flip=False),
                    output_sample_types = [ {'sample_type': SampleProcessor.SampleType.FACE_IMAGE, 'warp':False, 'transform':False, 'channel_type' : SampleProcessor.ChannelType.BGR, 'border_replicate':False, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': self.resolution},
                                          ],
                    generators_count=generators_count,
                    raise_on_no_data=False )

    def _init_sample_generators(self):
        """Create and register the training data generators for the current mode."""
        resolution = self.resolution
        cpu_count = min(multiprocessing.cpu_count(), 8)
        # Clamp to one worker: on a single-core machine cpu_count // 2 would be 0.
        generators_count = max(1, cpu_count // 2)

        if self.pretrain:
            pretrain_gen = SampleGeneratorFace(self.get_pretraining_data_path(), debug=self.is_debug(), batch_size=self.get_batch_size(),
                        sample_process_options=SampleProcessor.Options(random_flip=True),
                        output_sample_types = [ {'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':True, 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution},
                                                {'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':True, 'transform':True, 'channel_type' : SampleProcessor.ChannelType.G, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution},
                                              ],
                        uniform_yaw_distribution=False,
                        generators_count=cpu_count )
            self.set_training_data_generators ([pretrain_gen])
            return

        # Order matters: onTrainOneIter reads generator 0, onGetPreview unpacks all three.
        srcdst_generator = SampleGeneratorFaceXSeg([self.training_data_src_path, self.training_data_dst_path],
                                                    debug=self.is_debug(),
                                                    batch_size=self.get_batch_size(),
                                                    resolution=resolution,
                                                    face_type=self.face_type,
                                                    generators_count=generators_count,
                                                    data_format=nn.data_format)
        src_generator = self._make_plain_face_generator(self.training_data_src_path, generators_count)
        dst_generator = self._make_plain_face_generator(self.training_data_dst_path, generators_count)

        self.set_training_data_generators ([srcdst_generator, src_generator, dst_generator])

    #override
    def get_model_filename_list(self):
        return self.model.model_filename_list

    #override
    def onSave(self):
        self.model.save_weights()

    #override
    def onTrainOneIter(self):
        """Train on one batch from the first generator; report the mean loss."""
        image_np, target_np = self.generate_next_samples()[0]
        loss = self.train (image_np, target_np)

        return ( ('loss', np.mean(loss) ), )

    @staticmethod
    def _overlay_mask(image, mask, bg):
        """Keep ``image`` where mask is 1; blend half image / half ``bg`` where mask is 0."""
        return image*mask + 0.5*image*(1-mask) + 0.5*bg*(1-mask)

    def _preview_predicted_masks(self, images_np, n_samples, green_bg):
        """Stack ``n_samples`` rows of [face | predicted mask | overlay] for unlabeled faces."""
        D, DM = [ np.clip(nn.to_data_format(x,"NHWC", self.model_data_format), 0.0, 1.0) for x in ([images_np] + self.view (images_np) ) ]
        DM = np.repeat (DM, (3,), -1)

        rows = [ np.concatenate ( (D[i], DM[i], self._overlay_mask(D[i], DM[i], green_bg)), axis=1) for i in range(n_samples) ]
        return np.concatenate (rows, axis=0 )

    #override
    def onGetPreview(self, samples, for_history=False):
        """Build preview images as a list of ``(title, image)`` pairs.

        'XSeg training faces' rows are [input | prediction] in pretrain mode,
        otherwise [input + label overlay | predicted mask | input + prediction
        overlay]. In normal mode, 'XSeg src faces' / 'XSeg dst faces' are added
        when those generators returned samples.
        """
        n_samples = min(4, self.get_batch_size(), 800 // self.resolution )

        if self.pretrain:
            srcdst_samples, = samples
        else:
            srcdst_samples, src_samples, dst_samples = samples
        image_np, mask_np = srcdst_samples

        I, M, IM = [ np.clip( nn.to_data_format(x,"NHWC", self.model_data_format), 0.0, 1.0) for x in ([image_np,mask_np] + self.view (image_np) ) ]
        M, IM = [ np.repeat (x, (3,), -1) for x in [M, IM] ]

        green_bg = np.tile( np.array([0,1,0], dtype=np.float32)[None,None,...], (self.resolution,self.resolution,1) )

        rows = []
        for i in range(n_samples):
            if self.pretrain:
                ar = I[i], IM[i]
            else:
                ar = self._overlay_mask(I[i], M[i], green_bg), IM[i], self._overlay_mask(I[i], IM[i], green_bg)
            rows.append ( np.concatenate ( ar, axis=1) )
        result = [ ('XSeg training faces', np.concatenate (rows, axis=0 )), ]

        if not self.pretrain and len(src_samples) != 0:
            src_np, = src_samples
            result += [ ('XSeg src faces', self._preview_predicted_masks(src_np, n_samples, green_bg)), ]

        if not self.pretrain and len(dst_samples) != 0:
            dst_np, = dst_samples
            result += [ ('XSeg dst faces', self._preview_predicted_masks(dst_np, n_samples, green_bg)), ]

        return result

    def export_dfm (self):
        """Export the network as ``model.onnx`` in the model's storage folder.

        The exported graph takes ``in_face`` (NHWC float, shape
        ``(None, resolution, resolution, 3)``) and yields ``out_mask`` in NHWC.
        Internally the input is transposed to NCHW, the format on_initialize()
        selects when exporting.
        """
        output_path = self.get_strpath_storage_for_file('model.onnx')
        io.log_info(f'Dumping .onnx to {output_path}')
        tf = nn.tf

        with tf.device (nn.tf_default_device_name):
            input_t = tf.placeholder (nn.floatx, (None, self.resolution, self.resolution, 3), name='in_face')
            input_t = tf.transpose(input_t, (0,3,1,2))
            _, pred_t = self.model.flow(input_t)
            pred_t = tf.transpose(pred_t, (0,2,3,1))

        tf.identity(pred_t, name='out_mask')

        # Freeze variables into constants so the graph is self-contained.
        output_graph_def = tf.graph_util.convert_variables_to_constants(
            nn.tf_sess,
            tf.get_default_graph().as_graph_def(),
            ['out_mask']
        )

        import tf2onnx
        with tf.device("/CPU:0"):
            model_proto, _ = tf2onnx.convert._convert_common(
                output_graph_def,
                name='XSeg',
                input_names=['in_face:0'],
                output_names=['out_mask:0'],
                opset=13,
                output_path=output_path)
# Module-level alias for XSegModel; presumably the name the model loader
# looks up in this module — NOTE(review): confirm against the models package.
Model = XSegModel
284