Skip to content

Commit 4fdef3a

Browse files
ZihengJiangtqchen
authored andcommitted
[LANG] Support for Bitwise Operation (#502)
* [LANG] Support for Bitwise Operation * Add test
1 parent 9c2fc09 commit 4fdef3a

File tree

2 files changed

+29
-0
lines changed

2 files changed

+29
-0
lines changed

python/tvm/expr.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,24 @@ def __mod__(self, other):
6262
def __neg__(self):
6363
return self.__mul__(-1)
6464

65+
def __lshift__(self, other):
66+
return _make.Call(self.dtype, "shift_left", [self, other], Call.PureIntrinsic, None, 0)
67+
68+
def __rshift__(self, other):
69+
return _make.Call(self.dtype, "shift_right", [self, other], Call.PureIntrinsic, None, 0)
70+
71+
def __and__(self, other):
72+
return _make.Call(self.dtype, "bitwise_and", [self, other], Call.PureIntrinsic, None, 0)
73+
74+
def __or__(self, other):
75+
return _make.Call(self.dtype, "bitwise_or", [self, other], Call.PureIntrinsic, None, 0)
76+
77+
def __xor__(self, other):
78+
return _make.Call(self.dtype, "bitwise_xor", [self, other], Call.PureIntrinsic, None, 0)
79+
80+
def __invert__(self):
81+
return _make.Call(self.dtype, "bitwise_not", [self], Call.PureIntrinsic, None, 0)
82+
6583
def __lt__(self, other):
6684
return _make.LT(self, other)
6785

tests/python/unittest/test_lang_basic.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,16 @@ def test_all():
123123
'(((%s < %s) && (%s > (%s + 1))) && (%s < (%s*2)))' % (
124124
x.name, y.name, y.name, z.name, x.name, z.name)
125125

126+
def test_bitwise():
127+
x = tvm.var('x')
128+
y = tvm.var('y')
129+
assert str(x << y) == 'shift_left(x, y)'
130+
assert str(x >> y) == 'shift_right(x, y)'
131+
assert str(x & y) == 'bitwise_and(x, y)'
132+
assert str(x | y) == 'bitwise_or(x, y)'
133+
assert str(x ^ y) == 'bitwise_xor(x, y)'
134+
assert str(~x) == 'bitwise_not(x)'
135+
126136

127137
if __name__ == "__main__":
128138
test_cast()
@@ -137,3 +147,4 @@ def test_all():
137147
test_dtype()
138148
test_any()
139149
test_all()
150+
test_bitwise()

0 commit comments

Comments
 (0)