Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
iperov
GitHub Repository: iperov/deepfacelab
Path: blob/master/facelib/XSegNet.py
628 views
1
import os
2
import pickle
3
from functools import partial
4
from pathlib import Path
5
6
import cv2
7
import numpy as np
8
9
from core.interact import interact as io
10
from core.leras import nn
11
12
13
class XSegNet(object):
    """Wrapper around the leras ``nn.XSeg`` segmentation network.

    Builds the network, loads or initializes its weights (plus the optimizer's
    state in training mode), saves them back to disk, and runs mask inference
    through :meth:`extract`.
    """

    # Class-level version tag; not referenced inside this class.
    VERSION = 1

    def __init__(self, name,
                 resolution=256,
                 load_weights=True,
                 weights_file_root=None,
                 training=False,
                 place_model_on_cpu=False,
                 run_on_cpu=False,
                 optimizer=None,
                 data_format="NHWC",
                 raise_on_no_model_files=False):
        """Build the model and load (or freshly initialize) its weights.

        Args:
            name: Model name. Passed to ``nn.XSeg`` and used as the prefix of
                the weight file names (``f'{name}_{resolution}'``).
            resolution: Side length of the square input/target placeholders.
            load_weights: When False, no weight files are read and
                ``init_weights()`` is called on every registered object.
            weights_file_root: Directory holding the ``.npy`` weight files.
                Defaults to the directory containing this module.
            training: When True, ``optimizer`` is mandatory and its state is
                saved/loaded alongside the model. The inference function used
                by :meth:`extract` is NOT built in this mode.
            place_model_on_cpu: Create model (and optimizer) variables on the
                CPU instead of ``nn.tf_default_device_name``.
            run_on_cpu: Build the inference graph on the CPU (inference mode
                only).
            optimizer: leras optimizer instance; required when ``training``.
            data_format: Forwarded to ``nn.initialize``.
            raise_on_no_model_files: Raise instead of falling back when a
                weight file cannot be loaded.

        Raises:
            ValueError: ``training`` is True but no ``optimizer`` was given.
            FileNotFoundError: ``raise_on_no_model_files`` is True and
                ``load_weights()`` reported failure for a weight file
                (presumably because the file is missing).
        """
        self.resolution = resolution
        if weights_file_root is None:
            weights_file_root = Path(__file__).parent
        self.weights_file_root = Path(weights_file_root)

        nn.initialize(data_format=data_format)
        tf = nn.tf

        model_name = f'{name}_{resolution}'
        # [saveable object, file name] pairs used by the load loop below and
        # by save_weights().
        self.model_filename_list = []

        with tf.device('/CPU:0'):
            # Placeholders are kept on the CPU.
            self.input_t = tf.placeholder(nn.floatx, nn.get4Dshape(resolution, resolution, 3))
            self.target_t = tf.placeholder(nn.floatx, nn.get4Dshape(resolution, resolution, 1))

        # Initializing model classes
        with tf.device('/CPU:0' if place_model_on_cpu else nn.tf_default_device_name):
            self.model = nn.XSeg(3, 32, 1, name=name)
            self.model_weights = self.model.get_weights()
            if training:
                if optimizer is None:
                    raise ValueError("Optimizer should be provided for training mode.")
                self.opt = optimizer
                self.opt.initialize_variables(self.model_weights, vars_on_cpu=place_model_on_cpu)
                self.model_filename_list += [[self.opt, f'{model_name}_opt.npy']]

        self.model_filename_list += [[self.model, f'{model_name}.npy']]

        if not training:
            with tf.device('/CPU:0' if run_on_cpu else nn.tf_default_device_name):
                # The model returns a pair; only the second element is fetched
                # for inference.
                _, pred = self.model(self.input_t)

            def net_run(input_np):
                return nn.tf_sess.run([pred], feed_dict={self.input_t: input_np})[0]
            self.net_run = net_run

        # Cleared below when an inference-mode weight file fails to load;
        # extract() then returns a constant fallback mask.
        self.initialized = True
        # Loading/initializing all models/optimizers weights
        for model, filename in self.model_filename_list:
            do_init = not load_weights

            if not do_init:
                model_file_path = self.weights_file_root / filename
                do_init = not model.load_weights(model_file_path)
                if do_init:
                    if raise_on_no_model_files:
                        # FileNotFoundError subclasses Exception, so callers
                        # catching the previous bare Exception still work.
                        raise FileNotFoundError(f'{model_file_path} does not exist.')
                    if not training:
                        # Inference without stored weights: leave the model
                        # uninitialized and let extract() use the fallback.
                        self.initialized = False
                        break

            if do_init:
                model.init_weights()

    def get_resolution(self):
        """Return the square input resolution given at construction."""
        return self.resolution

    def flow(self, x, pretrain=False):
        """Call the underlying ``nn.XSeg`` model on ``x`` and return its raw output.

        ``pretrain`` is forwarded unchanged to the model call.
        """
        return self.model(x, pretrain=pretrain)

    def get_weights(self):
        """Return the model weight list captured at construction time."""
        return self.model_weights

    def save_weights(self):
        """Write every registered model/optimizer to ``weights_file_root``, with a progress bar."""
        for model, filename in io.progress_bar_generator(self.model_filename_list, "Saving", leave=False):
            model.save_weights(self.weights_file_root / filename)

    def extract(self, input_image):
        """Predict a mask for a single image or a batch of images.

        Only available in inference mode (``training=False``), because
        ``self.net_run`` is not built otherwise.

        Args:
            input_image: A single image (rank 3) or a batch (rank 4). A rank-3
                input gets a leading batch axis for inference. Presumably
                ``(resolution, resolution, 3)`` per image; TODO confirm
                against callers.

        Returns:
            The network output, clipped to [0, 1] and with values below 0.1
            zeroed. The batch axis is removed again for rank-3 input. If the
            weights could not be loaded (``self.initialized`` is False), a mask
            filled with 0.5 is returned instead, with shape
            ``(resolution, resolution, 1)`` or, for rank-4 input,
            ``(batch, resolution, resolution, 1)``.
        """
        input_shape_len = len(input_image.shape)

        if not self.initialized:
            mask_shape = (self.resolution, self.resolution, 1)
            if input_shape_len == 4:
                # Keep the batch layout consistent with the normal path.
                mask_shape = (input_image.shape[0],) + mask_shape
            return 0.5 * np.ones(mask_shape, nn.floatx.as_numpy_dtype)

        if input_shape_len == 3:
            input_image = input_image[None, ...]

        result = np.clip(self.net_run(input_image), 0, 1.0)
        result[result < 0.1] = 0  # get rid of noise

        if input_shape_len == 3:
            result = result[0]

        return result