Skip to content

Commit

Permalink
[TIR] Support Return in TIR (apache#7084)
Browse files Browse the repository at this point in the history
  • Loading branch information
ZihengJiang authored Jan 16, 2021
1 parent 3f15d06 commit 052ad3d
Show file tree
Hide file tree
Showing 10 changed files with 185 additions and 10 deletions.
4 changes: 4 additions & 0 deletions include/tvm/tir/builtin.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,10 @@ namespace tir {

/*! \brief Collection of builtin intrinsics as ops */
namespace builtin {
/*!
* \brief Return value.
*/
TVM_DLL const Op& ret();
/*!
* \brief Reinterpret the value using the target type.
*/
Expand Down
9 changes: 9 additions & 0 deletions include/tvm/tir/op.h
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,15 @@ TVM_DLL Type GetType(const PrimExpr& expr);
*/
TVM_DLL runtime::DataType GetRuntimeDataType(const Type& type);

/*!
* \brief Return the value.
*
* \param value The returned value.
* \param span The location of this operation in the source.
* \return The return expression.
*/
TVM_DLL PrimExpr ret(PrimExpr value, Span span = Span());

/*!
* Query the maximum possible value of dtype.
* \param dtype The data type.
Expand Down
6 changes: 5 additions & 1 deletion include/tvm/tir/op_attr_types.h
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,11 @@ enum class CallEffectKind : int {
/*!
* \brief Embed opaque information in the Expr, cannot be codegen.
*/
kEmbedInfo = 5
kEmbedInfo = 5,
/*!
* \brief Function that changes control flow
*/
kControlJump = 6,
};

/*! \brief Use integer to record the kind. */
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/tir/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
from .function import PrimFunc

from .op import call_packed, call_intrin, call_pure_extern, call_extern
from .op import call_llvm_intrin, call_llvm_pure_intrin, all, any, min_value, max_value, trace
from .op import call_llvm_intrin, call_llvm_pure_intrin, ret, all, any, min_value, max_value, trace
from .op import exp, exp2, exp10, log, log2, log10, log1p, ldexp
from .op import sin, sinh, asin, asinh
from .op import cos, cosh, acos, acosh
Expand Down
28 changes: 22 additions & 6 deletions python/tvm/tir/op.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,22 @@ def call_llvm_pure_intrin(dtype, name, *args, span=None):
)


def ret(val):
"""Create a tir return expression
Parameters
----------
val : Expr
The returned tir expression, whose data type is int, float or void pointer.
Returns
-------
ret : PrimExpr
The return expression
"""
return call_intrin(val.dtype, "tir.ret", val)


def any(*args, span=None):
"""Create a new experssion of the union of all conditions in the arguments
Expand All @@ -241,10 +257,10 @@ def any(*args, span=None):
raise ValueError("Any must take at least 1 argument")
if len(args) == 1:
return args[0]
ret = _ffi_api._OpOr(args[0], args[1], span)
val = _ffi_api._OpOr(args[0], args[1], span)
for i in range(2, len(args)):
ret = _ffi_api._OpOr(ret, args[i], span)
return ret
val = _ffi_api._OpOr(val, args[i], span)
return val


def all(*args, span=None):
Expand All @@ -268,10 +284,10 @@ def all(*args, span=None):
raise ValueError("Any must take at least 1 argument")
if len(args) == 1:
return args[0]
ret = _ffi_api._OpAnd(args[0], args[1], span)
val = _ffi_api._OpAnd(args[0], args[1], span)
for i in range(2, len(args)):
ret = _ffi_api._OpAnd(ret, args[i], span)
return ret
val = _ffi_api._OpAnd(val, args[i], span)
return val


@tvm._ffi.register_func("tvm.default_trace_action")
Expand Down
12 changes: 12 additions & 0 deletions src/target/llvm/codegen_llvm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -927,6 +927,18 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const CallNode* op) {
value->addIncoming(then_value, then_value_block);
value->addIncoming(else_value, else_value_block);
return value;
} else if (op->op.same_as(builtin::ret())) {
auto const* val = op->args[0].as<IntImmNode>();
ICHECK(val) << "the tir.ret should be transformed to return zero "
<< "before the llvm code generation.";
ICHECK_EQ(val->value, 0) << "the tir.ret should be transformed to "
<< "return zero before the llvm code generation.";
builder_->CreateRet(ConstInt32(0));
// LLVM allows exactly one terminator in a single basic block
// append a new dummy basic block to avoid error.
llvm::BasicBlock* ret_dummy = llvm::BasicBlock::Create(*ctx_, "ret_dummy", function_);
builder_->SetInsertPoint(ret_dummy);
return ret_dummy;
} else if (op->op.same_as(builtin::reinterpret())) {
llvm::Type* target = DTypeToLLVMType(op->dtype);
return builder_->CreateBitCast(MakeValue(op->args[0]), target);
Expand Down
4 changes: 4 additions & 0 deletions src/tir/op/builtin.cc
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,10 @@ TIR_DEFINE_BUILTIN_FUNC(reinterpret)
.set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kPure))
.set_num_inputs(1);

TIR_DEFINE_BUILTIN_FUNC(ret)
.set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kControlJump))
.set_num_inputs(1);

TIR_DEFINE_BUILTIN_FUNC(likely)
.set_num_inputs(1)
.set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kExprAnnotation))
Expand Down
4 changes: 4 additions & 0 deletions src/tir/op/op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,10 @@ void BinaryOpMatchTypes(PrimExpr& lhs, PrimExpr& rhs, Span span) { // NOLINT(*)
}
}

PrimExpr ret(PrimExpr value, Span span) {
return tir::Call(value.dtype(), tir::builtin::ret(), {value}, span);
}

// maximum and min limits
PrimExpr max_value(const DataType& dtype, Span span) {
using namespace tir;
Expand Down
66 changes: 64 additions & 2 deletions src/tir/transforms/make_packed_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,67 @@
namespace tvm {
namespace tir {

class ReturnRewriter : public StmtMutator {
public:
explicit ReturnRewriter(Var ret_var, Var ret_tcode) : ret_var_(ret_var), ret_tcode_(ret_tcode) {}

Stmt VisitStmt_(const ForNode* node) override {
if (node->for_type == ForType::Parallel) in_parallel_ += 1;
Stmt ret = StmtMutator::VisitStmt_(node);
if (node->for_type == ForType::Parallel) in_parallel_ -= 1;
return ret;
}

Stmt VisitStmt_(const EvaluateNode* node) override {
Stmt ret = StmtMutator::VisitStmt_(node);
const EvaluateNode* eval = ret.as<EvaluateNode>();
ICHECK(eval);
if (const CallNode* call = eval->value.as<CallNode>()) {
if (call->op.same_as(builtin::ret())) {
ICHECK_EQ(in_parallel_, 0) << "tir.ret cannot be used in parallel scope.";
ICHECK_EQ(call->args.size(), 1) << "tir.ret expect a single argument.";
ret = WriteToOut(call->args[0], ret_var_, ret_tcode_);
}
}
return ret;
}

private:
std::pair<int, PrimExpr> ConvertForFFI(PrimExpr val) {
// convert val's data type to FFI data type, return type code
DataType dtype = val.dtype();
if (dtype.is_int() || dtype.is_uint()) {
return {kTVMArgInt, Cast(DataType::Int(64), val)};
} else if (dtype.is_float()) {
return {kTVMArgFloat, Cast(DataType::Float(64), val)};
} else if (dtype.is_void()) {
return {kTVMNullptr, val};
} else {
LOG(FATAL) << "data type " << dtype << " not supported yet";
}
return {kTVMNullptr, val};
}

Stmt WriteToOut(PrimExpr val, Var ret_var, Var ret_tcode) {
auto p = ConvertForFFI(val);
int tcode = p.first;
val = p.second;
Stmt store_val = Store(ret_var_, val, 0, const_true());
Stmt store_tcode = Store(ret_tcode_, tcode, 0, const_true());
Stmt ret_zero = Evaluate(tvm::ret(0));
return SeqStmt({store_val, store_tcode, ret_zero});
}

Var ret_var_;
Var ret_tcode_;
int in_parallel_{0};
};

Stmt RewriteReturn(Stmt body, Var ret_var, Var ret_tcode) {
ReturnRewriter rewriter(ret_var, ret_tcode);
return rewriter(body);
}

inline Stmt MakeAssertEQ(PrimExpr lhs, PrimExpr rhs, std::string msg) {
return AssertStmt(lhs == rhs, tvm::tir::StringImm(msg), Evaluate(0));
}
Expand Down Expand Up @@ -182,8 +243,9 @@ PrimFunc MakePackedAPI(PrimFunc&& func, int num_unpacked_args) {
func = WithAttr(std::move(func), tvm::attr::kCallingConv, Integer(CallingConv::kCPackedFunc));
}

Stmt body = AttrStmt(make_zero(DataType::Int(32)), attr::compute_scope,
StringImm(name_hint + "_compute_"), func_ptr->body);
Stmt body = RewriteReturn(func_ptr->body, v_out_ret_value, v_out_ret_tcode);
body = AttrStmt(make_zero(DataType::Int(32)), attr::compute_scope,
StringImm(name_hint + "_compute_"), body);
// Set device context
if (vmap.count(device_id.get())) {
PrimExpr node = StringImm("default");
Expand Down
60 changes: 60 additions & 0 deletions tests/python/unittest/test_tir_base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import tvm
from tvm import tir
from tvm.ir.transform import PassContext


def build_tir_func(func):
func = func.with_attr("global_symbol", "main")
pass_ctx = PassContext.current()
if pass_ctx.config.get("tir.noalias", True):
func = func.with_attr("tir.noalias", True)
mod = tvm.IRModule({"main": func})
func = tvm.build(mod)
return func


def test_scalar_add():
a = tir.Var("a", "float32")
b = tir.Var("b", "float32")
c = a + b
c = tir.ret(c)
c = tir.Evaluate(c)
func = tir.PrimFunc([a, b], c)
func = build_tir_func(func)
out = func(1.0, 2.0)
assert out == 3.0


def test_control_flow_jump():
ib = tvm.tir.ir_builder.create()
a = tir.Var("a", "float32")
b = tir.Var("b", "float32")
with ib.if_scope(True):
ib.emit(tir.Evaluate(tir.ret(a)))
ib.emit(tir.Evaluate(tir.ret(b)))
stmt = ib.get()
func = tir.PrimFunc([a, b], stmt)
func = build_tir_func(func)
out = func(1.0, 2.0)
assert out == 1.0


if __name__ == "__main__":
test_scalar_add()
test_control_flow_jump()

0 comments on commit 052ad3d

Please sign in to comment.