Skip to content

Commit

Permalink
Fix slicing for variables
Browse files Browse the repository at this point in the history
tf.Variable's operator overloaded code had hard coded arities for specific
operators.  Now it's simpler and more general.  As a consequence, var[...]
now works.
Change: 130592832
  • Loading branch information
girving authored and tensorflower-gardener committed Aug 18, 2016
1 parent 1532027 commit 0aa130a
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 29 deletions.
9 changes: 9 additions & 0 deletions tensorflow/python/kernel_tests/array_ops_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -412,6 +412,15 @@ def testExpand(self):
# Ellipsis in middle of two newaxis
_ = checker[np.newaxis, ..., np.newaxis]

def testExpandVariable(self):
for use_gpu in False, True:
with self.test_session(use_gpu=use_gpu):
x = tf.Variable(7, dtype=tf.int32)
x.initializer.run()
y = x[None].eval()
self.assertEqual(y.shape, (1,))
self.assertAllEqual(y, (7,))

def testOptimizedCases(self):
for use_gpu in [False, True]:
with self.test_session(use_gpu=use_gpu):
Expand Down
36 changes: 7 additions & 29 deletions tensorflow/python/ops/variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -590,12 +590,6 @@ def _TensorConversionFunction(v, dtype=None, name=None, as_ref=False):
else:
return v.value()

# Operator overloading.
#
# To carry over all overloaded operators from ops.Tensor to Variable, we
# register the _RunOp() static method as the implementation of all operators.
# That function dynamically discovers the overloaded operator in ops.Tensor
# and invokes it after converting the Variable to a tensor.
@staticmethod
def _OverloadAllOperators():
"""Register overloads for all operators."""
Expand All @@ -604,15 +598,17 @@ def _OverloadAllOperators():

@staticmethod
def _OverloadOperator(operator):
"""Register _RunOp as the implementation of 'operator'.
"""Defer an operator overload to `ops.Tensor`.
We pull the operator out of ops.Tensor dynamically to avoid ordering issues.
Args:
operator: string. The operator name.
"""
if operator in ["__invert__", "__neg__", "__abs__"]:
setattr(Variable, operator, lambda a: Variable._RunOp(operator, a, None))
else:
setattr(Variable, operator, lambda a, b: Variable._RunOp(operator, a, b))
def _run_op(a, *args):
# pylint: disable=protected-access
return getattr(ops.Tensor, operator)(a._AsTensor(), *args)
setattr(Variable, operator, _run_op)

# NOTE(mrry): This enables the Variable's overloaded "right" binary
# operators to run when the left operand is an ndarray, because it
Expand All @@ -623,24 +619,6 @@ def _OverloadOperator(operator):
# with ndarrays.
__array_priority__ = 100

@staticmethod
def _RunOp(operator, a, b):
"""Run the operator 'op' for 'a'.
Args:
operator: string. The operator name.
a: A Variable.
b: Second argument to the operator. None if unary.
Returns:
The result of the operator.
"""
# pylint: disable=protected-access
if b is not None:
return getattr(ops.Tensor, operator)(a._AsTensor(), b)
else:
return getattr(ops.Tensor, operator)(a._AsTensor())
# pylint: enable=protected-access

@property
def name(self):
"""The name of this variable."""
Expand Down

0 comments on commit 0aa130a

Please sign in to comment.