Skip to content

Commit ff06917

Browse files
authored
[API/Refactor] Unified PackedFunc for API and Generated Functions (#26)
1 parent 4242b9c commit ff06917

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

46 files changed

+2375
-1723
lines changed

.travis.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ language: cpp
44

55
os:
66
- linux
7-
- osx
7+
# - osx
88

99
env:
1010
# code analysis

include/tvm/api_registry.h

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
/*!
2+
* Copyright (c) 2016 by Contributors
3+
* \file api_registry.h
4+
* \brief This file defines the TVM API registry.
5+
*
6+
* The API registry stores type-erased functions.
7+
* Each registered function is automatically exposed
8+
* to front-end language(e.g. python).
9+
* Front-end can also pass callbacks as PackedFunc, or register
10+
* then into the same global registry in C++.
11+
* The goal is to mix the front-end language and the TVM back-end.
12+
*
13+
* \code
14+
* // register the function as MyAPIFuncName
15+
* TVM_REGISTER_API(MyAPIFuncName)
16+
* .set_body([](TVMArgs args, TVMRetValue* rv) {
17+
* // my code.
18+
* });
19+
* \endcode
20+
*/
21+
#ifndef TVM_API_REGISTRY_H_
22+
#define TVM_API_REGISTRY_H_
23+
24+
#include <dmlc/base.h>
25+
#include <string>
26+
#include "./base.h"
27+
#include "./runtime/packed_func.h"
28+
#include "./packed_func_ext.h"
29+
30+
namespace tvm {
31+
32+
/*! \brief Utility to register API. */
33+
class APIRegistry {
34+
public:
35+
/*!
36+
* \brief set the body of the function to be f
37+
* \param f The body of the function.
38+
*/
39+
APIRegistry& set_body(PackedFunc f); // NOLINT(*)
40+
/*!
41+
* \brief set the body of the function to be f
42+
* \param f The body of the function.
43+
*/
44+
APIRegistry& set_body(PackedFunc::FType f) { // NOLINT(*)
45+
return set_body(PackedFunc(f));
46+
}
47+
/*!
48+
* \brief Register a function with given name
49+
* \param name The name of the function.
50+
*/
51+
static APIRegistry& __REGISTER__(const std::string& name); // NOLINT(*)
52+
53+
private:
54+
/*! \brief name of the function */
55+
std::string name_;
56+
};
57+
58+
/*!
59+
* \brief Get API function by name.
60+
*
61+
* \param name The name of the function.
62+
* \return the corresponding API function.
63+
* \note It is really PackedFunc::GetGlobal under the hood.
64+
*/
65+
inline PackedFunc GetAPIFunc(const std::string& name) {
66+
return PackedFunc::GetGlobal(name);
67+
}
68+
69+
#define _TVM_REGISTER_VAR_DEF_ \
70+
static DMLC_ATTRIBUTE_UNUSED ::tvm::APIRegistry& __make_TVMRegistry_
71+
72+
/*!
73+
* \brief Register API function globally.
74+
* \code
75+
* TVM_REGISTER_API(MyPrint)
76+
* .set_body([](TVMArgs args, TVMRetValue* rv) {
77+
* // my code.
78+
* });
79+
* \endcode
80+
*/
81+
#define TVM_REGISTER_API(OpName) \
82+
DMLC_STR_CONCAT(_TVM_REGISTER_VAR_DEF_, __COUNTER__) = \
83+
::tvm::APIRegistry::__REGISTER__(#OpName)
84+
} // namespace tvm
85+
#endif // TVM_API_REGISTRY_H_

include/tvm/c_api.h

Lines changed: 7 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -2,83 +2,23 @@
22
* Copyright (c) 2016 by Contributors
33
* \file c_api.h
44
* \brief C API of TVM DSL
5+
*
6+
* \note The API is designed in a minimum way.
7+
* Most of the API functions are registered and can be pulled out.
8+
*
9+
* The common flow is:
10+
* - Use TVMFuncListGlobalNames to get global function name
11+
* - Use TVMFuncCall to call these functions.
512
*/
613
#ifndef TVM_C_API_H_
714
#define TVM_C_API_H_
815

916
#include "./runtime/c_runtime_api.h"
1017

1118
TVM_EXTERN_C {
12-
/*! \brief handle to functions */
13-
typedef void* APIFuncHandle;
1419
/*! \brief handle to node */
1520
typedef void* NodeHandle;
1621

17-
/*!
18-
* \brief List all the node function name
19-
* \param out_size The number of functions
20-
* \param out_array The array of function names.
21-
* \return 0 when success, -1 when failure happens
22-
*/
23-
TVM_DLL int TVMListAPIFuncNames(int *out_size,
24-
const char*** out_array);
25-
/*!
26-
* \brief get function handle by name
27-
* \param name The name of function
28-
* \param handle The returning function handle
29-
* \return 0 when success, -1 when failure happens
30-
*/
31-
TVM_DLL int TVMGetAPIFuncHandle(const char* name,
32-
APIFuncHandle *handle);
33-
34-
/*!
35-
* \brief Get the detailed information about function.
36-
* \param handle The operator handle.
37-
* \param real_name The returned name of the function.
38-
* This name is not the alias name of the atomic symbol.
39-
* \param description The returned description of the symbol.
40-
* \param num_doc_args Number of arguments that contain documents.
41-
* \param arg_names Name of the arguments of doc args
42-
* \param arg_type_infos Type informations about the arguments.
43-
* \param arg_descriptions Description information about the arguments.
44-
* \param return_type Return type of the function, if any.
45-
* \return 0 when success, -1 when failure happens
46-
*/
47-
TVM_DLL int TVMGetAPIFuncInfo(APIFuncHandle handle,
48-
const char **real_name,
49-
const char **description,
50-
int *num_doc_args,
51-
const char ***arg_names,
52-
const char ***arg_type_infos,
53-
const char ***arg_descriptions,
54-
const char **return_type);
55-
56-
/*!
57-
* \brief Push an argument to the function calling stack.
58-
* If push fails, the stack will be reset to empty
59-
*
60-
* \param arg The argument
61-
* \param type_code The type_code of argument as in TVMTypeCode
62-
* \return 0 when success, -1 when failure happens
63-
* \note API calls always exchanges with type bits=64, lanes=1
64-
*/
65-
TVM_DLL int TVMAPIPushStack(TVMValue arg,
66-
int type_code);
67-
68-
/*!
69-
* \brief call a function by using arguments in the stack.
70-
* The stack will be cleanup to empty after this call, whether the call is successful.
71-
*
72-
* \param handle The function handle
73-
* \param ret_val The return value.
74-
* \param ret_type_code the type code of return value.
75-
* \return 0 when success, -1 when failure happens
76-
* \note API calls always exchanges with type bits=64, lanes=1
77-
*/
78-
TVM_DLL int TVMAPIFuncCall(APIFuncHandle handle,
79-
TVMValue* ret_val,
80-
int* ret_type_code);
81-
8222
/*!
8323
* \brief free the node handle
8424
* \param handle The node handle to be freed.

include/tvm/expr.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
#include <string>
1313
#include <algorithm>
1414
#include "./base.h"
15+
#include "./runtime/packed_func.h"
1516

1617
namespace tvm {
1718

include/tvm/packed_func_ext.h

Lines changed: 196 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,196 @@
1+
/*!
2+
* Copyright (c) 2016 by Contributors
3+
* \file packed_func_ext.h
4+
* \brief Extension package to PackedFunc
5+
* This enales pass NodeRef types into/from PackedFunc.
6+
*/
7+
#ifndef TVM_PACKED_FUNC_EXT_H_
8+
#define TVM_PACKED_FUNC_EXT_H_
9+
10+
#include <sstream>
11+
#include <string>
12+
#include <memory>
13+
#include <type_traits>
14+
15+
#include "./base.h"
16+
#include "./expr.h"
17+
18+
namespace tvm {
19+
using runtime::TVMArgs;
20+
using runtime::TVMRetValue;
21+
using runtime::PackedFunc;
22+
23+
namespace runtime {
24+
/*!
25+
* \brief Runtime type checker for node type.
26+
* \tparam T the type to be checked.
27+
*/
28+
template<typename T>
29+
struct NodeTypeChecker {
30+
static inline bool Check(Node* sptr) {
31+
// This is the only place in the project where RTTI is used
32+
// It can be turned off, but will make non strict checking.
33+
// TODO(tqchen) possibly find alternative to turn of RTTI
34+
using ContainerType = typename T::ContainerType;
35+
return (dynamic_cast<ContainerType*>(sptr) != nullptr);
36+
}
37+
static inline void PrintName(std::ostringstream& os) { // NOLINT(*)
38+
using ContainerType = typename T::ContainerType;
39+
os << ContainerType::_type_key;
40+
}
41+
};
42+
43+
template<typename T>
44+
struct NodeTypeChecker<Array<T> > {
45+
static inline bool Check(Node* sptr) {
46+
if (sptr == nullptr) return false;
47+
if (!sptr->is_type<ArrayNode>()) return false;
48+
ArrayNode* n = static_cast<ArrayNode*>(sptr);
49+
for (const auto& p : n->data) {
50+
if (!NodeTypeChecker<T>::Check(p.get())) return false;
51+
}
52+
return true;
53+
}
54+
static inline void PrintName(std::ostringstream& os) { // NOLINT(*)
55+
os << "array<";
56+
NodeTypeChecker<T>::PrintName(os);
57+
os << ">";
58+
}
59+
};
60+
61+
template<typename K, typename V>
62+
struct NodeTypeChecker<Map<K, V> > {
63+
static inline bool Check(Node* sptr) {
64+
if (sptr == nullptr) return false;
65+
if (!sptr->is_type<MapNode>()) return false;
66+
MapNode* n = static_cast<MapNode*>(sptr);
67+
for (const auto& kv : n->data) {
68+
if (!NodeTypeChecker<K>::Check(kv.first.get())) return false;
69+
if (!NodeTypeChecker<V>::Check(kv.second.get())) return false;
70+
}
71+
return true;
72+
}
73+
static inline void PrintName(std::ostringstream& os) { // NOLINT(*)
74+
os << "map<";
75+
NodeTypeChecker<K>::PrintName(os);
76+
os << ',';
77+
NodeTypeChecker<V>::PrintName(os);
78+
os << '>';
79+
}
80+
};
81+
82+
template<typename T>
83+
inline std::string NodeTypeName() {
84+
std::ostringstream os;
85+
NodeTypeChecker<T>::PrintName(os);
86+
return os.str();
87+
}
88+
89+
// extensions for tvm arg value
90+
91+
template<typename TNodeRef, typename>
92+
inline TVMArgValue::operator TNodeRef() const {
93+
static_assert(
94+
std::is_base_of<NodeRef, TNodeRef>::value,
95+
"Conversion only works for NodeRef");
96+
if (type_code_ == kNull) return TNodeRef();
97+
TVM_CHECK_TYPE_CODE(type_code_, kNodeHandle);
98+
std::shared_ptr<Node>& sptr = *ptr<std::shared_ptr<Node> >();
99+
CHECK(NodeTypeChecker<TNodeRef>::Check(sptr.get()))
100+
<< "Expected type " << NodeTypeName<TNodeRef>()
101+
<< " but get " << sptr->type_key();
102+
return TNodeRef(sptr);
103+
}
104+
105+
inline TVMArgValue::operator Halide::Expr() const {
106+
if (type_code_ == kNull) return Expr();
107+
if (type_code_ == kInt) {
108+
return Expr(static_cast<int>(value_.v_int64));
109+
}
110+
if (type_code_ == kFloat) {
111+
return Expr(static_cast<float>(value_.v_float64));
112+
}
113+
TVM_CHECK_TYPE_CODE(type_code_, kNodeHandle);
114+
std::shared_ptr<Node>& sptr = *ptr<std::shared_ptr<Node> >();
115+
if (sptr->is_type<IterVarNode>()) {
116+
return IterVar(sptr)->var;
117+
}
118+
CHECK(NodeTypeChecker<Expr>::Check(sptr.get()))
119+
<< "Expected type " << NodeTypeName<Expr>()
120+
<< " but get " << sptr->type_key();
121+
return Expr(sptr);
122+
}
123+
124+
inline std::shared_ptr<Node>& TVMArgValue::node_sptr() {
125+
TVM_CHECK_TYPE_CODE(type_code_, kNodeHandle);
126+
return *ptr<std::shared_ptr<Node> >();
127+
}
128+
129+
130+
template<typename TNodeRef, typename>
131+
inline bool TVMArgValue::IsNodeType() const {
132+
TVM_CHECK_TYPE_CODE(type_code_, kNodeHandle);
133+
std::shared_ptr<Node>& sptr =
134+
*ptr<std::shared_ptr<Node> >();
135+
return NodeTypeChecker<TNodeRef>::Check(sptr.get());
136+
}
137+
138+
// extensions for TVMRetValue
139+
inline TVMRetValue& TVMRetValue::operator=(
140+
const std::shared_ptr<Node>& other) {
141+
SwitchToClass<std::shared_ptr<Node> >(kNodeHandle, other);
142+
return *this;
143+
}
144+
145+
inline TVMRetValue& TVMRetValue::operator=(const NodeRef& other) {
146+
SwitchToClass<std::shared_ptr<Node> >(kNodeHandle, other.node_);
147+
return *this;
148+
}
149+
150+
template<typename TNodeRef, typename>
151+
inline TVMRetValue::operator TNodeRef() const {
152+
static_assert(
153+
std::is_base_of<NodeRef, TNodeRef>::value,
154+
"Conversion only works for NodeRef");
155+
if (type_code_ == kNull) return TNodeRef();
156+
TVM_CHECK_TYPE_CODE(type_code_, kNodeHandle);
157+
return TNodeRef(*ptr<std::shared_ptr<Node> >());
158+
}
159+
160+
inline void TVMArgsSetter::operator()(size_t i, NodeRef& other) const { // NOLINT(*)
161+
values_[i].v_handle = &(other.node_);
162+
type_codes_[i] = kNodeHandle;
163+
}
164+
165+
// Type related stuffs
166+
inline Type TVMType2Type(TVMType t) {
167+
return Type(static_cast<halide_type_code_t>(t.code), t.bits, t.lanes);
168+
}
169+
170+
inline TVMType Type2TVMType(Type t) {
171+
TVMType ret;
172+
ret.code = static_cast<uint8_t>(t.code());
173+
ret.bits = static_cast<uint8_t>(t.bits());
174+
ret.lanes = static_cast<uint16_t>(t.lanes());
175+
return ret;
176+
}
177+
178+
inline TVMRetValue& TVMRetValue::operator=(const Halide::Type& t) {
179+
return this->operator=(Type2TVMType(t));
180+
}
181+
182+
inline TVMRetValue::operator Halide::Type() const {
183+
return TVMType2Type(operator TVMType());
184+
}
185+
186+
inline TVMArgValue::operator Halide::Type() const {
187+
return TVMType2Type(operator TVMType());
188+
}
189+
190+
inline void TVMArgsSetter::operator()(
191+
size_t i, const Halide::Type& t) const {
192+
this->operator()(i, Type2TVMType(t));
193+
}
194+
} // namespace runtime
195+
} // namespace tvm
196+
#endif // TVM_PACKED_FUNC_EXT_H_

0 commit comments

Comments
 (0)