Book a Demo!
CoCalc Logo Icon
Store Features Docs Share Support News About Policies Sign Up Sign In
iperov
GitHub Repository: iperov/deepfacelab
Path: blob/master/core/leras/models/XSeg.py
628 views
1
from core.leras import nn
2
tf = nn.tf
3
4
class XSeg(nn.ModelBase):
    """U-Net-like encoder/decoder assembled from leras layers.

    Encoder: six stages of ConvBlocks, each stage followed by an nn.BlurPool.
    The output of the last conv of every stage is kept as a skip tensor.

    Bottleneck: flatten -> nn.Dense(4*4*base_ch*8, 512) -> nn.Dense(512,
    4*4*base_ch*8) -> nn.reshape_4D back to a 4x4 grid with base_ch*8 channels.

    Decoder: six stages, each starting with an UpConvBlock. Its output is
    concatenated with the matching skip tensor along nn.conv2d_ch_axis, then
    refined by ConvBlocks. A final 3x3 nn.Conv2D yields ``out_ch`` logit channels.

    NOTE(review): the bottleneck is hard-wired to a 4x4 spatial grid, so the
    input must shrink to exactly 4x4 after the six BlurPool stages (256x256 if
    each BlurPool halves the resolution -- confirm nn.BlurPool's default stride).
    """

    def on_build(self, in_ch, base_ch, out_ch):
        """Create every sub-layer of the network.

        Args:
            in_ch: channel count of the input tensor (fed to the first conv).
            base_ch: channel width of the first stage; deeper stages use
                multiples of it (x2, x4, x8).
            out_ch: channel count produced by the final convolution.

        NOTE(review): attribute names and their assignment order are kept
        exactly as in the original model. leras presumably derives sub-layer /
        weight names (and build order) from them, so renaming or reordering
        would likely break loading existing weights -- confirm in nn.ModelBase.
        """

        class ConvBlock(nn.ModelBase):
            """kernel_size=3 'SAME' nn.Conv2D, then nn.FRNorm2D, then nn.TLU."""

            def on_build(self, in_ch, out_ch):
                self.conv = nn.Conv2D(in_ch, out_ch, kernel_size=3, padding='SAME')
                self.frn = nn.FRNorm2D(out_ch)
                self.tlu = nn.TLU(out_ch)

            def forward(self, x):
                # conv -> norm -> activation, in that order.
                return self.tlu(self.frn(self.conv(x)))

        class UpConvBlock(ConvBlock):
            """ConvBlock variant whose convolution is nn.Conv2DTranspose.

            Reuses ConvBlock.forward (conv -> frn -> tlu) unchanged.
            """

            def on_build(self, in_ch, out_ch):
                self.conv = nn.Conv2DTranspose(in_ch, out_ch, kernel_size=3, padding='SAME')
                self.frn = nn.FRNorm2D(out_ch)
                self.tlu = nn.TLU(out_ch)

        # Stored because forward() needs it to reshape the bottleneck output.
        self.base_ch = base_ch

        # Channel widths used throughout the network.
        c1, c2, c4, c8 = base_ch, base_ch*2, base_ch*4, base_ch*8
        # Decoder input width where a c4 up-sampled tensor meets a c8 skip tensor.
        c12 = base_ch*12

        # ---- encoder ------------------------------------------------------
        self.conv01 = ConvBlock(in_ch, c1)
        self.conv02 = ConvBlock(c1, c1)
        self.bp0 = nn.BlurPool(filt_size=4)

        self.conv11 = ConvBlock(c1, c2)
        self.conv12 = ConvBlock(c2, c2)
        self.bp1 = nn.BlurPool(filt_size=3)

        self.conv21 = ConvBlock(c2, c4)
        self.conv22 = ConvBlock(c4, c4)
        self.bp2 = nn.BlurPool(filt_size=2)

        self.conv31 = ConvBlock(c4, c8)
        self.conv32 = ConvBlock(c8, c8)
        self.conv33 = ConvBlock(c8, c8)
        self.bp3 = nn.BlurPool(filt_size=2)

        self.conv41 = ConvBlock(c8, c8)
        self.conv42 = ConvBlock(c8, c8)
        self.conv43 = ConvBlock(c8, c8)
        self.bp4 = nn.BlurPool(filt_size=2)

        self.conv51 = ConvBlock(c8, c8)
        self.conv52 = ConvBlock(c8, c8)
        self.conv53 = ConvBlock(c8, c8)
        self.bp5 = nn.BlurPool(filt_size=2)

        # ---- dense bottleneck over a flattened 4x4 x c8 grid --------------
        self.dense1 = nn.Dense(4*4*c8, 512)
        self.dense2 = nn.Dense(512, 4*4*c8)

        # ---- decoder (first conv of each stage eats the skip concat) ------
        self.up5 = UpConvBlock(c8, c4)
        self.uconv53 = ConvBlock(c12, c8)
        self.uconv52 = ConvBlock(c8, c8)
        self.uconv51 = ConvBlock(c8, c8)

        self.up4 = UpConvBlock(c8, c4)
        self.uconv43 = ConvBlock(c12, c8)
        self.uconv42 = ConvBlock(c8, c8)
        self.uconv41 = ConvBlock(c8, c8)

        self.up3 = UpConvBlock(c8, c4)
        self.uconv33 = ConvBlock(c12, c8)
        self.uconv32 = ConvBlock(c8, c8)
        self.uconv31 = ConvBlock(c8, c8)

        self.up2 = UpConvBlock(c8, c4)
        self.uconv22 = ConvBlock(c8, c4)   # c4 up-sampled + c4 skip
        self.uconv21 = ConvBlock(c4, c4)

        self.up1 = UpConvBlock(c4, c2)
        self.uconv12 = ConvBlock(c4, c2)   # c2 up-sampled + c2 skip
        self.uconv11 = ConvBlock(c2, c2)

        self.up0 = UpConvBlock(c2, c1)
        self.uconv02 = ConvBlock(c2, c1)   # c1 up-sampled + c1 skip
        self.uconv01 = ConvBlock(c1, c1)
        self.out_conv = nn.Conv2D(c1, out_ch, kernel_size=3, padding='SAME')

    def forward(self, inp, pretrain=False):
        """Run the network on ``inp``.

        Args:
            inp: input tensor. Its channel axis is assumed to be
                nn.conv2d_ch_axis, the axis the skip concatenations use.
            pretrain: if True, each skip tensor is replaced by
                tf.zeros_like(skip) before concatenation, so the decoder only
                receives information through the dense bottleneck.

        Returns:
            Tuple ``(logits, tf.nn.sigmoid(logits))``.
        """
        # (convs applied in order, pooling layer) per encoder stage.
        down_stages = (
            ((self.conv01, self.conv02), self.bp0),
            ((self.conv11, self.conv12), self.bp1),
            ((self.conv21, self.conv22), self.bp2),
            ((self.conv31, self.conv32, self.conv33), self.bp3),
            ((self.conv41, self.conv42, self.conv43), self.bp4),
            ((self.conv51, self.conv52, self.conv53), self.bp5),
        )
        # (up-sampling block, (fusing conv, *refining convs)) per decoder
        # stage, deepest first so it pairs with skips.pop().
        up_stages = (
            (self.up5, (self.uconv53, self.uconv52, self.uconv51)),
            (self.up4, (self.uconv43, self.uconv42, self.uconv41)),
            (self.up3, (self.uconv33, self.uconv32, self.uconv31)),
            (self.up2, (self.uconv22, self.uconv21)),
            (self.up1, (self.uconv12, self.uconv11)),
            (self.up0, (self.uconv02, self.uconv01)),
        )

        skips = []
        h = inp
        for convs, pool in down_stages:
            for conv in convs:
                h = conv(h)
            # Skip tensor = last conv output of the stage, taken before pooling.
            skips.append(h)
            h = pool(h)

        h = nn.flatten(h)
        h = self.dense1(h)
        h = self.dense2(h)
        h = nn.reshape_4D(h, 4, 4, self.base_ch*8)

        for up, (fuse, *refine) in up_stages:
            h = up(h)
            skip = skips.pop()
            if pretrain:
                # Blank out the skip path so only the bottleneck carries signal.
                skip = tf.zeros_like(skip)
            h = fuse(tf.concat([h, skip], axis=nn.conv2d_ch_axis))
            for conv in refine:
                h = conv(h)

        logits = self.out_conv(h)
        return logits, tf.nn.sigmoid(logits)
169
170
# Attach the model class to the shared leras `nn` namespace so other code can refer to it as nn.XSeg.
setattr(nn, "XSeg", XSeg)
171