Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions docs/dev/virtual_machine.rst
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ AllocTensor
Allocate a tensor value of the appropriate shape (stored in `shape_register`) and `dtype`. The result
is saved to register `dst`.

AllocDatatype
AllocADT
^^^^^^^^^^^^^
**Arguments**:
::
Expand Down Expand Up @@ -176,7 +176,7 @@ GetTagi
RegName object
RegName dst

Get the object tag for Datatype object in register `object`. And saves the reult to register `dst`.
Get the object tag for ADT object in register `object`. And saves the reult to register `dst`.

Fatal
^^^^^
Expand Down Expand Up @@ -251,9 +251,9 @@ Currently, we support 3 types of objects: tensors, data types, and closures.

::

VMObject VMTensor(const tvm::runtime::NDArray& data);
VMObject VMDatatype(size_t tag, const std::vector<VMObject>& fields);
VMObject VMClosure(size_t func_index, std::vector<VMObject> free_vars);
Object Tensor(const tvm::runtime::NDArray& data);
Object ADT(size_t tag, const std::vector<Object>& fields);
Object Closure(size_t func_index, std::vector<Object> free_vars);


Stack and State
Expand Down
2 changes: 1 addition & 1 deletion include/tvm/runtime/object.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ enum TypeIndex {
kRoot = 0,
kVMTensor = 1,
kVMClosure = 2,
kVMDatatype = 3,
kVMADT = 3,
kStaticIndexEnd,
/*! \brief Type index is allocated during runtime. */
kDynamic = kStaticIndexEnd
Expand Down
24 changes: 12 additions & 12 deletions include/tvm/runtime/vm.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,31 +57,31 @@ class Tensor : public ObjectRef {


/*! \brief An object representing a structure or enumeration. */
class DatatypeObj : public Object {
class ADTObj : public Object {
public:
/*! \brief The tag representing the constructor used. */
size_t tag;
/*! \brief The fields of the structure. */
std::vector<ObjectRef> fields;

static constexpr const uint32_t _type_index = TypeIndex::kVMDatatype;
static constexpr const char* _type_key = "vm.Datatype";
TVM_DECLARE_FINAL_OBJECT_INFO(DatatypeObj, Object);
static constexpr const uint32_t _type_index = TypeIndex::kVMADT;
static constexpr const char* _type_key = "vm.ADT";
TVM_DECLARE_FINAL_OBJECT_INFO(ADTObj, Object);
};

/*! \brief reference to data type. */
class Datatype : public ObjectRef {
/*! \brief reference to algebraic data type objects. */
class ADT : public ObjectRef {
public:
Datatype(size_t tag, std::vector<ObjectRef> fields);
ADT(size_t tag, std::vector<ObjectRef> fields);

/*!
* \brief construct a tuple object.
* \param fields The fields of the tuple.
* \return The constructed tuple type.
*/
static Datatype Tuple(std::vector<ObjectRef> fields);
static ADT Tuple(std::vector<ObjectRef> fields);

TVM_DEFINE_OBJECT_REF_METHODS(Datatype, ObjectRef, DatatypeObj);
TVM_DEFINE_OBJECT_REF_METHODS(ADT, ObjectRef, ADTObj);
};

/*! \brief An object representing a closure. */
Expand Down Expand Up @@ -129,7 +129,7 @@ enum class Opcode {
InvokePacked = 4U,
AllocTensor = 5U,
AllocTensorReg = 6U,
AllocDatatype = 7U,
AllocADT = 7U,
AllocClosure = 8U,
GetField = 9U,
If = 10U,
Expand Down Expand Up @@ -237,7 +237,7 @@ struct Instruction {
/*! \brief The register to project from. */
RegName object;
} get_tag;
struct /* AllocDatatype Operands */ {
struct /* AllocADT Operands */ {
/*! \brief The datatype's constructor tag. */
Index constructor_tag;
/*! \brief The number of fields to store in the datatype. */
Expand Down Expand Up @@ -294,7 +294,7 @@ struct Instruction {
* \param dst The register name of the destination.
* \return The allocate instruction tensor.
*/
static Instruction AllocDatatype(Index tag, Index num_fields, const std::vector<RegName>& fields,
static Instruction AllocADT(Index tag, Index num_fields, const std::vector<RegName>& fields,
RegName dst);
/*! \brief Construct an allocate closure instruction.
* \param func_index The index of the function table.
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/relay/backend/vm.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
from .interpreter import Executor

Tensor = _obj.Tensor
Datatype = _obj.Datatype
ADT = _obj.ADT

def _convert(arg, cargs):
if isinstance(arg, (np.ndarray, tvm.nd.NDArray)):
Expand Down
20 changes: 10 additions & 10 deletions python/tvm/relay/backend/vmobj.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,14 +61,14 @@ def asnumpy(self):
return self.data.asnumpy()


@register_object("vm.Datatype")
class Datatype(Object):
"""Datatype object.
@register_object("vm.ADT")
class ADT(Object):
"""Algebatic data type(ADT) object.

Parameters
----------
tag : int
The tag of datatype.
The tag of ADT.

fields : list[Object] or tuple[Object]
The source tuple.
Expand All @@ -77,22 +77,22 @@ def __init__(self, tag, fields):
for f in fields:
assert isinstance(f, Object)
self.__init_handle_by_constructor__(
_vmobj.Datatype, tag, *fields)
_vmobj.ADT, tag, *fields)

@property
def tag(self):
return _vmobj.GetDatatypeTag(self)
return _vmobj.GetADTTag(self)

def __getitem__(self, idx):
return getitem_helper(
self, _vmobj.GetDatatypeFields, len(self), idx)
self, _vmobj.GetADTFields, len(self), idx)

def __len__(self):
return _vmobj.GetDatatypeNumberOfFields(self)
return _vmobj.GetADTNumberOfFields(self)


def tuple_object(fields):
"""Create a datatype object from source tuple.
"""Create a ADT object from source tuple.

Parameters
----------
Expand All @@ -101,7 +101,7 @@ def tuple_object(fields):

Returns
-------
ret : Datatype
ret : ADT
The created object.
"""
for f in fields:
Expand Down
8 changes: 4 additions & 4 deletions src/relay/backend/vm/compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@ class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> {
DLOG(INFO) << "VMCompiler::Emit: instr=" << instr;
CHECK((int)instr.op < 100) << "Invalid opcode " << (int)instr.op;
switch (instr.op) {
case Opcode::AllocDatatype:
case Opcode::AllocADT:
case Opcode::AllocTensor:
case Opcode::AllocTensorReg:
case Opcode::GetField:
Expand Down Expand Up @@ -287,7 +287,7 @@ class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> {
}

// TODO(@jroesch): use correct tag
Emit(Instruction::AllocDatatype(
Emit(Instruction::AllocADT(
0,
tuple->fields.size(),
fields_registers,
Expand Down Expand Up @@ -626,7 +626,7 @@ class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> {
for (size_t i = arity - return_count; i < arity; ++i) {
fields_registers.push_back(unpacked_arg_regs[i]);
}
Emit(Instruction::AllocDatatype(0, return_count, fields_registers, NewRegister()));
Emit(Instruction::AllocADT(0, return_count, fields_registers, NewRegister()));
}
}

Expand Down Expand Up @@ -659,7 +659,7 @@ class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> {
}
} else if (auto constructor_node = op.as<ConstructorNode>()) {
auto constructor = GetRef<Constructor>(constructor_node);
Emit(Instruction::AllocDatatype(constructor->tag, call_node->args.size(), args_registers,
Emit(Instruction::AllocADT(constructor->tag, call_node->args.size(), args_registers,
NewRegister()));
} else if (auto var_node = op.as<VarNode>()) {
VisitExpr(GetRef<Var>(var_node));
Expand Down
6 changes: 3 additions & 3 deletions src/runtime/vm/executable.cc
Original file line number Diff line number Diff line change
Expand Up @@ -315,7 +315,7 @@ VMInstructionSerializer SerializeInstruction(const Instruction& instr) {
fields.push_back(instr.dst);
break;
}
case Opcode::AllocDatatype: {
case Opcode::AllocADT: {
// Number of fields = 3 + instr.num_fields
fields.assign({instr.constructor_tag, instr.num_fields, instr.dst});

Expand Down Expand Up @@ -551,7 +551,7 @@ Instruction DeserializeInstruction(const VMInstructionSerializer& instr) {

return Instruction::AllocTensorReg(shape_register, dtype, dst);
}
case Opcode::AllocDatatype: {
case Opcode::AllocADT: {
// Number of fields = 3 + instr.num_fields
DCHECK_GE(instr.fields.size(), 3U);
DCHECK_EQ(instr.fields.size(), 3U + static_cast<size_t>(instr.fields[1]));
Expand All @@ -561,7 +561,7 @@ Instruction DeserializeInstruction(const VMInstructionSerializer& instr) {
RegName dst = instr.fields[2];
std::vector<Index> fields = ExtractFields(instr.fields, 3, num_fields);

return Instruction::AllocDatatype(constructor_tag, num_fields, fields, dst);
return Instruction::AllocADT(constructor_tag, num_fields, fields, dst);
}
case Opcode::AllocClosure: {
// Number of fields = 3 + instr.num_freevar
Expand Down
28 changes: 14 additions & 14 deletions src/runtime/vm/object.cc
Original file line number Diff line number Diff line change
Expand Up @@ -39,15 +39,15 @@ Tensor::Tensor(NDArray data) {
data_ = std::move(ptr);
}

Datatype::Datatype(size_t tag, std::vector<ObjectRef> fields) {
auto ptr = make_object<DatatypeObj>();
ADT::ADT(size_t tag, std::vector<ObjectRef> fields) {
auto ptr = make_object<ADTObj>();
ptr->tag = tag;
ptr->fields = std::move(fields);
data_ = std::move(ptr);
}

Datatype Datatype::Tuple(std::vector<ObjectRef> fields) {
return Datatype(0, fields);
ADT ADT::Tuple(std::vector<ObjectRef> fields) {
return ADT(0, fields);
}

Closure::Closure(size_t func_index, std::vector<ObjectRef> free_vars) {
Expand All @@ -66,28 +66,28 @@ TVM_REGISTER_GLOBAL("_vmobj.GetTensorData")
*rv = cell->data;
});

TVM_REGISTER_GLOBAL("_vmobj.GetDatatypeTag")
TVM_REGISTER_GLOBAL("_vmobj.GetADTTag")
.set_body([](TVMArgs args, TVMRetValue* rv) {
ObjectRef obj = args[0];
const auto* cell = obj.as<DatatypeObj>();
const auto* cell = obj.as<ADTObj>();
CHECK(cell != nullptr);
*rv = static_cast<int64_t>(cell->tag);
});

TVM_REGISTER_GLOBAL("_vmobj.GetDatatypeNumberOfFields")
TVM_REGISTER_GLOBAL("_vmobj.GetADTNumberOfFields")
.set_body([](TVMArgs args, TVMRetValue* rv) {
ObjectRef obj = args[0];
const auto* cell = obj.as<DatatypeObj>();
const auto* cell = obj.as<ADTObj>();
CHECK(cell != nullptr);
*rv = static_cast<int64_t>(cell->fields.size());
});


TVM_REGISTER_GLOBAL("_vmobj.GetDatatypeFields")
TVM_REGISTER_GLOBAL("_vmobj.GetADTFields")
.set_body([](TVMArgs args, TVMRetValue* rv) {
ObjectRef obj = args[0];
int idx = args[1];
const auto* cell = obj.as<DatatypeObj>();
const auto* cell = obj.as<ADTObj>();
CHECK(cell != nullptr);
CHECK_LT(idx, cell->fields.size());
*rv = cell->fields[idx];
Expand All @@ -104,22 +104,22 @@ TVM_REGISTER_GLOBAL("_vmobj.Tuple")
for (auto i = 0; i < args.size(); ++i) {
fields.push_back(args[i]);
}
*rv = Datatype::Tuple(fields);
*rv = ADT::Tuple(fields);
});

TVM_REGISTER_GLOBAL("_vmobj.Datatype")
TVM_REGISTER_GLOBAL("_vmobj.ADT")
.set_body([](TVMArgs args, TVMRetValue* rv) {
int itag = args[0];
size_t tag = static_cast<size_t>(itag);
std::vector<ObjectRef> fields;
for (int i = 1; i < args.size(); i++) {
fields.push_back(args[i]);
}
*rv = Datatype(tag, fields);
*rv = ADT(tag, fields);
});

TVM_REGISTER_OBJECT_TYPE(TensorObj);
TVM_REGISTER_OBJECT_TYPE(DatatypeObj);
TVM_REGISTER_OBJECT_TYPE(ADTObj);
TVM_REGISTER_OBJECT_TYPE(ClosureObj);
} // namespace vm
} // namespace runtime
Expand Down
Loading