Skip to content

Commit 4f1473f

Browse files
authored
[CODEGEN] Add LoweredFunc, MakeAPI to build a C API function (#23)
* [CODEGEN] Add LoweredFunc, MakeAPI and SplitHostDevice * update halideir
1 parent 3c1020d commit 4f1473f

25 files changed

+1346
-348
lines changed

HalideIR

Submodule HalideIR updated from adfa662 to 30bf0f0

include/tvm/buffer.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,9 @@ class Buffer : public NodeRef {
5050
* \return the pointer to the internal node container
5151
*/
5252
inline const BufferNode* operator->() const;
53+
54+
/*! \brief specify container node */
55+
using ContainerType = BufferNode;
5356
};
5457

5558
/*! \brief Node to represent a buffer */

include/tvm/c_runtime_api.h

Lines changed: 34 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
#endif
3131

3232
#include <stdint.h>
33+
#include <stddef.h>
3334

3435

3536
TVM_EXTERN_C {
@@ -216,18 +217,45 @@ TVM_DLL int TVMArrayCopyFromTo(TVMArrayHandle from,
216217
TVM_DLL int TVMSynchronize(TVMContext ctx, TVMStreamHandle stream);
217218

218219
/*!
219-
* \brief Launch a generated TVM function
220+
* \brief TVM Function API: Get resource requirement
221+
*
222+
* By default TVM function try not to do internal allocations.
223+
* Instead, TVMFuncRequirement can be called, given the input arguments.
224+
*
225+
* \param func function handle to be launched.
226+
* \param args The arguments
227+
* \param arg_type_ids The type id of the arguments
228+
* \param num_args Number of arguments.
229+
* \param out_workspace_size The workspace size needed to launch this function.
230+
* \param out_workspace_align The alignment requirement of workspace.
231+
*
232+
* \note The data pointer in the arrays is not used by requirement.
233+
*/
234+
TVM_DLL int TVMFuncRequirement(TVMFunctionHandle func,
235+
TVMArg* args,
236+
int* arg_type_ids,
237+
int num_args,
238+
size_t* out_workspace_size,
239+
size_t* out_workspace_align);
240+
241+
/*!
242+
* \brief TVM Function API: Launch generated function.
243+
*
220244
* \param func function handle to be launched.
221245
* \param args The arguments
222246
* \param arg_type_ids The type id of the arguments
223247
* \param num_args Number of arguments.
224248
* \param stream The stream this function to be launched on.
249+
* \param workspace Additional workspace used to launch this function.
250+
*
251+
* \sa TVMFuncRequirement
225252
*/
226-
TVM_DLL int TVMLaunch(TVMFunctionHandle func,
227-
TVMArg* args,
228-
int* arg_type_ids,
229-
int num_args,
230-
TVMStreamHandle stream);
253+
TVM_DLL int TVMFuncLaunch(TVMFunctionHandle func,
254+
TVMArg* args,
255+
int* arg_type_ids,
256+
int num_args,
257+
TVMStreamHandle stream,
258+
TVMArrayHandle workspace);
231259
} // TVM_EXTERN_C
232260

233261
#endif // TVM_C_RUNTIME_API_H_

include/tvm/codegen.h

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
/*!
2+
* Copyright (c) 2016 by Contributors
3+
* \file codegen.h
4+
* \brief Collection of Lowlevel IR pass to codegen.
5+
*/
6+
#ifndef TVM_CODEGEN_H_
7+
#define TVM_CODEGEN_H_
8+
9+
#include <string>
10+
#include "./base.h"
11+
#include "./expr.h"
12+
#include "./module.h"
13+
14+
namespace tvm {
15+
/*! \brief namespace for lowlevel IR pass and codegen */
16+
namespace codegen {
17+
/*!
18+
* \brief Make an user callable API LoweredFunc.
19+
*
20+
* The main task of this function is to create code to :
21+
* - Map the values in the api_args to of Var that is required by body.
22+
* - Insert assertions to check type/value of the passed arguments.
23+
*
24+
* \param body The body of the function.
25+
* \param name The name of the function.
26+
* \param api_args Arguments to the function, can be either Var, or Buffer
27+
* \param num_packed_args Number of arguments that are processed in packed form.
28+
* \return a LoweredFunc with the specified signiture.
29+
*
30+
* \note
31+
* The function signiture have two cases
32+
*
33+
* if num_packed_args is zero:
34+
* f(api_arg_0, api_arg_1, .., api_arg_n) where n == len(api_args)
35+
*
36+
* if num_packed_args is not zero:
37+
* f(TVMArg* packed_args, int* packed_arg_type_ids, int num_packed_args,
38+
* api_arg_k, api_arg_k+1, ... api_arg_n)
39+
*
40+
* where n == len(api_args), k == num_packed_args
41+
*
42+
* There is no thread_axis in generated function.
43+
*/
44+
LoweredFunc MakeAPI(Stmt body,
45+
std::string name,
46+
Array<NodeRef> api_args,
47+
int num_packed_args);
48+
49+
/*!
50+
* \brief Count number of undefined vars in f.
51+
* \param f The function to be checked.
52+
* \return Number of undefined vars.
53+
*/
54+
Array<Var> UndefinedVars(const LoweredFunc& f);
55+
56+
/*!
57+
* \brief Split the function into a host function and device functions.
58+
* \param func The function to be splitted.
59+
*
60+
* \return Array of functions, the first one is host function,
61+
* the others are device functions.
62+
*/
63+
Array<LoweredFunc> SplitHostDevice(LoweredFunc func);
64+
65+
} // namespace codegen
66+
} // namespace tvm
67+
68+
#endif // TVM_CODEGEN_H_

include/tvm/ir.h

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,48 @@ struct Reduce : public ExprNode<Reduce> {
4949
static constexpr const char* Min = "Min";
5050
};
5151

52+
/*! \brief namespace of TVM Intrinsic functions */
53+
namespace intrinsic {
54+
// Most of the intrinsics is to enab
55+
/*!
56+
* \brief See pesudo code
57+
*
58+
* Type tvm_api_load_arg(TVMArg* args, int* args_type_id, i) {
59+
* assert(arg_type_id[i] == typeid(Type));
60+
* return args[i];
61+
* }
62+
*/
63+
constexpr const char* tvm_api_load_arg = "tvm_api_load_arg";
64+
/*!
65+
* \brief See pesudo code
66+
*
67+
* Type tvm_array_get_field(TVMArray* arr, int field_id) {
68+
* return arr->field;
69+
* }
70+
* \sa TVMArrayFieldKind
71+
*/
72+
constexpr const char* tvm_array_get_field = "tvm_array_get_field";
73+
/*!
74+
* \brief See pesudo code
75+
*
76+
* bool tvm_handle_is_null(void* handle) {
77+
* return handle == nullptr
78+
* }
79+
*/
80+
constexpr const char* tvm_handle_is_null = "tvm_handle_is_null";
81+
82+
/*! \brief The field id of each field in array */
83+
enum TVMArrayFieldKind {
84+
kData = 0,
85+
kNDim = 1,
86+
kShape = 2,
87+
kStrides = 3,
88+
kTypeCode = 4,
89+
kTypeBits = 5,
90+
kTypeLanes = 6
91+
};
92+
} // namespace intrinsic
93+
5294
// Reuse IR node defintiion from HalideIR
5395
using Halide::Internal::IntImm;
5496
using Halide::Internal::UIntImm;

include/tvm/ir_mutator.h

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#include <tvm/ir_functor.h>
1010
#include <unordered_map>
1111
#include "./expr.h"
12+
#include "./ir.h"
1213

1314
namespace tvm {
1415
namespace ir {
@@ -51,6 +52,20 @@ class IRMutator {
5152
static FMutateExpr& vtable_expr(); // NOLINT(*)
5253
/*! \return internal stmt of expr */
5354
static FMutateStmt& vtable_stmt(); // NOLINT(*)
55+
// Set of overloadable functions
56+
// The underscore allows Mutate not to be shadowed by inheritance
57+
virtual Stmt Mutate_(const LetStmt* op, const Stmt& s);
58+
virtual Stmt Mutate_(const AttrStmt* op, const Stmt& s);
59+
virtual Stmt Mutate_(const For* op, const Stmt& s);
60+
virtual Stmt Mutate_(const Provide* op, const Stmt& s);
61+
virtual Stmt Mutate_(const Allocate* op, const Stmt& s);
62+
virtual Stmt Mutate_(const Realize* op, const Stmt& s);
63+
virtual Stmt Mutate_(const Store* op, const Stmt& s);
64+
virtual Stmt Mutate_(const Free* op, const Stmt& s);
65+
virtual Expr Mutate_(const Call* op, const Expr& e);
66+
virtual Expr Mutate_(const Load* op, const Expr& s);
67+
virtual Expr Mutate_(const Variable* op, const Expr& e);
68+
virtual Expr Mutate_(const Let* op, const Expr& e);
5469
};
5570

5671
/*!

include/tvm/ir_pass.h

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,12 @@ Stmt ScheduleOps(Schedule s, Map<IterVar, Range> dom_map);
5656
*/
5757
bool VerifySSA(const Stmt& ir);
5858

59+
/*!
60+
* \brief Whether the expression have side effect.
61+
* \return whether expression have side effect
62+
*/
63+
bool HasSideEffect(const Expr& e);
64+
5965
/*!
6066
* \brief Convert a IR node to be SSA form.
6167
* \param stmt The source statement to be converted.
@@ -79,7 +85,6 @@ Stmt Inline(Stmt stmt,
7985
Array<Var> args,
8086
Expr body);
8187

82-
8388
/*!
8489
* \brief Flatten the multi-dimensional read/write
8590
* to single dimensional Load/Store

include/tvm/ir_visitor.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,17 @@ class IRVisitor {
3434
using FVisit = IRFunctor<void(const NodeRef&, IRVisitor*)>;
3535
/*! \return internal vtable*/
3636
static FVisit& vtable();
37+
// overloadable visit function.
38+
virtual void Visit_(const Variable* op);
39+
virtual void Visit_(const AttrStmt* op);
40+
virtual void Visit_(const LetStmt* op);
41+
virtual void Visit_(const For* op);
42+
virtual void Visit_(const Allocate* op);
43+
virtual void Visit_(const Load* op);
44+
virtual void Visit_(const Store* op);
45+
virtual void Visit_(const Let* op);
46+
virtual void Visit_(const Free* op);
47+
virtual void Visit_(const Call* op);
3748
};
3849

3950
/*!

include/tvm/module.h

Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
/*!
2+
* Copyright (c) 2016 by Contributors
3+
* \file module.h
4+
* \brief Low level IR module,
5+
* Contains lowered function information.
6+
*/
7+
#ifndef TVM_MODULE_H_
8+
#define TVM_MODULE_H_
9+
10+
#include <tvm/container.h>
11+
#include <ir/FunctionBase.h>
12+
#include <string>
13+
14+
#include "./base.h"
15+
#include "./expr.h"
16+
#include "./tensor.h"
17+
18+
namespace tvm {
19+
20+
// Internal node container of lowered function.
21+
class LoweredFuncNode;
22+
23+
// Internal node container of module.
24+
class ModuleNode;
25+
26+
/*!
27+
* \brief LoweredFunc represents function after lowering.
28+
* This is the final IR representation before codegen.
29+
*/
30+
class LoweredFunc : public FunctionRef {
31+
public:
32+
LoweredFunc() {}
33+
explicit LoweredFunc(std::shared_ptr<Node> n) : FunctionRef(n) {}
34+
/*!
35+
* \brief access the internal node container
36+
* \return the pointer to the internal node container
37+
*/
38+
inline const LoweredFuncNode* operator->() const;
39+
/*! \brief specify container node */
40+
using ContainerType = LoweredFuncNode;
41+
};
42+
43+
/*! \brief Node container of LoweredFunc */
44+
class LoweredFuncNode : public FunctionBaseNode {
45+
public:
46+
/*! \brief The name of the function */
47+
std::string name;
48+
/*!
49+
* \brief The arguments of the function
50+
* This function can only take pod type(int, float) and void* as arguments.
51+
*/
52+
Array<Var> args;
53+
/*!
54+
* \brief The IterVar axis of threads
55+
* Each axis need host function to specify a size.
56+
* \note Calling convention into LoweredFunc
57+
*
58+
* Assume we have a LoweredFunc f, a call into f
59+
* Call(f, arg1, arg2, ..., arg_n,
60+
* size_axis_1, size_axis_2, ... size_axis_m)
61+
*
62+
* Here n = len(args), m = len(thread_axis)
63+
*
64+
* The CodeGen should take this and translate this call
65+
* to corresponding API specific kernel launchs or function calls.
66+
*/
67+
Array<IterVar> thread_axis;
68+
/*!
69+
* \brief The hint data type of Var handles defined in LetStmt
70+
* Can be used as hint when generating type signiture.
71+
* The creation rule is given by
72+
* handle_data_type[var_handle] = make_const(the_type, 0);
73+
*
74+
* \note Expr is used instead Type, because Type cannot be hold by Map.
75+
* constant Expr of given type is used.
76+
*/
77+
Map<Var, Expr> handle_data_type;
78+
/*! \brief The body statment of the function */
79+
Stmt body;
80+
/*! \return name of the operation */
81+
const std::string& func_name() const final {
82+
return name;
83+
}
84+
// there is no return value, but return 1
85+
// to enable Call into this function.
86+
int num_outputs() const final {
87+
return 1;
88+
}
89+
void VisitAttrs(AttrVisitor* v) final {
90+
v->Visit("name", &name);
91+
v->Visit("args", &args);
92+
v->Visit("thread_axis", &thread_axis);
93+
v->Visit("handle_data_type", &handle_data_type);
94+
v->Visit("body", &body);
95+
}
96+
97+
static constexpr const char* _type_key = "LoweredFunc";
98+
TVM_DECLARE_NODE_TYPE_INFO(LoweredFuncNode);
99+
};
100+
101+
// Implementations of inline functions
102+
inline const LoweredFuncNode* LoweredFunc::operator->() const {
103+
return static_cast<const LoweredFuncNode*>(node_.get());
104+
}
105+
106+
} // namespace tvm
107+
108+
#endif // TVM_MODULE_H_

python/tvm/collections.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,3 +56,9 @@ class IterVar(NodeBase, _expr.ExprOp):
5656
class Buffer(NodeBase):
5757
"""Represent a Buffer in TVM."""
5858
pass
59+
60+
61+
@register_node
62+
class LoweredFunc(NodeBase):
63+
"""Represent a LoweredFunc in TVM."""
64+
pass

0 commit comments

Comments
 (0)