[Bugfix][VM] Fix var binding to a ConstantNode; Force VM if.cond register to take an NDArray instead of POD. (apache#216)

Fix the bug in apache#212. The cause of this bug is that the VM codegen did not handle binding a ConstantNode to a variable (`x = relax.const([1, 2])`) and saving the constant NDArray to a register; previously the codegen only handled the case where a ConstantNode appears as a CallNode argument. This is now fixed and a unit test has been added.
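
For illustration, here is a minimal sketch of the now-supported pattern, modeled on the new `test_vm_print_const` test added in this commit. The imports (`from tvm import relax`, `from tvm.script import relax as R`) are assumed to match `tests/python/relax/test_vm.py` and may differ in a given snapshot:

```python
# Minimal sketch (not part of this commit): a ConstantNode bound to a variable.
# Imports are assumed to follow tests/python/relax/test_vm.py and may differ.
import numpy as np
import tvm
import tvm.testing
from tvm import relax
from tvm.script import relax as R


@tvm.script.ir_module
class ConstBinding:
    @R.function
    def main():
        x = relax.const([1, 2])  # previously unhandled: constant bound to a var
        return x


ex = relax.vm.build(ConstBinding, tvm.target.Target("llvm", host="llvm"))
vm = relax.VirtualMachine(ex, tvm.cpu())
res = vm["main"]()  # codegen now copies the constant into a register
tvm.testing.assert_allclose(res.numpy(), np.array([1, 2]))
```

Under the hood, `VisitExpr_(const ConstantNode*)` now emits a `vm.builtin.copy` call that moves the constant from the constant pool into a destination register (see the `codegen_vm.cc` hunk below).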

Fix the bug in tlc-pack/relax#214 (comment): the issue was caused by the VM simply reading the condition register of the If instruction and expecting it to hold a POD int or bool. tlc-pack/relax@811e877 adds a `LoadScalarInt` function, similar to the one in the Relay VM, that checks that the If.cond register stores an NDArray and casts it to int64_t. Since PrimValue and PrimType (which represent POD values like int and bool) have not been introduced to the Relax language yet, `If->cond` is enforced to be a Tensor (an NDArray at runtime).
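
Conceptually, the VM now requires the `If.cond` register to hold a scalar NDArray and reads it back roughly as sketched below. This is a Python restatement of `LoadScalarInt` from `src/runtime/relax_vm/vm.cc` for illustration only; the helper name `load_scalar_int` is invented here:

```python
# Conceptual sketch of LoadScalarInt; not the VM implementation.
import numpy as np
import tvm


def load_scalar_int(cond: tvm.nd.NDArray) -> int:
    host = cond.copyto(tvm.cpu(0))  # mirrors ndarray.CopyTo(devices[0]) in vm.cc
    return int(host.numpy().reshape(-1)[0])  # bool/int8/int16/int32/int64 scalars


# Callers now pass NDArray conditions, as the updated tests do:
assert load_scalar_int(tvm.nd.array(True)) == 1
assert load_scalar_int(tvm.nd.array(np.int32(0))) == 0
```
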
YuchenJin authored and junrushao committed Feb 5, 2023
1 parent cce713c commit 88f56a4
Showing 4 changed files with 182 additions and 38 deletions.
7 changes: 6 additions & 1 deletion include/tvm/runtime/relax_vm/vm.h
@@ -157,7 +157,12 @@ class VirtualMachine : public runtime::ModuleNode {
* \return The object representing the result.
*/
RegType Invoke(Index fidx, const std::vector<RegType>& args);

/*!
* \brief Read a VM register and cast it to int64_t.
* \param reg The register to read from.
* \return The read scalar.
*/
int64_t LoadScalarInt(RegName reg) const;
/*! \brief Run VM dispatch loop. */
void RunLoop();
/*!
66 changes: 35 additions & 31 deletions src/relax/backend/vm/codegen_vm.cc
Expand Up @@ -19,7 +19,7 @@

/*!
* \file src/relax/backend/vm/codegen_vm.cc
* \brief A codegen to generate VM executable from an IRModule with relax functions.
* \brief A codegen to generate VM executable from a Relax IRModule.
*/

#include "codegen_vm.h"
@@ -127,7 +127,7 @@ class CodeGenVM : public ExprFunctor<Instruction::Arg(const Expr&)> {
} else if (call_node->op == invoke_closure_op_) {
return EmitInvokeClosure(call);
} else {
// every "normal" operator is lowered to a global var in the IR module. The Attrs for those
// every "normal" operator is lowered to a global var in the IRModule. The Attrs for those
// ops are handled in a pass when lowering them to TIR.
LOG(FATAL) << "CodeGenVM cannot handle this intrinsic now:\n" << call_node->op;
}
@@ -142,17 +142,16 @@ class CodeGenVM : public ExprFunctor<Instruction::Arg(const Expr&)> {
LOG(FATAL) << "CodeGenVM does not support calls to " << call_node->op->GetTypeKey();
}
std::vector<Instruction::Arg> args;
// TODO(prakalp): For extern function `vm.builtin.alloc_shape_heap` we must pass vm register as
// well to find the device in which shape heap must be allocated.
// For extern function `vm.builtin.alloc_shape_heap` we must pass vm register as the first
// argument to find the device in which shape heap should be allocated.
if (name == "vm.builtin.alloc_shape_heap") {
args.push_back(Instruction::Arg(Instruction::kRegister, Instruction::kVMRegister));
}
for (auto arg : call_node->args) {
args.push_back(this->VisitExpr(arg));
}
size_t arg_register = NewRegister();
builder_->EmitCall(name, args, arg_register);
return Instruction::Arg(Instruction::kRegister, arg_register);
std::vector<Instruction::Arg> converted_args = ConvertArgs(GetRef<Call>(call_node));
args.insert(args.end(), converted_args.begin(), converted_args.end());
size_t dst_register = NewRegister();
builder_->EmitCall(name, args, dst_register);
return Instruction::Arg(Instruction::kRegister, dst_register);
}

Instruction::Arg VisitExpr_(const IfNode* op) {
@@ -208,7 +207,12 @@ class CodeGenVM : public ExprFunctor<Instruction::Arg(const Expr&)> {
TVMRetValue constant_data;
constant_data = op->data;
Index index = this->builder_->EmitConstant(constant_data);
return Instruction::Arg(Instruction::kConstIdx, index);

size_t dst_register = NewRegister();
std::vector<Instruction::Arg> args;
args.push_back(Instruction::Arg(Instruction::kConstIdx, index));
builder_->EmitCall("vm.builtin.copy", args, dst_register);
return Instruction::Arg(Instruction::kRegister, dst_register);
}

Instruction::Arg VisitExpr_(const ShapeExprNode* op) {
@@ -232,10 +236,10 @@ class CodeGenVM : public ExprFunctor<Instruction::Arg(const Expr&)> {
for (auto arg : tuple->fields) {
args.push_back(this->VisitExpr(arg));
}
size_t arg_register = NewRegister();
builder_->EmitCall("runtime.Tuple", args, arg_register);
size_t dst_register = NewRegister();
builder_->EmitCall("runtime.Tuple", args, dst_register);

return Instruction::Arg(Instruction::kRegister, arg_register);
return Instruction::Arg(Instruction::kRegister, dst_register);
}

Instruction::Arg VisitExpr_(const TupleGetItemNode* op) {
@@ -249,10 +253,10 @@ class CodeGenVM : public ExprFunctor<Instruction::Arg(const Expr&)> {
Index index = builder_->EmitConstant(shape_tuple_value);
args.push_back(Instruction::Arg(Instruction::kConstIdx, index));

size_t arg_register = NewRegister();
builder_->EmitCall("vm.runtime.TupleGetItem", args, arg_register);
size_t dst_register = NewRegister();
builder_->EmitCall("vm.runtime.TupleGetItem", args, dst_register);

return Instruction::Arg(Instruction::kRegister, arg_register);
return Instruction::Arg(Instruction::kRegister, dst_register);
}

Instruction::Arg EmitAllocStorage(const Call& call_node) {
@@ -274,9 +278,9 @@ class CodeGenVM : public ExprFunctor<Instruction::Arg(const Expr&)> {
Index index = this->builder_->EmitConstant(data_type);
args.push_back(Instruction::Arg(Instruction::kConstIdx, index));

size_t arg_register = NewRegister();
builder_->EmitCall("vm.builtin.alloc_storage", args, arg_register);
return Instruction::Arg(Instruction::kRegister, arg_register);
size_t dst_register = NewRegister();
builder_->EmitCall("vm.builtin.alloc_storage", args, dst_register);
return Instruction::Arg(Instruction::kRegister, dst_register);
}

Instruction::Arg EmitAllocTensor(const Call& call_node) {
@@ -298,9 +302,9 @@ class CodeGenVM : public ExprFunctor<Instruction::Arg(const Expr&)> {
data_type = dtype;
Index index = this->builder_->EmitConstant(data_type);
args.push_back(Instruction::Arg(Instruction::kConstIdx, index));
size_t arg_register = NewRegister();
builder_->EmitCall("vm.builtin.alloc_tensor", args, arg_register);
return Instruction::Arg(Instruction::kRegister, arg_register);
size_t dst_register = NewRegister();
builder_->EmitCall("vm.builtin.alloc_tensor", args, dst_register);
return Instruction::Arg(Instruction::kRegister, dst_register);
}

Instruction::Arg EmitShape(const Call& call_node) {
@@ -323,13 +327,13 @@ class CodeGenVM : public ExprFunctor<Instruction::Arg(const Expr&)> {
Index index = builder_->EmitConstant(indices_const);
args.push_back(Instruction::Arg(Instruction::kConstIdx, index));

size_t arg_register = NewRegister();
size_t dst_register = NewRegister();
if (call_node->op == store_shape_op_) {
builder_->EmitCall("vm.builtin.store_shape", args, arg_register);
builder_->EmitCall("vm.builtin.store_shape", args, dst_register);
} else if (call_node->op == load_shape_op_) {
builder_->EmitCall("vm.builtin.load_shape", args, arg_register);
builder_->EmitCall("vm.builtin.load_shape", args, dst_register);
}
return Instruction::Arg(Instruction::kRegister, arg_register);
return Instruction::Arg(Instruction::kRegister, dst_register);
}

Instruction::Arg EmitTirDynOp(const Call& call_node) {
@@ -393,11 +397,11 @@ class CodeGenVM : public ExprFunctor<Instruction::Arg(const Expr&)> {
// attributes.
Instruction::Arg EmitPackedFuncCall(const Call& call_node, const FCallPacked& name) {
std::vector<Instruction::Arg> args;
for (auto arg : call_node->args) args.push_back(this->VisitExpr(arg));
args = ConvertArgs(call_node);
AppendAttrsAsConstants(call_node, args);
size_t arg_register = NewRegister();
builder_->EmitCall(name, args, arg_register);
return Instruction::Arg(Instruction::kRegister, arg_register);
size_t dst_register = NewRegister();
builder_->EmitCall(name, args, dst_register);
return Instruction::Arg(Instruction::kRegister, dst_register);
}

Instruction::Arg EmitAllocClosure(const Call& call_node) {
36 changes: 35 additions & 1 deletion src/runtime/relax_vm/vm.cc
@@ -378,6 +378,40 @@ void VirtualMachine::RunInstrCall(VMFrame* curr_frame, Instruction instr) {
pc_++;
}

int64_t VirtualMachine::LoadScalarInt(RegName reg) const {
int64_t result = 0;
VMFrame* curr_frame = frames_.back().get();
const RegType& obj = ReadRegister(curr_frame, reg);
NDArray ndarray = obj.operator tvm::runtime::NDArray();
NDArray ndarray_host = ndarray.CopyTo(devices[0]);

switch (ndarray_host->dtype.bits) {
case 1: {
result = reinterpret_cast<bool*>(ndarray_host->data)[0];
break;
}
case 8: {
result = reinterpret_cast<int8_t*>(ndarray_host->data)[0];
break;
}
case 16: {
result = reinterpret_cast<int16_t*>(ndarray_host->data)[0];
break;
}
case 32: {
result = reinterpret_cast<int32_t*>(ndarray_host->data)[0];
break;
}
case 64: {
result = reinterpret_cast<int64_t*>(ndarray_host->data)[0];
break;
}
default:
LOG(FATAL) << "Unknown scalar int type: " << DLDataType2String(ndarray_host->dtype);
}
return result;
}

void VirtualMachine::RunLoop() {
VMFrame* curr_frame = frames_.back().get();

@@ -411,7 +445,7 @@ void VirtualMachine::RunLoop() {
break;
}
case Opcode::If: {
int64_t cond_val = ReadRegister(curr_frame, instr.cond);
int64_t cond_val = LoadScalarInt(instr.cond);
if (cond_val != 0) {
pc_++;
} else {
111 changes: 106 additions & 5 deletions tests/python/relax/test_vm.py
@@ -18,6 +18,8 @@
import os
from typing import Any, Callable, List, Tuple

import sys
import tempfile
import numpy as np
import pytest
import tvm
@@ -50,7 +52,7 @@ def mul(a, b):
@tvm.register_func("test.vm.equal_zero")
def equal_zero(a):
ret = np.all((a.numpy() == 0))
return bool(ret)
return tvm.nd.array(ret)


@tvm.register_func("test.vm.subtract_one")
@@ -298,9 +300,9 @@ def test_vm_if():
4,
)
)
res = vm["main"](False, a, b)
res = vm["main"](tvm.nd.array(False), a, b)
tvm.testing.assert_allclose(res.numpy(), a.numpy() * b.numpy(), rtol=1e-7, atol=1e-7)
res = vm["main"](1, a, b)
res = vm["main"](tvm.nd.array(1), a, b)
tvm.testing.assert_allclose(res.numpy(), a.numpy() + b.numpy(), rtol=1e-7, atol=1e-7)


@@ -320,9 +322,13 @@ def ife(cond: Tensor((), "bool"), x: Tensor((3, 4), "float32")) -> Tensor:
ex = relax.vm.build(mod, target)
vm = relax.VirtualMachine(ex, tvm.cpu())
inp = tvm.nd.array(np.random.rand(3, 4))
res = vm["ife"](True, inp)
res = vm["ife"](tvm.nd.array(1), inp)
tvm.testing.assert_allclose(res.numpy(), inp.numpy() + inp.numpy(), rtol=1e-7, atol=1e-7)
res = vm["ife"](tvm.nd.array(True), inp)
tvm.testing.assert_allclose(res.numpy(), inp.numpy() + inp.numpy(), rtol=1e-7, atol=1e-7)
res = vm["ife"](0, inp)
res = vm["ife"](tvm.nd.array(0), inp)
tvm.testing.assert_allclose(res.numpy(), inp.numpy() * inp.numpy(), rtol=1e-7, atol=1e-7)
res = vm["ife"](tvm.nd.array(False), inp)
tvm.testing.assert_allclose(res.numpy(), inp.numpy() * inp.numpy(), rtol=1e-7, atol=1e-7)


@@ -784,6 +790,101 @@ def tuple_get_item(x: Tensor((_, _), "float32"), y: Tensor((_, _), "float32")):
tvm.testing.assert_allclose(res.numpy(), x_inp.numpy() + y_inp.numpy(), rtol=1e-7, atol=1e-7)


def test_vm_print_const():
@tvm.script.ir_module
class PrintConst:
@R.function
def main():
x = relax.const([1, 2])
y = relax.print(x)
return x

try:
stdout = sys.stdout
with tempfile.TemporaryFile(mode="w+") as test_out:
sys.stdout = test_out
mod = PrintConst
target = tvm.target.Target("llvm", host="llvm")
ex = relax.vm.build(mod, target)
vm = relax.VirtualMachine(ex, tvm.cpu())
res = vm["main"]()
test_out.seek(0)
printed_text = str(test_out.read())
expected = "[1 2]\n"
assert printed_text == expected
tvm.testing.assert_allclose(res.numpy(), np.array([1, 2]))
finally:
sys.stdout = stdout


def test_vm_return_const_tuple():
@tvm.script.ir_module
class ReturnConstTuple:
@R.function
def main(x: Tensor((_, _), "float32")):
y = relax.const([1, 2])
z = (y, relax.const([3, 4]), x)
return z

mod = ReturnConstTuple
target = tvm.target.Target("llvm", host="llvm")
ex = relax.vm.build(mod, target)
vm = relax.VirtualMachine(ex, tvm.cpu())
inp = tvm.nd.array(np.random.rand(2, 3))
res0, res1, res2 = vm["main"](inp)
tvm.testing.assert_allclose(res0.numpy(), np.array([1, 2]))
tvm.testing.assert_allclose(res1.numpy(), np.array([3, 4]))
tvm.testing.assert_allclose(res2.numpy(), inp.numpy())


def test_vm_const_as_call_arg():
@tvm.script.ir_module
class TestVMConstAsCallArg:
@R.function
def main(x: Tensor((_, _), "float32")):
a = relax.call_packed(
"test.vm.add",
relax.const([1, 2]),
relax.const([3, 4]),
type_args=(Tensor(ndim=2, dtype="float32")),
)
b = relax.call_packed(
"test.vm.add",
a,
x,
type_args=(Tensor(ndim=2, dtype="float32")),
)
return b

mod = TestVMConstAsCallArg
target = tvm.target.Target("llvm", host="llvm")
ex = relax.vm.build(mod, target)
vm = relax.VirtualMachine(ex, tvm.cpu())
inp = tvm.nd.array(np.random.rand(1, 2))
res = vm["main"](inp)
tvm.testing.assert_allclose(res.numpy(), np.array([4, 6]) + inp.numpy())


def test_vm_if_cond_const():
@tvm.script.ir_module
class TestVMIfCondConst:
@R.function
def main(x: Tensor((_, _), "float32")) -> Tensor((1,), "int32"):
if relax.const(True, dtype="bool"):
ret = x
else:
ret = x
return ret

mod = TestVMIfCondConst
target = tvm.target.Target("llvm", host="llvm")
ex = relax.vm.build(mod, target)
vm = relax.VirtualMachine(ex, tvm.cpu())
inp = tvm.nd.array(np.random.rand(3, 4))
res = vm["main"](inp)
tvm.testing.assert_allclose(res.numpy(), inp.numpy())


def test_sub_func_call():
@tvm.script.ir_module
class TestVMSubFunction:
