Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
iperov
GitHub Repository: iperov/deepfacelab
Path: blob/master/core/leras/layers/BatchNorm2D.py
628 views
1
from core.leras import nn
2
tf = nn.tf
3
4
class BatchNorm2D(nn.LayerBase):
    """
    Inference-only 2D batch normalization (currently not for training).

    Each channel of the input is normalized with the stored statistics and
    then scaled/shifted:

        y = (x - running_mean) / sqrt(running_var + eps) * weight + bias

    This layer never updates ``running_mean`` / ``running_var``; they are
    created as non-trainable variables and are expected to be filled with
    pre-computed values (NOTE(review): presumably loaded from a checkpoint
    or converted model — confirm with callers).

    Args:
        dim: number of channels; length of every per-channel variable.
        eps: constant added to the variance before the square root.
        momentum: stored on the instance but not used by this layer
            (no running-statistics update is implemented here).
        dtype: dtype of the variables; ``nn.floatx`` when None.
        **kwargs: forwarded to ``nn.LayerBase.__init__``.
    """

    def __init__(self, dim, eps=1e-05, momentum=0.1, dtype=None, **kwargs):
        self.dim = dim
        self.eps = eps
        self.momentum = momentum
        self.dtype = nn.floatx if dtype is None else dtype
        super().__init__(**kwargs)

    def build_weights(self):
        """Create the per-channel affine parameters and running statistics.

        All four variables have shape ``(dim,)``. The variable names are
        left unchanged (NOTE(review): they are likely the keys used when
        saving/loading weights — do not rename without checking).
        """
        self.weight = tf.get_variable("weight", (self.dim,), dtype=self.dtype, initializer=tf.initializers.ones())
        self.bias = tf.get_variable("bias", (self.dim,), dtype=self.dtype, initializer=tf.initializers.zeros())
        self.running_mean = tf.get_variable("running_mean", (self.dim,), dtype=self.dtype, initializer=tf.initializers.zeros(), trainable=False)
        # Unit variance, not zero: with a zero variance an un-loaded layer
        # divides by sqrt(eps) (~0.00316), amplifying activations ~316x.
        # With mean 0 / var 1 / weight 1 / bias 0 the default layer is a
        # near-identity, matching PyTorch's BatchNorm2d initialization.
        self.running_var = tf.get_variable("running_var", (self.dim,), dtype=self.dtype, initializer=tf.initializers.ones(), trainable=False)

    def get_weights(self):
        """Return the layer's variables: weight, bias, running mean, running var."""
        return [self.weight, self.bias, self.running_mean, self.running_var]

    def forward(self, x):
        """Normalize ``x`` with the stored statistics and apply the affine transform.

        ``x`` is expected to be a 4-D tensor laid out per ``nn.data_format``:
        the per-channel vectors are reshaped so they broadcast along the
        last axis for "NHWC" and along axis 1 otherwise.
        """
        if nn.data_format == "NHWC":
            shape = (1, 1, 1, self.dim)
        else:
            shape = (1, self.dim, 1, 1)

        weight = tf.reshape(self.weight, shape)
        bias = tf.reshape(self.bias, shape)
        running_mean = tf.reshape(self.running_mean, shape)
        running_var = tf.reshape(self.running_var, shape)

        x = (x - running_mean) / tf.sqrt(running_var + self.eps)
        x *= weight
        x += bias
        return x
41
42
# Register the layer on the shared `nn` namespace so it is reachable as nn.BatchNorm2D.
setattr(nn, "BatchNorm2D", BatchNorm2D)
43