Path: blob/master/core/leras/layers/InstanceNorm2D.py
628 views
from core.leras import nn

tf = nn.tf


class InstanceNorm2D(nn.LayerBase):
    """Instance normalization with a learnable per-channel scale and bias.

    The input is standardized using the mean and standard deviation taken
    over ``nn.conv2d_spatial_axes``, then multiplied by ``weight`` and
    shifted by ``bias``. Both parameters hold ``in_ch`` values.
    """

    def __init__(self, in_ch, dtype=None, **kwargs):
        """Store the layer configuration and defer the rest to the base class.

        Args:
            in_ch: Number of channels; sizes the ``weight`` and ``bias``
                variables.
            dtype: Variable dtype. Falls back to ``nn.floatx`` when ``None``.
            **kwargs: Passed through unchanged to ``nn.LayerBase.__init__``.
        """
        self.in_ch = in_ch
        self.dtype = nn.floatx if dtype is None else dtype

        super().__init__(**kwargs)

    def build_weights(self):
        """Create the ``weight`` and ``bias`` variables, each of shape ``(in_ch,)``."""
        per_channel = (self.in_ch,)

        # NOTE(review): the scale uses a Glorot-uniform initializer here;
        # normalization layers more commonly start the scale at ones --
        # confirm this is intentional before changing it.
        self.weight = tf.get_variable(
            "weight",
            per_channel,
            dtype=self.dtype,
            initializer=tf.initializers.glorot_uniform(dtype=self.dtype),
        )
        self.bias = tf.get_variable(
            "bias",
            per_channel,
            dtype=self.dtype,
            initializer=tf.initializers.zeros(),
        )

    def get_weights(self):
        """Return the layer variables as ``[weight, bias]``."""
        return [self.weight, self.bias]

    def forward(self, x):
        """Normalize ``x`` over its spatial axes, then apply scale and bias.

        Args:
            x: Input tensor. The parameters are reshaped to four dimensions,
                so ``x`` is presumably rank 4 laid out per ``nn.data_format``
                -- TODO confirm against callers.

        Returns:
            ``(x - mean) / (std + 1e-5) * weight + bias``.
        """
        # Place the channel dimension where the active data format expects it
        # so the parameters broadcast against the input.
        channels_last = nn.data_format == "NHWC"
        broadcast_shape = (1, 1, 1, self.in_ch) if channels_last else (1, self.in_ch, 1, 1)

        scale = tf.reshape(self.weight, broadcast_shape)
        shift = tf.reshape(self.bias, broadcast_shape)

        spatial_axes = nn.conv2d_spatial_axes
        mean = tf.reduce_mean(x, axis=spatial_axes, keepdims=True)
        # NOTE(review): epsilon is added after the standard deviation is taken,
        # not to the variance before the square root -- verify gradient
        # behaviour on constant-valued feature maps.
        std = tf.math.reduce_std(x, axis=spatial_axes, keepdims=True) + 1e-5

        normalized = (x - mean) / std
        return normalized * scale + shift


nn.InstanceNorm2D = InstanceNorm2D