Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[PASS] StorageSync, safe condition in ScheduleOps, GEMM Example #31

Merged
merged 1 commit into from
Feb 4, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 0 additions & 48 deletions include/tvm/codegen.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,54 +21,6 @@ using runtime::PackedFunc;
using runtime::TVMArgs;
using runtime::TVMRetValue;

/*!
* \brief Make a user-callable API LoweredFunc.
*
* The main task of this function is to create code to:
* - Map the values in api_args to the Vars that are required by body.
* - Insert assertions to check type/value of the passed arguments.
*
* \param body The body of the function.
* \param name The name of the function.
* \param api_args Arguments to the function, can be either Var, or Buffer
* \param num_packed_args Number of arguments that are processed in packed form.
* \return a LoweredFunc with the specified signature.
*
* \note
* The function signature has two cases:
*
* if num_packed_args is zero:
* f(api_arg_0, api_arg_1, .., api_arg_n) where n == len(api_args)
*
* if num_packed_args is not zero:
* f(TVMArg* packed_args, int* packed_arg_type_ids, int num_packed_args,
* api_arg_k, api_arg_k+1, ... api_arg_n)
*
* where n == len(api_args), k == num_packed_args
*
* There is no thread_axis in generated function.
*/
LoweredFunc MakeAPI(Stmt body,
std::string name,
Array<NodeRef> api_args,
int num_packed_args);

/*!
* \brief Find the undefined vars in f.
* \param f The function to be checked.
* \return Array of the undefined vars.
*/
Array<Var> UndefinedVars(const LoweredFunc& f);

/*!
* \brief Split the function into a host function and device functions.
* \param func The function to be split.
*
* \return Array of functions, the first one is host function,
* the others are device functions.
*/
Array<LoweredFunc> SplitHostDevice(LoweredFunc func);

/*!
* \brief Build a stack VM function.
* \param func The LoweredFunc to be built
Expand Down
19 changes: 19 additions & 0 deletions include/tvm/ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,25 @@ constexpr const char* tvm_handle_is_null = "tvm_handle_is_null";
* }
*/
constexpr const char* tvm_call_global = "tvm_call_global";
/*!
* \brief See pseudo code
*
* int tvm_call_device(name, TVMValue* args) {
* PackedFunc df = CodeGenEnv->GetDevice(name);
* df(args, type_code_of(args), len(args));
* return 0;
* }
*/
constexpr const char* tvm_call_device = "tvm_call_device";
/*!
* \brief See pseudo code
*
* int tvm_storage_sync(std::string storage_scope) {
* __sync(storage_scope);
* return 0;
* }
*/
constexpr const char* tvm_storage_sync = "tvm_storage_sync";

/*! \brief The field id of each field in array */
enum TVMArrayFieldKind {
Expand Down
58 changes: 58 additions & 0 deletions include/tvm/ir_pass.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,11 @@
#include <tvm/ir_functor.h>
#include <unordered_map>
#include <vector>
#include <string>
#include "./expr.h"
#include "./buffer.h"
#include "./schedule.h"
#include "./lowered_func.h"

namespace tvm {
namespace ir {
Expand Down Expand Up @@ -95,6 +97,62 @@ Stmt Inline(Stmt stmt,
Stmt StorageFlatten(Stmt stmt,
Map<Tensor, Buffer> extern_buffer);

/*!
* \brief Make a user-callable API LoweredFunc.
*
* The main task of this function is to create code to:
* - Map the values in api_args to the Vars that are required by body.
* - Insert assertions to check type/value of the passed arguments.
*
* \param body The body of the function.
* \param name The name of the function.
* \param api_args Arguments to the function, can be either Var, or Buffer
* \param num_packed_args Number of arguments that are processed in packed form.
* \return a LoweredFunc with the specified signature.
*
* \note
* The function signature has two cases:
*
* if num_packed_args is zero:
* f(api_arg_0, api_arg_1, .., api_arg_n) where n == len(api_args)
*
* if num_packed_args is not zero:
* f(TVMArg* packed_args, int* packed_arg_type_ids, int num_packed_args,
* api_arg_k, api_arg_k+1, ... api_arg_n)
*
* where n == len(api_args), k == num_packed_args
*
* There is no thread_axis in generated function.
*/
LoweredFunc MakeAPI(Stmt body,
std::string name,
Array<NodeRef> api_args,
int num_packed_args);

/*!
* \brief Find the undefined vars in f.
* \param f The function to be checked.
* \return Array of the undefined vars.
*/
Array<Var> UndefinedVars(const LoweredFunc& f);

/*!
* \brief Split the function into a host function and device functions.
* \param func The function to be split.
*
* \return Array of functions, the first one is host function,
* the others are device functions.
*/
Array<LoweredFunc> SplitHostDevice(LoweredFunc func);

/*!
* \brief Insert sync between parallel read/write of shared buffers.
*
* \param stmt The stmt to be transformed.
* \param storage_scope The storage scope considered.
*/
LoweredFunc StorageSync(LoweredFunc stmt, std::string storage_scope);

} // namespace ir
} // namespace tvm

Expand Down
1 change: 1 addition & 0 deletions include/tvm/ir_visitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ class IRVisitor {
virtual void Visit_(const LetStmt* op);
virtual void Visit_(const For* op);
virtual void Visit_(const Allocate* op);
virtual void Visit_(const IfThenElse* op);
virtual void Visit_(const Load* op);
virtual void Visit_(const Store* op);
virtual void Visit_(const Let* op);
Expand Down
12 changes: 7 additions & 5 deletions python/tvm/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,17 +65,19 @@ def build(sch,
stmt = schedule.ScheduleOps(sch, bounds)
stmt = ir_pass.StorageFlatten(stmt, binds)
stmt = ir_pass.Simplify(stmt)
print(stmt)
fapi = codegen.MakeAPI(stmt, name, arg_list, len(arg_list))
fsplits = codegen.SplitHostDevice(fapi)
fapi = ir_pass.MakeAPI(stmt, name, arg_list, len(arg_list))
fsplits = ir_pass.SplitHostDevice(fapi)
fsplits = [x for x in fsplits]
for i in range(1, len(fsplits)):
fsplits[i] = ir_pass.StorageSync(fsplits[i], "shared")
fsplits[i] = ir_pass.StorageSync(fsplits[i], "global")

if record_codes is not None:
output_ssa = False
for i, f in enumerate(fsplits):
t = target if i >= 1 else "c"
record_codes.append(codegen.CompileToC(f, output_ssa, t))
for c in record_codes:
print(c)

if target == "cuda":
ret = codegen.BuildNVRTC(fsplits, "stackvm")
elif target == "opencl":
Expand Down
12 changes: 0 additions & 12 deletions src/api/api_codegen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,18 +31,6 @@ TVM_REGISTER_API(_codegen_CompileToC)
}
});


// Registers the `_codegen_MakeAPI` API entry: the body forwards the four
// packed arguments, in order, to MakeAPI and stores the result in *ret.
TVM_REGISTER_API(_codegen_MakeAPI)
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = MakeAPI(
args[0], args[1], args[2], args[3]);
});

// Registers the `_codegen_SplitHostDevice` API entry: the body passes the
// single packed argument to SplitHostDevice and stores the result in *ret.
TVM_REGISTER_API(_codegen_SplitHostDevice)
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = SplitHostDevice(args[0]);
});

TVM_REGISTER_API(_codegen_BuildStackVM)
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = BuildStackVM(args[0],
Expand Down
3 changes: 3 additions & 0 deletions src/api/api_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,9 @@ REGISTER_PASS1(ConvertSSA);
REGISTER_PASS1(VerifySSA);
REGISTER_PASS4(Inline);
REGISTER_PASS2(StorageFlatten);
REGISTER_PASS2(StorageSync);
REGISTER_PASS4(MakeAPI);
REGISTER_PASS1(SplitHostDevice);

} // namespace ir
} // namespace tvm
Loading