# Path: blob/master/core/leras/models/PatchDiscriminator.py
# 628 views
import numpy as np
from core.leras import nn
tf = nn.tf

# Lookup table for PatchDiscriminator:
#   patch size -> (suggested base channel count, [[kernel_size, stride], ...])
# with one [kernel_size, stride] pair per strided conv layer. The key is the
# receptive field of the layer stack (computed the same way as
# UNetPatchDiscriminator.calc_receptive_field_size), and every entry below
# satisfies receptive_field(layers) == key.
# NOTE: entries 27, 35 and 43 use [3,2], [3,2], [4,2] for their first three
# layers so that their receptive field equals the key. An earlier revision of
# this table repeated the layouts of 29, 37 and 45 there, so checkpoints saved
# with those patch sizes have a differently shaped second-layer kernel and
# may not load.
patch_discriminator_kernels = \
{ 1 : (512, [ [1,1] ]),
  2 : (512, [ [2,1] ]),
  3 : (512, [ [2,1], [2,1] ]),
  4 : (512, [ [2,2], [2,2] ]),
  5 : (512, [ [3,2], [2,2] ]),
  6 : (512, [ [4,2], [2,2] ]),
  7 : (512, [ [3,2], [3,2] ]),
  8 : (512, [ [4,2], [3,2] ]),
  9 : (512, [ [3,2], [4,2] ]),
  10 : (512, [ [4,2], [4,2] ]),
  11 : (512, [ [3,2], [3,2], [2,1] ]),
  12 : (512, [ [4,2], [3,2], [2,1] ]),
  13 : (512, [ [3,2], [4,2], [2,1] ]),
  14 : (512, [ [4,2], [4,2], [2,1] ]),
  15 : (512, [ [3,2], [3,2], [3,1] ]),
  16 : (512, [ [4,2], [3,2], [3,1] ]),
  17 : (512, [ [3,2], [4,2], [3,1] ]),
  18 : (512, [ [4,2], [4,2], [3,1] ]),
  19 : (512, [ [3,2], [3,2], [4,1] ]),
  20 : (512, [ [4,2], [3,2], [4,1] ]),
  21 : (512, [ [3,2], [4,2], [4,1] ]),
  22 : (512, [ [4,2], [4,2], [4,1] ]),
  23 : (256, [ [3,2], [3,2], [3,2], [2,1] ]),
  24 : (256, [ [4,2], [3,2], [3,2], [2,1] ]),
  25 : (256, [ [3,2], [4,2], [3,2], [2,1] ]),
  26 : (256, [ [4,2], [4,2], [3,2], [2,1] ]),
  27 : (256, [ [3,2], [3,2], [4,2], [2,1] ]),
  28 : (256, [ [4,2], [3,2], [4,2], [2,1] ]),
  29 : (256, [ [3,2], [4,2], [4,2], [2,1] ]),
  30 : (256, [ [4,2], [4,2], [4,2], [2,1] ]),
  31 : (256, [ [3,2], [3,2], [3,2], [3,1] ]),
  32 : (256, [ [4,2], [3,2], [3,2], [3,1] ]),
  33 : (256, [ [3,2], [4,2], [3,2], [3,1] ]),
  34 : (256, [ [4,2], [4,2], [3,2], [3,1] ]),
  35 : (256, [ [3,2], [3,2], [4,2], [3,1] ]),
  36 : (256, [ [4,2], [3,2], [4,2], [3,1] ]),
  37 : (256, [ [3,2], [4,2], [4,2], [3,1] ]),
  38 : (256, [ [4,2], [4,2], [4,2], [3,1] ]),
  39 : (256, [ [3,2], [3,2], [3,2], [4,1] ]),
  40 : (256, [ [4,2], [3,2], [3,2], [4,1] ]),
  41 : (256, [ [3,2], [4,2], [3,2], [4,1] ]),
  42 : (256, [ [4,2], [4,2], [3,2], [4,1] ]),
  43 : (256, [ [3,2], [3,2], [4,2], [4,1] ]),
  44 : (256, [ [4,2], [3,2], [4,2], [4,1] ]),
  45 : (256, [ [3,2], [4,2], [4,2], [4,1] ]),
  46 : (256, [ [4,2], [4,2], [4,2], [4,1] ]),
}


class PatchDiscriminator(nn.ModelBase):
    """Plain strided-convolution patch discriminator.

    The conv stack is taken from ``patch_discriminator_kernels[patch_size]``.
    ``forward`` returns a single-channel, un-activated score map (the output
    conv has no activation).
    """

    def on_build(self, patch_size, in_ch, base_ch=None, conv_kernel_initializer=None):
        """Create the conv layers.

        Args:
            patch_size: key into ``patch_discriminator_kernels`` (1..46); any
                other value raises KeyError.
            in_ch: number of channels of the input tensor.
            base_ch: channel count of the first conv. Layer ``i`` gets
                ``base_ch * min(2**i, 8)`` channels. Defaults to the base
                channel count suggested by the table.
            conv_kernel_initializer: forwarded unchanged to every ``nn.Conv2D``.
        """
        suggested_base_ch, kernels_strides = patch_discriminator_kernels[patch_size]

        if base_ch is None:
            base_ch = suggested_base_ch

        prev_ch = in_ch
        self.convs = []
        for i, (kernel_size, strides) in enumerate(kernels_strides):
            # Width doubles with every layer, capped at 8x base_ch.
            cur_ch = base_ch * min(2**i, 8)

            self.convs.append(nn.Conv2D(prev_ch, cur_ch, kernel_size=kernel_size, strides=strides,
                                        padding='SAME', kernel_initializer=conv_kernel_initializer))
            prev_ch = cur_ch

        # 1x1 projection down to a single score channel.
        self.out_conv = nn.Conv2D(prev_ch, 1, kernel_size=1, padding='VALID',
                                  kernel_initializer=conv_kernel_initializer)

    def forward(self, x):
        """Apply each conv with leaky ReLU (slope 0.1), then the 1x1 output conv."""
        for conv in self.convs:
            x = tf.nn.leaky_relu(conv(x), 0.1)
        return self.out_conv(x)


nn.PatchDiscriminator = PatchDiscriminator


class UNetPatchDiscriminator(nn.ModelBase):
    """
    Inspired by https://arxiv.org/abs/2002.12655 "A U-Net Based Discriminator for Generative Adversarial Networks"

    The encoder is a stack of 3x3 strided convs chosen by ``find_archi`` for the
    requested patch size. The decoder mirrors it with transposed convs and skip
    connections. ``forward`` returns two single-channel, un-activated score
    maps: one from the bottleneck and one from the end of the decoder.
    """

    def calc_receptive_field_size(self, layers):
        """Return the receptive field of a stack of ``[kernel_size, stride]`` layers.

        The result matches the calculator at
        https://fomoro.com/research/article/receptive-field-calculatorindex.html
        An empty stack yields 0.
        """
        rf = 0
        ts = 1  # product of the strides of all preceding layers
        for i, (k, s) in enumerate(layers):
            if i == 0:
                rf = k
            else:
                rf += (k-1)*ts
            ts *= s
        return rf

    def find_archi(self, target_patch_size, max_layers=9):
        """Find the best configuration of layers using only 3x3 convs for a target patch size.

        Candidate stacks have 1..``max_layers`` layers. The first layer always
        has stride 2, and every later layer has stride 1 or 2; all stride
        combinations are tried. For each reachable receptive field, the stack
        with the fewest layers is kept. Ties go to the largest sum of strides,
        and on a full tie the first candidate found wins.

        Returns:
            The ``[[3, stride], ...]`` list whose receptive field is closest to
            ``target_patch_size``. When two are equally close, the smaller
            receptive field is chosen.
            Raises ValueError if ``max_layers < 1`` (no candidates).
        """
        best = {}  # receptive field -> (layers_count, sum of strides, layers)
        for layers_count in range(1, max_layers+1):
            # Bit i of `mask` selects stride 2 (set) or 1 (clear) for layer i+1.
            # Masks are visited from all-ones down to zero; this order decides
            # which candidate wins a full tie.
            for mask in range((1 << (layers_count-1)) - 1, -1, -1):
                layers = [[3, 2]]
                sum_st = 2
                for i in range(layers_count-1):
                    st = 2 if mask & (1 << i) else 1
                    layers.append([3, st])
                    sum_st += st

                rf = self.calc_receptive_field_size(layers)

                prev = best.get(rf)
                if prev is None or layers_count < prev[0] or \
                   (layers_count == prev[0] and sum_st > prev[1]):
                    best[rf] = (layers_count, sum_st, layers)

        rfs = sorted(best)
        # argmin returns the first minimum, i.e. the smaller rf on a tie.
        q = rfs[np.abs(np.array(rfs) - target_patch_size).argmin()]
        return best[q][2]

    def on_build(self, patch_size, in_ch, base_ch = 16, use_fp16 = False):
        """Create encoder, bottleneck and decoder layers.

        Args:
            patch_size: target receptive field, passed to ``find_archi``.
            in_ch: number of channels of the input tensor.
            base_ch: width after the input 1x1 conv. Each encoder level doubles
                it, capped at 512.
            use_fp16: build the convs with float16 dtype. ``forward`` then casts
                its input to float16 and its outputs back to float32.
        """
        self.use_fp16 = use_fp16
        conv_dtype = tf.float16 if use_fp16 else tf.float32

        self.convs = []
        self.upconvs = []
        layers = self.find_archi(patch_size)
        n_layers = len(layers)

        # level_chs[-1] is the width after in_conv; level_chs[i] (0 <= i < n_layers)
        # is the width after encoder conv i.
        level_chs = {i-1: min(base_ch * (2**i), 512) for i in range(n_layers+1)}

        self.in_conv = nn.Conv2D(in_ch, level_chs[-1], kernel_size=1, padding='VALID', dtype=conv_dtype)

        for i, (kernel_size, strides) in enumerate(layers):
            self.convs.append(nn.Conv2D(level_chs[i-1], level_chs[i], kernel_size=kernel_size, strides=strides,
                                        padding='SAME', dtype=conv_dtype))

            # Decoder mirrors the encoder, so prepend: upconvs[0] undoes the last conv.
            # The deepest upconv reads the bottleneck directly. Every other upconv
            # reads concat(previous upconv output, encoder skip), hence 2x channels.
            up_in_ch = level_chs[i] * (2 if i != n_layers-1 else 1)
            self.upconvs.insert(0, nn.Conv2DTranspose(up_in_ch, level_chs[i-1], kernel_size=kernel_size,
                                                      strides=strides, padding='SAME', dtype=conv_dtype))

        # Final 1x1 conv over concat(in_conv features, last upconv output).
        self.out_conv = nn.Conv2D(level_chs[-1]*2, 1, kernel_size=1, padding='VALID', dtype=conv_dtype)

        bottleneck_ch = level_chs[n_layers-1]
        self.center_out = nn.Conv2D(bottleneck_ch, 1, kernel_size=1, padding='VALID', dtype=conv_dtype)
        self.center_conv = nn.Conv2D(bottleneck_ch, bottleneck_ch, kernel_size=1, padding='VALID', dtype=conv_dtype)

    def forward(self, x):
        """Return ``(center_out, x)``: the bottleneck score map and the decoder score map.

        Both are single-channel and un-activated. The decoder concatenates
        encoder skips along ``nn.conv2d_ch_axis``, so the input's spatial size
        presumably must be divisible by the product of the encoder strides for
        the shapes to line up (TODO confirm against nn.Conv2D/Conv2DTranspose).
        """
        if self.use_fp16:
            x = tf.cast(x, tf.float16)

        x = tf.nn.leaky_relu(self.in_conv(x), 0.2)

        # Encoder: remember each conv's input, deepest first, for the skips.
        encs = []
        for conv in self.convs:
            encs.insert(0, x)
            x = tf.nn.leaky_relu(conv(x), 0.2)

        # Bottleneck: a 1x1 score head, plus a 1x1 conv whose output feeds the decoder.
        center_out, x = self.center_out(x), tf.nn.leaky_relu(self.center_conv(x), 0.2)

        # Decoder: upsample, then concatenate with the matching encoder features.
        for upconv, enc in zip(self.upconvs, encs):
            x = tf.nn.leaky_relu(upconv(x), 0.2)
            x = tf.concat([enc, x], axis=nn.conv2d_ch_axis)

        x = self.out_conv(x)

        if self.use_fp16:
            center_out = tf.cast(center_out, tf.float32)
            x = tf.cast(x, tf.float32)

        return center_out, x


nn.UNetPatchDiscriminator = UNetPatchDiscriminator