Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions ffi/include/tvm/ffi/any.h
Original file line number Diff line number Diff line change
Expand Up @@ -425,6 +425,15 @@ struct AnyUnsafe : public ObjectUnsafe {
}
}

template <typename T>
static TVM_FFI_INLINE T MoveFromAnyStorageAfterCheck(Any&& ref) {
if constexpr (!std::is_same_v<T, Any>) {
return TypeTraits<T>::MoveFromAnyStorageAfterCheck(&(ref.data_));
} else {
return std::move(ref);
}
}

static TVM_FFI_INLINE Object* ObjectPtrFromAnyAfterCheck(const Any& ref) {
return reinterpret_cast<Object*>(ref.data_.v_obj);
}
Expand Down
6 changes: 3 additions & 3 deletions ffi/include/tvm/ffi/base_details.h
Original file line number Diff line number Diff line change
Expand Up @@ -123,9 +123,9 @@
* This macro is used to clear the padding parts for hash and equality check
* in 32bit platform.
*/
#define TVM_FFI_CLEAR_PTR_PADDING_IN_FFI_ANY(result) \
if constexpr (sizeof(result->v_obj) != sizeof(result->v_int64)) { \
result->v_int64 = 0; \
#define TVM_FFI_CLEAR_PTR_PADDING_IN_FFI_ANY(result) \
if constexpr (sizeof((result)->v_obj) != sizeof((result)->v_int64)) { \
(result)->v_int64 = 0; \
}

namespace tvm {
Expand Down
8 changes: 8 additions & 0 deletions ffi/include/tvm/ffi/container/container_details.h
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,14 @@ inline constexpr bool storage_enabled_v = std::is_same_v<T, Any> || TypeTraits<T
template <typename... T>
inline constexpr bool all_storage_enabled_v = (storage_enabled_v<T> && ...);

/*!
* \brief Check if all T are compatible with Any.
*
* \tparam T The type to check.
* \return True if T is compatible with Any, false otherwise.
*/
template <typename... T>
inline constexpr bool all_object_ref_v = (std::is_base_of_v<ObjectRef, T> && ...);
/**
* \brief Check if Any storage of Derived can always be directly used as Base.
*
Expand Down
98 changes: 78 additions & 20 deletions ffi/include/tvm/ffi/container/variant.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,15 +34,73 @@

namespace tvm {
namespace ffi {
namespace details {
/*!
* \brief Base class for Variant.
*
* \tparam all_storage_object Whether all types are derived from ObjectRef.
*/
template <bool all_storage_object = false>
class VariantBase {
public:
TVM_FFI_INLINE bool same_as(const VariantBase<all_storage_object>& other) const {
return data_.same_as(other.data_);
}

protected:
template <typename T>
explicit VariantBase(T other) : data_(std::move(other)) {}

TVM_FFI_INLINE void SetData(Any other_data) { data_ = std::move(other_data); }

TVM_FFI_INLINE Any MoveToAny() && { return std::move(data_); }

TVM_FFI_INLINE AnyView ToAnyView() const { return data_.operator AnyView(); }

Any data_;
};

// Specialization for all object ref case, backed by ObjectRef.
template <>
class VariantBase<true> : public ObjectRef {
protected:
template <typename T>
explicit VariantBase(const T& other) : ObjectRef(other) {}
template <typename T>
explicit VariantBase(T&& other) : ObjectRef(std::move(other)) {}
explicit VariantBase(ObjectPtr<Object> ptr) : ObjectRef(ptr) {}
explicit VariantBase(Any other)
: ObjectRef(details::AnyUnsafe::MoveFromAnyStorageAfterCheck<ObjectRef>(std::move(other))) {}

TVM_FFI_INLINE void SetData(ObjectPtr<Object> other) { data_ = std::move(other); }

TVM_FFI_INLINE Any MoveToAny() && { return Any(ObjectRef(std::move(data_))); }

TVM_FFI_INLINE AnyView ToAnyView() const {
TVMFFIAny any_data;
if (data_ == nullptr) {
any_data.type_index = TypeIndex::kTVMFFINone;
any_data.v_int64 = 0;
} else {
TVM_FFI_CLEAR_PTR_PADDING_IN_FFI_ANY(&any_data);
any_data.type_index = data_->type_index();
any_data.v_obj = details::ObjectUnsafe::TVMFFIObjectPtrFromObjectPtr<Object>(data_);
}
return AnyView::CopyFromTVMFFIAny(any_data);
}
};
} // namespace details

/*!
* \brief A typed variant container.
*
* A Variant is backed by Any container, with strong checks during construction.
* When all values are ObjectRef, Variant is backed by ObjectRef,
* otherwise it is backed by Any.
*/
template <typename... V>
class Variant {
class Variant : public details::VariantBase<details::all_object_ref_v<V...>> {
public:
using TParent = details::VariantBase<details::all_object_ref_v<V...>>;
static_assert(details::all_storage_enabled_v<V...>,
"All types used in Variant<...> must be compatible with Any");
/*
Expand All @@ -54,31 +112,30 @@ class Variant {
template <typename T>
using enable_if_variant_contains_t = std::enable_if_t<variant_contains_v<T>>;

Variant(const Variant<V...>& other) : data_(other.data_) {}
Variant(Variant<V...>&& other) : data_(std::move(other.data_)) {}
Variant(const Variant<V...>& other) : TParent(other.data_) {}
Variant(Variant<V...>&& other) : TParent(std::move(other.data_)) {}

TVM_FFI_INLINE Variant& operator=(const Variant<V...>& other) {
data_ = other.data_;
this->SetData(other.data_);
return *this;
}

TVM_FFI_INLINE Variant& operator=(Variant<V...>&& other) {
data_ = std::move(other.data_);
this->SetData(std::move(other.data_));
return *this;
}

template <typename T, typename = enable_if_variant_contains_t<T>>
Variant(T other) : data_(std::move(other)) {} // NOLINT(*)
Variant(T other) : TParent(std::move(other)) {} // NOLINT(*)

template <typename T, typename = enable_if_variant_contains_t<T>>
TVM_FFI_INLINE Variant& operator=(T other) {
data_ = std::move(other);
return *this;
return operator=(Variant(std::move(other)));
}

template <typename T, typename = enable_if_variant_contains_t<T>>
TVM_FFI_INLINE std::optional<T> as() const {
return data_.as<T>();
return this->TParent::ToAnyView().template as<T>();
}

/*
Expand All @@ -89,29 +146,27 @@ class Variant {
*/
template <typename T, typename = std::enable_if_t<std::is_base_of_v<Object, T>>>
TVM_FFI_INLINE const T* as() const {
return data_.as<const T*>().value_or(nullptr);
return this->TParent::ToAnyView().template as<const T*>().value_or(nullptr);
}

template <typename T, typename = enable_if_variant_contains_t<T>>
TVM_FFI_INLINE T get() const& {
return data_.template cast<T>();
return this->TParent::ToAnyView().template cast<T>();
}

template <typename T, typename = enable_if_variant_contains_t<T>>
TVM_FFI_INLINE T get() && {
return std::move(data_).template cast<T>();
return std::move(*this).TParent::MoveToAny().template cast<T>();
}

TVM_FFI_INLINE std::string GetTypeKey() const { return data_.GetTypeKey(); }
TVM_FFI_INLINE std::string GetTypeKey() const { return this->TParent::ToAnyView().GetTypeKey(); }

private:
friend struct TypeTraits<Variant<V...>>;
friend struct ObjectPtrHash;
friend struct ObjectPtrEqual;
// constructor from any
explicit Variant(Any data) : data_(std::move(data)) {}
// internal data is backed by Any
Any data_;
explicit Variant(Any data) : TParent(std::move(data)) {}
/*!
* \brief Get the object pointer from the variant
* \note This function is only available if all types used in Variant<...> are derived from
Expand All @@ -122,8 +177,11 @@ class Variant {
static_assert(all_object_v,
"All types used in Variant<...> must be derived from ObjectRef "
"to enable ObjectPtrHash/ObjectPtrEqual");
return details::AnyUnsafe::ObjectPtrFromAnyAfterCheck(data_);
return this->data_.get();
}
// rexpose to friend class
using TParent::MoveToAny;
using TParent::ToAnyView;
};

template <typename... V>
Expand All @@ -132,11 +190,11 @@ inline constexpr bool use_default_type_traits_v<Variant<V...>> = false;
template <typename... V>
struct TypeTraits<Variant<V...>> : public TypeTraitsBase {
static TVM_FFI_INLINE void CopyToAnyView(const Variant<V...>& src, TVMFFIAny* result) {
*result = AnyView(src.data_).CopyToTVMFFIAny();
*result = src.ToAnyView().CopyToTVMFFIAny();
}

static TVM_FFI_INLINE void MoveToAny(Variant<V...> src, TVMFFIAny* result) {
*result = details::AnyUnsafe::MoveAnyToTVMFFIAny(std::move(src.data_));
*result = details::AnyUnsafe::MoveAnyToTVMFFIAny(std::move(src).MoveToAny());
}

static TVM_FFI_INLINE std::string GetMismatchTypeInfo(const TVMFFIAny* src) {
Expand Down
1 change: 1 addition & 0 deletions ffi/tests/cpp/test_any.cc
Original file line number Diff line number Diff line change
Expand Up @@ -337,6 +337,7 @@ TEST(Any, ObjectMove) {
auto v0 = std::move(any1).cast<TPrimExpr>();
EXPECT_EQ(v0->value, 3.14);
EXPECT_EQ(v0.use_count(), 1);
EXPECT_TRUE(any1 == nullptr);
}

} // namespace
2 changes: 1 addition & 1 deletion ffi/tests/cpp/test_map.cc
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,7 @@ TEST(Map, AnyConvertCheck) {
::tvm::ffi::Error);
}

TEST(Map, ffi::FunctionGetItem) {
TEST(Map, FunctionGetItem) {
Function f = Function::FromTyped([](const MapObj* n, const Any& k) -> Any { return n->at(k); },
"map_get_item");
Map<String, int64_t> map{{"x", 1}, {"y", 2}};
Expand Down
27 changes: 27 additions & 0 deletions ffi/tests/cpp/test_variant.cc
Original file line number Diff line number Diff line change
Expand Up @@ -134,4 +134,31 @@ TEST(Variant, Upcast) {
EXPECT_EQ(a1[0].get<int>(), 1);
}

TEST(Variant, AllObjectRef) {
Variant<TInt, Array<TInt>> v0 = TInt(1);
EXPECT_EQ(v0.get<TInt>()->value, 1);
static_assert(std::is_base_of_v<ObjectRef, decltype(v0)>);
Any any0 = v0;
EXPECT_EQ(any0.cast<TInt>()->value, 1);
auto v2 = any0.cast<Variant<TInt, Array<TInt>>>();
EXPECT_TRUE(v0.same_as(v2));
// assignment operator
v0 = Array<TInt>({TInt(2), TInt(3)});
EXPECT_EQ(v0.get<Array<TInt>>().size(), 2);
EXPECT_EQ(v0.get<Array<TInt>>()[0]->value, 2);
EXPECT_EQ(v0.get<Array<TInt>>()[1]->value, 3);
EXPECT_EQ(sizeof(v0), sizeof(ObjectRef));
}

TEST(Variant, PODSameAs) {
Variant<String, int> v0 = 1;
Variant<String, int> v1 = 1;
EXPECT_TRUE(v0.same_as(v1));
String s = String("hello");
v0 = s;
v1 = s;
EXPECT_TRUE(v0.same_as(v1));
v1 = String("hello");
EXPECT_TRUE(!v0.same_as(v1));
}
} // namespace
Loading