Path: blob/master/core/leras/optimizers/OptimizerBase.py
import copy
from core.leras import nn
tf = nn.tf

class OptimizerBase(nn.Saveable):
    def __init__(self, name=None):
        super().__init__(name=name)

    def tf_clip_norm(self, g, c, n):
        """Clip the gradient `g` if the L2 norm `n` exceeds `c`.
        # Arguments
            g: Tensor, the gradient tensor
            c: float >= 0. Gradients will be clipped
                when their L2 norm exceeds this value.
            n: Tensor, actual norm of `g`.
        # Returns
            Tensor, the gradient clipped if required.
        """
        if c <= 0:  # if clipnorm == 0 no need to add ops to the graph
            return g

        condition = n >= c
        then_expression = tf.scalar_mul(c / n, g)
        else_expression = g

        # saving the shape to avoid converting sparse tensor to dense
        if isinstance(then_expression, tf.Tensor):
            g_shape = copy.copy(then_expression.get_shape())
        elif isinstance(then_expression, tf.IndexedSlices):
            g_shape = copy.copy(then_expression.dense_shape)
        if condition.dtype != tf.bool:
            condition = tf.cast(condition, 'bool')
        g = tf.cond(condition,
                    lambda: then_expression,
                    lambda: else_expression)
        if isinstance(then_expression, tf.Tensor):
            g.set_shape(g_shape)
        elif isinstance(then_expression, tf.IndexedSlices):
            g._dense_shape = g_shape

        return g
nn.OptimizerBase = OptimizerBase
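
For context, below is a minimal sketch of how a concrete optimizer subclass might call `tf_clip_norm` when building its update op. The `SGD` class, its `lr`/`clipnorm` parameters, and the `get_update_op` method are illustrative assumptions, not part of this file; the point shown is that the combined L2 norm of all gradients is computed once per step and then passed as `n` to `tf_clip_norm` for each gradient.

# Hypothetical usage sketch -- SGD, lr, clipnorm and get_update_op are
# illustrative names, not part of this file. Assumes TF1-style graph mode,
# which matches the tf.cond/shape-patching style used above.
class SGD(OptimizerBase):
    def __init__(self, lr=0.001, clipnorm=0.0, name=None):
        super().__init__(name=name)
        self.lr = lr
        self.clipnorm = clipnorm

    def get_update_op(self, grads_vars):
        # grads_vars: list of (gradient, variable) pairs
        if self.clipnorm > 0.0:
            # Combined L2 norm over all gradients, computed once per step.
            norm = tf.sqrt(sum([tf.reduce_sum(tf.square(g)) for g, v in grads_vars]))
        updates = []
        for g, v in grads_vars:
            if self.clipnorm > 0.0:
                g = self.tf_clip_norm(g, self.clipnorm, norm)
            updates.append(tf.assign_sub(v, self.lr * g))
        return tf.group(*updates)

Computing the norm once outside the loop avoids duplicating the reduction ops in the graph for every variable, while `tf_clip_norm` itself only adds the conditional rescale when `clipnorm` is positive.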