
Commit

Update function names for V2 consistency
MattConley committed Apr 9, 2019
1 parent bc68bc0 commit cd4bd9d
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions tensorflow/python/training/loss_scale_optimizer.py
@@ -129,10 +129,10 @@ def _scale_loss(self, loss):
   def _scale_grads(self, grads):
     loss_scale = self._loss_scale()
     loss_scale_reciprical = 1 / loss_scale
-    return [None if g is None else self._indexed_slices(
+    return [None if g is None else self._scale_grad(
         g, loss_scale_reciprical) for g in grads]

-  def _indexed_slices(self, grad, loss_scale_reciprical):
+  def _scale_grad(self, grad, loss_scale_reciprical):
     if isinstance(grad, ops.IndexedSlices):
       grad_vals = grad.values * loss_scale_reciprical
       return ops.IndexedSlices(grad_vals, grad.indices, grad.dense_shape)
@@ -170,7 +170,7 @@ def apply_gradients(self, grads_and_vars, global_step=None, name=None):

     # TODO(nluehr) cleanup GraphKeys.TRAIN_OP
     return replica_context.merge_call(
-        self._maybe_apply_gradients_cross_replica,
+        self._apply_gradients_cross_replica,
         args=(grads_and_vars, global_step, name))

   def _distributed_apply(self,
@@ -197,10 +197,10 @@ def _distributed_apply(self,
       replicas. If `global_step` was not None, that operation also
       increments `global_step`
     """
-    self._maybe_apply_gradients_cross_replica(distribution, grads_and_vars,
+    self._apply_gradients_cross_replica(distribution, grads_and_vars,
                                         global_step, name)

-  def _maybe_apply_gradients_cross_replica(self, distribution, grads_and_vars,
+  def _apply_gradients_cross_replica(self, distribution, grads_and_vars,
                                      global_step, name):
     """Conditionally apply gradients in cross replica context."""
     name = name if name is not None else self.get_name()
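For context, the renamed _scale_grad helper carries out the usual loss-scale unscaling step: each gradient is divided by the loss scale, and tf.IndexedSlices gradients are handled specially so that only their values tensor is scaled. The sketch below is an illustrative, standalone reconstruction of that pattern, not the TensorFlow implementation; the free functions scale_grad and scale_grads are hypothetical names chosen to mirror the renamed methods.

# Illustrative sketch only (not the actual TensorFlow code): undo loss scaling
# on a list of gradients, scaling just the .values tensor of IndexedSlices.
import tensorflow as tf

def scale_grad(grad, loss_scale_reciprocal):
  """Multiplies a single gradient by 1/loss_scale."""
  if isinstance(grad, tf.IndexedSlices):
    # Sparse gradient: scale the values, keep indices and dense_shape as-is.
    return tf.IndexedSlices(grad.values * loss_scale_reciprocal,
                            grad.indices, grad.dense_shape)
  return grad * loss_scale_reciprocal

def scale_grads(grads, loss_scale):
  """Divides every gradient in `grads` by `loss_scale`; None passes through."""
  loss_scale_reciprocal = 1 / loss_scale
  return [None if g is None else scale_grad(g, loss_scale_reciprocal)
          for g in grads]

In a training loop this would typically be called on the gradients of the scaled loss, e.g. scale_grads(tape.gradient(loss * loss_scale, variables), loss_scale), before the optimizer applies them.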
