Book a Demo!
CoCalc Logo Icon
Store | Features | Docs | Share | Support | News | About | Policies | Sign Up | Sign In
iperov
GitHub Repository: iperov/deepfacelab
Path: blob/master/core/leras/optimizers/RMSprop.py
628 views
1
import numpy as np
2
from tensorflow.python.ops import control_flow_ops, state_ops
3
from core.leras import nn
4
tf = nn.tf
5
6
class RMSprop(nn.OptimizerBase):
    """RMSprop optimizer built from explicit TensorFlow graph ops.

    For every (gradient, variable) pair the update is::

        acc   <- rho * acc + (1 - rho) * grad**2
        var   <- var - lr * grad / (sqrt(acc) + eps)

    On top of that the optimizer optionally applies:

    * clipping of each gradient based on the norm of all gradients
      (``clipnorm > 0``), delegated to ``self.tf_clip_norm``
      (expected to be provided by the base class);
    * a cosine modulation of the learning rate driven by the iteration
      counter (``lr_cos != 0``);
    * a per-element random mask on the update (``lr_dropout != 1.0``),
      produced by ``nn.random_binomial``.

    Args:
        lr: base learning rate.
        rho: decay factor of the running average of squared gradients.
        lr_dropout: value passed as ``p`` to ``nn.random_binomial`` when the
            update masks are built; ``1.0`` disables masking entirely.
        lr_cos: period, in iterations, of the cosine learning-rate
            modulation; ``0`` disables it.
        clipnorm: threshold handed to ``self.tf_clip_norm``; values
            ``<= 0`` disable clipping.
        name: mandatory name, used as the variable scope of the optimizer.
        **kwargs: accepted for call-site compatibility and ignored.

    Raises:
        ValueError: if ``name`` is None.
    """

    def __init__(self, lr=0.001, rho=0.9, lr_dropout=1.0, lr_cos=0, clipnorm=0.0, name=None, **kwargs):
        super().__init__(name=name)

        if name is None:
            raise ValueError('name must be defined.')

        self.lr_dropout = lr_dropout
        self.lr_cos = lr_cos
        self.lr = lr
        self.rho = rho
        self.clipnorm = clipnorm

        # The step counter is always created on the CPU, inside the
        # optimizer's own variable scope.
        with tf.device('/CPU:0'):
            with tf.variable_scope(self.name):
                self.iterations = tf.Variable(0, dtype=tf.int64, name='iters')

        # Both dicts are keyed by the trainable variable's `name`.
        self.accumulators_dict = {}
        self.lr_rnds_dict = {}

    def get_weights(self):
        """Return the optimizer state: the iteration counter followed by all accumulators."""
        return [self.iterations] + list(self.accumulators_dict.values())

    def initialize_variables(self, trainable_weights, vars_on_cpu=True, lr_dropout_on_cpu=False):
        """Create the optimizer state for the given trainable variables.

        For every variable a zero-initialised, non-trainable accumulator of
        the same shape and dtype is created. When ``lr_dropout != 1.0`` a
        random mask tensor is created per variable as well.

        Args:
            trainable_weights: iterable of variables that will later be passed
                to `get_update_op`. It is iterated more than once.
            vars_on_cpu: create the state under a ``/CPU:0`` device scope.
            lr_dropout_on_cpu: build the lr-dropout masks under their own
                ``/CPU:0`` device scope.

        NOTE: the masks are built while the `vars_on_cpu` scope is still
        active, so with ``vars_on_cpu=True`` they are created under the CPU
        scope regardless of `lr_dropout_on_cpu` (same nesting as before).
        """
        # Imported locally so the method does not depend on a module-level import.
        from contextlib import ExitStack

        # Each device scope gets its own ExitStack. Previously a single local
        # was reused for both scopes, so the outer scope was never exited (or
        # the inner one was exited twice) whenever lr dropout was enabled.
        # ExitStack also guarantees the scopes are released on exceptions.
        with ExitStack() as vars_device:
            if vars_on_cpu:
                vars_device.enter_context(tf.device('/CPU:0'))

            with tf.variable_scope(self.name):
                accumulators = {
                    v.name: tf.get_variable(
                        f'acc_{v.name}'.replace(':', '_'),
                        v.shape,
                        dtype=v.dtype,
                        initializer=tf.initializers.constant(0.0),
                        trainable=False,
                    )
                    for v in trainable_weights
                }
                self.accumulators_dict.update(accumulators)

                if self.lr_dropout != 1.0:
                    with ExitStack() as rnd_device:
                        if lr_dropout_on_cpu:
                            rnd_device.enter_context(tf.device('/CPU:0'))
                        lr_rnds = [
                            nn.random_binomial(v.shape, p=self.lr_dropout, dtype=v.dtype)
                            for v in trainable_weights
                        ]
                    self.lr_rnds_dict.update(
                        {v.name: rnd for v, rnd in zip(trainable_weights, lr_rnds)}
                    )

    def get_update_op(self, grads_vars):
        """Build a single grouped op performing one optimisation step.

        Args:
            grads_vars: re-iterable sequence of ``(gradient, variable)``
                pairs. Every variable must have been registered through
                `initialize_variables` beforehand.

        Returns:
            A grouped op named ``<name>_updates`` that increments the
            iteration counter and assigns the new accumulator and variable
            values.
        """
        updates = []

        if self.clipnorm > 0.0:
            # Norm over all gradients together, accumulated in float32.
            norm = tf.sqrt(
                sum([tf.reduce_sum(tf.square(tf.cast(g, tf.float32))) for g, _ in grads_vars])
            )

        updates += [state_ops.assign_add(self.iterations, 1)]

        for g, v in grads_vars:
            if self.clipnorm > 0.0:
                g = self.tf_clip_norm(g, self.clipnorm, tf.cast(norm, g.dtype))

            a = self.accumulators_dict[v.name]

            # Running average of the squared gradient.
            new_a = self.rho * a + (1. - self.rho) * tf.square(g)

            lr = tf.constant(self.lr, g.dtype)
            if self.lr_cos != 0:
                # Scale lr by (cos(2*pi*iter/lr_cos) + 1) / 2, i.e. into [0, 1].
                # NOTE(review): this read of `iterations` is not explicitly
                # ordered against the increment op above.
                lr *= (tf.cos(tf.cast(self.iterations, g.dtype) * (2 * 3.1415926535 / float(self.lr_cos))) + 1.0) / 2.0

            # The dtype's resolution serves as epsilon to avoid division by zero.
            v_diff = - lr * g / (tf.sqrt(new_a) + np.finfo(g.dtype.as_numpy_dtype).resolution)
            if self.lr_dropout != 1.0:
                # Mask the update element-wise with the pre-built random tensor.
                lr_rnd = self.lr_rnds_dict[v.name]
                v_diff *= lr_rnd
            new_v = v + v_diff

            updates.append(state_ops.assign(a, new_a))
            updates.append(state_ops.assign(v, new_v))

        return control_flow_ops.group(*updates, name=self.name + '_updates')
74
# Expose the optimizer as the `RMSprop` attribute of the leras `nn` object.
nn.RMSprop = RMSprop
75