Skip to content

Commit

Permalink
Revert "[FFI][RUNTIME] Introduce runtime boxed types for int/float/bo…
Browse files Browse the repository at this point in the history
…ol" (#17252)

Revert "[FFI][RUNTIME] Introduce runtime boxed types for int/float/bool (#16183)"

This reverts commit 5f22be4.
  • Loading branch information
tqchen authored Aug 7, 2024
1 parent 05e2bc3 commit 11be832
Show file tree
Hide file tree
Showing 184 changed files with 1,221 additions and 3,215 deletions.
76 changes: 18 additions & 58 deletions include/tvm/ir/attrs.h
Original file line number Diff line number Diff line change
Expand Up @@ -265,16 +265,7 @@ class DictAttrs : public Attrs {

auto it = node->dict.find(attr_key);
if (it != node->dict.end()) {
// For backwards compatibility, return through TVMRetValue.
// This triggers any automatic conversions registered with
// PackedFuncValueConverter. Importantly, this allows use of
// `GetAttr<Integer>` and `GetAttr<Bool>` for properties that
// are stored internally as `runtime::Box<int64_t>` and
// `runtime::Box<bool>`.
TVMRetValue ret;
ret = (*it).second;
Optional<TObjectRef> obj = ret;
return obj;
return Downcast<Optional<TObjectRef>>((*it).second);
} else {
return default_value;
}
Expand Down Expand Up @@ -324,46 +315,6 @@ inline TAttrs AttrsWithDefaultValues() {
return TAttrs(n);
}

/*!
* \brief Copy the DictAttrs, but overrides attributes with the
* entries from \p attrs.
*
* \param attrs The DictAttrs to update
*
* \param new_attrs Key/values attributes to add to \p attrs.
*
* \returns The new DictAttrs with updated attributes.
*/
DictAttrs WithAttrs(DictAttrs attrs, Map<String, ObjectRef> new_attrs);

/*!
* \brief Copy the DictAttrs, but overrides a single attribute.
*
* \param attrs The DictAttrs to update
*
* \param key The update to insert or update.
*
* \param value The new value of the attribute
*
* \returns The new DictAttrs with updated attributes.
*/
DictAttrs WithAttr(DictAttrs attrs, String key, ObjectRef value);

inline DictAttrs WithAttr(DictAttrs attrs, const std::string& key, ObjectRef value) {
return WithAttr(std::move(attrs), String(key), std::move(value));
}

/*!
* \brief Copy the DictAttrs, but without a specific attribute.
*
* \param attrs The DictAttrs to update
*
* \param key The key to remove
*
* \returns The new DictAttrs with updated attributes.
*/
DictAttrs WithoutAttr(DictAttrs attrs, const std::string& key);

/*!
* \brief Copy the function or module, but overrides
* the attribute value key with the value.
Expand Down Expand Up @@ -396,8 +347,12 @@ inline TFunc WithAttr(TFunc input, const std::string& attr_key, ObjectRef attr_v
using TNode = typename TFunc::ContainerType;
static_assert(TNode::_type_final, "Can only operate on the leaf nodes");
TNode* node = input.CopyOnWrite();
node->attrs = WithAttr(std::move(node->attrs), attr_key, attr_value);

if (node->attrs.defined()) {
node->attrs.CopyOnWrite()->dict.Set(attr_key, attr_value);
} else {
Map<String, ObjectRef> dict = {{attr_key, attr_value}};
node->attrs = DictAttrs(dict);
}
return input;
}

Expand All @@ -416,9 +371,13 @@ inline TFunc WithAttrs(TFunc input, Map<String, ObjectRef> attrs) {
using TNode = typename TFunc::ContainerType;
static_assert(TNode::_type_final, "Can only operate on the leaf nodes");
TNode* node = input.CopyOnWrite();

node->attrs = WithAttrs(std::move(node->attrs), attrs);

if (node->attrs.defined()) {
for (const auto& pair : attrs) {
node->attrs.CopyOnWrite()->dict.Set(pair.first, pair.second);
}
} else {
node->attrs = DictAttrs(std::move(attrs));
}
return input;
}

Expand Down Expand Up @@ -453,9 +412,10 @@ inline TFunc WithoutAttr(TFunc input, const std::string& attr_key) {
using TNode = typename TFunc::ContainerType;
static_assert(TNode::_type_final, "Can only operate on the leaf nodes");

TNode* node = input.CopyOnWrite();
node->attrs = WithoutAttr(std::move(node->attrs), attr_key);

if (input->attrs.defined()) {
TNode* node = input.CopyOnWrite();
node->attrs.CopyOnWrite()->dict.erase(attr_key);
}
return input;
}

Expand Down
130 changes: 31 additions & 99 deletions include/tvm/ir/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -770,121 +770,53 @@ inline const TTypeNode* RelayExprNode::type_as() const {

namespace tvm {
namespace runtime {

// Automatic conversion into IntImm, Integer, and Bool, when called
// through the FFI. Automatic conversions into PrimExpr are
// registered in "tvm/tir/expr.h", as it includes conversions to the
// TIR-only StringImm.
//
// While the FFI only requires the From() method, these
// implementations also define a TryFrom() method to avoid duplicate
// logic in the PrimExpr conversion.

// common rule for RetValue and ArgValue
template <>
struct PackedFuncValueConverter<tvm::IntImm> {
template <typename PODSubclass>
static Optional<tvm::IntImm> TryFrom(const PODSubclass& val) {
if (auto opt = val.TryAsInt()) {
int64_t value = opt.value();
auto dtype =
(value > std::numeric_limits<int>::max() || value < std::numeric_limits<int>::min())
? DataType::Int(64)
: DataType::Int(32);
return IntImm(dtype, value);
} else if (auto opt = val.TryAsBool()) {
return IntImm(DataType::Int(32), opt.value());
} else {
return NullOpt;
struct PackedFuncValueConverter<PrimExpr> {
static PrimExpr From(const TVMPODValue_& val) {
if (val.type_code() == kTVMNullptr) {
return PrimExpr(ObjectPtr<Object>(nullptr));
}
}

template <typename PODSubclass>
static tvm::IntImm From(const PODSubclass& val) {
if (auto opt = TryFrom(val)) {
return opt.value();
} else {
return val.template AsObjectRef<tvm::IntImm>();
if (val.type_code() == kDLInt) {
int64_t value = val.operator int64_t();
if (value > std::numeric_limits<int>::max() || value < std::numeric_limits<int>::min()) {
return IntImm(runtime::DataType::Int(64), value);
}
return IntImm(runtime::DataType::Int(32), val.operator int());
}
}
};

template <>
struct PackedFuncValueConverter<tvm::Integer> {
template <typename PODSubclass>
static tvm::Integer From(const PODSubclass& val) {
if (auto opt = PackedFuncValueConverter<tvm::IntImm>::TryFrom(val)) {
return Integer(opt.value());
} else {
return val.template AsObjectRef<tvm::Integer>();
if (val.type_code() == kDLFloat) {
return FloatImm(runtime::DataType::Float(32), val.operator double());
}
}
};

template <>
struct PackedFuncValueConverter<tvm::Bool> {
template <typename PODSubclass>
static Optional<tvm::Bool> TryFrom(const PODSubclass& val) {
if (auto opt = val.TryAsBool()) {
return tvm::Bool(opt.value());
} else if (auto opt = val.TryAsInt()) {
int value = opt.value();
ICHECK(value == 0 || value == 1)
<< "ValueError: boolean value can only be 0 or 1, but get " << value;
return tvm::Bool(static_cast<bool>(value));
} else {
return NullOpt;
}
}

template <typename PODSubclass>
static tvm::Bool From(const PODSubclass& val) {
if (auto opt = TryFrom(val)) {
return opt.value();
} else {
return val.template AsObjectRef<tvm::Bool>();
}
return PrimExpr::FromObject_(val.AsObjectRef<ObjectRef>());
}
};

template <>
struct PackedFuncValueConverter<tvm::FloatImm> {
static Optional<tvm::FloatImm> TryFrom(const TVMPODValue_& val) {
if (auto opt = val.TryAsFloat()) {
return FloatImm(runtime::DataType::Float(32), opt.value());
} else {
return NullOpt;
struct PackedFuncValueConverter<tvm::Integer> {
static tvm::Integer From(const TVMPODValue_& val) {
if (val.type_code() == kTVMNullptr) {
return Integer(ObjectPtr<Object>(nullptr));
}
}

template <typename PODSubclass>
static tvm::FloatImm From(const PODSubclass& val) {
if (auto opt = TryFrom(val)) {
return opt.value();
} else {
return val.template AsObjectRef<tvm::FloatImm>();
if (val.type_code() == kTVMArgInt) {
return Integer(val.operator int());
}
return val.AsObjectRef<tvm::Integer>();
}
};

/* \brief Backwards compatibility wrapper for IntImm arguments
*
* In previous versions of TVM, IntImm was the default FFI type for
* integer arguments, instead of runtime::Int. For backwards
* compatibility where the callee has been updated to expected a
* runtime::Int, the caller has not been updated to provide a
* runtime::Int (e.g. relay script parsing), and the auto-unboxing of
* runtime::Int does not apply (e.g. making an `Array<runtime::Int>`),
* allow the IntImm to be generated.
*/
template <>
struct PackedFuncValueConverter<runtime::Int> {
template <typename PODSubclass>
static runtime::Int From(const PODSubclass& val) {
if (val.template IsObjectRef<tvm::IntImm>()) {
return runtime::Int(val.template AsObjectRef<tvm::IntImm>()->value);
} else {
return val.template AsObjectRef<runtime::Int>();
struct PackedFuncValueConverter<tvm::Bool> {
static tvm::Bool From(const TVMPODValue_& val) {
if (val.type_code() == kTVMNullptr) {
return Bool(ObjectPtr<Object>(nullptr));
}
if (val.type_code() == kTVMArgInt) {
int v = val.operator int();
ICHECK(v == 0 || v == 1) << "ValueError: boolean value can only be 0 or 1, but get " << v;
return Bool(static_cast<bool>(v));
}
return val.AsObjectRef<tvm::Bool>();
}
};

Expand Down
34 changes: 2 additions & 32 deletions include/tvm/ir/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -271,36 +271,7 @@ class PassContext : public ObjectRef {
using ValueNodeType = typename ValueType::ContainerType;
// NOTE: we could further update the function later.
uint32_t tindex = ValueNodeType::_GetOrAllocRuntimeTypeIndex();
auto type_key = runtime::Object::TypeIndex2Key(tindex);

auto* reflection = ReflectionVTable::Global();

auto legalization = [=](ObjectRef obj) -> ObjectRef {
if (obj->IsInstance<Map<String, ObjectRef>::ContainerType>()) {
return reflection->CreateObject(type_key, Downcast<Map<String, ObjectRef>>(obj));
} else {
// Backwards compatibility for config options defined prior to
// https://github.com/apache/tvm/pull/16183. This commit
// changed the default FFI conversion of python integers from
// `tvm::IntImm` to `runtime::Int`.
//
// This backwards compatibility fix can be removed when all
// options registered with TVM_REGISTER_PASS_CONFIG_OPTION are
// updated to use `runtime::Int` and `runtime::Bool`.
TVMRetValue ret;
ret = obj;
try {
ValueType legalized = ret;
return legalized;
} catch (Error& err) {
LOG(FATAL) << "AttributeError: expect config " << key << " to have type " << type_key
<< ", but received error when converting to this type.\n"
<< err.what();
}
}
};

RegisterConfigOption(key, tindex, legalization);
RegisterConfigOption(key, tindex);
return tindex;
}

Expand All @@ -314,8 +285,7 @@ class PassContext : public ObjectRef {
// The exit of a pass context scope.
TVM_DLL void ExitWithScope();
// Register configuration key value type.
TVM_DLL static void RegisterConfigOption(const char* key, uint32_t value_type_index,
std::function<ObjectRef(ObjectRef)> legalization);
TVM_DLL static void RegisterConfigOption(const char* key, uint32_t value_type_index);

// Classes to get the Python `with` like syntax.
friend class Internal;
Expand Down
8 changes: 4 additions & 4 deletions include/tvm/meta_schedule/schedule_rule.h
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,7 @@ class ScheduleRule : public runtime::ObjectRef {
* \param thread_extents Candidates of thread axis extent (values are required to be positive).
* \return The schedule rule created
*/
TVM_DLL static ScheduleRule CrossThreadReduction(Array<runtime::Int> thread_extents);
TVM_DLL static ScheduleRule CrossThreadReduction(Array<Integer> thread_extents);
/*!
* \brief A rule that randomly select a compute-at location for a free block
* \return The schedule rule created
Expand All @@ -260,9 +260,9 @@ class ScheduleRule : public runtime::ObjectRef {
* \param unroll_explicit Whether to explicitly unroll the loop, or just add an "unroll" pragma.
* \return The schedule rule created
*/
TVM_DLL static ScheduleRule ParallelizeVectorizeUnroll(int max_jobs_per_core, //
int max_vectorize_extent, //
Array<runtime::Int> unroll_max_steps, //
TVM_DLL static ScheduleRule ParallelizeVectorizeUnroll(int max_jobs_per_core, //
int max_vectorize_extent, //
Array<Integer> unroll_max_steps, //
bool unroll_explicit);
/*!
* \brief Auto bind loops around the block to BlockIdx and ThreadIdx
Expand Down
2 changes: 1 addition & 1 deletion include/tvm/relay/attrs/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -325,7 +325,7 @@ struct SqueezeAttrs : public tvm::AttrsNode<SqueezeAttrs> {
}; // struct SqueezeAttrs

struct SplitAttrs : public tvm::AttrsNode<SplitAttrs> {
Variant<runtime::Int, Array<runtime::Int>> indices_or_sections;
ObjectRef indices_or_sections;
int axis;

TVM_DECLARE_ATTRS(SplitAttrs, "relay.attrs.SplitAttrs") {
Expand Down
5 changes: 1 addition & 4 deletions include/tvm/runtime/c_runtime_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,6 @@
#ifdef __cplusplus
extern "C" {
#endif
#include <stdbool.h>
#include <stddef.h>
#include <stdint.h>

Expand Down Expand Up @@ -187,12 +186,11 @@ typedef enum {
kTVMBytes = 12U,
kTVMNDArrayHandle = 13U,
kTVMObjectRValueRefArg = 14U,
kTVMArgBool = 15U,
// Extension codes for other frameworks to integrate TVM PackedFunc.
// To make sure each framework's id do not conflict, use first and
// last sections to mark ranges.
// Open an issue at the repo if you need a section of code.
kTVMExtBegin = 16U,
kTVMExtBegin = 15U,
kTVMNNVMFirst = 16U,
kTVMNNVMLast = 20U,
// The following section of code is used for non-reserved types.
Expand All @@ -209,7 +207,6 @@ typedef DLTensor* TVMArrayHandle;
*/
typedef union {
int64_t v_int64;
bool v_bool;
double v_float64;
void* v_handle;
const char* v_str;
Expand Down
Loading

0 comments on commit 11be832

Please sign in to comment.