Book a Demo!
CoCalc Logo Icon
Store · Features · Docs · Share · Support · News · About · Policies · Sign Up · Sign In
iperov
GitHub Repository: iperov/deepfacelab
Path: blob/master/core/leras/layers/Dense.py
628 views
1
import numpy as np
2
from core.leras import nn
3
tf = nn.tf
4
5
class Dense(nn.LayerBase):
    """Fully connected layer computing ``x @ weight (+ bias)``.

    Optionally applies runtime weight scaling (equalized learning rate) and
    maxout over groups of ``maxout_ch`` candidate units per output feature.
    """

    def __init__(self, in_ch, out_ch, use_bias=True, use_wscale=False, maxout_ch=0,
                 kernel_initializer=None, bias_initializer=None, trainable=True,
                 dtype=None, **kwargs):
        """
        in_ch               number of input features (rows of the kernel)
        out_ch              number of output features
        use_bias            if True, add a learnable bias of shape (out_ch,)
        use_wscale          enables weight scale (equalized learning rate):
                            the kernel is multiplied by 1/sqrt(in_ch) in forward()
        maxout_ch           https://link.springer.com/article/10.1186/s40537-019-0233-0
                            typical 2-4 if you want to enable DenseMaxout behaviour;
                            values <= 1 disable maxout
        kernel_initializer  if None: random_normal(mean=0, stddev=1.0) when
                            use_wscale is True, otherwise glorot_uniform
        bias_initializer    if None: zeros
        trainable           passed to tf.get_variable for weight and bias
        dtype               variable dtype; defaults to nn.floatx
        kwargs              forwarded to nn.LayerBase.__init__
        """
        self.in_ch = in_ch
        self.out_ch = out_ch
        self.use_bias = use_bias
        self.use_wscale = use_wscale
        self.maxout_ch = maxout_ch
        self.kernel_initializer = kernel_initializer
        self.bias_initializer = bias_initializer
        self.trainable = trainable
        if dtype is None:
            dtype = nn.floatx

        self.dtype = dtype
        # NOTE(review): attributes are assigned before super().__init__; the
        # base class may rely on them (e.g. to build weights) — keep this order.
        super().__init__(**kwargs)

    def build_weights(self):
        """Create the ``weight`` (and optional ``bias``) variables."""
        # With maxout, each output feature owns maxout_ch consecutive kernel
        # columns; forward() keeps the max of each group.
        if self.maxout_ch > 1:
            weight_shape = (self.in_ch, self.out_ch*self.maxout_ch)
        else:
            weight_shape = (self.in_ch, self.out_ch)

        kernel_initializer = self.kernel_initializer

        if self.use_wscale:
            gain = 1.0
            fan_in = np.prod(weight_shape[:-1])  # == in_ch for this 2-D kernel
            he_std = gain / np.sqrt(fan_in)  # He init
            # The He factor is applied at runtime in forward() rather than baked
            # into the initial values (equalized learning rate).
            self.wscale = tf.constant(he_std, dtype=self.dtype)
            if kernel_initializer is None:
                kernel_initializer = tf.initializers.random_normal(0, 1.0, dtype=self.dtype)

        # Default when no initializer was given and use_wscale is off.
        if kernel_initializer is None:
            kernel_initializer = tf.initializers.glorot_uniform(dtype=self.dtype)

        self.weight = tf.get_variable("weight", weight_shape, dtype=self.dtype, initializer=kernel_initializer, trainable=self.trainable)

        if self.use_bias:
            bias_initializer = self.bias_initializer
            if bias_initializer is None:
                bias_initializer = tf.initializers.zeros(dtype=self.dtype)
            # Bias has out_ch entries (not out_ch*maxout_ch): it is added after
            # the maxout reduction in forward().
            self.bias = tf.get_variable("bias", (self.out_ch,), dtype=self.dtype, initializer=bias_initializer, trainable=self.trainable)

    def get_weights(self):
        """Return the layer's variables: [weight], plus bias when use_bias."""
        weights = [self.weight]
        if self.use_bias:
            weights += [self.bias]
        return weights

    def forward(self, x):
        """Apply the layer to ``x``.

        Assumes x is (batch, in_ch) — TODO confirm. In the maxout path the
        reshape to (-1, out_ch, maxout_ch) collapses all leading dimensions.
        """
        weight = self.weight
        if self.use_wscale:
            weight = weight * self.wscale

        x = tf.matmul(x, weight)

        if self.maxout_ch > 1:
            # Group the out_ch*maxout_ch columns into out_ch groups of
            # maxout_ch candidates and keep the largest of each group.
            x = tf.reshape(x, (-1, self.out_ch, self.maxout_ch))
            x = tf.reduce_max(x, axis=-1)

        if self.use_bias:
            x = tf.add(x, tf.reshape(self.bias, (1, self.out_ch)))

        return x

nn.Dense = Dense
77