Skip to content

Commit

Permalink
[PASS] StorageFlatten and StorageSync, safe condition in schedul_ops,…
Browse files Browse the repository at this point in the history
… gemm example. (#31)
  • Loading branch information
tqchen authored Feb 4, 2017
1 parent a2c8a29 commit d89917b
Show file tree
Hide file tree
Showing 34 changed files with 1,113 additions and 389 deletions.
48 changes: 0 additions & 48 deletions include/tvm/codegen.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,54 +21,6 @@ using runtime::PackedFunc;
using runtime::TVMArgs;
using runtime::TVMRetValue;

/*!
* \brief Make an user callable API LoweredFunc.
*
* The main task of this function is to create code to :
* - Map the values in the api_args to of Var that is required by body.
* - Insert assertions to check type/value of the passed arguments.
*
* \param body The body of the function.
* \param name The name of the function.
* \param api_args Arguments to the function, can be either Var, or Buffer
* \param num_packed_args Number of arguments that are processed in packed form.
* \return a LoweredFunc with the specified signiture.
*
* \note
* The function signiture have two cases
*
* if num_packed_args is zero:
* f(api_arg_0, api_arg_1, .., api_arg_n) where n == len(api_args)
*
* if num_packed_args is not zero:
* f(TVMArg* packed_args, int* packed_arg_type_ids, int num_packed_args,
* api_arg_k, api_arg_k+1, ... api_arg_n)
*
* where n == len(api_args), k == num_packed_args
*
* There is no thread_axis in generated function.
*/
LoweredFunc MakeAPI(Stmt body,
std::string name,
Array<NodeRef> api_args,
int num_packed_args);

/*!
* \brief Count number of undefined vars in f.
* \param f The function to be checked.
* \return Number of undefined vars.
*/
Array<Var> UndefinedVars(const LoweredFunc& f);

/*!
* \brief Split the function into a host function and device functions.
* \param func The function to be splitted.
*
* \return Array of functions, the first one is host function,
* the others are device functions.
*/
Array<LoweredFunc> SplitHostDevice(LoweredFunc func);

/*!
* \brief Build a stack VM function.
* \param func The LoweredFunc to be build
Expand Down
19 changes: 19 additions & 0 deletions include/tvm/ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,25 @@ constexpr const char* tvm_handle_is_null = "tvm_handle_is_null";
* }
*/
constexpr const char* tvm_call_global = "tvm_call_global";
/*!
* \brief See pesudo code
*
* int tvm_call_device(name, TVMValue* args) {
* PackedFunc df = CodeGenEnv->GetDevice(name);
* f (args, type_code_of(args), len(args));
* return 0;
* }
*/
constexpr const char* tvm_call_device = "tvm_call_device";
/*!
* \brief See pesudo code
*
* int tvm_storage_sync(std::string storage_scope) {
* __sync(storage_scope);
* return 0;
* }
*/
constexpr const char* tvm_storage_sync = "tvm_storage_sync";

/*! \brief The field id of each field in array */
enum TVMArrayFieldKind {
Expand Down
58 changes: 58 additions & 0 deletions include/tvm/ir_pass.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,11 @@
#include <tvm/ir_functor.h>
#include <unordered_map>
#include <vector>
#include <string>
#include "./expr.h"
#include "./buffer.h"
#include "./schedule.h"
#include "./lowered_func.h"

namespace tvm {
namespace ir {
Expand Down Expand Up @@ -95,6 +97,62 @@ Stmt Inline(Stmt stmt,
Stmt StorageFlatten(Stmt stmt,
Map<Tensor, Buffer> extern_buffer);

/*!
* \brief Make an user callable API LoweredFunc.
*
* The main task of this function is to create code to :
* - Map the values in the api_args to of Var that is required by body.
* - Insert assertions to check type/value of the passed arguments.
*
* \param body The body of the function.
* \param name The name of the function.
* \param api_args Arguments to the function, can be either Var, or Buffer
* \param num_packed_args Number of arguments that are processed in packed form.
* \return a LoweredFunc with the specified signiture.
*
* \note
* The function signiture have two cases
*
* if num_packed_args is zero:
* f(api_arg_0, api_arg_1, .., api_arg_n) where n == len(api_args)
*
* if num_packed_args is not zero:
* f(TVMArg* packed_args, int* packed_arg_type_ids, int num_packed_args,
* api_arg_k, api_arg_k+1, ... api_arg_n)
*
* where n == len(api_args), k == num_packed_args
*
* There is no thread_axis in generated function.
*/
LoweredFunc MakeAPI(Stmt body,
std::string name,
Array<NodeRef> api_args,
int num_packed_args);

/*!
* \brief Count number of undefined vars in f.
* \param f The function to be checked.
* \return Number of undefined vars.
*/
Array<Var> UndefinedVars(const LoweredFunc& f);

/*!
* \brief Split the function into a host function and device functions.
* \param func The function to be splitted.
*
* \return Array of functions, the first one is host function,
* the others are device functions.
*/
Array<LoweredFunc> SplitHostDevice(LoweredFunc func);

/*!
* \brief Insert sync between parallel read/write of shared buffers.
*
* \param stmt The stmt to be trasnformed.
* \param storage_scope The storage scope considered.
*/
LoweredFunc StorageSync(LoweredFunc stmt, std::string storage_scope);

} // namespace ir
} // namespace tvm

Expand Down
1 change: 1 addition & 0 deletions include/tvm/ir_visitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ class IRVisitor {
virtual void Visit_(const LetStmt* op);
virtual void Visit_(const For* op);
virtual void Visit_(const Allocate* op);
virtual void Visit_(const IfThenElse* op);
virtual void Visit_(const Load* op);
virtual void Visit_(const Store* op);
virtual void Visit_(const Let* op);
Expand Down
12 changes: 7 additions & 5 deletions python/tvm/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,17 +65,19 @@ def build(sch,
stmt = schedule.ScheduleOps(sch, bounds)
stmt = ir_pass.StorageFlatten(stmt, binds)
stmt = ir_pass.Simplify(stmt)
print(stmt)
fapi = codegen.MakeAPI(stmt, name, arg_list, len(arg_list))
fsplits = codegen.SplitHostDevice(fapi)
fapi = ir_pass.MakeAPI(stmt, name, arg_list, len(arg_list))
fsplits = ir_pass.SplitHostDevice(fapi)
fsplits = [x for x in fsplits]
for i in range(1, len(fsplits)):
fsplits[i] = ir_pass.StorageSync(fsplits[i], "shared")
fsplits[i] = ir_pass.StorageSync(fsplits[i], "global")

if record_codes is not None:
output_ssa = False
for i, f in enumerate(fsplits):
t = target if i >= 1 else "c"
record_codes.append(codegen.CompileToC(f, output_ssa, t))
for c in record_codes:
print(c)

if target == "cuda":
ret = codegen.BuildNVRTC(fsplits, "stackvm")
elif target == "opencl":
Expand Down
12 changes: 0 additions & 12 deletions src/api/api_codegen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,18 +31,6 @@ TVM_REGISTER_API(_codegen_CompileToC)
}
});


TVM_REGISTER_API(_codegen_MakeAPI)
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = MakeAPI(
args[0], args[1], args[2], args[3]);
});

TVM_REGISTER_API(_codegen_SplitHostDevice)
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = SplitHostDevice(args[0]);
});

TVM_REGISTER_API(_codegen_BuildStackVM)
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = BuildStackVM(args[0],
Expand Down
3 changes: 3 additions & 0 deletions src/api/api_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,9 @@ REGISTER_PASS1(ConvertSSA);
REGISTER_PASS1(VerifySSA);
REGISTER_PASS4(Inline);
REGISTER_PASS2(StorageFlatten);
REGISTER_PASS2(StorageSync);
REGISTER_PASS4(MakeAPI);
REGISTER_PASS1(SplitHostDevice);

} // namespace ir
} // namespace tvm
Loading

0 comments on commit d89917b

Please sign in to comment.