Skip to content

Commit

Permalink
Update docstring and comments
Browse files Browse the repository at this point in the history
  • Loading branch information
MattConley committed Apr 9, 2019
1 parent 83bfd7e commit bc68bc0
Showing 1 changed file with 19 additions and 12 deletions.
31 changes: 19 additions & 12 deletions tensorflow/python/training/loss_scale_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,22 +30,30 @@
class LossScaleOptimizer(optimizer.Optimizer):
"""An optimizer that applies loss scaling.
Loss scaling is a process that multiplies the loss by a multiplier called the
loss scale, and divides each gradient by the same multiplier. The pseudocode
for this process is:
```
loss = ...
loss *= loss_scale
grads = gradients(loss, vars)
grads /= loss_scale
```
Mathematically, loss scaling has no effect, but can help avoid numerical
underflow in intermediate gradients when float16 tensors are used. By
multiplying the loss, each intermediate gradient will have the same multiplier
applied.
The loss scale can either be a fixed constant, chosen by the user, or be
dynamically determined. Dynamically determining the loss scale is convenient
as a loss scale does not have to be explicitly chosen. However it reduces
performance.
This optimizer wraps another optimizer and applies loss scaling to it. Loss
scaling is applied whenever gradients are computed.
Args:
opt: The Optimizer instance to wrap.
loss_scale: The loss scale or LossScale class to scale the loss and
gradients. This can either be an int/float to use a fixed loss scale,
the string "dynamic" to use dynamic loss scaling, or an instance of a
LossScale class. The string "dynamic" is equivalent to passing
`DynamicLossScale()`, and passing an int/float is equivalent
to passing a FixedLossScale instance with the given loss scale.
This optimizer wraps another optimizer and applies loss scaling to it via a
`LossScale`. Loss scaling is applied whenever gradients are
computed, such as through `minimize()`.
"""
def __init__(self, opt, loss_scale):
if not isinstance(opt, optimizer.Optimizer):
Expand Down Expand Up @@ -113,7 +121,6 @@ def compute_gradients(self, loss, var_list=None,
return list(zip(scaled_grads, variables))

def _scale_loss(self, loss):
# The loss is callable for `_compute_gradients`, but not `get_gradients`.
loss_scale = self._loss_scale()
if callable(loss):
return lambda: loss() * loss_scale
Expand Down

0 comments on commit bc68bc0

Please sign in to comment.