Skip to content

Commit 9fd8e3c

Browse files
yongwwwicemelon
authored and committed
[Relay][TOPI] operator All (#3124)
* [Relay][TOPI] operator All * Update tests/python/frontend/tensorflow/test_forward.py Co-Authored-By: yongwww <55wuyong@163.com> * fix comments * change to level 4
1 parent 3a9de90 commit 9fd8e3c

File tree

14 files changed

+232
-22
lines changed

14 files changed

+232
-22
lines changed

docs/api/python/topi.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,7 @@ List of operators
8888
topi.not_equal
8989
topi.greater_equal
9090
topi.less_equal
91+
topi.all
9192
topi.logical_and
9293
topi.logical_or
9394
topi.logical_not
@@ -140,6 +141,7 @@ topi
140141
.. autofunction:: topi.gather_nd
141142
.. autofunction:: topi.full
142143
.. autofunction:: topi.full_like
144+
.. autofunction:: topi.all
143145
.. autofunction:: topi.max
144146
.. autofunction:: topi.sum
145147
.. autofunction:: topi.min

docs/langref/relay_op.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,7 @@ This level enables additional math and transform operators.
135135
tvm.relay.greater_equal
136136
tvm.relay.less
137137
tvm.relay.less_equal
138+
tvm.relay.all
138139
tvm.relay.logical_and
139140
tvm.relay.logical_or
140141
tvm.relay.logical_not
@@ -277,6 +278,7 @@ Level 4 Definitions
277278
.. autofunction:: tvm.relay.greater_equal
278279
.. autofunction:: tvm.relay.less
279280
.. autofunction:: tvm.relay.less_equal
281+
.. autofunction:: tvm.relay.all
280282
.. autofunction:: tvm.relay.logical_and
281283
.. autofunction:: tvm.relay.logical_or
282284
.. autofunction:: tvm.relay.logical_not

include/tvm/expr_operator.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -428,6 +428,13 @@ TVM_DLL Expr abs(Expr x);
428428
*/
429429
TVM_DLL Expr sum(Expr source, Array<IterVar> axis);
430430

431+
/*!
432+
* \brief logical AND of source expression over axis
433+
* \param source The source expression.
434+
* \param axis List of iteration variables that will be used for reduction.
435+
*/
436+
TVM_DLL Expr all(Expr source, Array<IterVar> axis);
437+
431438
/*!
432439
* \brief max of of source expression over axis
433440
* \param source The source expression.

python/tvm/relay/frontend/tensorflow.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -767,6 +767,17 @@ def _impl(inputs, attr, params):
767767
ignores=['name', 'Tidx'])([inputs[0]], attr)
768768
return _impl
769769

770+
def _reduce_all():
    """Convert TensorFlow's ``All`` (tf.reduce_all) op into Relay's ``all``."""
    def _impl(inputs, attr, params):
        # The reduction axes come in as a constant second input; resolve the
        # constant from params and hand it to Relay as a tuple attribute.
        reduction_axes = tuple(params.pop(inputs[1].name_hint).asnumpy())
        return AttrCvt(
            op_name='all',
            extras={'axis': reduction_axes},
            transforms={'keep_dims': 'keepdims'},
            ignores=['name', 'Tidx'])([inputs[0]], attr)
    return _impl
780+
770781
def _square():
771782
def _impl(inputs, attr, params):
772783
return _op.multiply(inputs[0], inputs[0])
@@ -1180,6 +1191,7 @@ def _impl(inputs, attr, params):
11801191
# for N to 1 mapping, currently not supported(?)
11811192
_convert_map = {
11821193
'Add' : _elemwise('add'),
1194+
'All' : _reduce_all(),
11831195
'ArgMax' : _argx(_op.argmax, 'argmax'),
11841196
'ArgMin' : _argx(_op.argmin, 'argmin'),
11851197
'AvgPool' : _pooling('avg_pool'),

python/tvm/relay/op/_reduce.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ def _schedule_reduce(_, outs, target):
3030
_reg.register_schedule("argmax", _schedule_reduce)
3131
_reg.register_schedule("argmin", _schedule_reduce)
3232
_reg.register_schedule("sum", _schedule_reduce)
33+
_reg.register_schedule("all", _schedule_reduce)
3334
_reg.register_schedule("max", _schedule_reduce)
3435
_reg.register_schedule("min", _schedule_reduce)
3536
_reg.register_schedule("prod", _schedule_reduce)

python/tvm/relay/op/reduce.py

Lines changed: 59 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ def argmax(data, axis=None, keepdims=False, exclude=False):
3939
4040
exclude : bool
4141
If `exclude` is true, reduction will be performed on the axes that are
42-
NOT in axis instead.
42+
NOT in axis instead.
4343
4444
Returns
4545
-------
@@ -69,7 +69,7 @@ def argmin(data, axis=None, keepdims=False, exclude=False):
6969
7070
exclude : bool
7171
If `exclude` is true, reduction will be performed on the axes that are
72-
NOT in axis instead.
72+
NOT in axis instead.
7373
7474
Returns
7575
-------
@@ -100,7 +100,7 @@ def sum(data, axis=None, keepdims=False, exclude=False):
100100
101101
exclude : bool
102102
If `exclude` is true, reduction will be performed on the axes that are
103-
NOT in axis instead.
103+
NOT in axis instead.
104104
105105
Returns
106106
-------
@@ -111,6 +111,58 @@ def sum(data, axis=None, keepdims=False, exclude=False):
111111
return _make.sum(data, axis, keepdims, exclude)
112112

113113

114+
def all(data, axis=None, keepdims=False, exclude=False):
    """Computes the logical AND of boolean array elements over given axes.

    Parameters
    ----------
    data : relay.Expr
        The input boolean tensor

    axis : None or int or tuple of int
        Axis or axes along which the logical AND is performed. The default,
        axis=None, will reduce over all of the elements of the input array.
        If axis is negative it counts from the last to the first axis.

    keepdims : bool
        If this is set to True, the axes which are reduced are left in the result as
        dimensions with size one. With this option, the result will broadcast
        correctly against the input array.

    exclude : bool
        If `exclude` is true, reduction will be performed on the axes that are
        NOT in axis instead.

    Returns
    -------
    result : relay.Expr
        The computed result.

    Examples
    --------
    .. code-block:: python

        data = relay.Constant(tvm.nd.array([[[ True,  True,  True],
                                             [ True,  True,  True],
                                             [False,  True, False]],
                                            [[ True, False, False],
                                             [ True,  True, False],
                                             [False,  True,  True]]]))

        relay.all(data, axis=1)
        # [[False,  True, False],
        # [False, False, False]]

        relay.all(data, axis=0)
        # [[ True, False, False],
        # [ True,  True, False],
        # [False,  True, False]]

    """
    # Normalize a bare int axis into a list. Test only isinstance here:
    # the previous form `axis and isinstance(axis, int)` skipped the wrap
    # for axis=0 (falsy), passing a raw int where a list is expected.
    axis = [axis] if isinstance(axis, int) else axis
    return _make.all(data, axis, keepdims, exclude)
164+
165+
114166
def max(data, axis=None, keepdims=False, exclude=False):
115167
""" Computes the max of array elements over given axes.
116168
@@ -131,7 +183,7 @@ def max(data, axis=None, keepdims=False, exclude=False):
131183
132184
exclude : bool
133185
If `exclude` is true, reduction will be performed on the axes that are
134-
NOT in axis instead.
186+
NOT in axis instead.
135187
136188
Returns
137189
-------
@@ -163,7 +215,7 @@ def min(data, axis=None, keepdims=False, exclude=False):
163215
164216
exclude : bool
165217
If `exclude` is true, reduction will be performed on the axes that are
166-
NOT in axis instead.
218+
NOT in axis instead.
167219
168220
Returns
169221
-------
@@ -194,7 +246,7 @@ def mean(data, axis=None, keepdims=False, exclude=False):
194246
195247
exclude : bool
196248
If `exclude` is true, reduction will be performed on the axes that are
197-
NOT in axis instead.
249+
NOT in axis instead.
198250
199251
Returns
200252
-------
@@ -225,7 +277,7 @@ def prod(data, axis=None, keepdims=False, exclude=False):
225277
226278
exclude : bool
227279
If `exclude` is true, reduction will be performed on the axes that are
228-
NOT in axis instead.
280+
NOT in axis instead.
229281
230282
Returns
231283
-------

src/lang/expr_operator.cc

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -393,6 +393,16 @@ Expr sum(Expr source, Array<IterVar> rdom) {
393393
return ir::Reduce::make(combiner, {source}, rdom, make_const(Bool(1), true), 0);
394394
}
395395

396+
Expr all(Expr source, Array<IterVar> rdom) {
  // Logical-AND reduction is only defined for boolean sources.
  CHECK(source.type().is_bool());
  // Combiner placeholders: x is the accumulator, y the incoming element.
  Var x("x", source.type()), y("y", source.type());
  // AND combiner with identity element `true` (AND over an empty domain
  // is vacuously true).
  ir::CommReducer combiner = ir::CommReducerNode::make(
      {x}, {y}, {ir::And::make(x, y)}, {make_const(source.type(), true)});
  // Unconditional reduce (predicate `true`), single output at value index 0.
  return ir::Reduce::make(combiner, {source}, rdom, make_const(Bool(1), true), 0);
}
405+
396406
Expr max(Expr source, Array<IterVar> rdom) {
397407
Var x("x", source.type()), y("y", source.type());
398408
Expr result = ir::Max::make(x, y);

src/relay/op/tensor/reduce.cc

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -355,6 +355,43 @@ Example::
355355
.set_attr<TOpPattern>("TOpPattern", kCommReduce);
356356

357357

358+
// FTVMCompute for the relay "all" op: delegates to the shared ReduceCompute
// plumbing, lowering the reduction via topi::all.
Array<Tensor> AllCompute(const Attrs& attrs,
                         const Array<Tensor>& inputs,
                         const Type& out_type,
                         const Target& target) {
  return ReduceCompute(attrs, inputs, out_type, target, topi::all);
}
364+
365+
366+
// Register the relay "all" operator: a boolean AND reduction (mirroring
// np.all / tf.reduce_all). Shares the generic "Reduce" type relation and
// ReduceAttrs with the other reduce ops; level 4 per the review discussion.
RELAY_REGISTER_REDUCE_OP("all")
.describe(R"code(Computes the logical AND of boolean array elements over given axes.

Example::

  data = [[[ True, True, True],
           [ True, True, True],
           [False, True, False]],
          [[ True, False, False],
           [ True, True, False],
           [False, True, True]]]

  all(data, axis=1)
  [[False, True, False],
   [False, False, False]]

  all(data, axis=0)
  [[ True, False, False],
   [ True, True, False],
   [False, True, False]]

)code" TVM_ADD_FILELINE)
.set_attrs_type_key("relay.attrs.ReduceAttrs")
.set_support_level(4)
.add_type_rel("Reduce", ReduceRel)
.set_attr<FTVMCompute>("FTVMCompute", AllCompute)
.set_attr<TOpPattern>("TOpPattern", kCommReduce);
393+
394+
358395
Array<Tensor> MaxCompute(const Attrs& attrs,
359396
const Array<Tensor>& inputs,
360397
const Type& out_type,

tests/python/frontend/tensorflow/test_forward.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1597,6 +1597,17 @@ def check_mean(ishape, **kwargs):
15971597
check_mean((10, 8, 16, 32), axis=(2,3))
15981598
check_mean((10, 8, 16, 32), axis=(1,2), keepdims=True)
15991599

1600+
#######################################################################
1601+
# All
1602+
# ---
1603+
def test_forward_all():
    """Test the All operator."""
    tf.reset_default_graph()
    shape = (5, 7, 11)
    np_data = np.random.choice([True, False], size=shape)
    in_data = tf.placeholder(tf.bool, shape, name="in_data")
    tf.reduce_all(in_data, name="all")
    compare_tf_with_tvm([np_data], ['in_data:0'], 'all:0')
1610+
16001611
#######################################################################
16011612
# Relational operators
16021613
# --------------------
@@ -1718,6 +1729,7 @@ def test_placeholder():
17181729
test_forward_reduce()
17191730
test_forward_mean()
17201731
test_forward_reduce_prod()
1732+
test_forward_all()
17211733

17221734
# General
17231735
test_forward_multi_input()

tests/python/relay/test_op_level4.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,7 @@ def test_where():
138138
def verify_reduce(funcs, data, axis, keepdims, exclude, output, dtype="float32"):
139139
test_func = funcs[0]
140140
ref_func = funcs[1]
141+
dtype = "bool" if ref_func in [np.all] else dtype
141142

142143
x = relay.var("x", relay.TensorType(data, dtype))
143144
z = test_func(x, axis, keepdims, exclude)
@@ -155,7 +156,9 @@ def verify_reduce(funcs, data, axis, keepdims, exclude, output, dtype="float32")
155156
return
156157

157158
func = relay.Function([x], z)
158-
x_data = np.random.uniform(size=data).astype(dtype)
159+
x_data = np.random.choice([True, False], size=data) if ref_func in [np.all] \
160+
else np.random.uniform(size=data).astype(dtype)
161+
159162
if ref_func in [np.sum]:
160163
ref_res = ref_func(x_data + 0, axis=axis, dtype=dtype, keepdims=keepdims)
161164
elif ref_func in [np.max, np.min, np.mean, np.prod]:
@@ -194,6 +197,7 @@ def _wrapper(data, axis=None, keepdims=False):
194197
[relay.min, np.min],
195198
[relay.mean, np.mean],
196199
[relay.prod, np.prod],
200+
[relay.all, np.all],
197201
[relay.argmin, _with_keepdims(np.argmin)],
198202
[relay.argmax, _with_keepdims(np.argmax)]]:
199203
verify_reduce(func, (d1, d2, d3, d4), None, False, False, ())
@@ -203,6 +207,7 @@ def _wrapper(data, axis=None, keepdims=False):
203207
verify_reduce(func, (d1, d2, d3), (0, 1), True, False, (1, 1, d3))
204208
verify_reduce(func, (2, 3, 4), 1, True, False, (2, 1, 4))
205209
verify_reduce(func, (2, 3, 4), (1,), True, False, (2, 1, 4))
210+
verify_reduce(func, (2, 3, 4), -1, True, False, (2, 3, 1))
206211
verify_reduce(func, (2, 3, 4), (0, 1, 2), False, False, ())
207212
verify_reduce(func, (4, 4, 3), None, False, False, ())
208213
verify_reduce(func, (4, 4, 3), (0, 2), False, False, (4,))

0 commit comments

Comments
 (0)