Skip to content

Commit 824e1d8

Browse files
tqchenkevinthesun
authored andcommitted
[REFACTOR][NODE][RUNTIME] Move Node to the new Object protocol. (apache#4161)
* [REFACTOR][NODE][RUNTIME] Move Node to the new Object protocol. This PR removes the original node system, and make node as a subclass of Object. This is a major refactor towards a better unified runtime object system. List of changes in the refactor: - We now hide data_ field, use Downcast explicitly to get a sub-class object. - Removed the node system FFI in python. - Removed the node C API, instead use PackedFunc for list and get attrs. - Change relay::Op::set_attr_type_key(attr_key_name) to relay::Op::set_attr_type<AttrType>(). - This change was necessary because of the new Object registration mechanism. - Subsequent changes to the op registrations - The change revealed a few previous problems that is now fixed. - Patched up a few missing node type registration. - Now we will raise an error if we register object that is not registered. - The original node.h and container.h are kept in the same location. - Calling convention: kObjectHandle now equals the old kNodeHandle, kNodeHandle is removed. - IRFunctor now dispatches on ObjectRef. - Update to the new type checking API: is_type, derived_from are replaced by IsInstance. - Removed .hash member function, instead use C++ convention hasher functors. * Address review comments
1 parent ffc11b7 commit 824e1d8

File tree

185 files changed

+1442
-2387
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

185 files changed

+1442
-2387
lines changed

golang/src/value.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,8 +44,8 @@ var KTVMType = int32(C.kTVMType)
4444
var KTVMContext = int32(C.kTVMContext)
4545
// KArrayHandle is golang type code for TVM kArrayHandle.
4646
var KArrayHandle = int32(C.kArrayHandle)
47-
// KNodeHandle is golang type code for TVM kNodeHandle.
48-
var KNodeHandle = int32(C.kNodeHandle)
47+
// KObjectHandle is golang type code for TVM kObjectHandle.
48+
var KObjectHandle = int32(C.kObjectHandle)
4949
// KModuleHandle is gonag type code for TVM kModuleHandle.
5050
var KModuleHandle = int32(C.kModuleHandle)
5151
// KFuncHandle is gonalg type code for TVM kFuncHandle.

include/tvm/api_registry.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ class EnvFunc : public NodeRef {
7979
explicit EnvFunc(NodePtr<Node> n) : NodeRef(n) {}
8080
/*! \return The internal global function pointer */
8181
const EnvFuncNode* operator->() const {
82-
return static_cast<EnvFuncNode*>(node_.get());
82+
return static_cast<const EnvFuncNode*>(get());
8383
}
8484
/*!
8585
* \brief Invoke the function.
@@ -124,19 +124,19 @@ class TypedEnvFunc<R(Args...)> : public NodeRef {
124124
/*! \brief short hand for this function type */
125125
using TSelf = TypedEnvFunc<R(Args...)>;
126126
TypedEnvFunc() {}
127-
explicit TypedEnvFunc(NodePtr<Node> n) : NodeRef(n) {}
127+
explicit TypedEnvFunc(ObjectPtr<Object> n) : NodeRef(n) {}
128128
/*!
129129
* \brief Assign global function to a TypedEnvFunc
130130
* \param other Another global function.
131131
* \return reference to self.
132132
*/
133133
TSelf& operator=(const EnvFunc& other) {
134-
this->node_ = other.node_;
134+
ObjectRef::operator=(other);
135135
return *this;
136136
}
137137
/*! \return The internal global function pointer */
138138
const EnvFuncNode* operator->() const {
139-
return static_cast<EnvFuncNode*>(node_.get());
139+
return static_cast<const EnvFuncNode*>(get());
140140
}
141141
/*!
142142
* \brief Invoke the function.

include/tvm/arithmetic.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -362,7 +362,7 @@ class IntSet : public NodeRef {
362362
/*! \brief constructor */
363363
IntSet() {}
364364
// constructor from not container.
365-
explicit IntSet(NodePtr<Node> n) : NodeRef(n) {}
365+
explicit IntSet(ObjectPtr<Object> n) : NodeRef(n) {}
366366
/*!
367367
* \brief access the internal node container
368368
* \return the pointer to the internal node container
@@ -692,7 +692,7 @@ Array<Expr> DetectClipBound(const Expr& e,
692692

693693
// implementation
694694
inline const IntSetNode* IntSet::operator->() const {
695-
return static_cast<const IntSetNode*>(node_.get());
695+
return static_cast<const IntSetNode*>(get());
696696
}
697697
} // namespace arith
698698
} // namespace tvm

include/tvm/attrs.h

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,7 @@ class AttrsEqual {
163163
return lhs == rhs;
164164
}
165165
// node comparator
166-
TVM_DLL bool operator()(const NodeRef& lhs, const NodeRef& rhs) const;
166+
TVM_DLL bool operator()(const ObjectRef& lhs, const ObjectRef& rhs) const;
167167

168168
protected:
169169
friend class AttrsEqualHandler;
@@ -203,7 +203,7 @@ class AttrsHash {
203203
(static_cast<int>(value.bits()) << 8) |
204204
(static_cast<int>(value.lanes()) << 16));
205205
}
206-
TVM_DLL size_t operator()(const NodeRef& value) const;
206+
TVM_DLL size_t operator()(const ObjectRef& value) const;
207207

208208
private:
209209
friend class AttrsHashHandler;
@@ -260,7 +260,7 @@ class BaseAttrsNode : public Node {
260260
* \return The comparison result.
261261
*/
262262
TVM_DLL virtual bool ContentEqual(
263-
const Node* other, AttrsEqual equal) const = 0;
263+
const Object* other, AttrsEqual equal) const = 0;
264264
/*!
265265
* \brief Content aware hash.
266266
* \param hasher The hasher to run the hash.
@@ -290,7 +290,7 @@ class Attrs : public NodeRef {
290290
private:
291291
/*! \return the internal attribute node */
292292
const BaseAttrsNode* ptr() const {
293-
return static_cast<const BaseAttrsNode*>(node_.get());
293+
return static_cast<const BaseAttrsNode*>(get());
294294
}
295295
};
296296

@@ -315,7 +315,7 @@ class DictAttrsNode : public BaseAttrsNode {
315315
void VisitNonDefaultAttrs(AttrVisitor* v) final;
316316
void InitByPackedArgs(const runtime::TVMArgs& args, bool allow_unknown) final;
317317
Array<AttrFieldInfo> ListFieldInfo() const final;
318-
bool ContentEqual(const Node* other, AttrsEqual equal) const final;
318+
bool ContentEqual(const Object* other, AttrsEqual equal) const final;
319319
size_t ContentHash(AttrsHash hasher) const final;
320320
// type info
321321
static constexpr const char* _type_key = "DictAttrs";
@@ -369,7 +369,7 @@ class AttrsEqualVisitor {
369369
public:
370370
bool result_{true};
371371
// constructor
372-
AttrsEqualVisitor(const Node* lhs, const Node* rhs, const AttrsEqual& equal)
372+
AttrsEqualVisitor(const Object* lhs, const Object* rhs, const AttrsEqual& equal)
373373
: lhs_(lhs), rhs_(rhs), equal_(equal) {
374374
}
375375
template<typename T>
@@ -387,8 +387,8 @@ class AttrsEqualVisitor {
387387
}
388388

389389
private:
390-
const Node* lhs_;
391-
const Node* rhs_;
390+
const Object* lhs_;
391+
const Object* rhs_;
392392
const AttrsEqual& equal_;
393393
};
394394

@@ -488,7 +488,7 @@ inline void SetIntValue(T* ptr, const TVMArgValue& val) {
488488
} else if (const ir::UIntImm* op = expr.as<ir::UIntImm>()) {
489489
*ptr = static_cast<T>(op->value);
490490
} else {
491-
LOG(FATAL) << "Expect int value, but get " << expr->type_key();
491+
LOG(FATAL) << "Expect int value, but get " << expr->GetTypeKey();
492492
}
493493
}
494494
}
@@ -521,7 +521,7 @@ inline void SetValue<double>(double* ptr, const TVMArgValue& val) {
521521
} else if (const ir::UIntImm* op = expr.as<ir::UIntImm>()) {
522522
*ptr = static_cast<double>(op->value);
523523
} else {
524-
LOG(FATAL) << "Expect float value, but get " << expr->type_key();
524+
LOG(FATAL) << "Expect float value, but get " << expr->GetTypeKey();
525525
}
526526
}
527527
}
@@ -827,7 +827,7 @@ class AttrsNode : public BaseAttrsNode {
827827
return visitor.fields_;
828828
}
829829

830-
bool ContentEqual(const Node* other, AttrsEqual equal) const final {
830+
bool ContentEqual(const Object* other, AttrsEqual equal) const final {
831831
DerivedType* pself = self();
832832
if (pself == other) return true;
833833
if (other == nullptr) return false;
@@ -839,7 +839,7 @@ class AttrsNode : public BaseAttrsNode {
839839

840840
size_t ContentHash(AttrsHash hasher) const final {
841841
::tvm::detail::AttrsHashVisitor visitor(hasher);
842-
visitor.result_ = std::hash<std::string>()(this->type_key());
842+
visitor.result_ = this->GetTypeKeyHash();
843843
self()->__VisitAttrs__(visitor);
844844
return visitor.result_;
845845
}

include/tvm/base.h

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -47,9 +47,10 @@ using ::tvm::AttrVisitor;
4747
*/
4848
#define TVM_DEFINE_NODE_REF_METHODS(TypeName, BaseTypeName, NodeName) \
4949
TypeName() {} \
50-
explicit TypeName(::tvm::NodePtr<::tvm::Node> n) : BaseTypeName(n) {} \
50+
explicit TypeName(::tvm::ObjectPtr<::tvm::Object> n) \
51+
: BaseTypeName(n) {} \
5152
const NodeName* operator->() const { \
52-
return static_cast<const NodeName*>(node_.get()); \
53+
return static_cast<const NodeName*>(data_.get()); \
5354
} \
5455
operator bool() const { return this->defined(); } \
5556
using ContainerType = NodeName;
@@ -75,12 +76,12 @@ using ::tvm::AttrVisitor;
7576
*/
7677
#define TVM_DEFINE_NODE_REF_COW(NodeName) \
7778
NodeName* CopyOnWrite() { \
78-
CHECK(node_ != nullptr); \
79-
if (!node_.unique()) { \
79+
CHECK(data_ != nullptr); \
80+
if (!data_.unique()) { \
8081
NodePtr<NodeName> n = make_node<NodeName>(*(operator->())); \
81-
NodePtr<Node>(std::move(n)).swap(node_); \
82+
ObjectPtr<Object>(std::move(n)).swap(data_); \
8283
} \
83-
return static_cast<NodeName*>(node_.get()); \
84+
return static_cast<NodeName*>(data_.get()); \
8485
}
8586

8687
/*! \brief Macro to make it easy to define node ref type given node */
@@ -160,7 +161,7 @@ std::string SaveJSON(const NodeRef& node);
160161
*
161162
* \return The shared_ptr of the Node.
162163
*/
163-
NodePtr<Node> LoadJSON_(std::string json_str);
164+
ObjectPtr<Object> LoadJSON_(std::string json_str);
164165

165166
/*!
166167
* \brief Load the node from json string.
@@ -233,6 +234,7 @@ struct NodeFactoryReg {
233234
* \note This is necessary to enable serialization of the Node.
234235
*/
235236
#define TVM_REGISTER_NODE_TYPE(TypeName) \
237+
TVM_REGISTER_OBJECT_TYPE(TypeName); \
236238
static DMLC_ATTRIBUTE_UNUSED ::tvm::NodeFactoryReg & __make_Node ## _ ## TypeName ## __ = \
237239
::tvm::NodeFactoryReg::Registry()->__REGISTER__(TypeName::_type_key) \
238240
.set_creator([](const std::string&) { return ::tvm::make_node<TypeName>(); })

include/tvm/buffer.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ enum BufferType : int {
5151
class Buffer : public NodeRef {
5252
public:
5353
Buffer() {}
54-
explicit Buffer(NodePtr<Node> n) : NodeRef(n) {}
54+
explicit Buffer(ObjectPtr<Object> n) : NodeRef(n) {}
5555
/*!
5656
* \brief Return a new buffer that is equivalent with current one
5757
* but always add stride field.
@@ -171,7 +171,7 @@ class BufferNode : public Node {
171171
};
172172

173173
inline const BufferNode* Buffer::operator->() const {
174-
return static_cast<const BufferNode*>(node_.get());
174+
return static_cast<const BufferNode*>(get());
175175
}
176176

177177
/*!

include/tvm/build_module.h

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ class TargetNode : public Node {
9393
class Target : public NodeRef {
9494
public:
9595
Target() {}
96-
explicit Target(NodePtr<Node> n) : NodeRef(n) {}
96+
explicit Target(ObjectPtr<Object> n) : NodeRef(n) {}
9797
/*!
9898
* \brief Create a Target given a string
9999
* \param target_str the string to parse
@@ -110,7 +110,7 @@ class Target : public NodeRef {
110110
TVM_DLL static tvm::Target Current(bool allow_not_defined = true);
111111

112112
const TargetNode* operator->() const {
113-
return static_cast<const TargetNode*>(node_.get());
113+
return static_cast<const TargetNode*>(get());
114114
}
115115

116116
using ContainerType = TargetNode;
@@ -256,12 +256,12 @@ class BuildConfigNode : public Node {
256256
class BuildConfig : public ::tvm::NodeRef {
257257
public:
258258
BuildConfig() {}
259-
explicit BuildConfig(NodePtr<::tvm::Node> n) : NodeRef(n) {}
259+
explicit BuildConfig(ObjectPtr<Object> n) : NodeRef(n) {}
260260
const BuildConfigNode* operator->() const {
261-
return static_cast<const BuildConfigNode*>(node_.get());
261+
return static_cast<const BuildConfigNode*>(get());
262262
}
263263
BuildConfigNode* operator->() {
264-
return static_cast<BuildConfigNode*>(node_.get());
264+
return static_cast<BuildConfigNode*>(get_mutable());
265265
}
266266
/*!
267267
* \brief Construct a BuildConfig containing a empty build config node.
@@ -371,7 +371,7 @@ class GenericFuncNode;
371371
class GenericFunc : public NodeRef {
372372
public:
373373
GenericFunc() {}
374-
explicit GenericFunc(NodePtr<Node> n) : NodeRef(n) {}
374+
explicit GenericFunc(ObjectPtr<Object> n) : NodeRef(n) {}
375375

376376
/*!
377377
* \brief Set the default function implementaiton.
@@ -478,10 +478,10 @@ class GenericFuncNode : public Node {
478478
};
479479

480480
inline GenericFuncNode* GenericFunc::operator->() {
481-
return static_cast<GenericFuncNode*>(node_.get());
481+
return static_cast<GenericFuncNode*>(get_mutable());
482482
}
483483

484-
#define TVM_GENERIC_FUNC_REG_VAR_DEF \
484+
#define TVM_GENERIC_FUNC_REG_VAR_DEF \
485485
static TVM_ATTRIBUTE_UNUSED ::tvm::GenericFunc& __mk_ ## TVM
486486

487487
/*!

include/tvm/c_dsl_api.h

Lines changed: 0 additions & 98 deletions
This file was deleted.

include/tvm/channel.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ class Channel : public NodeRef {
3535
public:
3636
/*! \brief default constructor */
3737
Channel() {}
38-
explicit Channel(NodePtr<Node> n) : NodeRef(n) {}
38+
explicit Channel(ObjectPtr<Object> n) : NodeRef(n) {}
3939
/*!
4040
* \brief access the internal node container
4141
* \return the pointer to the internal node container
@@ -67,7 +67,7 @@ struct ChannelNode : public Node {
6767

6868
// Inline implementations
6969
inline const ChannelNode* Channel::operator->() const {
70-
return static_cast<const ChannelNode*>(node_.get());
70+
return static_cast<const ChannelNode*>(get());
7171
}
7272
} // namespace tvm
7373
#endif // TVM_CHANNEL_H_

0 commit comments

Comments
 (0)