Skip to content

Commit 2239508

Browse files
abergerontqchen
authored andcommitted
[Relay] Add logical operators (#2743)
1 parent 695647d commit 2239508

File tree

7 files changed

+133
-5
lines changed

7 files changed

+133
-5
lines changed

nnvm/src/top/tensor/elemwise.cc

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -366,7 +366,7 @@ NNVM_REGISTER_ELEMWISE_BINARY_OP(logical_and)
366366
.describe(R"code(Elementwise compute the logical AND
367367
368368
)code")
369-
.set_support_level(1)
369+
.set_support_level(4)
370370
.set_attr<FTVMCompute>(
371371
"FTVMCompute", [](const NodeAttrs& attrs,
372372
const Array<Tensor>& inputs,
@@ -378,7 +378,7 @@ NNVM_REGISTER_ELEMWISE_BINARY_OP(logical_or)
378378
.describe(R"code(Elementwise compute the logical OR
379379
380380
)code")
381-
.set_support_level(1)
381+
.set_support_level(4)
382382
.set_attr<FTVMCompute>(
383383
"FTVMCompute", [](const NodeAttrs& attrs,
384384
const Array<Tensor>& inputs,
@@ -413,7 +413,7 @@ NNVM_REGISTER_ELEMWISE_UNARY_OP(logical_not)
413413
.describe(R"code(Elementwise compute the logical NOT
414414
415415
)code" NNVM_ADD_FILELINE)
416-
.set_support_level(3)
416+
.set_support_level(4)
417417
.set_attr<FTVMCompute>(
418418
"FTVMCompute", [](const NodeAttrs& attrs,
419419
const Array<Tensor>& inputs,

python/tvm/relay/frontend/tensorflow.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -849,6 +849,11 @@ def _impl(inputs, attr, params):
849849
transforms={'axis': ('axis', 1)})([inputs[0]], attr)
850850
return _impl
851851

852+
def _logical(name):
853+
def _impl(inputs, attr, params):
854+
return AttrCvt(op_name=name)(inputs, attr)
855+
return _impl
856+
852857
# compatible operators that do NOT require any conversion.
853858
_identity_list = []
854859

@@ -909,6 +914,9 @@ def _impl(inputs, attr, params):
909914
'Transpose' : _transpose(),
910915
'Tanh' : AttrCvt('tanh'),
911916
'Mean' : _mean(),
917+
'LogicalAnd' : _logical('logical_and'),
918+
'LogicalOr' : _logical('logical_or'),
919+
'LogicalNot' : _logical('logical_not'),
912920
'Less' : _broadcast('less'),
913921
'Greater' : _broadcast('greater'),
914922
'LessEqual' : _broadcast('less_equal'),

python/tvm/relay/op/_tensor.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
register_schedule("round", schedule_broadcast)
1919
register_schedule("abs", schedule_broadcast)
2020
register_schedule("tanh", schedule_broadcast)
21+
register_schedule("logical_not", schedule_broadcast)
2122
register_schedule("negative", schedule_broadcast)
2223
register_schedule("copy", schedule_broadcast)
2324

@@ -27,6 +28,8 @@
2728
register_schedule("divide", schedule_broadcast)
2829
register_schedule("power", schedule_injective)
2930
register_schedule("mod", schedule_broadcast)
31+
register_schedule("logical_and", schedule_broadcast)
32+
register_schedule("logical_or", schedule_broadcast)
3033
register_schedule("equal", schedule_broadcast)
3134
register_schedule("not_equal", schedule_broadcast)
3235
register_schedule("less", schedule_broadcast)

python/tvm/relay/op/tensor.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -191,6 +191,22 @@ def negative(data):
191191
return _make.negative(data)
192192

193193

194+
def logical_not(data):
195+
"""Compute element-wise logical not of data.
196+
197+
Parameters
198+
----------
199+
data : relay.Expr
200+
The input data
201+
202+
Returns
203+
-------
204+
result : relay.Expr
205+
The computed result.
206+
"""
207+
return _make.logical_not(data)
208+
209+
194210
def add(lhs, rhs):
195211
"""Addition with numpy-style broadcasting.
196212
@@ -307,6 +323,42 @@ def mod(lhs, rhs):
307323
return _make.mod(lhs, rhs)
308324

309325

326+
def logical_and(lhs, rhs):
327+
"""logical AND with numpy-style broadcasting.
328+
329+
Parameters
330+
----------
331+
lhs : relay.Expr
332+
The left hand side input data
333+
rhs : relay.Expr
334+
The right hand side input data
335+
336+
Returns
337+
-------
338+
result : relay.Expr
339+
The computed result.
340+
"""
341+
return _make.logical_and(lhs, rhs)
342+
343+
344+
def logical_or(lhs, rhs):
345+
"""logical OR with numpy-style broadcasting.
346+
347+
Parameters
348+
----------
349+
lhs : relay.Expr
350+
The left hand side input data
351+
rhs : relay.Expr
352+
The right hand side input data
353+
354+
Returns
355+
-------
356+
result : relay.Expr
357+
The computed result.
358+
"""
359+
return _make.logical_or(lhs, rhs)
360+
361+
310362
def equal(lhs, rhs):
311363
"""Broadcasted elementwise test for (lhs == rhs).
312364

src/relay/op/tensor/binary.cc

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,18 @@ RELAY_REGISTER_BINARY_OP("mod")
8282
.set_attr<FTVMCompute>("FTVMCompute", RELAY_BINARY_COMPUTE(topi::mod));
8383

8484

85+
RELAY_REGISTER_BINARY_OP("logical_and")
86+
.describe("Elementwise logical AND with broadcasting")
87+
.set_support_level(4)
88+
.set_attr<FTVMCompute>("FTVMCompute", RELAY_BINARY_COMPUTE(topi::logical_and));
89+
90+
91+
RELAY_REGISTER_BINARY_OP("logical_or")
92+
.describe("Elementwise logical OR with broadcasting")
93+
.set_support_level(4)
94+
.set_attr<FTVMCompute>("FTVMCompute", RELAY_BINARY_COMPUTE(topi::logical_or));
95+
96+
8597
RELAY_REGISTER_CMP_OP("equal")
8698
.describe("Elementwise equal compare with broadcasting")
8799
.set_support_level(4)

src/relay/op/tensor/unary.cc

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -178,5 +178,16 @@ RELAY_REGISTER_UNARY_OP("negative")
178178
.set_support_level(3)
179179
.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::negative));
180180

181+
182+
RELAY_REGISTER_UNARY_OP("logical_not")
183+
.describe(R"code(Returns the logical inverse of input array, computed element-wise.
184+
185+
.. math::
186+
~(x)
187+
188+
)code" TVM_ADD_FILELINE)
189+
.set_support_level(4)
190+
.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::logical_not));
191+
181192
} // namespace relay
182193
} // namespace tvm

tests/python/frontend/tensorflow/test_forward.py

Lines changed: 44 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -682,6 +682,49 @@ def test_forward_pad():
682682
_test_pad((2, 3), [[1,1], [2,2]], mode="CONSTANT")
683683
_test_pad((2, 3), [[1,1], [2,2]], mode="CONSTANT", constant_values=1.0)
684684

685+
#######################################################################
686+
# Logical operators
687+
# --------------------
688+
def test_logical_and():
689+
with tf.Graph().as_default():
690+
in1 = tf.placeholder(tf.bool, shape=[1, 4, 4, 3], name='in1')
691+
in2 = tf.placeholder(tf.bool, shape=[1, 4, 4, 3], name='in2')
692+
out = tf.logical_and(in1, in2, name='out')
693+
in_data1 = np.random.choice(a=[False, True],size=(1, 4, 4, 3)).astype('bool')
694+
in_data2 = np.random.choice(a=[False, True],size=(1, 4, 4, 3)).astype('bool')
695+
compare_tf_with_tvm([in_data1, in_data2], ['in1:0', 'in2:0'], 'out:0')
696+
697+
def test_logical_or():
698+
with tf.Graph().as_default():
699+
in1 = tf.placeholder(tf.bool, shape=[1, 4, 4, 3], name='in1')
700+
in2 = tf.placeholder(tf.bool, shape=[1, 4, 4, 3], name='in2')
701+
out = tf.logical_or(in1, in2, name='out')
702+
in_data1 = np.random.choice(a=[False, True],size=(1, 4, 4, 3)).astype('bool')
703+
in_data2 = np.random.choice(a=[False, True],size=(1, 4, 4, 3)).astype('bool')
704+
compare_tf_with_tvm([in_data1, in_data2], ['in1:0', 'in2:0'], 'out:0')
705+
706+
def test_logical_xor():
707+
with tf.Graph().as_default():
708+
in1 = tf.placeholder(tf.bool, shape=[1, 4, 4, 3], name='in1')
709+
in2 = tf.placeholder(tf.bool, shape=[1, 4, 4, 3], name='in2')
710+
out = tf.logical_xor(in1, in2, name='out')
711+
in_data1 = np.random.choice(a=[False, True],size=(1, 4, 4, 3)).astype('bool')
712+
in_data2 = np.random.choice(a=[False, True],size=(1, 4, 4, 3)).astype('bool')
713+
compare_tf_with_tvm([in_data1, in_data2], ['in1:0', 'in2:0'], 'out:0')
714+
715+
def test_logical_not():
716+
with tf.Graph().as_default():
717+
in1 = tf.placeholder(tf.bool, shape=[1, 4, 4, 3], name='in1')
718+
out = tf.logical_not(in1, name='out')
719+
in_data1 = np.random.choice(a=[False, True],size=(1, 4, 4, 3)).astype('bool')
720+
compare_tf_with_tvm(in_data1, 'in1:0', 'out:0')
721+
722+
def test_forward_logical():
723+
test_logical_and()
724+
test_logical_or()
725+
test_logical_xor()
726+
test_logical_not()
727+
685728

686729
#######################################################################
687730
# Inception V3
@@ -1109,5 +1152,4 @@ def test_forward_rel_ops():
11091152

11101153
# Relational ops
11111154
test_forward_rel_ops()
1112-
1113-
1155+
test_forward_logical()

0 commit comments

Comments
 (0)