Book a Demo!
CoCalc Logo Icon
Store Features Docs Share Support News About Policies Sign Up Sign In
iperov
GitHub Repository: iperov/deepfacelab
Path: blob/master/core/leras/archis/DeepFakeArchi.py
628 views
1
from core.leras import nn
2
tf = nn.tf
3
4
class DeepFakeArchi(nn.ArchiBase):
    """
    Builds the Encoder, Inter and Decoder model classes for this architecture.

    After construction the classes are exposed as ``self.Encoder``,
    ``self.Inter`` and ``self.Decoder``.  They close over the constructor
    arguments (``opts``, ``use_fp16``, the activation and
    ``lowest_dense_res``), so every model built from them shares one
    configuration.

    resolution  used to size the Inter bottleneck:
                ``lowest_dense_res = resolution // 16`` (``// 32`` with 'd').
                Integer division is used, so resolution is presumably
                expected to be a multiple of 16 (32 with 'd'); this is
                not checked here.

    use_fp16    if True, conv layers are created with float16 dtype; the
                Encoder casts its input to float16 and its output back to
                float32, Inter casts its reshaped output to float16, and
                the Decoder casts both outputs back to float32.

    mod         None - default; the only implemented variant.  Any other
                value raises ValueError.

    opts        string of single-letter flags ('' when None):
                'c' - use x*cos(x) as the activation instead of leaky ReLU
                't' - deeper Encoder (5 downscales plus residual blocks),
                      no upscale inside Inter, one extra Decoder stage
                'd' - halve lowest_dense_res; the Decoder image is built
                      from four 3-channel convs passed through
                      depth_to_space(2), and the mask path gets one more
                      upscale
                'u' - pixel-normalize the flattened Encoder output
    """
    def __init__(self, resolution, use_fp16=False, mod=None, opts=None):
        super().__init__()

        if opts is None:
            opts = ''

        # Only the default variant defines the model classes.  Reject any
        # other value up front instead of failing later on an undefined
        # name when self.Encoder/Inter/Decoder are assigned.
        if mod is not None:
            raise ValueError("DeepFakeArchi: unsupported mod {!r}; only None is implemented".format(mod))

        conv_dtype = tf.float16 if use_fp16 else tf.float32

        if 'c' in opts:
            def act(x, alpha=0.1):
                # `alpha` is accepted only to match the leaky ReLU signature; it is ignored.
                return x*tf.cos(x)
        else:
            def act(x, alpha=0.1):
                return tf.nn.leaky_relu(x, alpha)

        class Downscale(nn.ModelBase):
            """Single stride-2 'SAME'-padded conv followed by the activation."""
            def __init__(self, in_ch, out_ch, kernel_size=5, *args, **kwargs):
                self.in_ch = in_ch
                self.out_ch = out_ch
                self.kernel_size = kernel_size
                # Forward extra positional AND keyword arguments to ModelBase.
                # The previous `*kwargs` signature rejected keyword arguments
                # with a TypeError.
                super().__init__(*args, **kwargs)

            def on_build(self, *args, **kwargs):
                self.conv1 = nn.Conv2D(self.in_ch, self.out_ch, kernel_size=self.kernel_size, strides=2, padding='SAME', dtype=conv_dtype)

            def forward(self, x):
                x = self.conv1(x)
                x = act(x, 0.1)
                return x

            def get_out_ch(self):
                return self.out_ch

        class DownscaleBlock(nn.ModelBase):
            """Chain of `n_downscales` Downscale layers; layer i outputs ch*min(2**i, 8) channels."""
            def on_build(self, in_ch, ch, n_downscales, kernel_size):
                self.downs = []

                last_ch = in_ch
                for i in range(n_downscales):
                    cur_ch = ch*( min(2**i, 8) )
                    self.downs.append( Downscale(last_ch, cur_ch, kernel_size=kernel_size) )
                    last_ch = self.downs[-1].get_out_ch()

            def forward(self, inp):
                x = inp
                for down in self.downs:
                    x = down(x)
                return x

        class Upscale(nn.ModelBase):
            """Conv to out_ch*4 channels, activation, then nn.depth_to_space(x, 2),
            which trades the extra channels for a 2x larger spatial size."""
            def on_build(self, in_ch, out_ch, kernel_size=3):
                self.conv1 = nn.Conv2D(in_ch, out_ch*4, kernel_size=kernel_size, padding='SAME', dtype=conv_dtype)

            def forward(self, x):
                x = self.conv1(x)
                x = act(x, 0.1)
                x = nn.depth_to_space(x, 2)
                return x

        class ResidualBlock(nn.ModelBase):
            """Computes act(inp + conv2(act(conv1(inp)))) with ch -> ch convs."""
            def on_build(self, ch, kernel_size=3):
                self.conv1 = nn.Conv2D(ch, ch, kernel_size=kernel_size, padding='SAME', dtype=conv_dtype)
                self.conv2 = nn.Conv2D(ch, ch, kernel_size=kernel_size, padding='SAME', dtype=conv_dtype)

            def forward(self, inp):
                x = self.conv1(inp)
                x = act(x, 0.2)
                x = self.conv2(x)
                x = act(inp + x, 0.2)
                return x

        class Encoder(nn.ModelBase):
            """Downscales the input 4 times (5 with 't') and returns it flattened."""
            def __init__(self, in_ch, e_ch, **kwargs):
                self.in_ch = in_ch
                self.e_ch = e_ch
                super().__init__(**kwargs)

            def on_build(self):
                if 't' in opts:
                    self.down1 = Downscale(self.in_ch, self.e_ch, kernel_size=5)
                    self.res1 = ResidualBlock(self.e_ch)
                    self.down2 = Downscale(self.e_ch, self.e_ch*2, kernel_size=5)
                    self.down3 = Downscale(self.e_ch*2, self.e_ch*4, kernel_size=5)
                    self.down4 = Downscale(self.e_ch*4, self.e_ch*8, kernel_size=5)
                    self.down5 = Downscale(self.e_ch*8, self.e_ch*8, kernel_size=5)
                    self.res5 = ResidualBlock(self.e_ch*8)
                else:
                    # 't' is absent in this branch, so the block always has 4 downscales.
                    self.down1 = DownscaleBlock(self.in_ch, self.e_ch, n_downscales=4, kernel_size=5)

            def forward(self, x):
                if use_fp16:
                    x = tf.cast(x, tf.float16)

                if 't' in opts:
                    x = self.down1(x)
                    x = self.res1(x)
                    x = self.down2(x)
                    x = self.down3(x)
                    x = self.down4(x)
                    x = self.down5(x)
                    x = self.res5(x)
                else:
                    x = self.down1(x)
                x = nn.flatten(x)
                if 'u' in opts:
                    x = nn.pixel_norm(x, axes=-1)

                if use_fp16:
                    x = tf.cast(x, tf.float32)
                return x

            def get_out_res(self, res):
                # Spatial size after the downscales, before flattening.
                return res // ( (2**4) if 't' not in opts else (2**5) )

            def get_out_ch(self):
                # Channel count of the last downscale, before flattening.
                return self.e_ch * 8

        # Side length of the square feature map produced by Inter's dense2/reshape.
        lowest_dense_res = resolution // (32 if 'd' in opts else 16)

        class Inter(nn.ModelBase):
            """Dense bottleneck: in_ch -> ae_ch -> reshaped to
            (lowest_dense_res, lowest_dense_res, ae_out_ch), then upscaled once unless 't'."""
            def __init__(self, in_ch, ae_ch, ae_out_ch, **kwargs):
                self.in_ch, self.ae_ch, self.ae_out_ch = in_ch, ae_ch, ae_out_ch
                super().__init__(**kwargs)

            def on_build(self):
                in_ch, ae_ch, ae_out_ch = self.in_ch, self.ae_ch, self.ae_out_ch

                self.dense1 = nn.Dense(in_ch, ae_ch)
                self.dense2 = nn.Dense(ae_ch, lowest_dense_res * lowest_dense_res * ae_out_ch)
                if 't' not in opts:
                    self.upscale1 = Upscale(ae_out_ch, ae_out_ch)

            def forward(self, inp):
                x = inp
                x = self.dense1(x)
                x = self.dense2(x)
                x = nn.reshape_4D(x, lowest_dense_res, lowest_dense_res, self.ae_out_ch)

                if use_fp16:
                    x = tf.cast(x, tf.float16)

                if 't' not in opts:
                    x = self.upscale1(x)

                return x

            def get_out_res(self):
                return lowest_dense_res * 2 if 't' not in opts else lowest_dense_res

            def get_out_ch(self):
                return self.ae_out_ch

        class Decoder(nn.ModelBase):
            """Upscales the Inter output along two paths and returns (x, m).

            x: 3-channel sigmoid image (with 'd': four 3-channel convs
               concatenated on nn.conv2d_ch_axis, then depth_to_space(2)).
            m: 1-channel sigmoid mask.
            Both are cast to float32 when use_fp16 is set.
            """
            def on_build(self, in_ch, d_ch, d_mask_ch):
                if 't' not in opts:
                    self.upscale0 = Upscale(in_ch, d_ch*8, kernel_size=3)
                    self.upscale1 = Upscale(d_ch*8, d_ch*4, kernel_size=3)
                    self.upscale2 = Upscale(d_ch*4, d_ch*2, kernel_size=3)
                    self.res0 = ResidualBlock(d_ch*8, kernel_size=3)
                    self.res1 = ResidualBlock(d_ch*4, kernel_size=3)
                    self.res2 = ResidualBlock(d_ch*2, kernel_size=3)

                    self.upscalem0 = Upscale(in_ch, d_mask_ch*8, kernel_size=3)
                    self.upscalem1 = Upscale(d_mask_ch*8, d_mask_ch*4, kernel_size=3)
                    self.upscalem2 = Upscale(d_mask_ch*4, d_mask_ch*2, kernel_size=3)

                    self.out_conv = nn.Conv2D(d_ch*2, 3, kernel_size=1, padding='SAME', dtype=conv_dtype)

                    if 'd' in opts:
                        self.out_conv1 = nn.Conv2D(d_ch*2, 3, kernel_size=3, padding='SAME', dtype=conv_dtype)
                        self.out_conv2 = nn.Conv2D(d_ch*2, 3, kernel_size=3, padding='SAME', dtype=conv_dtype)
                        self.out_conv3 = nn.Conv2D(d_ch*2, 3, kernel_size=3, padding='SAME', dtype=conv_dtype)
                        # Extra mask upscale so the mask matches the depth_to_space'd image.
                        self.upscalem3 = Upscale(d_mask_ch*2, d_mask_ch*1, kernel_size=3)
                        self.out_convm = nn.Conv2D(d_mask_ch*1, 1, kernel_size=1, padding='SAME', dtype=conv_dtype)
                    else:
                        self.out_convm = nn.Conv2D(d_mask_ch*2, 1, kernel_size=1, padding='SAME', dtype=conv_dtype)
                else:
                    self.upscale0 = Upscale(in_ch, d_ch*8, kernel_size=3)
                    self.upscale1 = Upscale(d_ch*8, d_ch*8, kernel_size=3)
                    self.upscale2 = Upscale(d_ch*8, d_ch*4, kernel_size=3)
                    self.upscale3 = Upscale(d_ch*4, d_ch*2, kernel_size=3)
                    self.res0 = ResidualBlock(d_ch*8, kernel_size=3)
                    self.res1 = ResidualBlock(d_ch*8, kernel_size=3)
                    self.res2 = ResidualBlock(d_ch*4, kernel_size=3)
                    self.res3 = ResidualBlock(d_ch*2, kernel_size=3)

                    self.upscalem0 = Upscale(in_ch, d_mask_ch*8, kernel_size=3)
                    self.upscalem1 = Upscale(d_mask_ch*8, d_mask_ch*8, kernel_size=3)
                    self.upscalem2 = Upscale(d_mask_ch*8, d_mask_ch*4, kernel_size=3)
                    self.upscalem3 = Upscale(d_mask_ch*4, d_mask_ch*2, kernel_size=3)
                    self.out_conv = nn.Conv2D(d_ch*2, 3, kernel_size=1, padding='SAME', dtype=conv_dtype)

                    if 'd' in opts:
                        self.out_conv1 = nn.Conv2D(d_ch*2, 3, kernel_size=3, padding='SAME', dtype=conv_dtype)
                        self.out_conv2 = nn.Conv2D(d_ch*2, 3, kernel_size=3, padding='SAME', dtype=conv_dtype)
                        self.out_conv3 = nn.Conv2D(d_ch*2, 3, kernel_size=3, padding='SAME', dtype=conv_dtype)
                        self.upscalem4 = Upscale(d_mask_ch*2, d_mask_ch*1, kernel_size=3)
                        self.out_convm = nn.Conv2D(d_mask_ch*1, 1, kernel_size=1, padding='SAME', dtype=conv_dtype)
                    else:
                        self.out_convm = nn.Conv2D(d_mask_ch*2, 1, kernel_size=1, padding='SAME', dtype=conv_dtype)

            def forward(self, z):
                # Image path.
                x = self.upscale0(z)
                x = self.res0(x)
                x = self.upscale1(x)
                x = self.res1(x)
                x = self.upscale2(x)
                x = self.res2(x)

                if 't' in opts:
                    x = self.upscale3(x)
                    x = self.res3(x)

                if 'd' in opts:
                    x = tf.nn.sigmoid( nn.depth_to_space(tf.concat( (self.out_conv(x),
                                                                     self.out_conv1(x),
                                                                     self.out_conv2(x),
                                                                     self.out_conv3(x)), nn.conv2d_ch_axis), 2) )
                else:
                    x = tf.nn.sigmoid(self.out_conv(x))

                # Mask path, computed independently from the same input z.
                m = self.upscalem0(z)
                m = self.upscalem1(m)
                m = self.upscalem2(m)

                if 't' in opts:
                    m = self.upscalem3(m)
                    if 'd' in opts:
                        m = self.upscalem4(m)
                else:
                    if 'd' in opts:
                        m = self.upscalem3(m)

                m = tf.nn.sigmoid(self.out_convm(m))

                if use_fp16:
                    x = tf.cast(x, tf.float32)
                    m = tf.cast(m, tf.float32)

                return x, m

        self.Encoder = Encoder
        self.Inter = Inter
        self.Decoder = Decoder

nn.DeepFakeArchi = DeepFakeArchi
266