Skip to content

Commit

Permalink
Internal change:
Browse files Browse the repository at this point in the history
Allow resetting colocation groups via colocate_with(None, ignore...=True)
Change: 140052949
  • Loading branch information
ebrevdo authored and tensorflower-gardener committed Nov 23, 2016
1 parent 0699425 commit 7eac949
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 8 deletions.
25 changes: 17 additions & 8 deletions tensorflow/python/framework/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -2959,24 +2959,31 @@ def colocate_with(self, op, ignore_existing=False):
`b` and `c` will always be colocated with `a`, no matter where `a`
is eventually placed.
**NOTE** Using a colocation scope resets any existing device constraints.
If `op` is `None` then `ignore_existing` must be `True` and the new
scope resets all colocation and device constraints.
Args:
op: The op to colocate all created ops with.
op: The op to colocate all created ops with, or `None`.
ignore_existing: If true, only applies colocation of this op within
the context, rather than applying all colocation properties
on the stack.
on the stack. If `op` is `None`, this value must be `True`.
Raises:
ValueError: if op is None.
ValueError: if op is None but ignore_existing is False.
Yields:
A context manager that specifies the op with which to colocate
newly created ops.
"""
if op is None:
raise ValueError("Tried to colocate with None")
if op is None and not ignore_existing:
raise ValueError(
"Trying to reset colocation (op is None) but "
"ignore_existing is not True")

if not isinstance(op, Operation):
if op is not None and not isinstance(op, Operation):
# We always want to colocate with the reference op.
op = internal_convert_to_tensor_or_indexed_slices(op, as_ref=True).op

Expand All @@ -2994,14 +3001,16 @@ def colocate_with(self, op, ignore_existing=False):
current_stack = self._colocation_stack
self._colocation_stack = []

self._colocation_stack.append(op)
if op is not None:
self._colocation_stack.append(op)

try:
yield
finally:
# Restore device function stack
self._device_function_stack = device_fn_tmp
self._colocation_stack.pop()
if op is not None:
self._colocation_stack.pop()

# Reset the colocation stack if requested.
if ignore_existing:
Expand Down
19 changes: 19 additions & 0 deletions tensorflow/python/framework/ops_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1489,6 +1489,25 @@ def testColocationIgnoreStack(self):
c = constant_op.constant(4.0)
self.assertEqual(set([b"loc:@b"]), set(c.op.colocation_groups()))

def testColocateWithReset(self):
a = constant_op.constant([2.0], name="a")
with ops.colocate_with(a.op):
b = constant_op.constant(3.0, name="b")
with ops.colocate_with(None, ignore_existing=True):
c = constant_op.constant(4.0, name="c")
self.assertEqual([b"loc:@a"], b.op.colocation_groups())
self.assertEqual([b"loc:@c"], c.op.colocation_groups())

def testColocateWithInitialNoneThenNested(self):
a = constant_op.constant([2.0], name="a")
with ops.colocate_with(a.op):
with ops.colocate_with(None, ignore_existing=True):
b = constant_op.constant(3.0, name="b")
with ops.colocate_with(b.op):
c = constant_op.constant(4.0, name="c")
self.assertEqual([b"loc:@b"], b.op.colocation_groups())
self.assertEqual([b"loc:@b"], c.op.colocation_groups())

def testColocateVariables(self):
a = variables.Variable([2.0], name="a")
with ops.colocate_with(a.op):
Expand Down

0 comments on commit 7eac949

Please sign in to comment.