Skip to content

Commit af8cbdd

Browse files
ZihengJiangtqchen
authored andcommitted
[TOPI] Add left_shift, right_shift, clip, cast (#504)
* [TOPI] Add left_shift, right_shift, clip, cast * [TOPI] Add test * [TOPI] Fix
1 parent 4fdef3a commit af8cbdd

File tree

2 files changed

+127
-0
lines changed

2 files changed

+127
-0
lines changed

topi/python/topi/math.py

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,3 +122,87 @@ def sigmoid(x):
122122
The result.
123123
"""
124124
return tvm.compute(x.shape, lambda *i: tvm.sigmoid(x(*i)))
125+
126+
127+
@tvm.tag_scope(tag=tag.ELEMWISE)
128+
def left_shift(x, n):
129+
"""Take n bits left shift of input x.
130+
131+
Parameters
132+
----------
133+
x : tvm.Tensor
134+
Input argument.
135+
n : int
136+
Number of bits.
137+
138+
Returns
139+
-------
140+
y : tvm.Tensor
141+
The result.
142+
"""
143+
return tvm.compute(x.shape, lambda *i: x(*i) << n)
144+
145+
146+
@tvm.tag_scope(tag=tag.ELEMWISE)
147+
def right_shift(x, n):
148+
"""Take n bits right shift of input x.
149+
150+
Parameters
151+
----------
152+
x : tvm.Tensor
153+
Input argument.
154+
n : int
155+
Number of bits.
156+
157+
Returns
158+
-------
159+
y : tvm.Tensor
160+
The result.
161+
"""
162+
return tvm.compute(x.shape, lambda *i: x(*i) >> n)
163+
164+
165+
@tvm.tag_scope(tag=tag.ELEMWISE)
166+
def clip(x, a_min, a_max):
167+
"""Clip (limit) the values in an array. Given an interval, values
168+
outside the interval are clipped to the interval edges.
169+
170+
Parameters
171+
----------
172+
x : tvm.Tensor
173+
Input argument.
174+
a_min : int or float
175+
Minimum value.
176+
a_max : int or float
177+
Maximum value.
178+
179+
Returns
180+
-------
181+
y : tvm.Tensor
182+
The result.
183+
"""
184+
def _compute(*indices):
185+
value = x(*indices)
186+
const_min = tvm.const(a_min, value.dtype)
187+
const_max = tvm.const(a_max, value.dtype)
188+
return tvm.max(tvm.min(value, const_max), const_min)
189+
return tvm.compute(x.shape, _compute)
190+
191+
192+
@tvm.tag_scope(tag=tag.ELEMWISE)
193+
def cast(x, dtype):
194+
"""Cast input to specified data type.
195+
196+
Parameters
197+
----------
198+
x : tvm.Tensor
199+
Input argument.
200+
dtype : str
201+
Data type.
202+
203+
Returns
204+
-------
205+
y : tvm.Tensor
206+
The result.
207+
"""
208+
return tvm.compute(x.shape, lambda *i: x(*i).astype(dtype))
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
"""Test code for clip operator"""
2+
import numpy as np
3+
import tvm
4+
import topi
5+
from topi.util import get_const_tuple
6+
from tvm.contrib.pickle_memoize import memoize
7+
8+
9+
def verify_clip(N, a_min, a_max, dtype):
10+
A = tvm.placeholder((N, N), dtype=dtype, name='A')
11+
B = topi.clip(A, a_min, a_max)
12+
s = tvm.create_schedule([B.op])
13+
14+
# use memoize to pickle the test data for next time use
15+
@memoize("topi.tests.test_topi_clip")
16+
def get_ref_data():
17+
a_np = np.random.uniform(a_min*2, a_max*2, size=(N, N)).astype(dtype)
18+
b_np = np.clip(a_np, a_min, a_max)
19+
return a_np, b_np
20+
a_np, b_np = get_ref_data()
21+
22+
def check_device(device):
23+
if not tvm.module.enabled(device):
24+
print("Skip because %s is not enabled" % device)
25+
return
26+
ctx = tvm.cpu(0) if device == "llvm" else tvm.gpu(0)
27+
a = tvm.nd.array(a_np, ctx)
28+
b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=dtype), ctx)
29+
f = tvm.build(s, [A, B], device, name="clip")
30+
f(a, b)
31+
np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)
32+
33+
for device in ['llvm']:
34+
check_device(device)
35+
36+
def test_clip():
37+
verify_clip(1024, -127, 127, 'int8')
38+
verify_clip(1024, -127, 127, 'int16')
39+
verify_clip(1024, -127, 127, 'float32')
40+
41+
42+
if __name__ == "__main__":
43+
test_clip()

0 commit comments

Comments
 (0)