|
15 | 15 |
|
16 | 16 | """Tensor utility functions.
|
17 | 17 |
|
| 18 | +@@assert_negative |
| 19 | +@@assert_positive |
| 20 | +@@assert_non_negative |
| 21 | +@@assert_non_positive |
| 22 | +@@assert_less |
| 23 | +@@assert_less_equal |
18 | 24 | @@assert_same_float_dtype
|
19 |
| -@@is_numeric_tensor |
20 | 25 | @@assert_scalar_int
|
| 26 | +@@is_numeric_tensor |
| 27 | +@@is_non_decreasing |
| 28 | +@@is_strictly_increasing |
21 | 29 | @@local_variable
|
22 | 30 | @@reduce_sum_n
|
23 | 31 | @@with_shape
|
|
30 | 38 | from tensorflow.python.framework import dtypes
|
31 | 39 | from tensorflow.python.framework import ops
|
32 | 40 | from tensorflow.python.ops import array_ops
|
| 41 | +from tensorflow.python.ops import control_flow_ops |
33 | 42 | from tensorflow.python.ops import logging_ops
|
34 | 43 | from tensorflow.python.ops import math_ops
|
35 | 44 | from tensorflow.python.ops import variables
|
36 | 45 |
|
37 | 46 | __all__ = [
|
38 | 47 | 'assert_same_float_dtype', 'is_numeric_tensor', 'assert_scalar_int',
|
39 |
| - 'local_variable', 'reduce_sum_n', 'with_shape', 'with_same_shape'] |
| 48 | + 'local_variable', 'reduce_sum_n', 'with_shape', 'with_same_shape', |
| 49 | + 'assert_positive', 'assert_negative', 'assert_non_positive', |
| 50 | + 'assert_non_negative', 'assert_less', 'assert_less_equal', |
| 51 | + 'is_strictly_increasing', 'is_non_decreasing', |
| 52 | +] |
40 | 53 |
|
41 | 54 |
|
42 | 55 | NUMERIC_TYPES = frozenset([dtypes.float32, dtypes.float64, dtypes.int8,
|
|
45 | 58 | dtypes.quint8, dtypes.complex64])
|
46 | 59 |
|
47 | 60 |
|
48 |
| -def is_numeric_tensor(tensor): |
49 |
| - return isinstance(tensor, ops.Tensor) and tensor.dtype in NUMERIC_TYPES |
| 61 | +def assert_negative(x, data=None, summarize=None, name=None): |
| 62 | + """Assert the condition `x < 0` holds element-wise. |
| 63 | +
|
| 64 | + Negative means, for every element `x[i]` of `x`, we have `x[i] < 0`. |
| 65 | + If `x` is empty this is trivially satisfied. |
| 66 | +
|
| 67 | + Args: |
| 68 | + x: Numeric `Tensor`. |
| 69 | + data: The tensors to print out if the condition is False. Defaults to |
| 70 | + error message and first few entries of `x`. |
| 71 | + summarize: Print this many entries of each tensor. |
| 72 | + name: A name for this operation (optional). Defaults to "assert_negative". |
| 73 | +
|
| 74 | + Returns: |
| 75 | + Op raising `InvalidArgumentError` unless `x` is all negative. |
| 76 | + """ |
| 77 | + with ops.op_scope([x, data], name, 'assert_negative'): |
| 78 | + x = ops.convert_to_tensor(x, name='x') |
| 79 | + if data is None: |
| 80 | + data = ['Condition x < 0 did not hold element-wise: x = ', x.name, x] |
| 81 | + zero = ops.convert_to_tensor(0, dtype=x.dtype) |
| 82 | + return assert_less(x, zero, data=data, summarize=summarize) |
| 83 | + |
| 84 | + |
| 85 | +def assert_positive(x, data=None, summarize=None, name=None): |
| 86 | + """Assert the condition `x > 0` holds element-wise. |
| 87 | +
|
| 88 | + Positive means, for every element `x[i]` of `x`, we have `x[i] > 0`. |
| 89 | + If `x` is empty this is trivially satisfied. |
| 90 | +
|
| 91 | + Args: |
| 92 | + x: Numeric `Tensor`. |
| 93 | + data: The tensors to print out if the condition is False. Defaults to |
| 94 | + error message and first few entries of `x`. |
| 95 | + summarize: Print this many entries of each tensor. |
| 96 | + name: A name for this operation (optional). Defaults to "assert_negative". |
| 97 | +
|
| 98 | + Returns: |
| 99 | + Op raising `InvalidArgumentError` unless `x` is all positive. |
| 100 | + """ |
| 101 | + with ops.op_scope([x, data], name, 'assert_positive'): |
| 102 | + x = ops.convert_to_tensor(x, name='x') |
| 103 | + if data is None: |
| 104 | + data = ['Condition x > 0 did not hold element-wise: x = ', x.name, x] |
| 105 | + zero = ops.convert_to_tensor(0, dtype=x.dtype) |
| 106 | + return assert_less(zero, x, data=data, summarize=summarize) |
| 107 | + |
| 108 | + |
| 109 | +def assert_non_negative(x, data=None, summarize=None, name=None): |
| 110 | + """Assert the condition `x >= 0` holds element-wise. |
| 111 | +
|
| 112 | + Non-negative means, for every element `x[i]` of `x`, we have `x[i] >= 0`. |
| 113 | + If `x` is empty this is trivially satisfied. |
| 114 | +
|
| 115 | + Args: |
| 116 | + x: Numeric `Tensor`. |
| 117 | + data: The tensors to print out if the condition is False. Defaults to |
| 118 | + error message and first few entries of `x`. |
| 119 | + summarize: Print this many entries of each tensor. |
| 120 | + name: A name for this operation (optional). |
| 121 | + Defaults to "assert_non_negative". |
| 122 | +
|
| 123 | + Returns: |
| 124 | + Op raising `InvalidArgumentError` unless `x` is all non-negative. |
| 125 | + """ |
| 126 | + with ops.op_scope([x, data], name, 'assert_non_negative'): |
| 127 | + x = ops.convert_to_tensor(x, name='x') |
| 128 | + if data is None: |
| 129 | + data = ['Condition x >= 0 did not hold element-wise: x = ', x.name, x] |
| 130 | + zero = ops.convert_to_tensor(0, dtype=x.dtype) |
| 131 | + return assert_less_equal(zero, x, data=data, summarize=summarize) |
| 132 | + |
| 133 | + |
| 134 | +def assert_non_positive(x, data=None, summarize=None, name=None): |
| 135 | + """Assert the condition `x <= 0` holds element-wise. |
| 136 | +
|
| 137 | + Non-positive means, for every element `x[i]` of `x`, we have `x[i] <= 0`. |
| 138 | + If `x` is empty this is trivially satisfied. |
| 139 | +
|
| 140 | + Args: |
| 141 | + x: Numeric `Tensor`. |
| 142 | + data: The tensors to print out if the condition is False. Defaults to |
| 143 | + error message and first few entries of `x`. |
| 144 | + summarize: Print this many entries of each tensor. |
| 145 | + name: A name for this operation (optional). |
| 146 | + Defaults to "assert_non_positive". |
| 147 | +
|
| 148 | + Returns: |
| 149 | + Op raising `InvalidArgumentError` unless `x` is all non-positive. |
| 150 | + """ |
| 151 | + with ops.op_scope([x, data], name, 'assert_non_positive'): |
| 152 | + x = ops.convert_to_tensor(x, name='x') |
| 153 | + if data is None: |
| 154 | + data = ['Condition x <= 0 did not hold element-wise: x = ', x.name, x] |
| 155 | + zero = ops.convert_to_tensor(0, dtype=x.dtype) |
| 156 | + return assert_less_equal(x, zero, data=data, summarize=summarize) |
| 157 | + |
| 158 | + |
| 159 | +def assert_less(x, y, data=None, summarize=None, name=None): |
| 160 | + """Assert the condition `x < y` holds element-wise. |
| 161 | +
|
| 162 | + This condition holds if for every pair of (possibly broadcast) elements |
| 163 | + `x[i]`, `y[i]`, we have `x[i] < y[i]`. |
| 164 | + If both `x` and `y` are empty, this is trivially satisfied. |
| 165 | +
|
| 166 | + Args: |
| 167 | + x: Numeric `Tensor`. |
| 168 | + y: Numeric `Tensor`, same dtype as and broadcastable to `x`. |
| 169 | + data: The tensors to print out if the condition is False. Defaults to |
| 170 | + error message and first few entries of `x`, `y`. |
| 171 | + summarize: Print this many entries of each tensor. |
| 172 | + name: A name for this operation (optional). Defaults to "assert_less". |
| 173 | +
|
| 174 | + Returns: |
| 175 | + Op that raises `InvalidArgumentError` if `x < y` is False. |
| 176 | + """ |
| 177 | + with ops.op_scope([x, y, data], name, 'assert_less'): |
| 178 | + x = ops.convert_to_tensor(x, name='x') |
| 179 | + y = ops.convert_to_tensor(y, name='y') |
| 180 | + if data is None: |
| 181 | + data = [ |
| 182 | + 'Condition x < y did not hold element-wise: x = ', |
| 183 | + x.name, x, 'y = ', y.name, y] |
| 184 | + condition = math_ops.reduce_all(math_ops.less(x, y)) |
| 185 | + return logging_ops.Assert(condition, data, summarize=summarize) |
| 186 | + |
| 187 | + |
| 188 | +def assert_less_equal(x, y, data=None, summarize=None, name=None): |
| 189 | + """Assert the condition `x <= y` holds element-wise. |
| 190 | +
|
| 191 | + This condition holds if for every pair of (possibly broadcast) elements |
| 192 | + `x[i]`, `y[i]`, we have `x[i] <= y[i]`. |
| 193 | + If both `x` and `y` are empty, this is trivially satisfied. |
| 194 | +
|
| 195 | + Args: |
| 196 | + x: Numeric `Tensor`. |
| 197 | + y: Numeric `Tensor`, same dtype as and broadcastable to `x`. |
| 198 | + data: The tensors to print out if the condition is False. Defaults to |
| 199 | + error message and first few entries of `x`, `y`. |
| 200 | + summarize: Print this many entries of each tensor. |
| 201 | + name: A name for this operation (optional). Defaults to "assert_less_equal" |
| 202 | +
|
| 203 | + Returns: |
| 204 | + Op that raises `InvalidArgumentError` if `x <= y` is False. |
| 205 | + """ |
| 206 | + with ops.op_scope([x, y, data], name, 'assert_less_equal'): |
| 207 | + x = ops.convert_to_tensor(x, name='x') |
| 208 | + y = ops.convert_to_tensor(y, name='y') |
| 209 | + if data is None: |
| 210 | + data = [ |
| 211 | + 'Condition x <= y did not hold element-wise: x = ', |
| 212 | + x.name, x, 'y = ', y.name, y] |
| 213 | + condition = math_ops.reduce_all(math_ops.less_equal(x, y)) |
| 214 | + return logging_ops.Assert(condition, data, summarize=summarize) |
| 215 | + |
| 216 | + |
| 217 | +def _get_diff_for_monotonic_comparison(x): |
| 218 | + """Gets the difference x[1:] - x[:-1].""" |
| 219 | + x = array_ops.reshape(x, [-1]) |
| 220 | + if not is_numeric_tensor(x): |
| 221 | + raise ValueError('Expected x to be numeric, instead found: %s' % x) |
| 222 | + |
| 223 | + # If x has less than 2 elements, there is nothing to compare. So return []. |
| 224 | + is_shorter_than_two = math_ops.less(array_ops.size(x), 2) |
| 225 | + short_result = lambda: ops.convert_to_tensor([], dtype=x.dtype) |
| 226 | + |
| 227 | + # With 2 or more elements, return x[1:] - x[:-1] |
| 228 | + s_len = array_ops.shape(x) - 1 |
| 229 | + diff = lambda: array_ops.slice(x, [1], s_len) - array_ops.slice(x, [0], s_len) |
| 230 | + return control_flow_ops.cond(is_shorter_than_two, short_result, diff) |
| 231 | + |
| 232 | + |
| 233 | +def is_non_decreasing(x, name=None): |
| 234 | + """Returns `True` if `x` is non-decreasing. |
| 235 | +
|
| 236 | + Elements of `x` are compared in row-major order. The tensor `[x[0],...]` |
| 237 | + is non-decreasing if for every adjacent pair we have `x[i] <= x[i+1]`. |
| 238 | + If `x` has less than two elements, it is trivially non-decreasing. |
| 239 | +
|
| 240 | + See also: `is_strictly_increasing` |
| 241 | +
|
| 242 | + Args: |
| 243 | + x: Numeric `Tensor`. |
| 244 | + name: A name for this operation (optional). Defaults to "is_non_decreasing" |
| 245 | +
|
| 246 | + Returns: |
| 247 | + Boolean `Tensor`, equal to `True` iff `x` is non-decreasing. |
| 248 | +
|
| 249 | + Raises: |
| 250 | + ValueError: if `x` is not a numeric tensor. |
| 251 | + """ |
| 252 | + with ops.op_scope([x], name, 'is_non_decreasing'): |
| 253 | + diff = _get_diff_for_monotonic_comparison(x) |
| 254 | + # When len(x) = 1, diff = [], less_equal = [], and reduce_all([]) = True. |
| 255 | + zero = ops.convert_to_tensor(0, dtype=diff.dtype) |
| 256 | + return math_ops.reduce_all(math_ops.less_equal(zero, diff)) |
| 257 | + |
| 258 | + |
| 259 | +def is_strictly_increasing(x, name=None): |
| 260 | + """Returns `True` if `x` is strictly increasing. |
| 261 | +
|
| 262 | + Elements of `x` are compared in row-major order. The tensor `[x[0],...]` |
| 263 | + is strictly increasing if for every adjacent pair we have `x[i] < x[i+1]`. |
| 264 | + If `x` has less than two elements, it is trivially strictly increasing. |
| 265 | +
|
| 266 | + See also: `is_non_decreasing` |
| 267 | +
|
| 268 | + Args: |
| 269 | + x: Numeric `Tensor`. |
| 270 | + name: A name for this operation (optional). |
| 271 | + Defaults to "is_strictly_increasing" |
| 272 | +
|
| 273 | + Returns: |
| 274 | + Boolean `Tensor`, equal to `True` iff `x` is strictly increasing. |
| 275 | +
|
| 276 | + Raises: |
| 277 | + ValueError: if `x` is not a numeric tensor. |
| 278 | + """ |
| 279 | + with ops.op_scope([x], name, 'is_strictly_increasing'): |
| 280 | + diff = _get_diff_for_monotonic_comparison(x) |
| 281 | + # When len(x) = 1, diff = [], less = [], and reduce_all([]) = True. |
| 282 | + zero = ops.convert_to_tensor(0, dtype=diff.dtype) |
| 283 | + return math_ops.reduce_all(math_ops.less(zero, diff)) |
50 | 284 |
|
51 | 285 |
|
52 | 286 | def _assert_same_base_type(items, expected_type=None):
|
@@ -126,6 +360,10 @@ def assert_scalar_int(tensor):
|
126 | 360 | return tensor
|
127 | 361 |
|
128 | 362 |
|
| 363 | +def is_numeric_tensor(tensor): |
| 364 | + return isinstance(tensor, ops.Tensor) and tensor.dtype in NUMERIC_TYPES |
| 365 | + |
| 366 | + |
129 | 367 | # TODO(ptucker): Move to tf.variables?
|
130 | 368 | def local_variable(initial_value, validate_shape=True, name=None):
|
131 | 369 | """Create variable and add it to `GraphKeys.LOCAL_VARIABLES` collection.
|
|
0 commit comments