Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
iperov
GitHub Repository: iperov/deepfacelab
Path: blob/master/core/leras/models/PatchDiscriminator.py
628 views
1
import numpy as np
2
from core.leras import nn
3
tf = nn.tf
4
5
# Mapping: patch_size -> (suggested_base_ch, [ [kernel_size, stride], ... ]).
# Each conv stack is chosen so that its receptive field equals the patch_size
# key exactly (rf = k0 + sum_{i>0} (k_i - 1) * prod_{j<i} s_j, i.e. the same
# arithmetic as UNetPatchDiscriminator.calc_receptive_field_size below).
#
# Fix: entries 27, 35 and 43 were copy-paste duplicates of entries 29, 37 and
# 45 ([3,2],[4,2],[4,2],[x,1] has receptive field 29/37/45, not 27/35/43).
# The second layer must be [3,2], giving 3 + 2*2 + 3*4 + (x-1)*8 = 27/35/43.
# NOTE(review): this changes the conv shapes for those three patch sizes, so
# checkpoints trained with the old (wrong) configs will not load for them.
patch_discriminator_kernels = \
    { 1  : (512, [ [1,1] ]),
      2  : (512, [ [2,1] ]),
      3  : (512, [ [2,1], [2,1] ]),
      4  : (512, [ [2,2], [2,2] ]),
      5  : (512, [ [3,2], [2,2] ]),
      6  : (512, [ [4,2], [2,2] ]),
      7  : (512, [ [3,2], [3,2] ]),
      8  : (512, [ [4,2], [3,2] ]),
      9  : (512, [ [3,2], [4,2] ]),
      10 : (512, [ [4,2], [4,2] ]),
      11 : (512, [ [3,2], [3,2], [2,1] ]),
      12 : (512, [ [4,2], [3,2], [2,1] ]),
      13 : (512, [ [3,2], [4,2], [2,1] ]),
      14 : (512, [ [4,2], [4,2], [2,1] ]),
      15 : (512, [ [3,2], [3,2], [3,1] ]),
      16 : (512, [ [4,2], [3,2], [3,1] ]),
      17 : (512, [ [3,2], [4,2], [3,1] ]),
      18 : (512, [ [4,2], [4,2], [3,1] ]),
      19 : (512, [ [3,2], [3,2], [4,1] ]),
      20 : (512, [ [4,2], [3,2], [4,1] ]),
      21 : (512, [ [3,2], [4,2], [4,1] ]),
      22 : (512, [ [4,2], [4,2], [4,1] ]),
      23 : (256, [ [3,2], [3,2], [3,2], [2,1] ]),
      24 : (256, [ [4,2], [3,2], [3,2], [2,1] ]),
      25 : (256, [ [3,2], [4,2], [3,2], [2,1] ]),
      26 : (256, [ [4,2], [4,2], [3,2], [2,1] ]),
      27 : (256, [ [3,2], [3,2], [4,2], [2,1] ]),  # was [3,2],[4,2],[4,2],[2,1] (rf 29)
      28 : (256, [ [4,2], [3,2], [4,2], [2,1] ]),
      29 : (256, [ [3,2], [4,2], [4,2], [2,1] ]),
      30 : (256, [ [4,2], [4,2], [4,2], [2,1] ]),
      31 : (256, [ [3,2], [3,2], [3,2], [3,1] ]),
      32 : (256, [ [4,2], [3,2], [3,2], [3,1] ]),
      33 : (256, [ [3,2], [4,2], [3,2], [3,1] ]),
      34 : (256, [ [4,2], [4,2], [3,2], [3,1] ]),
      35 : (256, [ [3,2], [3,2], [4,2], [3,1] ]),  # was [3,2],[4,2],[4,2],[3,1] (rf 37)
      36 : (256, [ [4,2], [3,2], [4,2], [3,1] ]),
      37 : (256, [ [3,2], [4,2], [4,2], [3,1] ]),
      38 : (256, [ [4,2], [4,2], [4,2], [3,1] ]),
      39 : (256, [ [3,2], [3,2], [3,2], [4,1] ]),
      40 : (256, [ [4,2], [3,2], [3,2], [4,1] ]),
      41 : (256, [ [3,2], [4,2], [3,2], [4,1] ]),
      42 : (256, [ [4,2], [4,2], [3,2], [4,1] ]),
      43 : (256, [ [3,2], [3,2], [4,2], [4,1] ]),  # was [3,2],[4,2],[4,2],[4,1] (rf 45)
      44 : (256, [ [4,2], [3,2], [4,2], [4,1] ]),
      45 : (256, [ [3,2], [4,2], [4,2], [4,1] ]),
      46 : (256, [ [4,2], [4,2], [4,2], [4,1] ]),
    }
53
54
55
class PatchDiscriminator(nn.ModelBase):
    """
    Patch discriminator built from a fixed table of conv stacks.

    The stack of [kernel, stride] layers for the requested `patch_size` is
    looked up in `patch_discriminator_kernels`; a final 1x1 conv projects the
    features down to a single logit per patch.
    """

    def on_build(self, patch_size, in_ch, base_ch=None, conv_kernel_initializer=None):
        # Table gives a suggested base width plus the [kernel, stride] stack.
        default_ch, layer_specs = patch_discriminator_kernels[patch_size]
        if base_ch is None:
            base_ch = default_ch

        self.convs = []
        in_channels = in_ch
        for level, (k_size, stride) in enumerate(layer_specs):
            # Width doubles every level, capped at 8x the base width.
            out_channels = base_ch * min(2 ** level, 8)
            self.convs.append(
                nn.Conv2D(in_channels, out_channels,
                          kernel_size=k_size, strides=stride, padding='SAME',
                          kernel_initializer=conv_kernel_initializer))
            in_channels = out_channels

        # 1x1 projection to a single-channel logit map.
        self.out_conv = nn.Conv2D(in_channels, 1, kernel_size=1, padding='VALID',
                                  kernel_initializer=conv_kernel_initializer)

    def forward(self, x):
        """Apply the conv stack with leaky-ReLU (slope 0.1), then project to logits."""
        out = x
        for conv in self.convs:
            out = tf.nn.leaky_relu(conv(out), 0.1)
        return self.out_conv(out)

nn.PatchDiscriminator = PatchDiscriminator
78
79
class UNetPatchDiscriminator(nn.ModelBase):
    """
    U-Net shaped patch discriminator.

    Inspired by https://arxiv.org/abs/2002.12655
    "A U-Net Based Discriminator for Generative Adversarial Networks".

    forward() returns two logit tensors: a low-resolution map taken at the
    encoder bottleneck and a full-resolution per-pixel map from the decoder.
    """
    def calc_receptive_field_size(self, layers):
        """
        Return the receptive field size of a stack of conv layers.

        layers : list of [kernel_size, stride] pairs, shallowest layer first.

        result the same as https://fomoro.com/research/article/receptive-field-calculator
        (NOTE(review): URL was garbled in the original source; de-garbled — verify.)
        """
        rf = 0
        ts = 1  # cumulative product of strides of the layers processed so far
        for i, (k, s) in enumerate(layers):
            if i == 0:
                rf = k
            else:
                # Each deeper layer widens the field by (k-1) input-space steps.
                rf += (k-1)*ts
            ts *= s
        return rf

    def find_archi(self, target_patch_size, max_layers=9):
        """
        Find the best configuration of layers using only 3x3 convs for target patch size.

        Enumerates, for every layer count up to max_layers, every combination of
        strides (first layer fixed at stride 2, later layers stride 1 or 2 taken
        from the bits of `val`) and keeps, per receptive field size, the config
        with the fewest layers (ties broken by larger total stride).  Returns the
        [kernel, stride] list whose receptive field is closest to
        target_patch_size.
        """
        s = {}  # receptive_field -> (layers_count, sum_of_strides, layers)
        for layers_count in range(1,max_layers+1):
            # val's bits enumerate stride choices for the layers after the first.
            val = 1 << (layers_count-1)
            while True:
                val -= 1

                layers = []
                sum_st = 0
                layers.append ( [3, 2])  # first layer: 3x3, stride 2, always
                sum_st += 2
                for i in range(layers_count-1):
                    # bit i of val set -> stride 2, clear -> stride 1
                    st = 1 + (1 if val & (1 << i) !=0 else 0 )
                    layers.append ( [3, st ])
                    sum_st += st

                rf = self.calc_receptive_field_size(layers)

                s_rf = s.get(rf, None)
                if s_rf is None:
                    s[rf] = (layers_count, sum_st, layers)
                else:
                    # Prefer fewer layers; at equal depth prefer more downscaling.
                    if layers_count < s_rf[0] or \
                       ( layers_count == s_rf[0] and sum_st > s_rf[1] ):
                        s[rf] = (layers_count, sum_st, layers)

                if val == 0:
                    break

        # Pick the achievable receptive field closest to the requested size.
        x = sorted(list(s.keys()))
        q=x[np.abs(np.array(x)-target_patch_size).argmin()]
        return s[q][2]

    def on_build(self, patch_size, in_ch, base_ch = 16, use_fp16 = False):
        """
        Build the encoder/decoder conv stacks.

        patch_size : desired receptive field; the actual architecture comes
                     from find_archi(patch_size).
        in_ch      : number of input channels.
        base_ch    : channel width at the shallowest level (doubles per level,
                     capped at 512).
        use_fp16   : if True, convs run in float16 and forward() casts in/out.
        """
        self.use_fp16 = use_fp16
        conv_dtype = tf.float16 if use_fp16 else tf.float32

        # NOTE(review): ResidualBlock is defined but not used anywhere in this
        # build — apparently leftover/dead code; kept as-is.
        class ResidualBlock(nn.ModelBase):
            def on_build(self, ch, kernel_size=3 ):
                self.conv1 = nn.Conv2D( ch, ch, kernel_size=kernel_size, padding='SAME', dtype=conv_dtype)
                self.conv2 = nn.Conv2D( ch, ch, kernel_size=kernel_size, padding='SAME', dtype=conv_dtype)

            def forward(self, inp):
                x = self.conv1(inp)
                x = tf.nn.leaky_relu(x, 0.2)
                x = self.conv2(x)
                x = tf.nn.leaky_relu(inp + x, 0.2)
                return x

        prev_ch = in_ch
        self.convs = []
        self.upconvs = []
        layers = self.find_archi(patch_size)

        # Channel width per level, keyed from -1 (pre-encoder, fed by in_conv)
        # up to len(layers)-1 (bottleneck); doubles each level, capped at 512.
        level_chs = { i-1:v for i,v in enumerate([ min( base_ch * (2**i), 512 ) for i in range(len(layers)+1)]) }

        self.in_conv = nn.Conv2D( in_ch, level_chs[-1], kernel_size=1, padding='VALID', dtype=conv_dtype)

        for i, (kernel_size, strides) in enumerate(layers):
            self.convs.append ( nn.Conv2D( level_chs[i-1], level_chs[i], kernel_size=kernel_size, strides=strides, padding='SAME', dtype=conv_dtype) )

            # Decoder mirror, built in reverse order.  Input channels are doubled
            # for all but the deepest level because forward() concatenates the
            # skip connection before the next upconv.
            self.upconvs.insert (0, nn.Conv2DTranspose( level_chs[i]*(2 if i != len(layers)-1 else 1), level_chs[i-1], kernel_size=kernel_size, strides=strides, padding='SAME', dtype=conv_dtype) )

        # Final decoder output also carries the concatenated skip, hence *2.
        self.out_conv = nn.Conv2D( level_chs[-1]*2, 1, kernel_size=1, padding='VALID', dtype=conv_dtype)

        # Bottleneck heads: logit map + 1x1 conv feeding the decoder.
        self.center_out = nn.Conv2D( level_chs[len(layers)-1], 1, kernel_size=1, padding='VALID', dtype=conv_dtype)
        self.center_conv = nn.Conv2D( level_chs[len(layers)-1], level_chs[len(layers)-1], kernel_size=1, padding='VALID', dtype=conv_dtype)


    def forward(self, x):
        """
        Run encoder, bottleneck and decoder.

        Returns (center_out, x): bottleneck logits and full-resolution per-pixel
        logits.  Both are cast back to float32 when use_fp16 is set.
        """
        if self.use_fp16:
            x = tf.cast(x, tf.float16)

        x = tf.nn.leaky_relu( self.in_conv(x), 0.2 )

        # Encoder: save the input of each conv (deepest first) as skip tensors.
        encs = []
        for conv in self.convs:
            encs.insert(0, x)
            x = tf.nn.leaky_relu( conv(x), 0.2 )

        # Bottleneck: logit head plus features passed on to the decoder.
        center_out, x = self.center_out(x), tf.nn.leaky_relu( self.center_conv(x), 0.2 )

        # Decoder: upsample and concatenate the matching skip along channels.
        for i, (upconv, enc) in enumerate(zip(self.upconvs, encs)):
            x = tf.nn.leaky_relu( upconv(x), 0.2 )
            x = tf.concat( [enc, x], axis=nn.conv2d_ch_axis)

        x = self.out_conv(x)

        if self.use_fp16:
            center_out = tf.cast(center_out, tf.float32)
            x = tf.cast(x, tf.float32)

        return center_out, x

# Register on the leras namespace, matching the file's convention above.
nn.UNetPatchDiscriminator = UNetPatchDiscriminator
195
196