GitHub Repository: iperov/deepfacelab
Path: blob/master/facelib/FANExtractor.py
import os
import traceback
from pathlib import Path

import cv2
import numpy as np
from numpy import linalg as npla

from facelib import FaceType, LandmarksProcessor
from core.leras import nn

"""
ported from https://github.com/1adrianb/face-alignment
"""
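# FANExtractor wraps a Face Alignment Network (FAN, see the repository linked
# above): a stack of hourglass modules that predicts one heatmap per facial
# landmark (68 channels) from a 256x256 face crop, from which 2D landmark
# coordinates are then recovered.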
class FANExtractor(object):
    def __init__ (self, landmarks_3D=False, place_model_on_cpu=False):

        model_path = Path(__file__).parent / ( "2DFAN.npy" if not landmarks_3D else "3DFAN.npy")
        if not model_path.exists():
            raise Exception("Unable to load FANExtractor model")

        nn.initialize(data_format="NHWC")
        tf = nn.tf

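        # Residual block used throughout FAN: three 3x3 convolutions producing
        # out_planes//2, out_planes//4 and out_planes//4 channels, concatenated
        # back to out_planes, plus an identity (or 1x1-projected) shortcut.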
        class ConvBlock(nn.ModelBase):
            def on_build(self, in_planes, out_planes):
                self.in_planes = in_planes
                self.out_planes = out_planes

                self.bn1 = nn.BatchNorm2D(in_planes)
                self.conv1 = nn.Conv2D (in_planes, out_planes//2, kernel_size=3, strides=1, padding='SAME', use_bias=False )

                self.bn2 = nn.BatchNorm2D(out_planes//2)
                self.conv2 = nn.Conv2D (out_planes//2, out_planes//4, kernel_size=3, strides=1, padding='SAME', use_bias=False )

                self.bn3 = nn.BatchNorm2D(out_planes//4)
                self.conv3 = nn.Conv2D (out_planes//4, out_planes//4, kernel_size=3, strides=1, padding='SAME', use_bias=False )

                if self.in_planes != self.out_planes:
                    self.down_bn1 = nn.BatchNorm2D(in_planes)
                    self.down_conv1 = nn.Conv2D (in_planes, out_planes, kernel_size=1, strides=1, padding='VALID', use_bias=False )
                else:
                    self.down_bn1 = None
                    self.down_conv1 = None

            def forward(self, input):
                x = input
                x = self.bn1(x)
                x = tf.nn.relu(x)
                x = out1 = self.conv1(x)

                x = self.bn2(x)
                x = tf.nn.relu(x)
                x = out2 = self.conv2(x)

                x = self.bn3(x)
                x = tf.nn.relu(x)
                x = out3 = self.conv3(x)

                x = tf.concat ([out1, out2, out3], axis=-1)

                if self.in_planes != self.out_planes:
                    downsample = self.down_bn1(input)
                    downsample = tf.nn.relu (downsample)
                    downsample = self.down_conv1 (downsample)
                    x = x + downsample
                else:
                    x = x + input

                return x

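        # Recursive hourglass module: a skip branch at the current resolution
        # (b1) plus a downsampled branch (avg-pool -> b2 -> b2_plus recursion ->
        # b3 -> upsample), with the two branches summed on the way out.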
        class HourGlass (nn.ModelBase):
            def on_build(self, in_planes, depth):
                self.b1 = ConvBlock (in_planes, 256)
                self.b2 = ConvBlock (in_planes, 256)

                if depth > 1:
                    self.b2_plus = HourGlass(256, depth-1)
                else:
                    self.b2_plus = ConvBlock(256, 256)

                self.b3 = ConvBlock(256, 256)

            def forward(self, input):
                up1 = self.b1(input)

                low1 = tf.nn.avg_pool(input, [1,2,2,1], [1,2,2,1], 'VALID')
                low1 = self.b2 (low1)

                low2 = self.b2_plus(low1)
                low3 = self.b3(low2)

                up2 = nn.upsample2d(low3)

                return up1+up2

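        # The full network: a strided 7x7 stem followed by four stacked hourglass
        # modules. Each stack emits a 68-channel heatmap head (self.l); the last
        # head is the output, while intermediate heads are folded back into the
        # next stack through the 1x1 convolutions in self.bl and self.al.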
        class FAN (nn.ModelBase):
            def __init__(self):
                super().__init__(name='FAN')

            def on_build(self):
                self.conv1 = nn.Conv2D (3, 64, kernel_size=7, strides=2, padding='SAME')
                self.bn1 = nn.BatchNorm2D(64)

                self.conv2 = ConvBlock(64, 128)
                self.conv3 = ConvBlock(128, 128)
                self.conv4 = ConvBlock(128, 256)

                self.m = []
                self.top_m = []
                self.conv_last = []
                self.bn_end = []
                self.l = []
                self.bl = []
                self.al = []
                for i in range(4):
                    self.m += [ HourGlass(256, 4) ]
                    self.top_m += [ ConvBlock(256, 256) ]

                    self.conv_last += [ nn.Conv2D (256, 256, kernel_size=1, strides=1, padding='VALID') ]
                    self.bn_end += [ nn.BatchNorm2D(256) ]

                    self.l += [ nn.Conv2D (256, 68, kernel_size=1, strides=1, padding='VALID') ]

                    if i < 4-1:
                        self.bl += [ nn.Conv2D (256, 256, kernel_size=1, strides=1, padding='VALID') ]
                        self.al += [ nn.Conv2D (68, 256, kernel_size=1, strides=1, padding='VALID') ]

            def forward(self, inp):
                x, = inp
                x = self.conv1(x)
                x = self.bn1(x)
                x = tf.nn.relu(x)

                x = self.conv2(x)
                x = tf.nn.avg_pool(x, [1,2,2,1], [1,2,2,1], 'VALID')
                x = self.conv3(x)
                x = self.conv4(x)

                outputs = []
                previous = x
                for i in range(4):
                    ll = self.m[i] (previous)
                    ll = self.top_m[i] (ll)
                    ll = self.conv_last[i] (ll)
                    ll = self.bn_end[i] (ll)
                    ll = tf.nn.relu(ll)
                    tmp_out = self.l[i](ll)
                    outputs.append(tmp_out)
                    if i < 4 - 1:
                        ll = self.bl[i](ll)
                        previous = previous + ll + self.al[i](tmp_out)
                x = outputs[-1]
                x = tf.transpose(x, (0,3,1,2) )
                return x

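        # Build the FAN graph and load the pretrained weights, optionally pinning
        # the whole model to the CPU.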
        e = None
        if place_model_on_cpu:
            e = tf.device("/CPU:0")

        if e is not None: e.__enter__()
        self.model = FAN()
        self.model.load_weights(str(model_path))
        if e is not None: e.__exit__(None,None,None)

        self.model.build_for_run ([ ( tf.float32, (None,256,256,3) ) ])

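    # Runs the FAN model on each face rect of `input_image` and returns one
    # (68, 2) landmark array per rect (None where extraction failed). When
    # `second_pass_extractor` is given, the face is re-detected inside an
    # aligned 256x256 crop and the landmarks are refined from that crop.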
    def extract (self, input_image, rects, second_pass_extractor=None, is_bgr=True, multi_sample=False):
        if len(rects) == 0:
            return []

        if is_bgr:
            input_image = input_image[:,:,::-1]
            is_bgr = False

        (h, w, ch) = input_image.shape

        landmarks = []
        for (left, top, right, bottom) in rects:
            scale = (right - left + bottom - top) / 195.0

            center = np.array( [ (left + right) / 2.0, (top + bottom) / 2.0] )
            centers = [ center ]

            if multi_sample:
                centers += [ center + [-1,-1],
                             center + [1,-1],
                             center + [1,1],
                             center + [-1,1],
                           ]

            images = []
            ptss = []

            try:
                for c in centers:
                    images += [ self.crop(input_image, c, scale) ]

                images = np.stack (images)
                images = images.astype(np.float32) / 255.0

                predicted = []
                for i in range( len(images) ):
                    predicted += [ self.model.run ( [ images[i][None,...] ] )[0] ]

                predicted = np.stack(predicted)

                for i, pred in enumerate(predicted):
                    ptss += [ self.get_pts_from_predict ( pred, centers[i], scale) ]
                pts_img = np.mean ( np.array(ptss), 0 )

                landmarks.append (pts_img)
            except:
                landmarks.append (None)

        if second_pass_extractor is not None:
            for i, lmrks in enumerate(landmarks):
                try:
                    if lmrks is not None:
                        image_to_face_mat = LandmarksProcessor.get_transform_mat (lmrks, 256, FaceType.FULL)
                        face_image = cv2.warpAffine(input_image, image_to_face_mat, (256, 256), cv2.INTER_CUBIC )

                        rects2 = second_pass_extractor.extract(face_image, is_bgr=is_bgr)
                        if len(rects2) == 1: # skip the second pass unless exactly one face is detected in the cropped image
                            lmrks2 = self.extract (face_image, [ rects2[0] ], is_bgr=is_bgr, multi_sample=True)[0]
                            landmarks[i] = LandmarksProcessor.transform_points (lmrks2, image_to_face_mat, True)
                except:
                    pass

        return landmarks

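    # Maps a point given in crop/heatmap coordinates (at `resolution`) back to
    # original-image coordinates. The crop is a square of side 200*scale centred
    # on `center`; the matrix built below maps image -> crop space and is
    # inverted before being applied.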
    def transform(self, point, center, scale, resolution):
        pt = np.array ( [point[0], point[1], 1.0] )
        h = 200.0 * scale
        m = np.eye(3)
        m[0,0] = resolution / h
        m[1,1] = resolution / h
        m[0,2] = resolution * ( -center[0] / h + 0.5 )
        m[1,2] = resolution * ( -center[1] / h + 0.5 )
        m = np.linalg.inv(m)
        return np.matmul (m, pt)[0:2]

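    # Cuts a square region of side ~200*scale around `center` out of `image`,
    # zero-padding wherever the region falls outside the frame, and resizes the
    # result to resolution x resolution (256 by default) for the network.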
    def crop(self, image, center, scale, resolution=256.0):
        # np.int was removed from recent NumPy releases; the builtin int is the equivalent dtype here.
        ul = self.transform([1, 1], center, scale, resolution).astype( int )
        br = self.transform([resolution, resolution], center, scale, resolution).astype( int )

        if image.ndim > 2:
            newDim = np.array([br[1] - ul[1], br[0] - ul[0], image.shape[2]], dtype=np.int32)
            newImg = np.zeros(newDim, dtype=np.uint8)
        else:
            newDim = np.array([br[1] - ul[1], br[0] - ul[0]], dtype=np.int32)
            newImg = np.zeros(newDim, dtype=np.uint8)
        ht = image.shape[0]
        wd = image.shape[1]
        newX = np.array([max(1, -ul[0] + 1), min(br[0], wd) - ul[0]], dtype=np.int32)
        newY = np.array([max(1, -ul[1] + 1), min(br[1], ht) - ul[1]], dtype=np.int32)
        oldX = np.array([max(1, ul[0] + 1), min(br[0], wd)], dtype=np.int32)
        oldY = np.array([max(1, ul[1] + 1), min(br[1], ht)], dtype=np.int32)
        newImg[newY[0] - 1:newY[1], newX[0] - 1:newX[1] ] = image[oldY[0] - 1:oldY[1], oldX[0] - 1:oldX[1], :]

        newImg = cv2.resize(newImg, dsize=(int(resolution), int(resolution)), interpolation=cv2.INTER_LINEAR)
        return newImg

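    # Converts the network output `a` (a_ch x a_h x a_w heatmaps, 68 x 64 x 64
    # here) into landmark coordinates: take the argmax of each heatmap, nudge it
    # a quarter pixel toward the higher neighbouring activation, then map the
    # result back to original-image coordinates with self.transform.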
    def get_pts_from_predict(self, a, center, scale):
        a_ch, a_h, a_w = a.shape

        b = a.reshape ( (a_ch, a_h*a_w) )
        # np.float was removed from recent NumPy releases; use the builtin float dtype.
        c = b.argmax(1).reshape ( (a_ch, 1) ).repeat(2, axis=1).astype(float)
        c[:,0] %= a_w
        c[:,1] = np.apply_along_axis ( lambda x: np.floor(x / a_w), 0, c[:,1] )

        for i in range(a_ch):
            pX, pY = int(c[i,0]), int(c[i,1])
            # heatmaps are 64x64, so only interior maxima get the sub-pixel shift
            if pX > 0 and pX < 63 and pY > 0 and pY < 63:
                diff = np.array ( [a[i,pY,pX+1]-a[i,pY,pX-1], a[i,pY+1,pX]-a[i,pY-1,pX]] )
                c[i] += np.sign(diff)*0.25

        c += 0.5

        return np.array( [ self.transform (c[i], center, scale, a_w) for i in range(a_ch) ] )
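
# A minimal usage sketch, assuming the pretrained "2DFAN.npy" weights sit next to
# this module and that `rects` (left, top, right, bottom) comes from a separate
# face detector such as DeepFaceLab's S3FDExtractor:
#
#   import cv2
#   from facelib import FANExtractor      # assuming facelib re-exports this class
#
#   image = cv2.imread("face.jpg")                # BGR, as loaded by OpenCV
#   fan = FANExtractor(place_model_on_cpu=True)
#   landmarks = fan.extract(image, rects)         # one (68, 2) array (or None) per rect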