Book a Demo!
CoCalc Logo Icon
Store · Features · Docs · Share · Support · News · About · Policies · Sign Up · Sign In
iperov
GitHub Repository: iperov/deepfacelab
Path: blob/master/core/leras/optimizers/OptimizerBase.py
628 views
1
import copy
2
from core.leras import nn
3
# Module-level shorthand for the TensorFlow handle exposed as `nn.tf`;
# the class below calls tf.cond / tf.scalar_mul / tf.cast through it.
tf = nn.tf
4
5
class OptimizerBase(nn.Saveable):
    """Shared base class for leras optimizers.

    Derives from ``nn.Saveable`` and adds a gradient-norm clipping helper,
    :meth:`tf_clip_norm`. The class is published on the ``nn`` namespace as
    ``nn.OptimizerBase`` right after its definition.
    """

    def __init__(self, name=None):
        # Only forwards `name` to nn.Saveable; no extra state is created here.
        super().__init__(name=name)

    def tf_clip_norm(self, g, c, n):
        """Rescale gradient `g` to norm `c` when its L2 norm `n` reaches `c`.

        # Arguments
            g: Tensor or IndexedSlices, the gradient to (maybe) clip.
            c: float. Clipping threshold; any value <= 0 disables clipping
                and `g` is handed back as-is without adding graph ops.
            n: Tensor, the already-computed L2 norm of `g`.

        # Returns
            `g` scaled by `c / n` when `n >= c`, otherwise `g` unchanged.
            The static shape of the result is restored after `tf.cond`.
        """
        if c <= 0:
            # Clipping disabled: keep the graph free of extra ops.
            return g

        should_clip = n >= c
        clipped = tf.scalar_mul(c / n, g)

        # tf.cond can drop static shape info, so remember it up front.
        # For IndexedSlices use dense_shape to avoid densifying the gradient.
        is_dense = isinstance(clipped, tf.Tensor)
        is_sparse = (not is_dense) and isinstance(clipped, tf.IndexedSlices)
        if is_dense:
            saved_shape = copy.copy(clipped.get_shape())
        elif is_sparse:
            saved_shape = copy.copy(clipped.dense_shape)

        if should_clip.dtype != tf.bool:
            should_clip = tf.cast(should_clip, 'bool')

        # `g` is never rebound in this method, so the else-branch closure
        # always yields the original, unclipped gradient.
        result = tf.cond(should_clip, lambda: clipped, lambda: g)

        if is_dense:
            result.set_shape(saved_shape)
        elif is_sparse:
            result._dense_shape = saved_shape
        return result


nn.OptimizerBase = OptimizerBase
43
44