|
23 | 23 | from .registry import register_intrin |
24 | 24 |
|
25 | 25 |
|
26 | | -@register_intrin |
| 26 | +@register_intrin() |
27 | 27 | def bool(imm): |
28 | | - return tvm.tir.const(imm.value, "bool") |
| 28 | + return tvm.tir.const(imm, "bool") |
29 | 29 |
|
30 | 30 |
|
31 | | -@register_intrin |
| 31 | +@register_intrin() |
32 | 32 | def int8(imm): |
33 | | - return tvm.tir.const(imm.value, "int8") |
| 33 | + return tvm.tir.const(imm, "int8") |
34 | 34 |
|
35 | 35 |
|
36 | | -@register_intrin |
| 36 | +@register_intrin() |
37 | 37 | def int16(imm): |
38 | | - return tvm.tir.const(imm.value, "int16") |
| 38 | + return tvm.tir.const(imm, "int16") |
39 | 39 |
|
40 | 40 |
|
41 | | -@register_intrin |
| 41 | +@register_intrin() |
42 | 42 | def int32(imm): |
43 | | - return tvm.tir.const(imm.value, "int32") |
| 43 | + return tvm.tir.const(imm, "int32") |
44 | 44 |
|
45 | 45 |
|
46 | | -@register_intrin |
| 46 | +@register_intrin() |
47 | 47 | def int64(imm): |
48 | | - return tvm.tir.const(imm.value, "int64") |
| 48 | + return tvm.tir.const(imm, "int64") |
49 | 49 |
|
50 | 50 |
|
51 | | -@register_intrin |
| 51 | +@register_intrin() |
52 | 52 | def uint8(imm): |
53 | | - return tvm.tir.const(imm.value, "uint8") |
| 53 | + return tvm.tir.const(imm, "uint8") |
54 | 54 |
|
55 | 55 |
|
56 | | -@register_intrin |
| 56 | +@register_intrin() |
57 | 57 | def uint16(imm): |
58 | | - return tvm.tir.const(imm.value, "uint16") |
| 58 | + return tvm.tir.const(imm, "uint16") |
59 | 59 |
|
60 | 60 |
|
61 | | -@register_intrin |
| 61 | +@register_intrin() |
62 | 62 | def uint32(imm): |
63 | | - return tvm.tir.const(imm.value, "uint32") |
| 63 | + return tvm.tir.const(imm, "uint32") |
64 | 64 |
|
65 | 65 |
|
66 | | -@register_intrin |
| 66 | +@register_intrin() |
67 | 67 | def uint64(imm): |
68 | | - return tvm.tir.const(imm.value, "uint64") |
| 68 | + return tvm.tir.const(imm, "uint64") |
69 | 69 |
|
70 | 70 |
|
71 | | -@register_intrin |
| 71 | +@register_intrin() |
72 | 72 | def float8(imm): |
73 | | - return tvm.tir.const(imm.value, "float8") |
| 73 | + return tvm.tir.const(imm, "float8") |
74 | 74 |
|
75 | 75 |
|
76 | | -@register_intrin |
| 76 | +@register_intrin() |
77 | 77 | def float16(imm): |
78 | | - return tvm.tir.const(imm.value, "float16") |
| 78 | + return tvm.tir.const(imm, "float16") |
79 | 79 |
|
80 | 80 |
|
81 | | -@register_intrin |
| 81 | +@register_intrin() |
82 | 82 | def float32(imm): |
83 | | - return tvm.tir.const(imm.value, "float32") |
| 83 | + return tvm.tir.const(imm, "float32") |
84 | 84 |
|
85 | 85 |
|
86 | | -@register_intrin |
| 86 | +@register_intrin() |
87 | 87 | def float64(imm): |
88 | | - return tvm.tir.const(imm.value, "float64") |
| 88 | + return tvm.tir.const(imm, "float64") |
89 | 89 |
|
90 | 90 |
|
91 | | -@register_intrin |
| 91 | +@register_intrin() |
92 | 92 | def floordiv(x, y): |
93 | 93 | return tvm.tir.floordiv(x, y) |
94 | 94 |
|
95 | 95 |
|
96 | | -@register_intrin |
| 96 | +@register_intrin() |
97 | 97 | def floormod(x, y): |
98 | 98 | return tvm.tir.floormod(x, y) |
99 | 99 |
|
100 | 100 |
|
101 | | -@register_intrin |
| 101 | +@register_intrin() |
102 | 102 | def load(dtype, var, index, predicate=True): |
103 | 103 | return tvm.tir.Load(dtype, var, index, predicate) |
104 | 104 |
|
105 | 105 |
|
106 | | -@register_intrin |
107 | | -def cast(dtype, value): |
| 106 | +@register_intrin() |
| 107 | +def cast(value, dtype): |
108 | 108 | return tvm.tir.Cast(dtype, value) |
109 | 109 |
|
110 | 110 |
|
111 | | -@register_intrin |
| 111 | +@register_intrin() |
112 | 112 | def ramp(base, stride, lanes): |
113 | | - lanes = lanes.value if not isinstance(lanes, int) else lanes |
114 | 113 | return tvm.tir.Ramp(base, stride, lanes) |
115 | 114 |
|
116 | 115 |
|
117 | | -@register_intrin |
| 116 | +@register_intrin() |
118 | 117 | def broadcast(value, lanes): |
119 | | - lanes = lanes.value if not isinstance(lanes, int) else lanes |
120 | 118 | return tvm.tir.Broadcast(value, lanes) |
121 | 119 |
|
122 | 120 |
|
123 | | -@register_intrin |
| 121 | +@register_intrin() |
124 | 122 | def evaluate(value): |
125 | 123 | return tvm.tir.Evaluate(value) |
126 | 124 |
|
127 | 125 |
|
128 | | -@register_intrin |
| 126 | +@register_intrin() |
129 | 127 | def store(var, index, value, predicate=True): |
130 | 128 | return tvm.tir.Store(var, value, index, predicate) |
131 | 129 |
|
132 | 130 |
|
133 | | -@register_intrin |
| 131 | +@register_intrin() |
134 | 132 | def iter_var(var, dom, iter_type, thread_tag): |
135 | 133 | iter_type = getattr(tvm.tir.IterVar, iter_type) |
136 | 134 | return tvm.tir.IterVar(dom, var, iter_type, thread_tag) |
| 135 | + |
| 136 | + |
| 137 | +@register_intrin() |
| 138 | +def max(a, b): # pylint: disable=redefined-builtin |
| 139 | + return tvm.tir.Max(a, b) |
| 140 | + |
| 141 | + |
| 142 | +def get_axis(begin, end, iter_type): |
| 143 | + ana = tvm.arith.Analyzer() |
| 144 | + extent = ana.simplify(end - begin) |
| 145 | + block_var_dom = tvm.ir.Range.from_min_extent(begin, extent) |
| 146 | + |
| 147 | + iter_type_dict = {"data_par": 0, "reduce": 2, "scan": 3, "opaque": 4} |
| 148 | + return tvm.tir.IterVar(block_var_dom, "bv", iter_type_dict[iter_type]) |
| 149 | + |
| 150 | + |
| 151 | +@register_intrin() |
| 152 | +def range(begin, end): |
| 153 | + return get_axis(begin, end, "data_par") |
| 154 | + |
| 155 | + |
| 156 | +@register_intrin() |
| 157 | +def reduce_axis(begin, end): |
| 158 | + return get_axis(begin, end, "reduce") |
| 159 | + |
| 160 | + |
| 161 | +@register_intrin() |
| 162 | +def scan_axis(begin, end): |
| 163 | + return get_axis(begin, end, "scan") |
| 164 | + |
| 165 | + |
| 166 | +@register_intrin() |
| 167 | +def opaque_axis(begin, end): |
| 168 | + return get_axis(begin, end, "opaque") |
0 commit comments