Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CODEGEN] Add LoweredFunc, MakeAPI to build a C API function #23

Merged
merged 2 commits into from
Jan 23, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion HalideIR
Submodule HalideIR updated from adfa66 to 30bf0f
3 changes: 3 additions & 0 deletions include/tvm/buffer.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,9 @@ class Buffer : public NodeRef {
* \return the pointer to the internal node container
*/
inline const BufferNode* operator->() const;

/*! \brief specify container node */
using ContainerType = BufferNode;
};

/*! \brief Node to represent a buffer */
Expand Down
40 changes: 34 additions & 6 deletions include/tvm/c_runtime_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
#endif

#include <stdint.h>
#include <stddef.h>


TVM_EXTERN_C {
Expand Down Expand Up @@ -216,18 +217,45 @@ TVM_DLL int TVMArrayCopyFromTo(TVMArrayHandle from,
TVM_DLL int TVMSynchronize(TVMContext ctx, TVMStreamHandle stream);

/*!
* \brief Launch a generated TVM function
* \brief TVM Function API: Get resource requirement
*
* By default TVM function try not to do internal allocations.
* Instead, TVMFuncRequirement can be called, given the input arguments.
*
* \param func function handle to be launched.
* \param args The arguments
* \param arg_type_ids The type id of the arguments
* \param num_args Number of arguments.
* \param out_workspace_size The workspace size needed to launch this function.
* \param out_workspace_align The alignment requirement of workspace.
*
* \note The data pointer in the arrays is not used by requirement.
*/
TVM_DLL int TVMFuncRequirement(TVMFunctionHandle func,
TVMArg* args,
int* arg_type_ids,
int num_args,
size_t* out_workspace_size,
size_t* out_workspace_align);

/*!
* \brief TVM Function API: Launch generated function.
*
* \param func function handle to be launched.
* \param args The arguments
* \param arg_type_ids The type id of the arguments
* \param num_args Number of arguments.
* \param stream The stream this function to be launched on.
* \param workspace Additional workspace used to launch this function.
*
* \sa TVMFuncRequirement
*/
TVM_DLL int TVMLaunch(TVMFunctionHandle func,
TVMArg* args,
int* arg_type_ids,
int num_args,
TVMStreamHandle stream);
TVM_DLL int TVMFuncLaunch(TVMFunctionHandle func,
TVMArg* args,
int* arg_type_ids,
int num_args,
TVMStreamHandle stream,
TVMArrayHandle workspace);
} // TVM_EXTERN_C

#endif // TVM_C_RUNTIME_API_H_
68 changes: 68 additions & 0 deletions include/tvm/codegen.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
/*!
* Copyright (c) 2016 by Contributors
* \file codegen.h
* \brief Collection of Lowlevel IR pass to codegen.
*/
#ifndef TVM_CODEGEN_H_
#define TVM_CODEGEN_H_

#include <string>
#include "./base.h"
#include "./expr.h"
#include "./module.h"

namespace tvm {
/*! \brief namespace for lowlevel IR pass and codegen */
namespace codegen {
/*!
* \brief Make an user callable API LoweredFunc.
*
* The main task of this function is to create code to :
* - Map the values in the api_args to of Var that is required by body.
* - Insert assertions to check type/value of the passed arguments.
*
* \param body The body of the function.
* \param name The name of the function.
* \param api_args Arguments to the function, can be either Var, or Buffer
* \param num_packed_args Number of arguments that are processed in packed form.
* \return a LoweredFunc with the specified signiture.
*
* \note
* The function signiture have two cases
*
* if num_packed_args is zero:
* f(api_arg_0, api_arg_1, .., api_arg_n) where n == len(api_args)
*
* if num_packed_args is not zero:
* f(TVMArg* packed_args, int* packed_arg_type_ids, int num_packed_args,
* api_arg_k, api_arg_k+1, ... api_arg_n)
*
* where n == len(api_args), k == num_packed_args
*
* There is no thread_axis in generated function.
*/
LoweredFunc MakeAPI(Stmt body,
std::string name,
Array<NodeRef> api_args,
int num_packed_args);

/*!
* \brief Count number of undefined vars in f.
* \param f The function to be checked.
* \return Number of undefined vars.
*/
Array<Var> UndefinedVars(const LoweredFunc& f);

/*!
* \brief Split the function into a host function and device functions.
* \param func The function to be splitted.
*
* \return Array of functions, the first one is host function,
* the others are device functions.
*/
Array<LoweredFunc> SplitHostDevice(LoweredFunc func);

} // namespace codegen
} // namespace tvm

#endif // TVM_CODEGEN_H_
42 changes: 42 additions & 0 deletions include/tvm/ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,48 @@ struct Reduce : public ExprNode<Reduce> {
static constexpr const char* Min = "Min";
};

/*! \brief namespace of TVM Intrinsic functions */
namespace intrinsic {
// Most of the intrinsics is to enab
/*!
* \brief See pesudo code
*
* Type tvm_api_load_arg(TVMArg* args, int* args_type_id, i) {
* assert(arg_type_id[i] == typeid(Type));
* return args[i];
* }
*/
constexpr const char* tvm_api_load_arg = "tvm_api_load_arg";
/*!
* \brief See pesudo code
*
* Type tvm_array_get_field(TVMArray* arr, int field_id) {
* return arr->field;
* }
* \sa TVMArrayFieldKind
*/
constexpr const char* tvm_array_get_field = "tvm_array_get_field";
/*!
* \brief See pesudo code
*
* bool tvm_handle_is_null(void* handle) {
* return handle == nullptr
* }
*/
constexpr const char* tvm_handle_is_null = "tvm_handle_is_null";

/*! \brief The field id of each field in array */
enum TVMArrayFieldKind {
kData = 0,
kNDim = 1,
kShape = 2,
kStrides = 3,
kTypeCode = 4,
kTypeBits = 5,
kTypeLanes = 6
};
} // namespace intrinsic

// Reuse IR node defintiion from HalideIR
using Halide::Internal::IntImm;
using Halide::Internal::UIntImm;
Expand Down
15 changes: 15 additions & 0 deletions include/tvm/ir_mutator.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include <tvm/ir_functor.h>
#include <unordered_map>
#include "./expr.h"
#include "./ir.h"

namespace tvm {
namespace ir {
Expand Down Expand Up @@ -51,6 +52,20 @@ class IRMutator {
static FMutateExpr& vtable_expr(); // NOLINT(*)
/*! \return internal stmt of expr */
static FMutateStmt& vtable_stmt(); // NOLINT(*)
// Set of overloadable functions
// The underscore allows Mutate not to be shadowed by inheritance
virtual Stmt Mutate_(const LetStmt* op, const Stmt& s);
virtual Stmt Mutate_(const AttrStmt* op, const Stmt& s);
virtual Stmt Mutate_(const For* op, const Stmt& s);
virtual Stmt Mutate_(const Provide* op, const Stmt& s);
virtual Stmt Mutate_(const Allocate* op, const Stmt& s);
virtual Stmt Mutate_(const Realize* op, const Stmt& s);
virtual Stmt Mutate_(const Store* op, const Stmt& s);
virtual Stmt Mutate_(const Free* op, const Stmt& s);
virtual Expr Mutate_(const Call* op, const Expr& e);
virtual Expr Mutate_(const Load* op, const Expr& s);
virtual Expr Mutate_(const Variable* op, const Expr& e);
virtual Expr Mutate_(const Let* op, const Expr& e);
};

/*!
Expand Down
7 changes: 6 additions & 1 deletion include/tvm/ir_pass.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,12 @@ Stmt ScheduleOps(Schedule s, Map<IterVar, Range> dom_map);
*/
bool VerifySSA(const Stmt& ir);

/*!
* \brief Whether the expression have side effect.
* \return whether expression have side effect
*/
bool HasSideEffect(const Expr& e);

/*!
* \brief Convert a IR node to be SSA form.
* \param stmt The source statement to be converted.
Expand All @@ -79,7 +85,6 @@ Stmt Inline(Stmt stmt,
Array<Var> args,
Expr body);


/*!
* \brief Flatten the multi-dimensional read/write
* to single dimensional Load/Store
Expand Down
11 changes: 11 additions & 0 deletions include/tvm/ir_visitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,17 @@ class IRVisitor {
using FVisit = IRFunctor<void(const NodeRef&, IRVisitor*)>;
/*! \return internal vtable*/
static FVisit& vtable();
// overloadable visit function.
virtual void Visit_(const Variable* op);
virtual void Visit_(const AttrStmt* op);
virtual void Visit_(const LetStmt* op);
virtual void Visit_(const For* op);
virtual void Visit_(const Allocate* op);
virtual void Visit_(const Load* op);
virtual void Visit_(const Store* op);
virtual void Visit_(const Let* op);
virtual void Visit_(const Free* op);
virtual void Visit_(const Call* op);
};

/*!
Expand Down
108 changes: 108 additions & 0 deletions include/tvm/module.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
/*!
* Copyright (c) 2016 by Contributors
* \file module.h
* \brief Low level IR module,
* Contains lowered function information.
*/
#ifndef TVM_MODULE_H_
#define TVM_MODULE_H_

#include <tvm/container.h>
#include <ir/FunctionBase.h>
#include <string>

#include "./base.h"
#include "./expr.h"
#include "./tensor.h"

namespace tvm {

// Internal node container of lowered function.
class LoweredFuncNode;

// Internal node container of module.
class ModuleNode;

/*!
* \brief LoweredFunc represents function after lowering.
* This is the final IR representation before codegen.
*/
class LoweredFunc : public FunctionRef {
public:
LoweredFunc() {}
explicit LoweredFunc(std::shared_ptr<Node> n) : FunctionRef(n) {}
/*!
* \brief access the internal node container
* \return the pointer to the internal node container
*/
inline const LoweredFuncNode* operator->() const;
/*! \brief specify container node */
using ContainerType = LoweredFuncNode;
};

/*! \brief Node container of LoweredFunc */
class LoweredFuncNode : public FunctionBaseNode {
public:
/*! \brief The name of the function */
std::string name;
/*!
* \brief The arguments of the function
* This function can only take pod type(int, float) and void* as arguments.
*/
Array<Var> args;
/*!
* \brief The IterVar axis of threads
* Each axis need host function to specify a size.
* \note Calling convention into LoweredFunc
*
* Assume we have a LoweredFunc f, a call into f
* Call(f, arg1, arg2, ..., arg_n,
* size_axis_1, size_axis_2, ... size_axis_m)
*
* Here n = len(args), m = len(thread_axis)
*
* The CodeGen should take this and translate this call
* to corresponding API specific kernel launchs or function calls.
*/
Array<IterVar> thread_axis;
/*!
* \brief The hint data type of Var handles defined in LetStmt
* Can be used as hint when generating type signiture.
* The creation rule is given by
* handle_data_type[var_handle] = make_const(the_type, 0);
*
* \note Expr is used instead Type, because Type cannot be hold by Map.
* constant Expr of given type is used.
*/
Map<Var, Expr> handle_data_type;
/*! \brief The body statment of the function */
Stmt body;
/*! \return name of the operation */
const std::string& func_name() const final {
return name;
}
// there is no return value, but return 1
// to enable Call into this function.
int num_outputs() const final {
return 1;
}
void VisitAttrs(AttrVisitor* v) final {
v->Visit("name", &name);
v->Visit("args", &args);
v->Visit("thread_axis", &thread_axis);
v->Visit("handle_data_type", &handle_data_type);
v->Visit("body", &body);
}

static constexpr const char* _type_key = "LoweredFunc";
TVM_DECLARE_NODE_TYPE_INFO(LoweredFuncNode);
};

// Implementations of inline functions
inline const LoweredFuncNode* LoweredFunc::operator->() const {
return static_cast<const LoweredFuncNode*>(node_.get());
}

} // namespace tvm

#endif // TVM_MODULE_H_
6 changes: 6 additions & 0 deletions python/tvm/collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,3 +56,9 @@ class IterVar(NodeBase, _expr.ExprOp):
class Buffer(NodeBase):
"""Represent a Buffer in TVM."""
pass


@register_node
class LoweredFunc(NodeBase):
"""Represent a LoweredFunc in TVM."""
pass
3 changes: 2 additions & 1 deletion src/base/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#define TVM_BASE_COMMON_H_

#include <tvm/base.h>
#include <tvm/expr.h>
#include <string>

namespace tvm {
Expand All @@ -30,7 +31,7 @@ inline Type String2Type(std::string s) {
} else if (s.substr(0, 5) == "float") {
code = Type::Float; s = s.substr(5);
} else if (s == "handle") {
return Type(Type::Handle, 32, 1);
return Handle();
} else {
LOG(FATAL) << "unknown type " << s;
}
Expand Down
Loading