Book a Demo!
CoCalc Logo Icon
Store | Features | Docs | Share | Support | News | About | Policies | Sign Up | Sign In
iperov
GitHub Repository: iperov/deepfacelab
Path: blob/master/core/imagelib/warp.py
628 views
1
import numpy as np
2
import numpy.linalg as npla
3
import cv2
4
from core import randomex
5
6
def mls_rigid_deformation(vy, vx, src_pts, dst_pts, alpha=1.0, eps=1e-8):
    """Compute a dense rigid moving-least-squares (MLS) deformation field.

    For every grid position ``v = (vx, vy)`` a rigid transform is fitted to the
    control points, each point weighted by ``1 / (squared distance + eps) ** alpha``.
    The transform is then applied to ``v``.

    The two point sets are swapped internally, so the field is solved from
    ``dst_pts`` towards ``src_pts`` (the inverse of the src -> dst deformation).
    This is presumably so the result can drive a backward-sampling remap;
    confirm against callers.

    Args:
        vy, vx: 2-D arrays of identical shape ``(rows, cols)`` holding the grid
            coordinates.
        src_pts, dst_pts: ``(N, 2)`` control-point arrays. The last axis is
            reversed before use, so column 1 is paired with ``vx`` and column 0
            with ``vy``. Values are truncated to int16.
        alpha: Exponent applied to ``(squared distance + eps)`` in the weights.
        eps: Added to squared distances so a pixel lying exactly on a control
            point does not divide by zero.

    Returns:
        Floating-point array of shape ``(2, rows, cols)``. Component 0 is paired
        with ``vx`` and component 1 with ``vy``.

        Some pixels have a zero-length fitted direction, so their rotation is
        undefined. These are filled by 1-D linear interpolation over the
        flattened pixel index, using the valid pixels. If no pixel is valid
        (e.g. a single control point), a pure translation ``v - p* + q*`` is
        used instead of failing.
    """
    # Reverse the last axis of each point array and truncate to int16.
    dst_pts = dst_pts[..., ::-1].astype(np.int16)
    src_pts = src_pts[..., ::-1].astype(np.int16)

    # Solve from destination to source points (inverse mapping).
    src_pts, dst_pts = dst_pts, src_pts

    grow = vx.shape[0]
    gcol = vx.shape[1]
    ctrls = src_pts.shape[0]

    reshaped_p = src_pts.reshape(ctrls, 2, 1, 1)
    reshaped_v = np.vstack((vx.reshape(1, grow, gcol), vy.reshape(1, grow, gcol)))

    # Per-pixel inverse-distance weights, normalized to sum to 1 over control points.
    w = 1.0 / (np.sum((reshaped_p - reshaped_v).astype(np.float32) ** 2, axis=1) + eps) ** alpha
    w /= np.sum(w, axis=0, keepdims=True)

    # Weighted centroid p* of the (swapped) source points, per pixel.
    pstar = np.zeros((2, grow, gcol), np.float32)
    for i in range(ctrls):
        pstar += w[i] * reshaped_p[i]

    vpstar = reshaped_v - pstar

    # Per-pixel 2x2 matrix whose columns are (v - p*) and its perpendicular
    # (vp1, -vp0); shape (rows, cols, 2, 2).
    reshaped_mul_right = np.concatenate(
        (vpstar[:, None, ...],
         np.concatenate((vpstar[1:2, None, ...], -vpstar[0:1, None, ...]), 0)),
        axis=1,
    ).transpose(2, 3, 0, 1)

    reshaped_q = dst_pts.reshape((ctrls, 2, 1, 1))

    # Weighted centroid q* of the (swapped) destination points, per pixel.
    qstar = np.zeros((2, grow, gcol), np.float32)
    for i in range(ctrls):
        qstar += w[i] * reshaped_q[i]

    # Accumulate sum_i qhat_i^T * w_i * [phat_i; perp(phat_i)] * [v-p*, perp(v-p*)].
    temp = np.zeros((grow, gcol, 2), np.float32)
    for i in range(ctrls):
        phat = reshaped_p[i] - pstar
        qhat = reshaped_q[i] - qstar

        # Rows phat and (phat1, -phat0), scaled by this point's weight; shape (2, 2, rows, cols).
        phat_mat = w[None, i:i + 1, ...] * np.concatenate(
            (phat.reshape(1, 2, grow, gcol),
             np.concatenate((phat[None, 1:2], -phat[None, 0:1]), 1)),
            0,
        )
        qhat_row = qhat.reshape(1, 2, grow, gcol).transpose(2, 3, 0, 1)
        temp += np.matmul(
            qhat_row,
            np.matmul(phat_mat.transpose(2, 3, 0, 1), reshaped_mul_right),
        ).reshape(grow, gcol, 2)

    temp = temp.transpose(2, 0, 1)

    normed_temp = np.linalg.norm(temp, axis=0, keepdims=True)
    normed_vpstar = np.linalg.norm(vpstar, axis=0, keepdims=True)
    # Pixels whose fitted direction has zero length cannot be normalized.
    nan_mask = normed_temp[0] == 0

    # Unit direction scaled by |v - p*|, shifted to q*; invalid pixels get q* for now.
    transformers = np.true_divide(temp, normed_temp, out=np.zeros_like(temp), where=~nan_mask) * normed_vpstar + qstar

    if nan_mask.any():
        valid_idx = np.flatnonzero(~nan_mask)
        if valid_idx.size:
            # Fill invalid pixels by linear interpolation along the flattened index.
            nan_idx = np.flatnonzero(nan_mask)
            transformers[0][nan_mask] = np.interp(nan_idx, valid_idx, transformers[0][~nan_mask])
            transformers[1][nan_mask] = np.interp(nan_idx, valid_idx, transformers[1][~nan_mask])
        else:
            # No pixel has a defined rotation (e.g. one control point): np.interp
            # would raise on empty sample arrays, so use a pure translation instead.
            transformers = vpstar + qstar

    return transformers
66
67
def gen_pts(W, H, rnd_state=None):
    """Generate random pairs of control points for an MLS warp.

    Draws ``randint(4, 8)`` (i.e. 4 to 7) anchor points uniformly inside a
    ``W x H`` image. For each anchor it also produces a displaced point at a
    random angle. The displacement is ``cos(angle) * W * rad`` horizontally and
    ``sin(angle) * H * rad`` vertically, truncated to int, with ``rad`` drawn
    from [0, 0.1).

    NOTE(review): the overlap test compares a pixel distance against
    ``(rad + prad) * 2``, where the radii are fractions (< 0.1) of the image
    size. The threshold is therefore below 0.4 px and only rejects exact
    duplicate anchors. The intended scale is unclear, so it is left unchanged.

    Args:
        W, H: Image width and height in pixels; must be positive.
        rnd_state: Object with ``randint``/``rand`` (e.g. ``np.random.RandomState``);
            defaults to the global ``np.random``.

    Returns:
        ``(pts1, pts2)``: integer arrays of shape ``(n, 2)`` holding the ``(x, y)``
        anchors and their displaced counterparts. Displaced points are not
        clipped and may fall outside the image. ``n`` is capped at ``W * H``.

    Raises:
        ValueError: if ``W`` or ``H`` is not positive.
    """
    if rnd_state is None:
        rnd_state = np.random
    if W <= 0 or H <= 0:
        raise ValueError(f"gen_pts: W and H must be positive, got W={W}, H={H}")

    min_pts, max_pts = 4, 8
    n_pts = rnd_state.randint(min_pts, max_pts)
    # Anchors are effectively required to be distinct (see the NOTE above).
    # An image with fewer pixels than n_pts could therefore never finish the
    # loop below, so cap the count.
    n_pts = min(n_pts, W * H)

    min_radius_per = 0.00
    max_radius_per = 0.10
    pts = []

    for _ in range(n_pts):
        while True:
            x, y = rnd_state.randint(W), rnd_state.randint(H)
            rad = min_radius_per + rnd_state.rand() * (max_radius_per - min_radius_per)

            if any(npla.norm([x - px, y - py]) <= (rad + prad) * 2
                   for px, py, prad, _, _ in pts):
                continue

            angle = rnd_state.rand() * (2 * np.pi)
            x2 = int(x + np.cos(angle) * W * rad)
            y2 = int(y + np.sin(angle) * H * rad)
            break
        pts.append((x, y, rad, x2, y2))

    pts1 = np.array([[x, y] for x, y, *_ in pts])
    pts2 = np.array([[x2, y2] for *_, x2, y2 in pts])

    return pts1, pts2
105
106
107
def gen_warp_params(w, flip=False, rotation_range=(-10, 10), scale_range=(-0.5, 0.5),
                    tx_range=(-0.05, 0.05), ty_range=(-0.05, 0.05),
                    rnd_state=None, warp_rnd_state=None):
    """Draw random warp, affine and flip parameters for a square image.

    Args:
        w: Side length of the square image. Sizes below 64 are generated at 64;
            the original size is returned as ``rw`` so ``warp_by_params`` can
            scale back.
        flip: If true, the returned ``flip`` flag is set with probability 4/10.
        rotation_range: ``(min, max)`` rotation angle for ``cv2.getRotationMatrix2D``.
        scale_range: ``(a, b)``; the scale is drawn uniformly from ``[1/(1-a), 1+b]``.
        tx_range, ty_range: Translation ranges as fractions of ``w``.
        rnd_state: RNG for the affine transform and flip; defaults to ``np.random``.
        warp_rnd_state: RNG for the non-linear warp grid; defaults to ``np.random``.

    Returns:
        dict with keys:
          * ``'mapx'``/``'mapy'``: float32 ``w x w`` remap grids.
          * ``'rmat'``: 2x3 affine matrix in pixels.
          * ``'umat'``: ``rmat`` with its translation column divided by ``w``.
          * ``'w'``: working size.
          * ``'rw'``: original size if it was below 64, else None.
          * ``'flip'``: bool.
    """
    if rnd_state is None:
        rnd_state = np.random
    if warp_rnd_state is None:
        warp_rnd_state = np.random

    # Small images are processed at 64x64; remember the real size for resizing back.
    rw = None
    if w < 64:
        rw = w
        w = 64

    rotation = rnd_state.uniform(rotation_range[0], rotation_range[1])
    scale = rnd_state.uniform(1 / (1 - scale_range[0]), 1 + scale_range[1])
    tx = rnd_state.uniform(tx_range[0], tx_range[1])
    ty = rnd_state.uniform(ty_range[0], ty_range[1])
    # Short-circuit: no RNG draw is consumed when flipping is disabled.
    p_flip = flip and rnd_state.randint(10) < 4

    # Random warp V1: jitter the interior nodes of a coarse grid, then upsample it.
    cell_size = [w // (2 ** i) for i in range(1, 4)][warp_rnd_state.randint(3)]
    cell_count = w // cell_size + 1
    grid_points = np.linspace(0, w, cell_count)
    mapx = np.broadcast_to(grid_points, (cell_count, cell_count)).copy()
    # Must be a copy: a bare ``.T`` is a view of mapx. Both in-place jitters
    # below would then land in the same buffer, making mapy == mapx.T instead
    # of two independent fields.
    mapy = mapx.T.copy()
    inner_shape = (cell_count - 2, cell_count - 2)
    noise_scale = cell_size * 0.24
    mapx[1:-1, 1:-1] = mapx[1:-1, 1:-1] + randomex.random_normal(size=inner_shape, rnd_state=warp_rnd_state) * noise_scale
    mapy[1:-1, 1:-1] = mapy[1:-1, 1:-1] + randomex.random_normal(size=inner_shape, rnd_state=warp_rnd_state) * noise_scale
    # Upsample to (w + cell_size) and crop half a cell from each side -> w x w.
    half_cell_size = cell_size // 2
    crop = slice(half_cell_size, -half_cell_size)
    mapx = cv2.resize(mapx, (w + cell_size,) * 2)[crop, crop].astype(np.float32)
    mapy = cv2.resize(mapy, (w + cell_size,) * 2)[crop, crop].astype(np.float32)
    ##############

    # Random warp V2 (disabled; kept for reference):
    # pts1, pts2 = gen_pts(w, w, rnd_state)
    # gridX = np.arange(w, dtype=np.int16)
    # gridY = np.arange(w, dtype=np.int16)
    # vy, vx = np.meshgrid(gridX, gridY)
    # drigid = mls_rigid_deformation(vy, vx, pts1, pts2)
    # mapy, mapx = drigid.astype(np.float32)
    ################

    # Random affine transform about the image centre, with translation in pixels.
    random_transform_mat = cv2.getRotationMatrix2D((w // 2, w // 2), rotation, scale)
    random_transform_mat[:, 2] += (tx * w, ty * w)

    # Same matrix with the translation expressed as a fraction of w.
    u_mat = random_transform_mat.copy()
    u_mat[:, 2] /= w

    params = dict()
    params['mapx'] = mapx
    params['mapy'] = mapy
    params['rmat'] = random_transform_mat
    params['umat'] = u_mat
    params['w'] = w
    params['rw'] = rw
    params['flip'] = p_flip

    return params
161
162
def warp_by_params(params, img, can_warp, can_transform, can_flip, border_replicate, cv2_inter=cv2.INTER_CUBIC):
    """Apply parameters produced by ``gen_warp_params`` to an image.

    Args:
        params: dict with keys ``'mapx'``, ``'mapy'``, ``'rmat'``, ``'w'``,
            ``'rw'`` and ``'flip'``, as returned by ``gen_warp_params``.
        img: Image array to process.
        can_warp: Apply the non-linear remap (``mapx``/``mapy``).
        can_transform: Apply the affine matrix ``rmat`` with output size ``w x w``.
        can_flip: Mirror horizontally when ``params['flip']`` is set.
        border_replicate: Use ``BORDER_REPLICATE`` for the affine step instead
            of ``BORDER_CONSTANT``.
        cv2_inter: OpenCV interpolation flag used by every resampling step.

    Returns:
        The processed image. A 2-D result gets a trailing channel axis, so the
        output is always at least 3-D.
    """
    rw = params['rw']
    w = params['w']
    # Params for small images were generated at the working size ``w``. Scale
    # the image up so it matches the remap grids and affine matrix, then back
    # down afterwards. Use params['w'] rather than a duplicated 64 constant.
    resample = (can_warp or can_transform) and rw is not None

    if resample:
        img = cv2.resize(img, (w, w), interpolation=cv2_inter)

    if can_warp:
        img = cv2.remap(img, params['mapx'], params['mapy'], cv2_inter)
    if can_transform:
        border_mode = cv2.BORDER_REPLICATE if border_replicate else cv2.BORDER_CONSTANT
        img = cv2.warpAffine(img, params['rmat'], (w, w), borderMode=border_mode, flags=cv2_inter)

    if resample:
        img = cv2.resize(img, (rw, rw), interpolation=cv2_inter)

    # Presumably restores a channel axis that cv2 drops for single-channel
    # images (or adds one for 2-D input).
    if len(img.shape) == 2:
        img = img[..., None]
    if can_flip and params['flip']:
        img = img[:, ::-1, ...]
    return img
182