Skip to content

Commit

Permalink
Put CPU-only ops in ProdGrad in a device('/cpu:0') block
Browse files Browse the repository at this point in the history
Fixes tensorflow#3957.
Change: 132025448
  • Loading branch information
girving authored and tensorflower-gardener committed Sep 2, 2016
1 parent e9ec815 commit 182fef1
Showing 1 changed file with 9 additions and 7 deletions.
16 changes: 9 additions & 7 deletions tensorflow/python/ops/math_grad.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,13 +125,15 @@ def _ProdGrad(op, grad):

   # Pack all reduced dimensions into a single one, so we can perform the
   # cumprod ops. If the reduction dims list is empty, it defaults to float32,
-  # so we need to cast here.
-  reduced = math_ops.cast(op.inputs[1], dtypes.int32)
-  idx = math_ops.range(0, array_ops.rank(op.inputs[0]))
-  other, _ = array_ops.listdiff(idx, reduced)
-  perm = array_ops.concat(0, [reduced, other])
-  reduced_num = math_ops.reduce_prod(array_ops.gather(input_shape, reduced))
-  other_num = math_ops.reduce_prod(array_ops.gather(input_shape, other))
+  # so we need to cast here. We put all the shape-related ops on CPU to avoid
+  # copying back and forth, and since listdiff is CPU only.
+  with ops.device("/cpu:0"):
+    reduced = math_ops.cast(op.inputs[1], dtypes.int32)
+    idx = math_ops.range(0, array_ops.rank(op.inputs[0]))
+    other, _ = array_ops.listdiff(idx, reduced)
+    perm = array_ops.concat(0, [reduced, other])
+    reduced_num = math_ops.reduce_prod(array_ops.gather(input_shape, reduced))
+    other_num = math_ops.reduce_prod(array_ops.gather(input_shape, other))
   permuted = array_ops.transpose(op.inputs[0], perm)
   permuted_shape = array_ops.shape(permuted)
   reshaped = array_ops.reshape(permuted, (reduced_num, other_num))
Expand Down

0 comments on commit 182fef1

Please sign in to comment.