Skip to content

Commit ea3e12b

Browse files
committed
[TOPI][PYTORCH]Logical & Bitwise operator support
1 parent b1364eb commit ea3e12b

File tree

11 files changed

+215
-1
lines changed

11 files changed

+215
-1
lines changed

docs/api/python/topi.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,7 @@ List of operators
9999
topi.logical_and
100100
topi.logical_or
101101
topi.logical_not
102+
topi.logical_xor
102103
topi.arange
103104
topi.stack
104105
topi.repeat
@@ -193,6 +194,7 @@ topi
193194
.. autofunction:: topi.logical_and
194195
.. autofunction:: topi.logical_or
195196
.. autofunction:: topi.logical_not
197+
.. autofunction:: topi.logical_xor
196198

197199
topi.nn
198200
~~~~~~~

docs/langref/relay_op.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,7 @@ This level enables additional math and transform operators.
150150
tvm.relay.logical_and
151151
tvm.relay.logical_or
152152
tvm.relay.logical_not
153+
tvm.relay.logical_xor
153154
tvm.relay.maximum
154155
tvm.relay.minimum
155156
tvm.relay.power

python/tvm/relay/frontend/pytorch.py

Lines changed: 59 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1168,7 +1168,6 @@ def _impl(inputs, input_types):
11681168

11691169
def _clamp():
11701170
def _impl(inputs, input_types):
1171-
print(inputs, input_types)
11721171
data = inputs[0]
11731172
amin = inputs[1] if inputs[1] else np.finfo(np.float32).min
11741173
amax = inputs[2] if inputs[2] else np.finfo(np.float32).max
@@ -1297,6 +1296,61 @@ def _impl(inputs, input_types):
12971296
return _op.nn.dense(inputs[0], inputs[1])
12981297
return _impl
12991298

1299+
def _bitwise_not():
1300+
def _impl(inputs, input_types):
1301+
data = inputs[0]
1302+
if input_types[0] == "bool":
1303+
data = _op.cast(data, "bool")
1304+
else:
1305+
data = _op.cast(data, "int")
1306+
1307+
return _op.bitwise_not(data)
1308+
return _impl
1309+
1310+
def _bitwise_xor():
1311+
def _impl(inputs, input_types):
1312+
lhs = inputs[0]
1313+
1314+
import torch
1315+
if isinstance(inputs[1], _expr.Var):
1316+
rhs = inputs[1]
1317+
elif isinstance(inputs[1], torch.Tensor):
1318+
rhs = _wrap_const(inputs[1].numpy())
1319+
else:
1320+
msg = "Data type %s could not be parsed in bitwise_xor operator." % (type(inputs[1]))
1321+
raise AssertionError(msg)
1322+
1323+
lhs = _op.cast(lhs, "bool") if input_types[0] == "bool" else _op.cast(lhs, "int")
1324+
rhs = _op.cast(rhs, "bool") if input_types[1] == "bool" else _op.cast(rhs, "int")
1325+
1326+
return _op.bitwise_xor(lhs, rhs)
1327+
return _impl
1328+
1329+
def _logical_not():
1330+
def _impl(inputs, input_types):
1331+
data = inputs[0]
1332+
1333+
return _op.logical_not(_op.cast(data, "bool"))
1334+
return _impl
1335+
1336+
1337+
def _logical_xor():
1338+
def _impl(inputs, input_types):
1339+
lhs = _op.cast(inputs[0], "bool")
1340+
1341+
import torch
1342+
if isinstance(inputs[1], _expr.Var):
1343+
rhs = inputs[1]
1344+
elif isinstance(inputs[1], torch.Tensor):
1345+
rhs = _wrap_const(inputs[1].numpy())
1346+
else:
1347+
msg = "Data type %s could not be parsed in logical_xor operator." % (type(inputs[1]))
1348+
raise AssertionError(msg)
1349+
1350+
rhs = _op.cast(rhs, "bool")
1351+
return _op.logical_xor(lhs, rhs)
1352+
return _impl
1353+
13001354

13011355
def _isfinite():
13021356
def _impl(inputs, input_types):
@@ -1524,6 +1578,10 @@ def _get_convert_map(prelude):
15241578
"aten::ge" : _elemwise("greater_equal"),
15251579
"aten::ne" : _elemwise("not_equal"),
15261580
"aten::eq" : _elemwise("equal"),
1581+
"aten::logical_not" : _logical_not(),
1582+
"aten::logical_xor" : _logical_xor(),
1583+
"aten::bitwise_not" : _bitwise_not(),
1584+
"aten::bitwise_xor" : _bitwise_xor(),
15271585
"aten::isfinite" : _isfinite(),
15281586
"aten::isnan" : _isnan(),
15291587
"aten::Bool" : _Bool(),

python/tvm/relay/op/_tensor.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@
5353
register_broadcast_schedule("logical_not")
5454
register_broadcast_schedule("logical_and")
5555
register_broadcast_schedule("logical_or")
56+
register_broadcast_schedule("logical_xor")
5657
register_broadcast_schedule("bitwise_not")
5758
register_broadcast_schedule("bitwise_and")
5859
register_broadcast_schedule("bitwise_or")
@@ -205,6 +206,7 @@ def elemwise_shape_func(attrs, inputs, _):
205206
register_shape_func("floor_mod", False, broadcast_shape_func)
206207
register_shape_func("logical_and", False, broadcast_shape_func)
207208
register_shape_func("logical_or", False, broadcast_shape_func)
209+
register_shape_func("logical_xor", False, broadcast_shape_func)
208210
register_shape_func("bitwise_not", False, broadcast_shape_func)
209211
register_shape_func("bitwise_and", False, broadcast_shape_func)
210212
register_shape_func("bitwise_or", False, broadcast_shape_func)

python/tvm/relay/op/tensor.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -537,6 +537,23 @@ def logical_or(lhs, rhs):
537537
return _make.logical_or(lhs, rhs)
538538

539539

540+
def logical_xor(lhs, rhs):
541+
"""logical XOR with numpy-style broadcasting.
542+
543+
Parameters
544+
----------
545+
lhs : relay.Expr
546+
The left hand side input data
547+
rhs : relay.Expr
548+
The right hand side input data
549+
550+
Returns
551+
-------
552+
result : relay.Expr
553+
The computed result.
554+
"""
555+
return _make.logical_xor(lhs, rhs)
556+
540557
def bitwise_and(lhs, rhs):
541558
"""bitwise AND with numpy-style broadcasting.
542559

src/relay/op/tensor/binary.cc

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,12 @@ RELAY_REGISTER_BINARY_OP("logical_or")
123123
.set_attr<FTVMCompute>("FTVMCompute", RELAY_BINARY_COMPUTE(topi::logical_or));
124124

125125

126+
RELAY_REGISTER_BINARY_OP("logical_xor")
127+
.describe("Elementwise logical XOR with broadcasting")
128+
.set_support_level(4)
129+
.set_attr<FTVMCompute>("FTVMCompute", RELAY_BINARY_COMPUTE(topi::logical_xor));
130+
131+
126132
RELAY_REGISTER_BINARY_OP("bitwise_and")
127133
.describe("Elementwise bitwise AND with broadcasting")
128134
.set_support_level(4)

tests/python/frontend/pytorch/test_forward.py

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1600,6 +1600,95 @@ def forward(self, *args):
16001600
verify_model(Topk6().float().eval(), input_data=input_data)
16011601

16021602

1603+
def test_forward_logical_not():
1604+
torch.set_grad_enabled(False)
1605+
1606+
class LogicalNot1(Module):
1607+
def forward(self, *args):
1608+
return torch.logical_not(args[0])
1609+
1610+
input_data = torch.tensor([True, False])
1611+
verify_model(LogicalNot1().float().eval(), input_data=input_data)
1612+
1613+
input_data = torch.tensor([0, 1, -10], dtype=torch.int8)
1614+
verify_model(LogicalNot1().float().eval(), input_data=input_data)
1615+
1616+
input_data = torch.tensor([0., 1.5, -10.], dtype=torch.double)
1617+
verify_model(LogicalNot1().float().eval(), input_data=input_data)
1618+
1619+
input_data = torch.tensor([0., 1., -10.], dtype=torch.int32)
1620+
verify_model(LogicalNot1().float().eval(), input_data=input_data)
1621+
1622+
1623+
def test_forward_bitwise_not():
1624+
torch.set_grad_enabled(False)
1625+
1626+
class BitwiseNot1(Module):
1627+
def forward(self, *args):
1628+
return torch.bitwise_not(args[0])
1629+
1630+
input_data = torch.tensor([0, 1, -10], dtype=torch.int8)
1631+
verify_model(BitwiseNot1().float().eval(), input_data=input_data)
1632+
1633+
input_data = torch.tensor([0., 1., -10.], dtype=torch.int32)
1634+
verify_model(BitwiseNot1().float().eval(), input_data=input_data)
1635+
1636+
input_data = torch.tensor([True, False])
1637+
verify_model(BitwiseNot1().float().eval(), input_data=input_data)
1638+
1639+
1640+
def test_forward_bitwise_xor():
1641+
torch.set_grad_enabled(False)
1642+
1643+
class BitwiseXor1(Module):
1644+
def forward(self, *args):
1645+
return torch.bitwise_xor(args[0], args[1])
1646+
1647+
class BitwiseXor2(Module):
1648+
def forward(self, *args):
1649+
rhs = torch.tensor([1, 0, 3], dtype=torch.int8)
1650+
if torch.cuda.is_available():
1651+
rhs = rhs.cuda()
1652+
return torch.bitwise_xor(args[0], rhs)
1653+
1654+
lhs = torch.tensor([-1, -2, 3], dtype=torch.int8)
1655+
rhs = torch.tensor([1, 0, 3], dtype=torch.int8)
1656+
verify_model(BitwiseXor1().float().eval(), input_data=[lhs, rhs])
1657+
1658+
lhs = torch.tensor([True, True, False])
1659+
rhs = torch.tensor([False, True, False])
1660+
verify_model(BitwiseXor1().float().eval(), input_data=[lhs, rhs])
1661+
1662+
lhs = torch.tensor([-1, -2, 3], dtype=torch.int8)
1663+
verify_model(BitwiseXor2().float().eval(), input_data=[lhs])
1664+
1665+
1666+
def test_forward_logical_xor():
1667+
torch.set_grad_enabled(False)
1668+
1669+
class LogicalXor1(Module):
1670+
def forward(self, *args):
1671+
return torch.logical_xor(args[0], args[1])
1672+
1673+
class LogicalXor2(Module):
1674+
def forward(self, *args):
1675+
rhs = torch.tensor([1, 0, 3], dtype=torch.int8)
1676+
if torch.cuda.is_available():
1677+
rhs = rhs.cuda()
1678+
return torch.logical_xor(args[0], rhs)
1679+
1680+
lhs = torch.tensor([-1, -2, 3], dtype=torch.int8)
1681+
rhs = torch.tensor([1, 0, 3], dtype=torch.int8)
1682+
verify_model(LogicalXor1().float().eval(), input_data=[lhs, rhs])
1683+
1684+
lhs = torch.tensor([True, True, False])
1685+
rhs = torch.tensor([False, True, False])
1686+
verify_model(LogicalXor1().float().eval(), input_data=[lhs, rhs])
1687+
1688+
lhs = torch.tensor([-1, -2, 3], dtype=torch.int8)
1689+
verify_model(LogicalXor2().float().eval(), input_data=[lhs])
1690+
1691+
16031692
if __name__ == "__main__":
16041693
# Single operator tests
16051694
test_forward_add()
@@ -1663,6 +1752,10 @@ def forward(self, *args):
16631752
test_forward_clamp()
16641753
test_forward_floor()
16651754
test_forward_round()
1755+
test_forward_logical_not()
1756+
test_forward_bitwise_not()
1757+
test_forward_bitwise_xor()
1758+
test_forward_logical_xor()
16661759
test_forward_isfinite()
16671760
test_forward_isnan()
16681761
test_forward_isinf()

topi/include/topi/broadcast.h

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,19 @@ TOPI_DEFINE_OP_OVERLOAD(operator&&, logical_and);
140140
TOPI_DEFINE_BCAST_OP(logical_or, { return a || b; });
141141
TOPI_DEFINE_OP_OVERLOAD(operator||, logical_or);
142142

143+
/*!
144+
* \fn logical_xor
145+
* \brief Compute A ^ B with auto-broadcasting.
146+
*
147+
* \param A The first tensor, or Expr
148+
* \param B The second tensor, or Expr
149+
* \param name The name of the operation
150+
* \param tag The tag to mark the operation
151+
*
152+
* \return The result.
153+
*/
154+
TOPI_DEFINE_BCAST_OP(logical_xor, { return a ^ b; });
155+
143156
/*!
144157
* \fn bitwise_and
145158
* \brief Compute A & B with auto-broadcasting.

topi/python/topi/broadcast.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -420,6 +420,25 @@ def logical_or(lhs, rhs):
420420
return _cpp.logical_or(lhs, rhs)
421421

422422

423+
def logical_xor(lhs, rhs):
424+
"""Compute element-wise logical xor of data.
425+
426+
Parameters
427+
----------
428+
lhs : tvm.te.Tensor or Expr
429+
The left operand
430+
rhs : tvm.te.Tensor or Expr
431+
The right operand
432+
433+
Returns
434+
-------
435+
ret : tvm.te.Tensor or Expr
436+
Returns Expr if both operands are Expr.
437+
Otherwise returns Tensor.
438+
"""
439+
return _cpp.logical_xor(lhs, rhs)
440+
441+
423442
def bitwise_and(lhs, rhs):
424443
"""Compute element-wise bitwise and of data.
425444

topi/src/broadcast.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@ TOPI_REGISTER_BCAST_OP("topi.power", topi::power);
6565
TOPI_REGISTER_BCAST_OP("topi.left_shift", topi::left_shift);
6666
TOPI_REGISTER_BCAST_OP("topi.logical_and", topi::logical_and);
6767
TOPI_REGISTER_BCAST_OP("topi.logical_or", topi::logical_or);
68+
TOPI_REGISTER_BCAST_OP("topi.logical_xor", topi::logical_xor);
6869
TOPI_REGISTER_BCAST_OP("topi.bitwise_and", topi::bitwise_and);
6970
TOPI_REGISTER_BCAST_OP("topi.bitwise_or", topi::bitwise_or);
7071
TOPI_REGISTER_BCAST_OP("topi.bitwise_xor", topi::bitwise_xor);

0 commit comments

Comments
 (0)