Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
[MXNET-347] Logical Operators AND, XOR, OR (#10679)
Browse files Browse the repository at this point in the history
* logical and

* logical OR and XOR operators.

* better examples

* nits.

* elemwise operators

* non broadcast examples and tests.

* doc API

* rerun CI
  • Loading branch information
anirudhacharya authored and eric-haibin-lin committed May 1, 2018
1 parent 9f8f042 commit 61f86fc
Show file tree
Hide file tree
Showing 12 changed files with 348 additions and 5 deletions.
12 changes: 12 additions & 0 deletions docs/api/python/ndarray/ndarray.md
Original file line number Diff line number Diff line change
Expand Up @@ -555,6 +555,18 @@ The `ndarray` package provides several classes:
lesser_equal
```

### Logical operators

```eval_rst
.. autosummary::
    :nosignatures:

    logical_and
    logical_or
    logical_xor
    logical_not
```

### Random sampling

```eval_rst
Expand Down
12 changes: 12 additions & 0 deletions docs/api/python/symbol/symbol.md
Original file line number Diff line number Diff line change
Expand Up @@ -554,6 +554,18 @@ Composite multiple symbols into a new one by an operator.
broadcast_lesser_equal
```

### Logical operators

```eval_rst
.. autosummary::
    :nosignatures:

    broadcast_logical_and
    broadcast_logical_or
    broadcast_logical_xor
    broadcast_logical_not
```

### Random sampling

```eval_rst
Expand Down
181 changes: 177 additions & 4 deletions python/mxnet/ndarray/ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,9 @@

__all__ = ["NDArray", "concatenate", "_DTYPE_NP_TO_MX", "_DTYPE_MX_TO_NP", "_GRAD_REQ_MAP",
"ones", "add", "arange", "eye", "divide", "equal", "full", "greater", "greater_equal",
"imdecode", "lesser", "lesser_equal", "maximum", "minimum", "moveaxis", "modulo",
"multiply", "not_equal", "onehot_encode", "power", "subtract", "true_divide",
"waitall", "_new_empty_handle"]
"imdecode", "lesser", "lesser_equal", "logical_and", "logical_or", "logical_xor",
"maximum", "minimum", "moveaxis", "modulo", "multiply", "not_equal", "onehot_encode",
"power", "subtract", "true_divide", "waitall", "_new_empty_handle"]

_STORAGE_TYPE_UNDEFINED = -1
_STORAGE_TYPE_DEFAULT = 0
Expand Down Expand Up @@ -2485,7 +2485,7 @@ def add(lhs, rhs):
.. note::
If the corresponding dimensions of two arrays have the same size or one of them has size 1,
then the arrays are broadcastable to a common shape.
then the arrays are broadcastable to a common shape
Parameters
----------
Expand Down Expand Up @@ -3337,6 +3337,179 @@ def lesser_equal(lhs, rhs):
_internal._greater_equal_scalar)
# pylint: enable= no-member, protected-access

def logical_and(lhs, rhs):
    """Returns the result of element-wise **logical and** comparison
    operation with broadcasting.

    For each pair of elements the result is 1 (true) when both the lhs
    element and the rhs element are true (nonzero), and 0 (false) otherwise.

    Equivalent to ``lhs and rhs`` and ``mx.nd.broadcast_logical_and(lhs, rhs)``.

    .. note::

       If the corresponding dimensions of two arrays have the same size or one of them has size 1,
       then the arrays are broadcastable to a common shape.

    Parameters
    ----------
    lhs : scalar or array
        First input of the function.
    rhs : scalar or array
        Second input of the function. If ``lhs.shape != rhs.shape``, they must be
        broadcastable to a common shape.

    Returns
    -------
    NDArray
        Output array of boolean values.

    Examples
    --------
    >>> x = mx.nd.ones((2,3))
    >>> y = mx.nd.arange(2).reshape((2,1))
    >>> z = mx.nd.arange(2).reshape((1,2))
    >>> x.asnumpy()
    array([[ 1.,  1.,  1.],
           [ 1.,  1.,  1.]], dtype=float32)
    >>> y.asnumpy()
    array([[ 0.],
           [ 1.]], dtype=float32)
    >>> z.asnumpy()
    array([[ 0.,  1.]], dtype=float32)
    >>> mx.nd.logical_and(x, 1).asnumpy()
    array([[ 1.,  1.,  1.],
           [ 1.,  1.,  1.]], dtype=float32)
    >>> mx.nd.logical_and(x, y).asnumpy()
    array([[ 0.,  0.,  0.],
           [ 1.,  1.,  1.]], dtype=float32)
    >>> mx.nd.logical_and(z, y).asnumpy()
    array([[ 0.,  0.],
           [ 0.,  1.]], dtype=float32)
    """
    # pylint: disable= no-member, protected-access
    # Dispatch: broadcast op when both args are arrays, the Python lambda when
    # both are scalars, the *_scalar op for the array/scalar mix.  The trailing
    # None marks the scalar op as commutative, so a scalar lhs reuses it.
    return _ufunc_helper(
        lhs,
        rhs,
        op.broadcast_logical_and,
        lambda a, b: int(bool(a) and bool(b)),
        _internal._logical_and_scalar,
        None)
    # pylint: enable= no-member, protected-access

def logical_or(lhs, rhs):
    """Returns the result of element-wise **logical or** comparison
    operation with broadcasting.

    For each pair of elements the result is 1 (true) when at least one of the
    lhs and rhs elements is true (nonzero), and 0 (false) otherwise.

    Equivalent to ``lhs or rhs`` and ``mx.nd.broadcast_logical_or(lhs, rhs)``.

    .. note::

       If the corresponding dimensions of two arrays have the same size or one of them has size 1,
       then the arrays are broadcastable to a common shape.

    Parameters
    ----------
    lhs : scalar or array
        First input of the function.
    rhs : scalar or array
        Second input of the function. If ``lhs.shape != rhs.shape``, they must be
        broadcastable to a common shape.

    Returns
    -------
    NDArray
        Output array of boolean values.

    Examples
    --------
    >>> x = mx.nd.ones((2,3))
    >>> y = mx.nd.arange(2).reshape((2,1))
    >>> z = mx.nd.arange(2).reshape((1,2))
    >>> x.asnumpy()
    array([[ 1.,  1.,  1.],
           [ 1.,  1.,  1.]], dtype=float32)
    >>> y.asnumpy()
    array([[ 0.],
           [ 1.]], dtype=float32)
    >>> z.asnumpy()
    array([[ 0.,  1.]], dtype=float32)
    >>> mx.nd.logical_or(x, 1).asnumpy()
    array([[ 1.,  1.,  1.],
           [ 1.,  1.,  1.]], dtype=float32)
    >>> mx.nd.logical_or(x, y).asnumpy()
    array([[ 1.,  1.,  1.],
           [ 1.,  1.,  1.]], dtype=float32)
    >>> mx.nd.logical_or(z, y).asnumpy()
    array([[ 0.,  1.],
           [ 1.,  1.]], dtype=float32)
    """
    # pylint: disable= no-member, protected-access
    # Dispatch: broadcast op when both args are arrays, the Python lambda when
    # both are scalars, the *_scalar op for the array/scalar mix.  The trailing
    # None marks the scalar op as commutative, so a scalar lhs reuses it.
    return _ufunc_helper(
        lhs,
        rhs,
        op.broadcast_logical_or,
        lambda a, b: int(bool(a) or bool(b)),
        _internal._logical_or_scalar,
        None)
    # pylint: enable= no-member, protected-access

def logical_xor(lhs, rhs):
    """Returns the result of element-wise **logical xor** comparison
    operation with broadcasting.

    For each element in input arrays, return 1(true) if exactly one of the lhs
    and rhs elements is true (nonzero), otherwise return 0(false).

    Equivalent to ``bool(lhs) ^ bool(rhs)`` and ``mx.nd.broadcast_logical_xor(lhs, rhs)``.

    .. note::

       If the corresponding dimensions of two arrays have the same size or one of them has size 1,
       then the arrays are broadcastable to a common shape.

    Parameters
    ----------
    lhs : scalar or array
        First input of the function.
    rhs : scalar or array
        Second input of the function. If ``lhs.shape != rhs.shape``, they must be
        broadcastable to a common shape.

    Returns
    -------
    NDArray
        Output array of boolean values.

    Examples
    --------
    >>> x = mx.nd.ones((2,3))
    >>> y = mx.nd.arange(2).reshape((2,1))
    >>> z = mx.nd.arange(2).reshape((1,2))
    >>> x.asnumpy()
    array([[ 1.,  1.,  1.],
           [ 1.,  1.,  1.]], dtype=float32)
    >>> y.asnumpy()
    array([[ 0.],
           [ 1.]], dtype=float32)
    >>> z.asnumpy()
    array([[ 0.,  1.]], dtype=float32)
    >>> mx.nd.logical_xor(x, 1).asnumpy()
    array([[ 0.,  0.,  0.],
           [ 0.,  0.,  0.]], dtype=float32)
    >>> mx.nd.logical_xor(x, y).asnumpy()
    array([[ 1.,  1.,  1.],
           [ 0.,  0.,  0.]], dtype=float32)
    """
    # pylint: disable= no-member, protected-access
    # Dispatch: broadcast op when both args are arrays, the Python lambda when
    # both are scalars, the *_scalar op for the array/scalar mix.  The trailing
    # None marks the scalar op as commutative, so a scalar lhs reuses it.
    return _ufunc_helper(
        lhs,
        rhs,
        op.broadcast_logical_xor,
        # bool() coercion makes ^ act as a logical (not bitwise) xor on scalars.
        lambda x, y: 1 if bool(x) ^ bool(y) else 0,
        _internal._logical_xor_scalar,
        None)
    # pylint: enable= no-member, protected-access

def true_divide(lhs, rhs):

Expand Down
6 changes: 6 additions & 0 deletions src/operator/mshadow_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -317,6 +317,12 @@ MXNET_BINARY_MATH_OP_NC(eq, a == b ? DType(1) : DType(0));

MXNET_BINARY_MATH_OP_NC(ne, a != b ? DType(1) : DType(0));

// Element-wise logical AND: 1 when both operands are nonzero, else 0.
MXNET_BINARY_MATH_OP(logical_and, a && b ? DType(1) : DType(0));

// Element-wise logical OR: 1 when at least one operand is nonzero, else 0.
MXNET_BINARY_MATH_OP(logical_or, a || b ? DType(1) : DType(0));

// Element-wise logical XOR: 1 when exactly one operand is nonzero, else 0.
MXNET_BINARY_MATH_OP(logical_xor, (a || b) && !(a && b) ? DType(1) : DType(0));

MXNET_UNARY_MATH_OP(square_root, math::sqrt(a));

MXNET_UNARY_MATH_OP(square_root_grad, 0.5f / math::id(a));
Expand Down
6 changes: 6 additions & 0 deletions src/operator/operator_tune.cc
Original file line number Diff line number Diff line change
Expand Up @@ -342,6 +342,12 @@ IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::ne); // NOLINT()
IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::ne); // NOLINT()
IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::eq); // NOLINT()
IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::eq); // NOLINT()
// Tuning workloads for the logical binary ops, following the pattern used for
// the comparison ops (eq/ne) above.  BWD workloads are registered as well even
// though these ops register MakeZeroGradNodes as their gradient.
IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::logical_and); // NOLINT()
IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::logical_and); // NOLINT()
IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::logical_or); // NOLINT()
IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::logical_or); // NOLINT()
IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::logical_xor); // NOLINT()
IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::logical_xor); // NOLINT()
IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::smooth_l1_loss); // NOLINT()
IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::smooth_l1_gradient); // NOLINT()
IMPLEMENT_BLANK_WORKLOAD_FWD(mxnet::op::mxnet_op::set_to_int<0>); // NOLINT()
Expand Down
54 changes: 54 additions & 0 deletions src/operator/tensor/elemwise_binary_broadcast_op_logic.cc
Original file line number Diff line number Diff line change
Expand Up @@ -137,5 +137,59 @@ Example::
.set_attr<FCompute>("FCompute<cpu>", BinaryBroadcastCompute<cpu, mshadow_op::le>)
.set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes);

// Broadcasting element-wise logical AND.  Like the comparison ops above, the
// gradient is defined as zero (MakeZeroGradNodes).
MXNET_OPERATOR_REGISTER_BINARY_BROADCAST(broadcast_logical_and)
.describe(R"code(Returns the result of element-wise **logical and** with broadcasting.
Example::
x = [[ 1., 1., 1.],
[ 1., 1., 1.]]
y = [[ 0.],
[ 1.]]
broadcast_logical_and(x, y) = [[ 0., 0., 0.],
[ 1., 1., 1.]]
)code" ADD_FILELINE)
.set_attr<FCompute>("FCompute<cpu>", BinaryBroadcastCompute<cpu, mshadow_op::logical_and>)
.set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes);

// Broadcasting element-wise logical OR; zero gradient.
MXNET_OPERATOR_REGISTER_BINARY_BROADCAST(broadcast_logical_or)
.describe(R"code(Returns the result of element-wise **logical or** with broadcasting.
Example::
x = [[ 1., 1., 0.],
[ 1., 1., 0.]]
y = [[ 1.],
[ 0.]]
broadcast_logical_or(x, y) = [[ 1., 1., 1.],
[ 1., 1., 0.]]
)code" ADD_FILELINE)
.set_attr<FCompute>("FCompute<cpu>", BinaryBroadcastCompute<cpu, mshadow_op::logical_or>)
.set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes);

// Broadcasting element-wise logical XOR; zero gradient.
MXNET_OPERATOR_REGISTER_BINARY_BROADCAST(broadcast_logical_xor)
.describe(R"code(Returns the result of element-wise **logical xor** with broadcasting.
Example::
x = [[ 1., 1., 0.],
[ 1., 1., 0.]]
y = [[ 1.],
[ 0.]]
broadcast_logical_xor(x, y) = [[ 0., 0., 1.],
[ 1., 1., 0.]]
)code" ADD_FILELINE)
.set_attr<FCompute>("FCompute<cpu>", BinaryBroadcastCompute<cpu, mshadow_op::logical_xor>)
.set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes);

} // namespace op
} // namespace mxnet
9 changes: 9 additions & 0 deletions src/operator/tensor/elemwise_binary_broadcast_op_logic.cu
Original file line number Diff line number Diff line change
Expand Up @@ -47,5 +47,14 @@ NNVM_REGISTER_OP(broadcast_lesser)
NNVM_REGISTER_OP(broadcast_lesser_equal)
.set_attr<FCompute>("FCompute<gpu>", BinaryBroadcastCompute<gpu, mshadow_op::le>);

// GPU compute kernels for the broadcasting logical ops; the operator
// attributes (description, gradient) live in the matching .cc registrations.
NNVM_REGISTER_OP(broadcast_logical_and)
.set_attr<FCompute>("FCompute<gpu>", BinaryBroadcastCompute<gpu, mshadow_op::logical_and>);

NNVM_REGISTER_OP(broadcast_logical_or)
.set_attr<FCompute>("FCompute<gpu>", BinaryBroadcastCompute<gpu, mshadow_op::logical_or>);

NNVM_REGISTER_OP(broadcast_logical_xor)
.set_attr<FCompute>("FCompute<gpu>", BinaryBroadcastCompute<gpu, mshadow_op::logical_xor>);
} // namespace op
} // namespace mxnet
15 changes: 15 additions & 0 deletions src/operator/tensor/elemwise_binary_op_logic.cc
Original file line number Diff line number Diff line change
Expand Up @@ -57,5 +57,20 @@ MXNET_OPERATOR_REGISTER_BINARY(_lesser_equal)
.set_attr<FCompute>("FCompute<cpu>", ElemwiseBinaryOp::Compute<cpu, mshadow_op::le>)
.set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes);

// Non-broadcasting (same-shape) element-wise logical ops; zero gradient.
// NOTE(review): the alias style `_Logical_And` differs from the CamelCase
// aliases used by the scalar variants (e.g. `_LogicalAndScalar`) — confirm
// this naming is intentional before relying on the aliases.
MXNET_OPERATOR_REGISTER_BINARY(_logical_and)
.add_alias("_Logical_And")
.set_attr<FCompute>("FCompute<cpu>", ElemwiseBinaryOp::Compute<cpu, mshadow_op::logical_and>)
.set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes);

MXNET_OPERATOR_REGISTER_BINARY(_logical_or)
.add_alias("_Logical_Or")
.set_attr<FCompute>("FCompute<cpu>", ElemwiseBinaryOp::Compute<cpu, mshadow_op::logical_or>)
.set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes);

MXNET_OPERATOR_REGISTER_BINARY(_logical_xor)
.add_alias("_Logical_Xor")
.set_attr<FCompute>("FCompute<cpu>", ElemwiseBinaryOp::Compute<cpu, mshadow_op::logical_xor>)
.set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes);

} // namespace op
} // namespace mxnet
9 changes: 9 additions & 0 deletions src/operator/tensor/elemwise_binary_op_logic.cu
Original file line number Diff line number Diff line change
Expand Up @@ -45,5 +45,14 @@ NNVM_REGISTER_OP(_lesser)
NNVM_REGISTER_OP(_lesser_equal)
.set_attr<FCompute>("FCompute<gpu>", ElemwiseBinaryOp::Compute<gpu, mshadow_op::le>);

// GPU compute kernels for the same-shape element-wise logical ops; operator
// attributes are registered in the matching .cc file.
NNVM_REGISTER_OP(_logical_and)
.set_attr<FCompute>("FCompute<gpu>", ElemwiseBinaryOp::Compute<gpu, mshadow_op::logical_and>);

NNVM_REGISTER_OP(_logical_or)
.set_attr<FCompute>("FCompute<gpu>", ElemwiseBinaryOp::Compute<gpu, mshadow_op::logical_or>);

NNVM_REGISTER_OP(_logical_xor)
.set_attr<FCompute>("FCompute<gpu>", ElemwiseBinaryOp::Compute<gpu, mshadow_op::logical_xor>);

} // namespace op
} // namespace mxnet
15 changes: 15 additions & 0 deletions src/operator/tensor/elemwise_binary_scalar_op_logic.cc
Original file line number Diff line number Diff line change
Expand Up @@ -59,5 +59,20 @@ MXNET_OPERATOR_REGISTER_BINARY_SCALAR(_lesser_equal_scalar)
.set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes)
.add_alias("_LesserEqualScalar");

// Array-scalar variants of the logical ops; these back the Python frontend's
// logical_and/logical_or/logical_xor helpers (via _internal._logical_*_scalar).
// Zero gradient, matching the comparison-op registrations above.
MXNET_OPERATOR_REGISTER_BINARY_SCALAR(_logical_and_scalar)
.set_attr<FCompute>("FCompute<cpu>", BinaryScalarOp::Compute<cpu, mshadow_op::logical_and>)
.set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes)
.add_alias("_LogicalAndScalar");

MXNET_OPERATOR_REGISTER_BINARY_SCALAR(_logical_or_scalar)
.set_attr<FCompute>("FCompute<cpu>", BinaryScalarOp::Compute<cpu, mshadow_op::logical_or>)
.set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes)
.add_alias("_LogicalOrScalar");

MXNET_OPERATOR_REGISTER_BINARY_SCALAR(_logical_xor_scalar)
.set_attr<FCompute>("FCompute<cpu>", BinaryScalarOp::Compute<cpu, mshadow_op::logical_xor>)
.set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes)
.add_alias("_LogicalXorScalar");

} // namespace op
} // namespace mxnet
Loading

0 comments on commit 61f86fc

Please sign in to comment.