Skip to content

Commit 2a93563

Browse files
rebase and update
1 parent 69842d5 commit 2a93563

File tree

3 files changed

+23
-21
lines changed

3 files changed

+23
-21
lines changed

python/tvm/target/datatype.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,11 @@
1717
"""Custom datatype functionality"""
1818
import tvm
1919
from tvm.runtime import convert, DataType
20-
from tvm.tir.expr import Call as _Call, Cast as _Cast, FloatImm as _FloatImm
20+
from tvm.tir.expr import Call as _Call, Cast as _Cast, FloatImm as _FloatImm, BinaryOpExpr as _BinaryOpExpr
21+
from tvm.tir.op import call_pure_extern
2122
from tvm._ffi import register_func as _register_func
2223
from tvm.tir import call_intrin
24+
from tvm.ir import Op
2325

2426

2527
def register(type_name, type_code):
@@ -154,17 +156,15 @@ def lower(op):
154156
dtype += "x" + str(t.lanes)
155157
if isinstance(op, _Cast):
156158
src_bits = bit_length(op.value.dtype)
157-
return _Call(dtype, extern_func_map[(src_bits, t.bits)], convert([op.value]),
158-
_Call.Extern)
159+
return call_pure_extern(dtype, extern_func_map[(src_bits, t.bits)], op.value)
159160
elif isinstance(op, _FloatImm):
160-
return _Call(dtype, extern_func_map[t.bits], convert([op.value]),
161-
_Call.Extern)
162-
elif isinstance(op, _Call) and (op.call_type == _Call.Intrinsic or
163-
op.call_type == _Call.PureIntrinsic):
164-
return _Call(dtype, extern_func_map[t.bits], convert(op.args),
165-
_Call.Extern)
166-
return _Call(dtype, extern_func_map[t.bits], convert([op.a, op.b]),
167-
_Call.Extern)
161+
return call_pure_extern(dtype, extern_func_map[t.bits], op.value)
162+
elif isinstance(op, _Call):
163+
return call_pure_extern(dtype, extern_func_map[t.bits], *op.args)
164+
elif isinstance(op, _BinaryOpExpr):
165+
return call_pure_extern(dtype, extern_func_map[t.bits], op.a, op.b)
166+
167+
raise RuntimeError(f"lowering unsupported op: {op.astext()}")
168168

169169
return lower
170170

@@ -179,6 +179,6 @@ def lower_ite(ite_intrin):
179179
dtype = "uint" + str(t.bits)
180180
if t.lanes > 1:
181181
dtype += "x" + str(t.lanes)
182-
return call_intrin(dtype, "tvm_if_then_else", convert(ite_intrin.args[0]),
182+
return call_intrin(dtype, "tir.if_then_else", convert(ite_intrin.args[0]),
183183
convert(ite_intrin.args[1]),
184184
convert(ite_intrin.args[2]))

src/tir/transforms/lower_custom_datatypes.cc

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323

2424
#include <tvm/runtime/registry.h>
2525
#include <tvm/target/target.h>
26+
#include <tvm/tir/op.h>
2627
#include <tvm/tir/stmt_functor.h>
2728
#include <tvm/tir/transform.h>
2829

@@ -102,11 +103,12 @@ class CustomDatatypesLowerer : public StmtExprMutator {
102103
PrimExpr expr = StmtExprMutator::VisitExpr_(call);
103104
call = expr.as<CallNode>();
104105
if (toBeLowered) {
105-
CHECK(call->call_type == CallNode::Intrinsic || call->call_type == CallNode::PureIntrinsic)
106+
auto op = call->op.as<OpNode>();
107+
CHECK(op != nullptr)
106108
<< "Lowering non-intrinsic Calls not implemented";
107-
auto lower = datatype::GetIntrinLowerFunc(target_, call->name, call->dtype.code());
109+
auto lower = datatype::GetIntrinLowerFunc(target_, op->name, call->dtype.code());
108110
CHECK(lower) << "Intrinsic lowering function for target " << target_ << ", intrinsic name "
109-
<< call->name << ", type " << static_cast<unsigned>(call->dtype.code())
111+
<< op->name << ", type " << static_cast<unsigned>(call->dtype.code())
110112
<< " not found";
111113
return (*lower)(expr);
112114
}

tests/python/unittest/test_custom_datatypes.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -142,32 +142,32 @@ def setup():
142142
32: 'Posit32es2Sqrt',
143143
16: 'Posit16es2Sqrt',
144144
8: 'Posit8es2Sqrt'
145-
}), "Call", "llvm", "posites2", intrinsic_name="sqrt")
145+
}), "Call", "llvm", "posites2", intrinsic_name="tir.sqrt")
146146
register_op(lower_ite,
147147
"Call",
148148
"llvm",
149149
"posites2",
150-
intrinsic_name="tvm_if_then_else")
150+
intrinsic_name="tir.if_then_else")
151151
register_op(create_lower_func({
152152
32: 'Posit32es2Exp',
153153
16: 'Posit16es2Exp',
154154
8: 'Posit8es2Exp'
155-
}), "Call", "llvm", "posites2", intrinsic_name="exp")
155+
}), "Call", "llvm", "posites2", intrinsic_name="tir.exp")
156156
register_op(create_lower_func({
157157
32: 'Posit32es2Log',
158158
16: 'Posit16es2Log',
159159
8: 'Posit8es2Log'
160-
}), "Call", "llvm", "posites2", intrinsic_name="log")
160+
}), "Call", "llvm", "posites2", intrinsic_name="tir.log")
161161
register_op(create_lower_func({
162162
32: 'Posit32es2Sigmoid',
163163
16: 'Posit16es2Sigmoid',
164164
8: 'Posit8es2Sigmoid'
165-
}), "Call", "llvm", "posites2", intrinsic_name="sigmoid")
165+
}), "Call", "llvm", "posites2", intrinsic_name="tir.sigmoid")
166166
register_op(create_lower_func({
167167
32: 'Posit32es2Tanh',
168168
16: 'Posit16es2Tanh',
169169
8: 'Posit8es2Tanh'
170-
}), "Call", "llvm", "posites2", intrinsic_name="tanh")
170+
}), "Call", "llvm", "posites2", intrinsic_name="tir.tanh")
171171
register_min_func(lambda num_bits: - (2 ** 2 ** 2) ** (num_bits - 2), "posites2")
172172

173173
def run_ops(src_dtype, dst_dtype, rtol=1e-7, atol=1e-7):

0 commit comments

Comments
 (0)