Skip to content

Commit 2b3ef79

Browse files
authored
Merge branch 'master' into ansor-gpu-tutorial
2 parents e602929 + de0c3a4 commit 2b3ef79

File tree

11 files changed

+1208
-1230
lines changed

11 files changed

+1208
-1230
lines changed

python/tvm/exec/rpc_server.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ def init_utvm(args):
6969
args : argparse.Namespace
7070
parsed args from command-line invocation
7171
"""
72-
from tvm import micro
72+
from tvm import micro # pylint: disable=import-outside-toplevel
7373

7474
if args.utvm_dev_config and args.utvm_dev_id:
7575
raise RuntimeError("only one of --utvm-dev-config and --utvm-dev-id allowed")

python/tvm/hybrid/_ffi_api.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,5 +17,4 @@
1717
"""FFI APIs for tvm.hybrid"""
1818
import tvm._ffi
1919

20-
21-
tvm._ffi._init_api("tir.hybrid", __name__)
20+
tvm._ffi._init_api("hybrid", __name__)

python/tvm/hybrid/intrin.py

Lines changed: 70 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -23,114 +23,146 @@
2323
from .registry import register_intrin
2424

2525

26-
@register_intrin
26+
@register_intrin()
2727
def bool(imm):
28-
return tvm.tir.const(imm.value, "bool")
28+
return tvm.tir.const(imm, "bool")
2929

3030

31-
@register_intrin
31+
@register_intrin()
3232
def int8(imm):
33-
return tvm.tir.const(imm.value, "int8")
33+
return tvm.tir.const(imm, "int8")
3434

3535

36-
@register_intrin
36+
@register_intrin()
3737
def int16(imm):
38-
return tvm.tir.const(imm.value, "int16")
38+
return tvm.tir.const(imm, "int16")
3939

4040

41-
@register_intrin
41+
@register_intrin()
4242
def int32(imm):
43-
return tvm.tir.const(imm.value, "int32")
43+
return tvm.tir.const(imm, "int32")
4444

4545

46-
@register_intrin
46+
@register_intrin()
4747
def int64(imm):
48-
return tvm.tir.const(imm.value, "int64")
48+
return tvm.tir.const(imm, "int64")
4949

5050

51-
@register_intrin
51+
@register_intrin()
5252
def uint8(imm):
53-
return tvm.tir.const(imm.value, "uint8")
53+
return tvm.tir.const(imm, "uint8")
5454

5555

56-
@register_intrin
56+
@register_intrin()
5757
def uint16(imm):
58-
return tvm.tir.const(imm.value, "uint16")
58+
return tvm.tir.const(imm, "uint16")
5959

6060

61-
@register_intrin
61+
@register_intrin()
6262
def uint32(imm):
63-
return tvm.tir.const(imm.value, "uint32")
63+
return tvm.tir.const(imm, "uint32")
6464

6565

66-
@register_intrin
66+
@register_intrin()
6767
def uint64(imm):
68-
return tvm.tir.const(imm.value, "uint64")
68+
return tvm.tir.const(imm, "uint64")
6969

7070

71-
@register_intrin
71+
@register_intrin()
7272
def float8(imm):
73-
return tvm.tir.const(imm.value, "float8")
73+
return tvm.tir.const(imm, "float8")
7474

7575

76-
@register_intrin
76+
@register_intrin()
7777
def float16(imm):
78-
return tvm.tir.const(imm.value, "float16")
78+
return tvm.tir.const(imm, "float16")
7979

8080

81-
@register_intrin
81+
@register_intrin()
8282
def float32(imm):
83-
return tvm.tir.const(imm.value, "float32")
83+
return tvm.tir.const(imm, "float32")
8484

8585

86-
@register_intrin
86+
@register_intrin()
8787
def float64(imm):
88-
return tvm.tir.const(imm.value, "float64")
88+
return tvm.tir.const(imm, "float64")
8989

9090

91-
@register_intrin
91+
@register_intrin()
9292
def floordiv(x, y):
9393
return tvm.tir.floordiv(x, y)
9494

9595

96-
@register_intrin
96+
@register_intrin()
9797
def floormod(x, y):
9898
return tvm.tir.floormod(x, y)
9999

100100

101-
@register_intrin
101+
@register_intrin()
102102
def load(dtype, var, index, predicate=True):
103103
return tvm.tir.Load(dtype, var, index, predicate)
104104

105105

106-
@register_intrin
107-
def cast(dtype, value):
106+
@register_intrin()
107+
def cast(value, dtype):
108108
return tvm.tir.Cast(dtype, value)
109109

110110

111-
@register_intrin
111+
@register_intrin()
112112
def ramp(base, stride, lanes):
113-
lanes = lanes.value if not isinstance(lanes, int) else lanes
114113
return tvm.tir.Ramp(base, stride, lanes)
115114

116115

117-
@register_intrin
116+
@register_intrin()
118117
def broadcast(value, lanes):
119-
lanes = lanes.value if not isinstance(lanes, int) else lanes
120118
return tvm.tir.Broadcast(value, lanes)
121119

122120

123-
@register_intrin
121+
@register_intrin()
124122
def evaluate(value):
125123
return tvm.tir.Evaluate(value)
126124

127125

128-
@register_intrin
126+
@register_intrin()
129127
def store(var, index, value, predicate=True):
130128
return tvm.tir.Store(var, value, index, predicate)
131129

132130

133-
@register_intrin
131+
@register_intrin()
134132
def iter_var(var, dom, iter_type, thread_tag):
135133
iter_type = getattr(tvm.tir.IterVar, iter_type)
136134
return tvm.tir.IterVar(dom, var, iter_type, thread_tag)
135+
136+
137+
@register_intrin()
138+
def max(a, b): # pylint: disable=redefined-builtin
139+
return tvm.tir.Max(a, b)
140+
141+
142+
def get_axis(begin, end, iter_type):
143+
ana = tvm.arith.Analyzer()
144+
extent = ana.simplify(end - begin)
145+
block_var_dom = tvm.ir.Range.from_min_extent(begin, extent)
146+
147+
iter_type_dict = {"data_par": 0, "reduce": 2, "scan": 3, "opaque": 4}
148+
return tvm.tir.IterVar(block_var_dom, "bv", iter_type_dict[iter_type])
149+
150+
151+
@register_intrin()
152+
def range(begin, end):
153+
return get_axis(begin, end, "data_par")
154+
155+
156+
@register_intrin()
157+
def reduce_axis(begin, end):
158+
return get_axis(begin, end, "reduce")
159+
160+
161+
@register_intrin()
162+
def scan_axis(begin, end):
163+
return get_axis(begin, end, "scan")
164+
165+
166+
@register_intrin()
167+
def opaque_axis(begin, end):
168+
return get_axis(begin, end, "opaque")

0 commit comments

Comments
 (0)