From 7eac94936d99e32245dfd5ad9f4bd5714df50755 Mon Sep 17 00:00:00 2001 From: Eugene Brevdo Date: Wed, 23 Nov 2016 11:19:36 -0800 Subject: [PATCH] Internal change: Allow resetting colocation groups via colocate_with(None, ignore...=True) Change: 140052949 --- tensorflow/python/framework/ops.py | 25 +++++++++++++++++-------- tensorflow/python/framework/ops_test.py | 19 +++++++++++++++++++ 2 files changed, 36 insertions(+), 8 deletions(-) diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py index c4c64f157750c5..c8d5b117391fe4 100644 --- a/tensorflow/python/framework/ops.py +++ b/tensorflow/python/framework/ops.py @@ -2959,24 +2959,31 @@ def colocate_with(self, op, ignore_existing=False): `b` and `c` will always be colocated with `a`, no matter where `a` is eventually placed. + **NOTE** Using a colocation scope resets any existing device constraints. + + If `op` is `None` then `ignore_existing` must be `True` and the new + scope resets all colocation and device constraints. + Args: - op: The op to colocate all created ops with. + op: The op to colocate all created ops with, or `None`. ignore_existing: If true, only applies colocation of this op within the context, rather than applying all colocation properties - on the stack. + on the stack. If `op` is `None`, this value must be `True`. Raises: - ValueError: if op is None. + ValueError: if op is None but ignore_existing is False. Yields: A context manager that specifies the op with which to colocate newly created ops. """ - if op is None: - raise ValueError("Tried to colocate with None") + if op is None and not ignore_existing: + raise ValueError( + "Trying to reset colocation (op is None) but " + "ignore_existing is not True") - if not isinstance(op, Operation): + if op is not None and not isinstance(op, Operation): # We always want to colocate with the reference op. op = internal_convert_to_tensor_or_indexed_slices(op, as_ref=True).op @@ -2994,14 +3001,16 @@ def colocate_with(self, op, ignore_existing=False): current_stack = self._colocation_stack self._colocation_stack = [] - self._colocation_stack.append(op) + if op is not None: + self._colocation_stack.append(op) try: yield finally: # Restore device function stack self._device_function_stack = device_fn_tmp - self._colocation_stack.pop() + if op is not None: + self._colocation_stack.pop() # Reset the colocation stack if requested. if ignore_existing: diff --git a/tensorflow/python/framework/ops_test.py b/tensorflow/python/framework/ops_test.py index 57cb74ec494965..3fc3d16d8d7439 100644 --- a/tensorflow/python/framework/ops_test.py +++ b/tensorflow/python/framework/ops_test.py @@ -1489,6 +1489,25 @@ def testColocationIgnoreStack(self): c = constant_op.constant(4.0) self.assertEqual(set([b"loc:@b"]), set(c.op.colocation_groups())) + def testColocateWithReset(self): + a = constant_op.constant([2.0], name="a") + with ops.colocate_with(a.op): + b = constant_op.constant(3.0, name="b") + with ops.colocate_with(None, ignore_existing=True): + c = constant_op.constant(4.0, name="c") + self.assertEqual([b"loc:@a"], b.op.colocation_groups()) + self.assertEqual([b"loc:@c"], c.op.colocation_groups()) + + def testColocateWithInitialNoneThenNested(self): + a = constant_op.constant([2.0], name="a") + with ops.colocate_with(a.op): + with ops.colocate_with(None, ignore_existing=True): + b = constant_op.constant(3.0, name="b") + with ops.colocate_with(b.op): + c = constant_op.constant(4.0, name="c") + self.assertEqual([b"loc:@b"], b.op.colocation_groups()) + self.assertEqual([b"loc:@b"], c.op.colocation_groups()) + def testColocateVariables(self): a = variables.Variable([2.0], name="a") with ops.colocate_with(a.op):