Path: blob/master/core/leras/layers/Conv2DTranspose.py
628 views
import numpy as np
from core.leras import nn
tf = nn.tf


class Conv2DTranspose(nn.LayerBase):
    """
    2D transposed convolution layer built on tf.nn.conv2d_transpose.

    use_wscale enables weight scale (equalized learning rate): the stored
    kernel is multiplied by a constant He-style factor in forward().
    When use_wscale is set and kernel_initializer is None, the kernel
    initializer is forced to random_normal(0, 1). Without use_wscale, a None
    kernel_initializer is passed to tf.get_variable unchanged.
    """

    def __init__(self, in_ch, out_ch, kernel_size, strides=2, padding='SAME', use_bias=True, use_wscale=False,
                 kernel_initializer=None, bias_initializer=None, trainable=True, dtype=None, **kwargs):
        """
        in_ch / out_ch: number of input / output channels.
        kernel_size: side length of the square kernel (coerced to int).
        strides: integer upsampling stride applied to both spatial axes.
        padding: 'SAME' or 'VALID' (forwarded to tf.nn.conv2d_transpose).
        dtype: defaults to nn.floatx when None.

        Raises ValueError if strides is not an int.
        """
        if not isinstance(strides, int):
            raise ValueError ("strides must be an int type")
        kernel_size = int(kernel_size)

        if dtype is None:
            dtype = nn.floatx

        self.in_ch = in_ch
        self.out_ch = out_ch
        self.kernel_size = kernel_size
        self.strides = strides
        self.padding = padding
        self.use_bias = use_bias
        self.use_wscale = use_wscale
        self.kernel_initializer = kernel_initializer
        self.bias_initializer = bias_initializer
        self.trainable = trainable
        self.dtype = dtype
        super().__init__(**kwargs)

    def build_weights(self):
        """Create the kernel (and, if use_bias, the bias) variables."""
        kernel_initializer = self.kernel_initializer
        if self.use_wscale:
            gain = 1.0 if self.kernel_size == 1 else np.sqrt(2)
            fan_in = self.kernel_size*self.kernel_size*self.in_ch
            he_std = gain / np.sqrt(fan_in) # He init
            # The runtime scale factor is applied to the kernel in forward().
            self.wscale = tf.constant(he_std, dtype=self.dtype )
            if kernel_initializer is None:
                kernel_initializer = tf.initializers.random_normal(0, 1.0, dtype=self.dtype)

        # Kernel layout (k, k, out_ch, in_ch): conv2d_transpose expects the
        # filter as [height, width, output_channels, input_channels].
        self.weight = tf.get_variable("weight", (self.kernel_size,self.kernel_size,self.out_ch,self.in_ch), dtype=self.dtype, initializer=kernel_initializer, trainable=self.trainable )

        if self.use_bias:
            bias_initializer = self.bias_initializer
            if bias_initializer is None:
                bias_initializer = tf.initializers.zeros(dtype=self.dtype)

            self.bias = tf.get_variable("bias", (self.out_ch,), dtype=self.dtype, initializer=bias_initializer, trainable=self.trainable )

    def get_weights(self):
        """Return the list of trainable variables: [weight] or [weight, bias]."""
        weights = [self.weight]
        if self.use_bias:
            weights += [self.bias]
        return weights

    def _dim_size(self, x, axis):
        """Size of x along axis: the static int when known, else a dynamic scalar tensor."""
        dim = x.shape[axis]
        # Unwrap tf.Dimension objects (TF1-style shapes) to a plain int / None.
        dim = getattr(dim, 'value', dim)
        if dim is None:
            dim = tf.shape(x)[axis]
        return dim

    def forward(self, x):
        """
        Apply the transposed convolution to x using nn.data_format layout.

        The output spatial size along each axis is computed by deconv_length
        from that axis's own input size (height -> output height, width ->
        output width).
        """
        if nn.data_format == "NHWC":
            h_axis, w_axis = 1, 2
        else:
            h_axis, w_axis = 2, 3

        out_h = self.deconv_length(self._dim_size(x, h_axis), self.strides, self.kernel_size, self.padding)
        out_w = self.deconv_length(self._dim_size(x, w_axis), self.strides, self.kernel_size, self.padding)
        batch = tf.shape(x)[0]

        if nn.data_format == "NHWC":
            output_shape = tf.stack( (batch, out_h, out_w, self.out_ch) )
            strides = [1,self.strides,self.strides,1]
        else:
            output_shape = tf.stack( (batch, self.out_ch, out_h, out_w) )
            strides = [1,1,self.strides,self.strides]

        weight = self.weight
        if self.use_wscale:
            weight = weight * self.wscale

        x = tf.nn.conv2d_transpose(x, weight, output_shape, strides, padding=self.padding, data_format=nn.data_format)

        if self.use_bias:
            # Reshape so the per-channel bias broadcasts over batch and spatial axes.
            if nn.data_format == "NHWC":
                bias = tf.reshape (self.bias, (1,1,1,self.out_ch) )
            else:
                bias = tf.reshape (self.bias, (1,self.out_ch,1,1) )
            x = tf.add(x, bias)
        return x

    def __str__(self):
        r = f"{self.__class__.__name__} : in_ch:{self.in_ch} out_ch:{self.out_ch} "

        return r

    def deconv_length(self, dim_size, stride_size, kernel_size, padding):
        """
        Output length of a transposed convolution along one axis.

        dim_size may be an int or a scalar tensor; None is returned as None.
        Raises ValueError for a padding other than 'SAME', 'VALID', 'FULL'.
        """
        if padding not in {'SAME', 'VALID', 'FULL'}:
            raise ValueError(f"padding must be one of 'SAME', 'VALID', 'FULL', got {padding!r}")
        if dim_size is None:
            return None
        if padding == 'VALID':
            dim_size = dim_size * stride_size + max(kernel_size - stride_size, 0)
        elif padding == 'FULL':
            dim_size = dim_size * stride_size - (stride_size + kernel_size - 2)
        elif padding == 'SAME':
            dim_size = dim_size * stride_size
        return dim_size


nn.Conv2DTranspose = Conv2DTranspose