Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
iperov
GitHub Repository: iperov/deepfacelab
Path: blob/master/facelib/FaceEnhancer.py
628 views
1
import operator
2
from pathlib import Path
3
4
import cv2
5
import numpy as np
6
7
from core.leras import nn
8
9
class FaceEnhancer(object):
    """
    x4 face enhancer

    Wraps a convolutional encoder/decoder network whose weights are loaded
    from ``FaceEnhancer.npy`` located next to this file.  The network is built
    for 192x192x3 inputs and produces an output upscaled 4 times; larger images
    are processed as overlapping 192-px patches that are blended together.
    """

    # Side length (pixels) of the square patches the network is built for.
    _PATCH_SIZE = 192
    # Upscale factor between a network input patch and its output.
    _UP_RES = 4
    # Lower bound for blend weights.  The triangular mask is exactly 0 on its
    # outermost rows/columns; without a floor the image border would receive a
    # total weight of 0 and come out as 0 instead of the network output.
    _MIN_BLEND_WEIGHT = 1e-6

    def __init__(self, place_model_on_cpu=False, run_on_cpu=False):
        """
        Build the network, load its weights and prepare it for inference.

        place_model_on_cpu  create the model and load its weights under '/CPU:0'
                            instead of nn.tf_default_device_name
        run_on_cpu          build the run graph under '/CPU:0'
                            instead of nn.tf_default_device_name

        Raises FileNotFoundError (an Exception subclass) when FaceEnhancer.npy
        is not found next to this file.
        """
        nn.initialize(data_format="NHWC")
        tf = nn.tf

        # NOTE: this inner model class deliberately keeps the same name as the
        # outer wrapper (it shadows it only inside __init__); the model is
        # registered under name='FaceEnhancer'.
        class FaceEnhancer (nn.ModelBase):
            def __init__(self, name='FaceEnhancer'):
                super().__init__(name=name)

            def on_build(self):
                # Stem: image conv plus two scalar conditioning projections.
                self.conv1 = nn.Conv2D (3, 64, kernel_size=3, strides=1, padding='SAME')

                self.dense1 = nn.Dense (1, 64, use_bias=False)
                self.dense2 = nn.Dense (1, 64, use_bias=False)

                # Encoder stages.
                self.e0_conv0 = nn.Conv2D (64, 64, kernel_size=3, strides=1, padding='SAME')
                self.e0_conv1 = nn.Conv2D (64, 64, kernel_size=3, strides=1, padding='SAME')

                self.e1_conv0 = nn.Conv2D (64, 112, kernel_size=3, strides=1, padding='SAME')
                self.e1_conv1 = nn.Conv2D (112, 112, kernel_size=3, strides=1, padding='SAME')

                self.e2_conv0 = nn.Conv2D (112, 192, kernel_size=3, strides=1, padding='SAME')
                self.e2_conv1 = nn.Conv2D (192, 192, kernel_size=3, strides=1, padding='SAME')

                self.e3_conv0 = nn.Conv2D (192, 336, kernel_size=3, strides=1, padding='SAME')
                self.e3_conv1 = nn.Conv2D (336, 336, kernel_size=3, strides=1, padding='SAME')

                self.e4_conv0 = nn.Conv2D (336, 512, kernel_size=3, strides=1, padding='SAME')
                self.e4_conv1 = nn.Conv2D (512, 512, kernel_size=3, strides=1, padding='SAME')

                # Bottleneck.
                self.center_conv0 = nn.Conv2D (512, 512, kernel_size=3, strides=1, padding='SAME')
                self.center_conv1 = nn.Conv2D (512, 512, kernel_size=3, strides=1, padding='SAME')
                self.center_conv2 = nn.Conv2D (512, 512, kernel_size=3, strides=1, padding='SAME')
                self.center_conv3 = nn.Conv2D (512, 512, kernel_size=3, strides=1, padding='SAME')

                # Decoder stages; input channels = upsampled features + skip
                # connection (e.g. 512 + 512, 512 + 336, ...).
                self.d4_conv0 = nn.Conv2D (1024, 512, kernel_size=3, strides=1, padding='SAME')
                self.d4_conv1 = nn.Conv2D (512, 512, kernel_size=3, strides=1, padding='SAME')

                self.d3_conv0 = nn.Conv2D (848, 512, kernel_size=3, strides=1, padding='SAME')
                self.d3_conv1 = nn.Conv2D (512, 512, kernel_size=3, strides=1, padding='SAME')

                self.d2_conv0 = nn.Conv2D (704, 288, kernel_size=3, strides=1, padding='SAME')
                self.d2_conv1 = nn.Conv2D (288, 288, kernel_size=3, strides=1, padding='SAME')

                self.d1_conv0 = nn.Conv2D (400, 160, kernel_size=3, strides=1, padding='SAME')
                self.d1_conv1 = nn.Conv2D (160, 160, kernel_size=3, strides=1, padding='SAME')

                self.d0_conv0 = nn.Conv2D (224, 96, kernel_size=3, strides=1, padding='SAME')
                self.d0_conv1 = nn.Conv2D (96, 96, kernel_size=3, strides=1, padding='SAME')

                # Output heads at 1x, 2x and 4x resolution.
                self.out1x_conv0 = nn.Conv2D (96, 48, kernel_size=3, strides=1, padding='SAME')
                self.out1x_conv1 = nn.Conv2D (48, 3, kernel_size=3, strides=1, padding='SAME')

                self.dec2x_conv0 = nn.Conv2D (96, 96, kernel_size=3, strides=1, padding='SAME')
                self.dec2x_conv1 = nn.Conv2D (96, 96, kernel_size=3, strides=1, padding='SAME')

                self.out2x_conv0 = nn.Conv2D (96, 48, kernel_size=3, strides=1, padding='SAME')
                self.out2x_conv1 = nn.Conv2D (48, 3, kernel_size=3, strides=1, padding='SAME')

                self.dec4x_conv0 = nn.Conv2D (96, 72, kernel_size=3, strides=1, padding='SAME')
                self.dec4x_conv1 = nn.Conv2D (72, 72, kernel_size=3, strides=1, padding='SAME')

                self.out4x_conv0 = nn.Conv2D (72, 36, kernel_size=3, strides=1, padding='SAME')
                self.out4x_conv1 = nn.Conv2D (36, 3 , kernel_size=3, strides=1, padding='SAME')

            def forward(self, inp):
                """
                inp is (bgr, param, param1): the image tensor and two scalar
                conditioning inputs.  Returns the 4x upscaled image tensor.
                """
                bgr, param, param1 = inp

                def lrelu(t):
                    return tf.nn.leaky_relu(t, 0.1)

                def downscale(t):
                    # 2x2 average pooling, halves the spatial resolution.
                    return tf.nn.avg_pool(t, [1,2,2,1], [1,2,2,1], "VALID")

                def upscale_concat(t, skip):
                    # Upsample and append the encoder skip connection channel-wise.
                    return tf.concat( [nn.resize2d_bilinear(t), skip], -1 )

                # Stem: the two projected scalars are broadcast over the
                # feature map as per-channel offsets.
                x = self.conv1(bgr)
                a = tf.reshape(self.dense1(param), (-1,1,1,64) )
                b = tf.reshape(self.dense2(param1), (-1,1,1,64) )
                x = lrelu(x+a+b)

                # Encoder.
                x = lrelu(self.e0_conv0(x))
                x = e0 = lrelu(self.e0_conv1(x))

                x = downscale(x)
                x = lrelu(self.e1_conv0(x))
                x = e1 = lrelu(self.e1_conv1(x))

                x = downscale(x)
                x = lrelu(self.e2_conv0(x))
                x = e2 = lrelu(self.e2_conv1(x))

                x = downscale(x)
                x = lrelu(self.e3_conv0(x))
                x = e3 = lrelu(self.e3_conv1(x))

                x = downscale(x)
                x = lrelu(self.e4_conv0(x))
                x = e4 = lrelu(self.e4_conv1(x))

                # Bottleneck.
                x = downscale(x)
                x = lrelu(self.center_conv0(x))
                x = lrelu(self.center_conv1(x))
                x = lrelu(self.center_conv2(x))
                x = lrelu(self.center_conv3(x))

                # Decoder.
                x = upscale_concat(x, e4)
                x = lrelu(self.d4_conv0(x))
                x = lrelu(self.d4_conv1(x))

                x = upscale_concat(x, e3)
                x = lrelu(self.d3_conv0(x))
                x = lrelu(self.d3_conv1(x))

                x = upscale_concat(x, e2)
                x = lrelu(self.d2_conv0(x))
                x = lrelu(self.d2_conv1(x))

                x = upscale_concat(x, e1)
                x = lrelu(self.d1_conv0(x))
                x = lrelu(self.d1_conv1(x))

                x = upscale_concat(x, e0)
                x = lrelu(self.d0_conv0(x))
                x = d0 = lrelu(self.d0_conv1(x))

                # 1x head: residual on top of the input image.
                x = lrelu(self.out1x_conv0(x))
                x = self.out1x_conv1(x)
                out1x = bgr + tf.nn.tanh(x)

                # 2x head: residual on top of the upsampled 1x output.
                x = d0
                x = lrelu(self.dec2x_conv0(x))
                x = lrelu(self.dec2x_conv1(x))
                x = d2x = nn.resize2d_bilinear(x)

                x = lrelu(self.out2x_conv0(x))
                x = self.out2x_conv1(x)

                out2x = nn.resize2d_bilinear(out1x) + tf.nn.tanh(x)

                # 4x head: residual on top of the upsampled 2x output.
                x = d2x
                x = lrelu(self.dec4x_conv0(x))
                x = lrelu(self.dec4x_conv1(x))
                x = nn.resize2d_bilinear(x)

                x = lrelu(self.out4x_conv0(x))
                x = self.out4x_conv1(x)

                out4x = nn.resize2d_bilinear(out2x) + tf.nn.tanh(x)

                return out4x

        model_path = Path(__file__).parent / "FaceEnhancer.npy"
        if not model_path.exists():
            raise FileNotFoundError("Unable to load FaceEnhancer.npy")

        with tf.device ('/CPU:0' if place_model_on_cpu else nn.tf_default_device_name):
            self.model = FaceEnhancer()
            self.model.load_weights (model_path)

        with tf.device ('/CPU:0' if run_on_cpu else nn.tf_default_device_name):
            self.model.build_for_run ([ (tf.float32, nn.get4Dshape (192,192,3) ),
                                        (tf.float32, (None,1,) ),
                                        (tf.float32, (None,1,) ),
                                      ])

    @staticmethod
    def _patch_starts(length, patch_size):
        """Start offsets of half-overlapping patches covering [0, length); the last patch is flush with the end."""
        last = length - patch_size
        return list(range(0, last, patch_size // 2)) + [last]

    @classmethod
    def _blend_mask(cls):
        """Separable triangular blend weights for one upscaled patch, shape (side, side, 1), strictly positive."""
        half = (cls._PATCH_SIZE // 2) * cls._UP_RES
        ramp = np.concatenate ( [ np.linspace (0,1,half), np.linspace (1,0,half) ] )
        mask = np.outer(ramp, ramp)[...,None]
        # Only the exactly-zero border weights are affected by the floor: the
        # smallest non-zero weight of the ramp product is well above it.
        return np.maximum(mask, cls._MIN_BLEND_WEIGHT)

    def enhance (self, inp_img, is_tanh=False, preserve_size=True):
        """
        Enhance an image with the x4 network.

        inp_img         HxWxC array; expected range is [0,1], or [-1,1]
                        when is_tanh is True
        is_tanh         when False the input is mapped to [-1,1] before the
                        network and the result is mapped back and clipped to [0,1]
        preserve_size   when True the 4x result is resized back to the input
                        size with cv2 Lanczos interpolation, otherwise the
                        4x-sized image is returned

        Inputs smaller than 192 px along an axis are zero-padded (centered)
        to 192 px for processing and the padding is cropped from the result.
        """
        if not is_tanh:
            inp_img = np.clip( inp_img * 2 -1, -1, 1 )

        param = np.array([0.2])
        param1 = np.array([1.0])
        up_res = self._UP_RES
        patch_size = self._PATCH_SIZE

        ih,iw,_ = inp_img.shape

        # Centered zero padding up to the patch size along each short axis.
        pad_h = max(patch_size-ih, 0)
        pad_w = max(patch_size-iw, 0)
        t_padding = pad_h//2
        b_padding = pad_h - t_padding
        l_padding = pad_w//2
        r_padding = pad_w - l_padding

        if pad_h != 0 or pad_w != 0:
            inp_img = np.pad (inp_img, ( (t_padding,b_padding), (l_padding,r_padding), (0,0) ), mode='constant')

        h,w,c = inp_img.shape

        # Weighted sum of patch outputs and the matching sum of weights.
        final_img = np.zeros ( (h*up_res,w*up_res,c), dtype=np.float32 )
        final_img_div = np.zeros ( (h*up_res,w*up_res,1), dtype=np.float32 )

        patch_mask = self._blend_mask()

        col_starts = self._patch_starts(w, patch_size)
        for j in self._patch_starts(h, patch_size):
            rows = slice(j*up_res, (j+patch_size)*up_res)
            for i in col_starts:
                cols = slice(i*up_res, (i+patch_size)*up_res)
                patch_img = inp_img[j:j+patch_size, i:i+patch_size,:]
                patch_out = self.model.run( [ patch_img[None,...], [param], [param1] ] )[0]
                final_img [rows, cols, :] += patch_out*patch_mask
                final_img_div[rows, cols, :] += patch_mask

        # Every pixel is covered by at least one patch and all weights are
        # strictly positive, so the divisor is never zero.
        final_img /= final_img_div

        # Remove the upscaled padding (a no-op slice when nothing was padded).
        final_img = final_img [t_padding*up_res:(h-b_padding)*up_res, l_padding*up_res:(w-r_padding)*up_res,:]

        if preserve_size:
            final_img = cv2.resize (final_img, (iw,ih), interpolation=cv2.INTER_LANCZOS4)

        if not is_tanh:
            final_img = np.clip( final_img/2+0.5, 0, 1 )

        return final_img
257
258
# Everything between the triple quotes below is an inert string literal: a
# disabled, earlier variant of enhance() that is never executed.
# It differs from the active enhance() in how it treats inputs smaller than
# the 192-px patch: instead of zero-padding, it rescales the whole image with
# cv2.resize (Lanczos) by 1 / (max(h, w) / patch_size) before patching.
# NOTE(review): in this variant h and w are reassigned after the pre-upscale,
# so the preserve_size branch resizes to the pre-upscaled size rather than the
# caller's original size -- confirm before ever re-enabling it.
# (The bare numbers interleaved with the code are line-number residue from the
# page this text was captured from.)
259
"""
260
261
def enhance (self, inp_img, is_tanh=False, preserve_size=True):
262
if not is_tanh:
263
inp_img = np.clip( inp_img * 2 -1, -1, 1 )
264
265
param = np.array([0.2])
266
param1 = np.array([1.0])
267
up_res = 4
268
patch_size = 192
269
patch_size_half = patch_size // 2
270
271
h,w,c = inp_img.shape
272
273
th,tw = h*up_res, w*up_res
274
275
preupscale_rate = 1.0
276
277
if h < patch_size or w < patch_size:
278
preupscale_rate = 1.0 / ( max(h,w) / patch_size )
279
280
if preupscale_rate != 1.0:
281
inp_img = cv2.resize (inp_img, ( int(w*preupscale_rate), int(h*preupscale_rate) ), interpolation=cv2.INTER_LANCZOS4)
282
h,w,c = inp_img.shape
283
284
i_max = w-patch_size+1
285
j_max = h-patch_size+1
286
287
final_img = np.zeros ( (h*up_res,w*up_res,c), dtype=np.float32 )
288
final_img_div = np.zeros ( (h*up_res,w*up_res,1), dtype=np.float32 )
289
290
x = np.concatenate ( [ np.linspace (0,1,patch_size_half*up_res), np.linspace (1,0,patch_size_half*up_res) ] )
291
x,y = np.meshgrid(x,x)
292
patch_mask = (x*y)[...,None]
293
294
j=0
295
while j < j_max:
296
i = 0
297
while i < i_max:
298
patch_img = inp_img[j:j+patch_size, i:i+patch_size,:]
299
x = self.model.run( [ patch_img[None,...], [param], [param1] ] )[0]
300
final_img [j*up_res:(j+patch_size)*up_res, i*up_res:(i+patch_size)*up_res,:] += x*patch_mask
301
final_img_div[j*up_res:(j+patch_size)*up_res, i*up_res:(i+patch_size)*up_res,:] += patch_mask
302
if i == i_max-1:
303
break
304
i = min( i+patch_size_half, i_max-1)
305
if j == j_max-1:
306
break
307
j = min( j+patch_size_half, j_max-1)
308
309
final_img_div[final_img_div==0] = 1.0
310
final_img /= final_img_div
311
312
if preserve_size:
313
final_img = cv2.resize (final_img, (w,h), interpolation=cv2.INTER_LANCZOS4)
314
else:
315
if preupscale_rate != 1.0:
316
final_img = cv2.resize (final_img, (tw,th), interpolation=cv2.INTER_LANCZOS4)
317
318
if not is_tanh:
319
final_img = np.clip( final_img/2+0.5, 0, 1 )
320
321
return final_img
322
"""
323