Skip to content

Commit 9229fca

Browse files
A. Unique TensorFlowertensorflower-gardener
A. Unique TensorFlower
authored andcommitted
ADD: Inequality based asserts and tests.
* is_strictly_increasing(x) * is_non_decreasing(x) * assert_positive(x) * assert_non_negative(x) * assert_negative(x) * assert_non_positive(x) * assert_less(x, y) * assert_less_equal(x, y) Change: 118992401
1 parent cca5d0a commit 9229fca

File tree

3 files changed

+588
-5
lines changed

3 files changed

+588
-5
lines changed

tensorflow/contrib/framework/__init__.py

+9-1
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,17 @@
1414
# ==============================================================================
1515
"""Framework utilities.
1616
17+
@@assert_negative
18+
@@assert_positive
19+
@@assert_non_negative
20+
@@assert_non_positive
21+
@@assert_less
22+
@@assert_less_equal
1723
@@assert_same_float_dtype
18-
@@is_numeric_tensor
1924
@@assert_scalar_int
25+
@@is_numeric_tensor
26+
@@is_non_decreasing
27+
@@is_strictly_increasing
2028
@@local_variable
2129
@@reduce_sum_n
2230
@@with_shape

tensorflow/contrib/framework/python/framework/tensor_util.py

+242-4
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,17 @@
1515

1616
"""Tensor utility functions.
1717
18+
@@assert_negative
19+
@@assert_positive
20+
@@assert_non_negative
21+
@@assert_non_positive
22+
@@assert_less
23+
@@assert_less_equal
1824
@@assert_same_float_dtype
19-
@@is_numeric_tensor
2025
@@assert_scalar_int
26+
@@is_numeric_tensor
27+
@@is_non_decreasing
28+
@@is_strictly_increasing
2129
@@local_variable
2230
@@reduce_sum_n
2331
@@with_shape
@@ -30,13 +38,18 @@
3038
from tensorflow.python.framework import dtypes
3139
from tensorflow.python.framework import ops
3240
from tensorflow.python.ops import array_ops
41+
from tensorflow.python.ops import control_flow_ops
3342
from tensorflow.python.ops import logging_ops
3443
from tensorflow.python.ops import math_ops
3544
from tensorflow.python.ops import variables
3645

3746
__all__ = [
3847
'assert_same_float_dtype', 'is_numeric_tensor', 'assert_scalar_int',
39-
'local_variable', 'reduce_sum_n', 'with_shape', 'with_same_shape']
48+
'local_variable', 'reduce_sum_n', 'with_shape', 'with_same_shape',
49+
'assert_positive', 'assert_negative', 'assert_non_positive',
50+
'assert_non_negative', 'assert_less', 'assert_less_equal',
51+
'is_strictly_increasing', 'is_non_decreasing',
52+
]
4053

4154

4255
NUMERIC_TYPES = frozenset([dtypes.float32, dtypes.float64, dtypes.int8,
@@ -45,8 +58,229 @@
4558
dtypes.quint8, dtypes.complex64])
4659

4760

48-
def is_numeric_tensor(tensor):
49-
return isinstance(tensor, ops.Tensor) and tensor.dtype in NUMERIC_TYPES
61+
def assert_negative(x, data=None, summarize=None, name=None):
62+
"""Assert the condition `x < 0` holds element-wise.
63+
64+
Negative means, for every element `x[i]` of `x`, we have `x[i] < 0`.
65+
If `x` is empty this is trivially satisfied.
66+
67+
Args:
68+
x: Numeric `Tensor`.
69+
data: The tensors to print out if the condition is False. Defaults to
70+
error message and first few entries of `x`.
71+
summarize: Print this many entries of each tensor.
72+
name: A name for this operation (optional). Defaults to "assert_negative".
73+
74+
Returns:
75+
Op raising `InvalidArgumentError` unless `x` is all negative.
76+
"""
77+
with ops.op_scope([x, data], name, 'assert_negative'):
78+
x = ops.convert_to_tensor(x, name='x')
79+
if data is None:
80+
data = ['Condition x < 0 did not hold element-wise: x = ', x.name, x]
81+
zero = ops.convert_to_tensor(0, dtype=x.dtype)
82+
return assert_less(x, zero, data=data, summarize=summarize)
83+
84+
85+
def assert_positive(x, data=None, summarize=None, name=None):
86+
"""Assert the condition `x > 0` holds element-wise.
87+
88+
Positive means, for every element `x[i]` of `x`, we have `x[i] > 0`.
89+
If `x` is empty this is trivially satisfied.
90+
91+
Args:
92+
x: Numeric `Tensor`.
93+
data: The tensors to print out if the condition is False. Defaults to
94+
error message and first few entries of `x`.
95+
summarize: Print this many entries of each tensor.
96+
name: A name for this operation (optional). Defaults to "assert_negative".
97+
98+
Returns:
99+
Op raising `InvalidArgumentError` unless `x` is all positive.
100+
"""
101+
with ops.op_scope([x, data], name, 'assert_positive'):
102+
x = ops.convert_to_tensor(x, name='x')
103+
if data is None:
104+
data = ['Condition x > 0 did not hold element-wise: x = ', x.name, x]
105+
zero = ops.convert_to_tensor(0, dtype=x.dtype)
106+
return assert_less(zero, x, data=data, summarize=summarize)
107+
108+
109+
def assert_non_negative(x, data=None, summarize=None, name=None):
110+
"""Assert the condition `x >= 0` holds element-wise.
111+
112+
Non-negative means, for every element `x[i]` of `x`, we have `x[i] >= 0`.
113+
If `x` is empty this is trivially satisfied.
114+
115+
Args:
116+
x: Numeric `Tensor`.
117+
data: The tensors to print out if the condition is False. Defaults to
118+
error message and first few entries of `x`.
119+
summarize: Print this many entries of each tensor.
120+
name: A name for this operation (optional).
121+
Defaults to "assert_non_negative".
122+
123+
Returns:
124+
Op raising `InvalidArgumentError` unless `x` is all non-negative.
125+
"""
126+
with ops.op_scope([x, data], name, 'assert_non_negative'):
127+
x = ops.convert_to_tensor(x, name='x')
128+
if data is None:
129+
data = ['Condition x >= 0 did not hold element-wise: x = ', x.name, x]
130+
zero = ops.convert_to_tensor(0, dtype=x.dtype)
131+
return assert_less_equal(zero, x, data=data, summarize=summarize)
132+
133+
134+
def assert_non_positive(x, data=None, summarize=None, name=None):
135+
"""Assert the condition `x <= 0` holds element-wise.
136+
137+
Non-positive means, for every element `x[i]` of `x`, we have `x[i] <= 0`.
138+
If `x` is empty this is trivially satisfied.
139+
140+
Args:
141+
x: Numeric `Tensor`.
142+
data: The tensors to print out if the condition is False. Defaults to
143+
error message and first few entries of `x`.
144+
summarize: Print this many entries of each tensor.
145+
name: A name for this operation (optional).
146+
Defaults to "assert_non_positive".
147+
148+
Returns:
149+
Op raising `InvalidArgumentError` unless `x` is all non-positive.
150+
"""
151+
with ops.op_scope([x, data], name, 'assert_non_positive'):
152+
x = ops.convert_to_tensor(x, name='x')
153+
if data is None:
154+
data = ['Condition x <= 0 did not hold element-wise: x = ', x.name, x]
155+
zero = ops.convert_to_tensor(0, dtype=x.dtype)
156+
return assert_less_equal(x, zero, data=data, summarize=summarize)
157+
158+
159+
def assert_less(x, y, data=None, summarize=None, name=None):
160+
"""Assert the condition `x < y` holds element-wise.
161+
162+
This condition holds if for every pair of (possibly broadcast) elements
163+
`x[i]`, `y[i]`, we have `x[i] < y[i]`.
164+
If both `x` and `y` are empty, this is trivially satisfied.
165+
166+
Args:
167+
x: Numeric `Tensor`.
168+
y: Numeric `Tensor`, same dtype as and broadcastable to `x`.
169+
data: The tensors to print out if the condition is False. Defaults to
170+
error message and first few entries of `x`, `y`.
171+
summarize: Print this many entries of each tensor.
172+
name: A name for this operation (optional). Defaults to "assert_less".
173+
174+
Returns:
175+
Op that raises `InvalidArgumentError` if `x < y` is False.
176+
"""
177+
with ops.op_scope([x, y, data], name, 'assert_less'):
178+
x = ops.convert_to_tensor(x, name='x')
179+
y = ops.convert_to_tensor(y, name='y')
180+
if data is None:
181+
data = [
182+
'Condition x < y did not hold element-wise: x = ',
183+
x.name, x, 'y = ', y.name, y]
184+
condition = math_ops.reduce_all(math_ops.less(x, y))
185+
return logging_ops.Assert(condition, data, summarize=summarize)
186+
187+
188+
def assert_less_equal(x, y, data=None, summarize=None, name=None):
189+
"""Assert the condition `x <= y` holds element-wise.
190+
191+
This condition holds if for every pair of (possibly broadcast) elements
192+
`x[i]`, `y[i]`, we have `x[i] <= y[i]`.
193+
If both `x` and `y` are empty, this is trivially satisfied.
194+
195+
Args:
196+
x: Numeric `Tensor`.
197+
y: Numeric `Tensor`, same dtype as and broadcastable to `x`.
198+
data: The tensors to print out if the condition is False. Defaults to
199+
error message and first few entries of `x`, `y`.
200+
summarize: Print this many entries of each tensor.
201+
name: A name for this operation (optional). Defaults to "assert_less_equal"
202+
203+
Returns:
204+
Op that raises `InvalidArgumentError` if `x <= y` is False.
205+
"""
206+
with ops.op_scope([x, y, data], name, 'assert_less_equal'):
207+
x = ops.convert_to_tensor(x, name='x')
208+
y = ops.convert_to_tensor(y, name='y')
209+
if data is None:
210+
data = [
211+
'Condition x <= y did not hold element-wise: x = ',
212+
x.name, x, 'y = ', y.name, y]
213+
condition = math_ops.reduce_all(math_ops.less_equal(x, y))
214+
return logging_ops.Assert(condition, data, summarize=summarize)
215+
216+
217+
def _get_diff_for_monotonic_comparison(x):
218+
"""Gets the difference x[1:] - x[:-1]."""
219+
x = array_ops.reshape(x, [-1])
220+
if not is_numeric_tensor(x):
221+
raise ValueError('Expected x to be numeric, instead found: %s' % x)
222+
223+
# If x has less than 2 elements, there is nothing to compare. So return [].
224+
is_shorter_than_two = math_ops.less(array_ops.size(x), 2)
225+
short_result = lambda: ops.convert_to_tensor([], dtype=x.dtype)
226+
227+
# With 2 or more elements, return x[1:] - x[:-1]
228+
s_len = array_ops.shape(x) - 1
229+
diff = lambda: array_ops.slice(x, [1], s_len) - array_ops.slice(x, [0], s_len)
230+
return control_flow_ops.cond(is_shorter_than_two, short_result, diff)
231+
232+
233+
def is_non_decreasing(x, name=None):
234+
"""Returns `True` if `x` is non-decreasing.
235+
236+
Elements of `x` are compared in row-major order. The tensor `[x[0],...]`
237+
is non-decreasing if for every adjacent pair we have `x[i] <= x[i+1]`.
238+
If `x` has less than two elements, it is trivially non-decreasing.
239+
240+
See also: `is_strictly_increasing`
241+
242+
Args:
243+
x: Numeric `Tensor`.
244+
name: A name for this operation (optional). Defaults to "is_non_decreasing"
245+
246+
Returns:
247+
Boolean `Tensor`, equal to `True` iff `x` is non-decreasing.
248+
249+
Raises:
250+
ValueError: if `x` is not a numeric tensor.
251+
"""
252+
with ops.op_scope([x], name, 'is_non_decreasing'):
253+
diff = _get_diff_for_monotonic_comparison(x)
254+
# When len(x) = 1, diff = [], less_equal = [], and reduce_all([]) = True.
255+
zero = ops.convert_to_tensor(0, dtype=diff.dtype)
256+
return math_ops.reduce_all(math_ops.less_equal(zero, diff))
257+
258+
259+
def is_strictly_increasing(x, name=None):
260+
"""Returns `True` if `x` is strictly increasing.
261+
262+
Elements of `x` are compared in row-major order. The tensor `[x[0],...]`
263+
is strictly increasing if for every adjacent pair we have `x[i] < x[i+1]`.
264+
If `x` has less than two elements, it is trivially strictly increasing.
265+
266+
See also: `is_non_decreasing`
267+
268+
Args:
269+
x: Numeric `Tensor`.
270+
name: A name for this operation (optional).
271+
Defaults to "is_strictly_increasing"
272+
273+
Returns:
274+
Boolean `Tensor`, equal to `True` iff `x` is strictly increasing.
275+
276+
Raises:
277+
ValueError: if `x` is not a numeric tensor.
278+
"""
279+
with ops.op_scope([x], name, 'is_strictly_increasing'):
280+
diff = _get_diff_for_monotonic_comparison(x)
281+
# When len(x) = 1, diff = [], less = [], and reduce_all([]) = True.
282+
zero = ops.convert_to_tensor(0, dtype=diff.dtype)
283+
return math_ops.reduce_all(math_ops.less(zero, diff))
50284

51285

52286
def _assert_same_base_type(items, expected_type=None):
@@ -126,6 +360,10 @@ def assert_scalar_int(tensor):
126360
return tensor
127361

128362

363+
def is_numeric_tensor(tensor):
364+
return isinstance(tensor, ops.Tensor) and tensor.dtype in NUMERIC_TYPES
365+
366+
129367
# TODO(ptucker): Move to tf.variables?
130368
def local_variable(initial_value, validate_shape=True, name=None):
131369
"""Create variable and add it to `GraphKeys.LOCAL_VARIABLES` collection.

0 commit comments

Comments
 (0)