Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
iperov
GitHub Repository: iperov/deepfacelab
Path: blob/master/core/leras/layers/BlurPool.py
628 views
1
import numpy as np
2
from core.leras import nn
3
tf = nn.tf
4
5
class BlurPool(nn.LayerBase):
    """Blur-then-subsample layer built on a depthwise convolution.

    Every input channel is convolved independently with the same normalized
    2-D binomial kernel of size ``filt_size`` x ``filt_size`` and strided by
    ``stride`` along both spatial axes. The input is zero padded explicitly
    (``tf.pad``) and the convolution itself uses 'VALID' padding.

    Args:
        filt_size: Side length of the square blur kernel. The 1-D taps are
            row ``filt_size - 1`` of Pascal's triangle (1 -> [1],
            3 -> [1, 2, 1], 5 -> [1, 4, 6, 4, 1], ...), so any positive
            integer is accepted.
        stride: Spatial subsampling factor applied to height and width.
        **kwargs: Forwarded unchanged to ``nn.LayerBase.__init__``.

    Raises:
        ValueError: If ``filt_size`` is not a positive integer. Previously an
            unsupported size (anything outside 1..7) failed with an
            ``UnboundLocalError``.
    """

    def __init__(self, filt_size=3, stride=2, **kwargs):
        n_taps = int(filt_size)
        # Integral floats such as 3.0 were accepted by the old size table, so
        # keep accepting them; reject non-integral and non-positive sizes.
        if n_taps != filt_size or n_taps < 1:
            raise ValueError(
                'filt_size must be a positive integer, got {!r}'.format(filt_size))

        # Stride layout depends on the global data format chosen in nn.
        if nn.data_format == "NHWC":
            self.strides = [1, stride, stride, 1]
        else:
            self.strides = [1, 1, stride, stride]

        self.filt_size = filt_size

        # Split the (n_taps - 1) border pixels between the two sides; for even
        # kernels the extra pixel goes on the trailing side
        # (floor / ceil of (n_taps - 1) / 2).
        pad = [(n_taps - 1) // 2, n_taps // 2]

        # Only the two spatial axes are padded; batch and channel are not.
        if nn.data_format == "NHWC":
            self.padding = [[0, 0], pad, pad, [0, 0]]
        else:
            self.padding = [[0, 0], [0, 0], pad, pad]

        # Binomial taps: repeatedly convolving [1] with [1, 1] yields the
        # n-th row of Pascal's triangle (exact in float64 for practical sizes).
        taps = np.array([1.])
        for _ in range(n_taps - 1):
            taps = np.convolve(taps, [1., 1.])

        # Separable 2-D kernel (outer product), normalized to sum to 1, then
        # given shape (filt_size, filt_size, 1, 1) so it can be tiled across
        # channels into a depthwise filter in forward().
        kernel = taps[:, None] * taps[None, :]
        kernel = kernel / np.sum(kernel)
        self.a = kernel[:, :, None, None]

        # Called last, after all attributes above are set; kept in the
        # original order in case LayerBase.__init__ relies on them.
        super().__init__(**kwargs)

    def build_weights(self):
        """Create the fixed (non-trainable) blur kernel as a TF constant."""
        self.k = tf.constant(self.a, dtype=nn.floatx)

    def forward(self, x):
        """Blur and subsample ``x``, which must be laid out per ``nn.data_format``.

        NOTE(review): the channel count is read from the static shape
        ``x.shape[nn.conv2d_ch_axis]``; this presumably has to be known at
        graph-construction time — confirm for dynamic-shape inputs.
        """
        # Replicate the single-channel kernel once per input channel:
        # depthwise filter shape is (h, w, in_channels, channel_multiplier=1).
        k = tf.tile(self.k, (1, 1, x.shape[nn.conv2d_ch_axis], 1))
        x = tf.pad(x, self.padding)
        x = tf.nn.depthwise_conv2d(x, k, self.strides, 'VALID',
                                   data_format=nn.data_format)
        return x
50
# Publish the layer on the shared `nn` namespace so callers can use nn.BlurPool.
setattr(nn, 'BlurPool', BlurPool)
51