import multiprocessing
from functools import partial
import numpy as np
from core import mathlib
from core.interact import interact as io
from core.leras import nn
from facelib import FaceType
from models import ModelBase
from samplelib import *
class QModel(ModelBase):
    """Fixed-configuration face-swap autoencoder (96x96, ``FaceType.FULL``).

    A single encoder and "inter" network are shared by two decoders,
    ``decoder_src`` and ``decoder_dst``. In training mode the graph holds
    per-device losses and one RMSprop update op; otherwise it only holds the
    tensors needed by :meth:`predictor_func`.
    """

    def on_initialize(self):
        """Build the TF graph, load or initialise weights, set up generators.

        Sets on ``self``: the six input placeholders, the four networks,
        ``model_filename_list``, and either ``src_dst_train``/``AE_view``
        (training) or ``AE_merge`` (non-training).
        """
        device_config = nn.getCurrentDeviceConfig()
        devices = device_config.devices
        has_devices = len(devices) != 0

        self.model_data_format = "NCHW" if has_devices and not self.is_debug() else "NHWC"
        nn.initialize(data_format=self.model_data_format)
        tf = nn.tf

        # Hard-coded hyper-parameters of this model.
        resolution = self.resolution = 96
        self.face_type = FaceType.FULL
        ae_dims = 128
        e_dims = 64
        d_dims = 64
        d_mask_dims = 16
        self.pretrain = False
        self.pretrain_just_disabled = False

        masked_training = True

        # Networks/optimizer go on the default device only when training and
        # every device reports at least 4 GB; otherwise they stay on the CPU.
        models_opt_on_gpu = has_devices and all(dev.total_mem_gb >= 4 for dev in devices)
        models_opt_device = nn.tf_default_device_name if models_opt_on_gpu and self.is_training else '/CPU:0'
        optimizer_vars_on_cpu = models_opt_device == '/CPU:0'

        input_ch = 3
        bgr_shape = nn.get4Dshape(resolution, resolution, input_ch)
        mask_shape = nn.get4Dshape(resolution, resolution, 1)

        self.model_filename_list = []

        model_archi = nn.DeepFakeArchi(resolution, opts='ud')

        # Input placeholders are pinned to the CPU.
        with tf.device('/CPU:0'):
            self.warped_src = tf.placeholder(nn.floatx, bgr_shape)
            self.warped_dst = tf.placeholder(nn.floatx, bgr_shape)

            self.target_src = tf.placeholder(nn.floatx, bgr_shape)
            self.target_dst = tf.placeholder(nn.floatx, bgr_shape)

            self.target_srcm = tf.placeholder(nn.floatx, mask_shape)
            self.target_dstm = tf.placeholder(nn.floatx, mask_shape)

        # Network objects (and, when training, the optimizer).
        with tf.device(models_opt_device):
            self.encoder = model_archi.Encoder(in_ch=input_ch, e_ch=e_dims, name='encoder')
            encoder_out_ch = self.encoder.get_out_ch() * self.encoder.get_out_res(resolution) ** 2

            self.inter = model_archi.Inter(in_ch=encoder_out_ch, ae_ch=ae_dims, ae_out_ch=ae_dims, name='inter')
            inter_out_ch = self.inter.get_out_ch()

            self.decoder_src = model_archi.Decoder(in_ch=inter_out_ch, d_ch=d_dims, d_mask_ch=d_mask_dims, name='decoder_src')
            self.decoder_dst = model_archi.Decoder(in_ch=inter_out_ch, d_ch=d_dims, d_mask_ch=d_mask_dims, name='decoder_dst')

            self.model_filename_list += [[self.encoder, 'encoder.npy'],
                                         [self.inter, 'inter.npy'],
                                         [self.decoder_src, 'decoder_src.npy'],
                                         [self.decoder_dst, 'decoder_dst.npy']]

            if self.is_training:
                self.src_dst_trainable_weights = (self.encoder.get_weights()
                                                  + self.inter.get_weights()
                                                  + self.decoder_src.get_weights()
                                                  + self.decoder_dst.get_weights())

                self.src_dst_opt = nn.RMSprop(lr=2e-4, lr_dropout=0.3, name='src_dst_opt')
                self.src_dst_opt.initialize_variables(self.src_dst_trainable_weights, vars_on_cpu=optimizer_vars_on_cpu)
                self.model_filename_list += [(self.src_dst_opt, 'src_dst_opt.npy')]

        if self.is_training:
            # Total batch is a multiple of the device count, at least 1 per device.
            gpu_count = max(1, len(devices))
            bs_per_gpu = max(1, 4 // gpu_count)
            self.set_batch_size(gpu_count * bs_per_gpu)

            # Loop-invariant loss parameters.
            mask_blur_radius = max(1, resolution // 32)
            dssim_filter_size = int(resolution / 11.6)

            # Per-device outputs, merged after the loop.
            gpu_pred_src_src_list = []
            gpu_pred_dst_dst_list = []
            gpu_pred_src_dst_list = []
            gpu_pred_dst_dstm_list = []
            gpu_pred_src_dstm_list = []

            gpu_src_losses = []
            gpu_dst_losses = []
            gpu_src_dst_loss_gvs = []

            for gpu_id in range(gpu_count):
                with tf.device(f'/{devices[gpu_id].tf_dev_type}:{gpu_id}' if has_devices else '/CPU:0'):
                    batch_slice = slice(gpu_id * bs_per_gpu, (gpu_id + 1) * bs_per_gpu)
                    with tf.device('/CPU:0'):
                        # Slicing is placed on the CPU; presumably so that only
                        # this device's share of the batch is transferred.
                        gpu_warped_src = self.warped_src[batch_slice, :, :, :]
                        gpu_warped_dst = self.warped_dst[batch_slice, :, :, :]
                        gpu_target_src = self.target_src[batch_slice, :, :, :]
                        gpu_target_dst = self.target_dst[batch_slice, :, :, :]
                        gpu_target_srcm = self.target_srcm[batch_slice, :, :, :]
                        gpu_target_dstm = self.target_dstm[batch_slice, :, :, :]

                    # Forward passes: src->src, dst->dst, and dst code through
                    # the src decoder (the swap).
                    gpu_src_code = self.inter(self.encoder(gpu_warped_src))
                    gpu_dst_code = self.inter(self.encoder(gpu_warped_dst))
                    gpu_pred_src_src, gpu_pred_src_srcm = self.decoder_src(gpu_src_code)
                    gpu_pred_dst_dst, gpu_pred_dst_dstm = self.decoder_dst(gpu_dst_code)
                    gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder_src(gpu_dst_code)

                    gpu_pred_src_src_list.append(gpu_pred_src_src)
                    gpu_pred_dst_dst_list.append(gpu_pred_dst_dst)
                    gpu_pred_src_dst_list.append(gpu_pred_src_dst)

                    gpu_pred_dst_dstm_list.append(gpu_pred_dst_dstm)
                    gpu_pred_src_dstm_list.append(gpu_pred_src_dstm)

                    gpu_target_srcm_blur = nn.gaussian_blur(gpu_target_srcm, mask_blur_radius)
                    gpu_target_dstm_blur = nn.gaussian_blur(gpu_target_dstm, mask_blur_radius)

                    gpu_target_dst_masked = gpu_target_dst * gpu_target_dstm_blur

                    # With masked_training the image losses only see the
                    # region selected by the blurred target mask.
                    gpu_target_src_masked_opt = gpu_target_src * gpu_target_srcm_blur if masked_training else gpu_target_src
                    gpu_target_dst_masked_opt = gpu_target_dst_masked if masked_training else gpu_target_dst

                    gpu_pred_src_src_masked_opt = gpu_pred_src_src * gpu_target_srcm_blur if masked_training else gpu_pred_src_src
                    gpu_pred_dst_dst_masked_opt = gpu_pred_dst_dst * gpu_target_dstm_blur if masked_training else gpu_pred_dst_dst

                    # Loss per side: DSSIM + squared error on the image, plus
                    # squared error on the predicted mask, each scaled by 10.
                    gpu_src_loss = tf.reduce_mean(10 * nn.dssim(gpu_target_src_masked_opt, gpu_pred_src_src_masked_opt, max_val=1.0, filter_size=dssim_filter_size), axis=[1])
                    gpu_src_loss += tf.reduce_mean(10 * tf.square(gpu_target_src_masked_opt - gpu_pred_src_src_masked_opt), axis=[1, 2, 3])
                    gpu_src_loss += tf.reduce_mean(10 * tf.square(gpu_target_srcm - gpu_pred_src_srcm), axis=[1, 2, 3])

                    gpu_dst_loss = tf.reduce_mean(10 * nn.dssim(gpu_target_dst_masked_opt, gpu_pred_dst_dst_masked_opt, max_val=1.0, filter_size=dssim_filter_size), axis=[1])
                    gpu_dst_loss += tf.reduce_mean(10 * tf.square(gpu_target_dst_masked_opt - gpu_pred_dst_dst_masked_opt), axis=[1, 2, 3])
                    gpu_dst_loss += tf.reduce_mean(10 * tf.square(gpu_target_dstm - gpu_pred_dst_dstm), axis=[1, 2, 3])

                    gpu_src_losses.append(gpu_src_loss)
                    gpu_dst_losses.append(gpu_dst_loss)

                    gpu_G_loss = gpu_src_loss + gpu_dst_loss
                    gpu_src_dst_loss_gvs.append(nn.gradients(gpu_G_loss, self.src_dst_trainable_weights))

            # Merge per-device results and create the single update op.
            with tf.device(models_opt_device):
                pred_src_src = nn.concat(gpu_pred_src_src_list, 0)
                pred_dst_dst = nn.concat(gpu_pred_dst_dst_list, 0)
                pred_src_dst = nn.concat(gpu_pred_src_dst_list, 0)
                pred_dst_dstm = nn.concat(gpu_pred_dst_dstm_list, 0)
                pred_src_dstm = nn.concat(gpu_pred_src_dstm_list, 0)

                src_loss = nn.average_tensor_list(gpu_src_losses)
                dst_loss = nn.average_tensor_list(gpu_dst_losses)
                src_dst_loss_gv = nn.average_gv_list(gpu_src_dst_loss_gvs)
                src_dst_loss_gv_op = self.src_dst_opt.get_update_op(src_dst_loss_gv)

            def src_dst_train(warped_src, target_src, target_srcm,
                              warped_dst, target_dst, target_dstm):
                """Run one optimizer step; return mean (src_loss, dst_loss)."""
                s, d, _ = nn.tf_sess.run([src_loss, dst_loss, src_dst_loss_gv_op],
                                         feed_dict={self.warped_src: warped_src,
                                                    self.target_src: target_src,
                                                    self.target_srcm: target_srcm,
                                                    self.warped_dst: warped_dst,
                                                    self.target_dst: target_dst,
                                                    self.target_dstm: target_dstm,
                                                    })
                return np.mean(s), np.mean(d)
            self.src_dst_train = src_dst_train

            def AE_view(warped_src, warped_dst):
                """Return [src->src, dst->dst, dst mask, dst->src, dst->src mask]."""
                return nn.tf_sess.run([pred_src_src, pred_dst_dst, pred_dst_dstm, pred_src_dst, pred_src_dstm],
                                      feed_dict={self.warped_src: warped_src,
                                                 self.warped_dst: warped_dst})
            self.AE_view = AE_view
        else:
            # Non-training mode: only the dst->src path and both masks are built.
            with tf.device(nn.tf_default_device_name if has_devices else '/CPU:0'):
                gpu_dst_code = self.inter(self.encoder(self.warped_dst))
                gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder_src(gpu_dst_code)
                _, gpu_pred_dst_dstm = self.decoder_dst(gpu_dst_code)

            def AE_merge(warped_dst):
                """Return [swapped image, dst-decoder mask, src-decoder mask]."""
                return nn.tf_sess.run([gpu_pred_src_dst, gpu_pred_dst_dstm, gpu_pred_src_dstm],
                                      feed_dict={self.warped_dst: warped_dst})

            self.AE_merge = AE_merge

        # Load weights for every registered model/optimizer. Fall back to the
        # pretrained path (if set and the file exists), then to fresh init.
        for model, filename in io.progress_bar_generator(self.model_filename_list, "Initializing models"):
            if self.pretrain_just_disabled:
                do_init = False
                if model == self.inter:
                    do_init = True
            else:
                do_init = self.is_first_run()

            if not do_init:
                do_init = not model.load_weights(self.get_strpath_storage_for_file(filename))

            if do_init and self.pretrained_model_path is not None:
                pretrained_filepath = self.pretrained_model_path / filename
                if pretrained_filepath.exists():
                    do_init = not model.load_weights(pretrained_filepath)

            if do_init:
                model.init_weights()

        if self.is_training:
            training_data_src_path = self.training_data_src_path if not self.pretrain else self.get_pretraining_data_path()
            training_data_dst_path = self.training_data_dst_path if not self.pretrain else self.get_pretraining_data_path()

            # Half of the (capped) CPU count per side, but never zero:
            # on a single-core machine ``cpu_count // 2`` would be 0.
            cpu_count = min(multiprocessing.cpu_count(), 8)
            src_generators_count = max(1, cpu_count // 2)
            dst_generators_count = max(1, cpu_count // 2)

            def make_output_sample_types():
                """Return a fresh [warped image, target image, target mask] spec."""
                return [
                    {'sample_type': SampleProcessor.SampleType.FACE_IMAGE, 'warp': True, 'transform': True, 'channel_type': SampleProcessor.ChannelType.BGR, 'face_type': self.face_type, 'data_format': nn.data_format, 'resolution': resolution},
                    {'sample_type': SampleProcessor.SampleType.FACE_IMAGE, 'warp': False, 'transform': True, 'channel_type': SampleProcessor.ChannelType.BGR, 'face_type': self.face_type, 'data_format': nn.data_format, 'resolution': resolution},
                    {'sample_type': SampleProcessor.SampleType.FACE_MASK, 'warp': False, 'transform': True, 'channel_type': SampleProcessor.ChannelType.G, 'face_mask_type': SampleProcessor.FaceMaskType.FULL_FACE, 'face_type': self.face_type, 'data_format': nn.data_format, 'resolution': resolution},
                ]

            self.set_training_data_generators([
                SampleGeneratorFace(training_data_src_path, debug=self.is_debug(), batch_size=self.get_batch_size(),
                                    sample_process_options=SampleProcessor.Options(random_flip=bool(self.pretrain)),
                                    output_sample_types=make_output_sample_types(),
                                    generators_count=src_generators_count),

                SampleGeneratorFace(training_data_dst_path, debug=self.is_debug(), batch_size=self.get_batch_size(),
                                    sample_process_options=SampleProcessor.Options(random_flip=bool(self.pretrain)),
                                    output_sample_types=make_output_sample_types(),
                                    generators_count=dst_generators_count),
            ])

            # Last batch drawn by onTrainOneIter, kept so it can be reused.
            self.last_samples = None

    def get_model_filename_list(self):
        """Return the (model, filename) pairs registered in on_initialize."""
        return self.model_filename_list

    def onSave(self):
        """Save the weights of every registered model/optimizer to its file."""
        for model, filename in io.progress_bar_generator(self.get_model_filename_list(), "Saving", leave=False):
            model.save_weights(self.get_strpath_storage_for_file(filename))

    def onTrainOneIter(self):
        """Run one training step and return the named src/dst losses.

        When the iteration counter is a multiple of 3 and a previous batch is
        available, that batch is reused with the unwarped targets fed as the
        network inputs; otherwise a new batch is generated and remembered.
        """
        if self.get_iter() % 3 == 0 and self.last_samples is not None:
            ((_, target_src, target_srcm),
             (_, target_dst, target_dstm)) = self.last_samples
            warped_src = target_src
            warped_dst = target_dst
        else:
            samples = self.last_samples = self.generate_next_samples()
            ((warped_src, target_src, target_srcm),
             (warped_dst, target_dst, target_dstm)) = samples

        src_loss, dst_loss = self.src_dst_train(warped_src, target_src, target_srcm,
                                                warped_dst, target_dst, target_dstm)

        return (('src_loss', src_loss), ('dst_loss', dst_loss), )

    def onGetPreview(self, samples, for_history=False):
        """Build the two preview images ('Quick96' and 'Quick96 masked').

        ``samples`` has the same nested layout as a training batch. Each
        preview stacks up to four rows of source, src->src, destination,
        dst->dst and dst->src tiles. ``for_history`` is accepted but unused.
        """
        ((_, target_src, target_srcm),
         (_, target_dst, target_dstm)) = samples

        # Converted with to_data_format(x, "NHWC", model format) and clipped
        # to [0, 1].
        S, D, SS, DD, DDM, SD, SDM = [
            np.clip(nn.to_data_format(x, "NHWC", self.model_data_format), 0.0, 1.0)
            for x in ([target_src, target_dst] + self.AE_view(target_src, target_dst))
        ]
        # Repeat the masks 3x along the last axis so they multiply the images.
        DDM, SDM = [np.repeat(x, (3,), -1) for x in [DDM, SDM]]

        target_srcm, target_dstm = [nn.to_data_format(x, "NHWC", self.model_data_format)
                                    for x in [target_srcm, target_dstm]]

        # Never index past the number of samples actually supplied.
        n_samples = min(4, self.get_batch_size(), len(S))

        plain_rows = [np.concatenate((S[i], SS[i], D[i], DD[i], SD[i]), axis=1)
                      for i in range(n_samples)]
        masked_rows = [np.concatenate((S[i] * target_srcm[i],
                                       SS[i],
                                       D[i] * target_dstm[i],
                                       DD[i] * DDM[i],
                                       SD[i] * (DDM[i] * SDM[i])), axis=1)
                       for i in range(n_samples)]

        return [('Quick96', np.concatenate(plain_rows, axis=0)),
                ('Quick96 masked', np.concatenate(masked_rows, axis=0))]

    def predictor_func(self, face=None):
        """Swap one face image and return (image, src-decoder mask, dst-decoder mask).

        ``face`` is required despite its ``None`` default. A leading batch
        axis is added before the image is handed to ``AE_merge``; the masks
        are returned with their last axis dropped.

        Raises:
            TypeError: if ``face`` is None.
        """
        if face is None:
            # Same exception type the original raised on None[None, ...],
            # but with a message that names the actual problem.
            raise TypeError("predictor_func() requires a face image, got None")

        face = nn.to_data_format(face[None, ...], self.model_data_format, "NHWC")

        bgr, mask_dst_dstm, mask_src_dstm = [
            nn.to_data_format(x, "NHWC", self.model_data_format).astype(np.float32)
            for x in self.AE_merge(face)
        ]
        return bgr[0], mask_src_dstm[0][..., 0], mask_dst_dstm[0][..., 0]

    def get_MergerConfig(self):
        """Return (predictor function, input shape, MergerConfigMasked) for merging."""
        # Imported at call time, as in the original; presumably to avoid an
        # import cycle with the merger package - TODO confirm.
        import merger
        return (self.predictor_func,
                (self.resolution, self.resolution, 3),
                merger.MergerConfigMasked(face_type=self.face_type,
                                          default_mode='overlay',
                                          ))
# Module-level alias for QModel; presumably the name the framework's model
# loader looks up in this module - TODO confirm against the loader.
Model = QModel