Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/api/python/topi.rst
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ List of operators
topi.not_equal
topi.greater_equal
topi.less_equal
topi.all
topi.logical_and
topi.logical_or
topi.logical_not
Expand Down Expand Up @@ -140,6 +141,7 @@ topi
.. autofunction:: topi.gather_nd
.. autofunction:: topi.full
.. autofunction:: topi.full_like
.. autofunction:: topi.all
.. autofunction:: topi.max
.. autofunction:: topi.sum
.. autofunction:: topi.min
Expand Down
2 changes: 2 additions & 0 deletions docs/langref/relay_op.rst
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,7 @@ This level enables additional math and transform operators.
tvm.relay.greater_equal
tvm.relay.less
tvm.relay.less_equal
tvm.relay.all
tvm.relay.logical_and
tvm.relay.logical_or
tvm.relay.logical_not
Expand Down Expand Up @@ -277,6 +278,7 @@ Level 4 Definitions
.. autofunction:: tvm.relay.greater_equal
.. autofunction:: tvm.relay.less
.. autofunction:: tvm.relay.less_equal
.. autofunction:: tvm.relay.all
.. autofunction:: tvm.relay.logical_and
.. autofunction:: tvm.relay.logical_or
.. autofunction:: tvm.relay.logical_not
Expand Down
7 changes: 7 additions & 0 deletions include/tvm/expr_operator.h
Original file line number Diff line number Diff line change
Expand Up @@ -427,6 +427,13 @@ TVM_DLL Expr abs(Expr x);
*/
TVM_DLL Expr sum(Expr source, Array<IterVar> axis);

/*!
* \brief logical And of of source expression over axis
* \param source The source expression.
* \param axis List of iteration variables that will be used for reduction.
*/
TVM_DLL Expr all(Expr source, Array<IterVar> axis);

/*!
* \brief max of of source expression over axis
* \param source The source expression.
Expand Down
12 changes: 12 additions & 0 deletions python/tvm/relay/frontend/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -696,6 +696,17 @@ def _impl(inputs, attr, params):
ignores=['name', 'Tidx'])([inputs[0]], attr)
return _impl

def _reduce_all():
def _impl(inputs, attr, params):
axis = params.pop(inputs[1].name_hint).asnumpy()
axis = tuple(axis)
return AttrCvt(
op_name='all',
extras={'axis': axis},
transforms={'keep_dims':'keepdims'},
ignores=['name', 'Tidx'])([inputs[0]], attr)
return _impl

def _square():
def _impl(inputs, attr, params):
return _op.multiply(inputs[0], inputs[0])
Expand Down Expand Up @@ -1099,6 +1110,7 @@ def _impl(inputs, attr, params):
# for N to 1 mapping, currently not supported(?)
_convert_map = {
'Add' : _elemwise('add'),
'All' : _reduce_all(),
'ArgMax' : _argx(_op.argmax, 'argmax'),
'ArgMin' : _argx(_op.argmin, 'argmin'),
'AvgPool' : _pooling('avg_pool'),
Expand Down
1 change: 1 addition & 0 deletions python/tvm/relay/op/_reduce.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ def _schedule_reduce(_, outs, target):
_reg.register_schedule("argmax", _schedule_reduce)
_reg.register_schedule("argmin", _schedule_reduce)
_reg.register_schedule("sum", _schedule_reduce)
_reg.register_schedule("all", _schedule_reduce)
_reg.register_schedule("max", _schedule_reduce)
_reg.register_schedule("min", _schedule_reduce)
_reg.register_schedule("prod", _schedule_reduce)
Expand Down
66 changes: 59 additions & 7 deletions python/tvm/relay/op/reduce.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def argmax(data, axis=None, keepdims=False, exclude=False):

exclude : bool
If `exclude` is true, reduction will be performed on the axes that are
NOT in axis instead.
NOT in axis instead.

Returns
-------
Expand Down Expand Up @@ -69,7 +69,7 @@ def argmin(data, axis=None, keepdims=False, exclude=False):

exclude : bool
If `exclude` is true, reduction will be performed on the axes that are
NOT in axis instead.
NOT in axis instead.

Returns
-------
Expand Down Expand Up @@ -100,7 +100,7 @@ def sum(data, axis=None, keepdims=False, exclude=False):

exclude : bool
If `exclude` is true, reduction will be performed on the axes that are
NOT in axis instead.
NOT in axis instead.

Returns
-------
Expand All @@ -111,6 +111,58 @@ def sum(data, axis=None, keepdims=False, exclude=False):
return _make.sum(data, axis, keepdims, exclude)


def all(data, axis=None, keepdims=False, exclude=False):
"""Computes the logical AND of boolean array elements over given axes.

Parameters
----------
data : relay.Expr
The input boolean tensor

axis : None or int or tuple of int
Axis or axes along which a sum is performed. The default, axis=None,
will sum all of the elements of the input array. If axis is
negative it counts from the last to the first axis.

keepdims : bool
If this is set to True, the axes which are reduced are left in the result as
dimensions with size one. With this option, the result will broadcast
correctly against the input array.

exclude : bool
If `exclude` is true, reduction will be performed on the axes that are
NOT in axis instead.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you also add the example in cpp to python as well?

Returns
-------
result : relay.Expr
The computed result.

Examples
--------
.. code-block:: python

data = relay.Constant(tvm.nd.array([[[ True, True, True],
[ True, True, True],
[False, True, False]],
[[ True, False, False],
[ True, True, False],
[False, True, True]]]))

relay.all(data, axis=1)
# [[False, True, False],
# [False, False, False]]

relay.all(data, axis=0)
# [[ True, False, False],
# [ True, True, False],
# [False, True, False]]

"""
axis = [axis] if axis and isinstance(axis, int) else axis
return _make.all(data, axis, keepdims, exclude)


def max(data, axis=None, keepdims=False, exclude=False):
""" Computes the max of array elements over given axes.

Expand All @@ -131,7 +183,7 @@ def max(data, axis=None, keepdims=False, exclude=False):

exclude : bool
If `exclude` is true, reduction will be performed on the axes that are
NOT in axis instead.
NOT in axis instead.

Returns
-------
Expand Down Expand Up @@ -163,7 +215,7 @@ def min(data, axis=None, keepdims=False, exclude=False):

exclude : bool
If `exclude` is true, reduction will be performed on the axes that are
NOT in axis instead.
NOT in axis instead.

Returns
-------
Expand Down Expand Up @@ -194,7 +246,7 @@ def mean(data, axis=None, keepdims=False, exclude=False):

exclude : bool
If `exclude` is true, reduction will be performed on the axes that are
NOT in axis instead.
NOT in axis instead.

Returns
-------
Expand Down Expand Up @@ -225,7 +277,7 @@ def prod(data, axis=None, keepdims=False, exclude=False):

exclude : bool
If `exclude` is true, reduction will be performed on the axes that are
NOT in axis instead.
NOT in axis instead.

Returns
-------
Expand Down
10 changes: 10 additions & 0 deletions src/lang/expr_operator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -393,6 +393,16 @@ Expr sum(Expr source, Array<IterVar> rdom) {
return ir::Reduce::make(combiner, {source}, rdom, make_const(Bool(1), true), 0);
}

Expr all(Expr source, Array<IterVar> rdom) {
CHECK(source.type().is_bool());
Var x("x", source.type()), y("y", source.type());
Expr result = ir::And::make(x, y);
Expr identity_element = make_const(source.type(), true);
ir::CommReducer combiner =
ir::CommReducerNode::make({x}, {y}, {result}, {identity_element});
return ir::Reduce::make(combiner, {source}, rdom, make_const(Bool(1), true), 0);
}

Expr max(Expr source, Array<IterVar> rdom) {
Var x("x", source.type()), y("y", source.type());
Expr result = ir::Max::make(x, y);
Expand Down
37 changes: 37 additions & 0 deletions src/relay/op/tensor/reduce.cc
Original file line number Diff line number Diff line change
Expand Up @@ -355,6 +355,43 @@ Example::
.set_attr<TOpPattern>("TOpPattern", kCommReduce);


Array<Tensor> AllCompute(const Attrs& attrs,
const Array<Tensor>& inputs,
const Type& out_type,
const Target& target) {
return ReduceCompute(attrs, inputs, out_type, target, topi::all);
}


RELAY_REGISTER_REDUCE_OP("all")
.describe(R"code(Computes the logical AND of boolean array elements over given axes.

Example::

data = [[[ True, True, True],
Copy link
Contributor

@kevinthesun kevinthesun May 6, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this operator allow numerical type?We'd better update the doc to reflect this.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

just allow boolean tensor, have updated the doc

[ True, True, True],
[False, True, False]],
[[ True, False, False],
[ True, True, False],
[False, True, True]]]

all(data, axis=1)
[[False, True, False],
[False, False, False]]

all(data, axis=0)
[[ True, False, False],
[ True, True, False],
[False, True, False]]

)code" TVM_ADD_FILELINE)
.set_attrs_type_key("relay.attrs.ReduceAttrs")
.set_support_level(4)
.add_type_rel("Reduce", ReduceRel)
.set_attr<FTVMCompute>("FTVMCompute", AllCompute)
.set_attr<TOpPattern>("TOpPattern", kCommReduce);


Array<Tensor> MaxCompute(const Attrs& attrs,
const Array<Tensor>& inputs,
const Type& out_type,
Expand Down
12 changes: 12 additions & 0 deletions tests/python/frontend/tensorflow/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -1471,6 +1471,17 @@ def check_mean(ishape, **kwargs):
check_mean((10, 8, 16, 32), axis=(2,3))
check_mean((10, 8, 16, 32), axis=(1,2), keepdims=True)

#######################################################################
# All
# ---
def test_forward_all():
"""Test the All operator."""
np_data = np.random.choice([True, False], size=(5, 7, 11))
tf.reset_default_graph()
in_data = tf.placeholder(tf.bool, (5, 7, 11), name="in_data")
tf.reduce_all(in_data, name="all")
compare_tf_with_tvm([np_data], ['in_data:0'], 'all:0')

#######################################################################
# Relational operators
# --------------------
Expand Down Expand Up @@ -1569,6 +1580,7 @@ def test_forward_reduce_prod():
test_forward_reduce()
test_forward_mean()
test_forward_reduce_prod()
test_forward_all()

# General
test_forward_multi_input()
Expand Down
7 changes: 6 additions & 1 deletion tests/python/relay/test_op_level4.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,7 @@ def test_where():
def verify_reduce(funcs, data, axis, keepdims, exclude, output, dtype="float32"):
test_func = funcs[0]
ref_func = funcs[1]
dtype = "bool" if ref_func in [np.all] else dtype

x = relay.var("x", relay.TensorType(data, dtype))
z = test_func(x, axis, keepdims, exclude)
Expand All @@ -155,7 +156,9 @@ def verify_reduce(funcs, data, axis, keepdims, exclude, output, dtype="float32")
return

func = relay.Function([x], z)
x_data = np.random.uniform(size=data).astype(dtype)
x_data = np.random.choice([True, False], size=data) if ref_func in [np.all] \
else np.random.uniform(size=data).astype(dtype)

if ref_func in [np.sum]:
ref_res = ref_func(x_data + 0, axis=axis, dtype=dtype, keepdims=keepdims)
elif ref_func in [np.max, np.min, np.mean, np.prod]:
Expand Down Expand Up @@ -194,6 +197,7 @@ def _wrapper(data, axis=None, keepdims=False):
[relay.min, np.min],
[relay.mean, np.mean],
[relay.prod, np.prod],
[relay.all, np.all],
[relay.argmin, _with_keepdims(np.argmin)],
[relay.argmax, _with_keepdims(np.argmax)]]:
verify_reduce(func, (d1, d2, d3, d4), None, False, False, ())
Expand All @@ -203,6 +207,7 @@ def _wrapper(data, axis=None, keepdims=False):
verify_reduce(func, (d1, d2, d3), (0, 1), True, False, (1, 1, d3))
verify_reduce(func, (2, 3, 4), 1, True, False, (2, 1, 4))
verify_reduce(func, (2, 3, 4), (1,), True, False, (2, 1, 4))
verify_reduce(func, (2, 3, 4), -1, True, False, (2, 3, 1))
verify_reduce(func, (2, 3, 4), (0, 1, 2), False, False, ())
verify_reduce(func, (4, 4, 3), None, False, False, ())
verify_reduce(func, (4, 4, 3), (0, 2), False, False, (4,))
Expand Down
21 changes: 21 additions & 0 deletions topi/include/topi/reduction.h
Original file line number Diff line number Diff line change
Expand Up @@ -368,6 +368,27 @@ inline Tensor collapse_sum(const Tensor& data, Array<Expr> target_shape) {
return DoCommReduce(data, tvm::sum, target_shape, reduce_axes, squeeze_axes);
}

/*!
* \brief Creates an operation that computes the logical AND of elements
* over a given axis
*
* \param data The input boolean tensor
* \param axis The axes to reduce. If axis is empty, the operation will
* perform logical AND over all elements of the array.
* \param keepdims If this is set to true, the axes which are reduced are
* left in the result as dimensions with size one. This enables the result
* to broadcast correctly against the input array.
* \param atleast1d Whether the output need to be atleast1d.
*
* \return A Tensor whose op member is the all operation
*/
inline Tensor all(const Tensor& data,
const Array<Integer>& axis,
bool keepdims = false,
bool atleast1d = false) {
return CommReduce(data, axis, tvm::all, keepdims, atleast1d);
}

/*!
* \brief Creates an operation that finds the minimum of elements over
* a given axis.
Expand Down
25 changes: 25 additions & 0 deletions topi/python/topi/reduction.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,31 @@ def sum(data, axis=None, keepdims=False):
return cpp.sum(data, axis, keepdims)


def all(data, axis=None, keepdims=False):
"""Logical AND of array elements over a given axis or a list of axes

Parameters
----------
data : tvm.Tensor
The input tvm boolean tensor

axis : None or int or tuple of int
Axis or axes along which a logical AND is performed.
The default, axis=None, will perform logical AND over all elements of the input array.
If axis is negative it counts from the last to the first axis.

keepdims : bool
If this is set to True, the axes which are reduced are left in the result as dimensions
with size one.
With this option, the result will broadcast correctly against the input array.

Returns
-------
ret : tvm.Tensor
"""
return cpp.all(data, axis, keepdims)


def max(data, axis=None, keepdims=False):
"""Maximum of array elements over a given axis or a list of axes

Expand Down
5 changes: 5 additions & 0 deletions topi/src/topi.cc
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,11 @@ TVM_REGISTER_GLOBAL("topi.prod")
*rv = topi::prod(args[0], ArrayOrInt(args[1]), args[2]);
});

TVM_REGISTER_GLOBAL("topi.all")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = topi::all(args[0], ArrayOrInt(args[1]), args[2]);
});

/* Ops from transform.h */
TVM_REGISTER_GLOBAL("topi.expand_dims")
.set_body([](TVMArgs args, TVMRetValue *rv) {
Expand Down
Loading