Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add threadIdx filtering in Multi-Level-Tiling and Verify-GPU-Code #20

Merged
merged 4 commits into from
Jan 23, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion include/tvm/tir/stmt.h
Original file line number Diff line number Diff line change
Expand Up @@ -1373,9 +1373,17 @@ constexpr const char* meta_schedule_cooperative_fetch = "meta_schedule.cooperati
*/
constexpr const char* meta_schedule_auto_tensorize = "meta_schedule.auto_tensorize";

/*! \brief Mark that tensor core is enbaled in the PrimExpr */
/*! \brief Mark that tensor core is enabled in the PrimExpr */
constexpr const char* meta_schedule_tensor_core_enabled = "meta_schedule.tensor_core_enabled";

/*! \brief The lower bound (inclusive) of the allowed thread extent in thread bindings */
constexpr const char* meta_schedule_thread_extent_low_inclusive =
"meta_schedule.thread_extent_low_inclusive";

/*! \brief The upper bound (inclusive) of the allowed thread extent in thread bindings */
constexpr const char* meta_schedule_thread_extent_high_inclusive =
"meta_schedule.thread_extent_high_inclusive";

/*!
* \brief Mark a block as generated by cache_read or cache_write block.
* 0 means cache_read; 1 means cache_write.
Expand Down
53 changes: 53 additions & 0 deletions src/meta_schedule/postproc/verify_gpu_code.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,56 @@

#include "../utils.h"

namespace tvm {
namespace tir {

class ThreadExtentChecker : private StmtVisitor {
public:
static bool Check(const Stmt& stmt) {
try {
ThreadExtentChecker().VisitStmt(stmt);
return true;
} catch (const dmlc::Error& e) {
return false;
}
}

private:
void VisitStmt_(const ForNode* loop) {
if (IsThreadIdx(GetThreadScope(loop))) {
if (const int64_t* p_ext = GetLoopIntExtent(loop)) {
thread_extent_product *= *p_ext;
StmtVisitor::VisitStmt_(loop);
thread_extent_product /= *p_ext;
return;
} else {
throw dmlc::Error("Dynamic thread extent");
}
}
StmtVisitor::VisitStmt_(loop);
}

void VisitStmt_(const BlockNode* block) {
if (Optional<Integer> low_inclusive =
GetAnn<Integer>(block, attr::meta_schedule_thread_extent_low_inclusive)) {
if (Optional<Integer> high_inclusive =
GetAnn<Integer>(block, attr::meta_schedule_thread_extent_high_inclusive)) {
int64_t low = low_inclusive.value()->value;
int64_t high = high_inclusive.value()->value;
if (!(low <= thread_extent_product && thread_extent_product <= high)) {
throw dmlc::Error("Thread extent");
}
}
}
StmtVisitor::VisitStmt_(block);
}

int64_t thread_extent_product = 1;
};

} // namespace tir
} // namespace tvm

namespace tvm {
namespace meta_schedule {

Expand Down Expand Up @@ -66,6 +116,9 @@ class VerifyGPUCodeNode : public PostprocNode {
const GlobalVar& g_var = kv.first;
const BaseFunc& base_func = kv.second;
if (const auto* prim_func = base_func.as<tir::PrimFuncNode>()) {
if (!tir::ThreadExtentChecker::Check(prim_func->body)) {
return false;
}
IRModule lowered{nullptr};
try {
auto pass_list = Array<tvm::transform::Pass>();
Expand Down
41 changes: 39 additions & 2 deletions src/meta_schedule/schedule_rule/multi_level_tiling.cc
Original file line number Diff line number Diff line change
Expand Up @@ -287,7 +287,17 @@ class MultiLevelTilingNode : public ScheduleRuleNode {
}

// Do nothing; Inherited from ScheduleRuleNode
void InitializeWithTuneContext(const TuneContext& context) final {}
void InitializeWithTuneContext(const TuneContext& context) final {
if (Optional<Integer> v = context->target.value()->GetAttr<Integer>("max_threads_per_block")) {
this->max_threads_per_block_ = v.value()->value;
if (Optional<Integer> v = context->target.value()->GetAttr<Integer>("thread_warp_size")) {
this->thread_warp_size_ = v.value()->value;
} else {
LOG(INFO) << "'thread_warp_size' is not defined in the target";
}
}
}

// Entry of the mega rule; Inherited from ScheduleRuleNode
Array<Schedule> Apply(const Schedule& sch, const BlockRV& block_rv) final {
if (!NeedsMultiLevelTiling(sch->state(), sch->GetSRef(block_rv))) {
Expand Down Expand Up @@ -331,6 +341,10 @@ class MultiLevelTilingNode : public ScheduleRuleNode {
std::vector<int> s_indices_;
/*! \brief The indices of reduction tiles in `structure` */
std::vector<int> r_indices_;
/*! \brief The size of the thread warp */
int thread_warp_size_;
/*! \brief The maximum number of threads allowed in a thread block */
int max_threads_per_block_;

void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("structure", &structure);
Expand All @@ -342,6 +356,8 @@ class MultiLevelTilingNode : public ScheduleRuleNode {
// `reuse_write_` is not visited
// `s_indices_` is not visited
// `r_indices_` is not visited
// `thread_warp_size_` is not visited
// `max_threads_per_block_` is not visited
}

static constexpr const char* _type_key = "meta_schedule.MultiLevelTiling";
Expand Down Expand Up @@ -419,19 +435,27 @@ inline std::vector<State> MultiLevelTilingNode::TileLoopNest(State state) const
std::vector<IterVarType> iter_types = GetBlockVarTypes(sch->GetSRef(state.block_rv));
ICHECK_EQ(loops.size(), iter_types.size());
// Step 2. For each loop axis, tile it
int64_t spatial_loop_product = 1;
std::vector<Array<LoopRV>> tiles(s_indices_.size() + r_indices_.size());
for (int i = 0, n = loops.size(); i < n; ++i) {
LoopRV loop = loops[i];
const std::vector<int>* idx = nullptr;
if (iter_types[i] == IterVarType::kDataPar) {
idx = &s_indices_;
if (spatial_loop_product != -1) {
if (const int64_t* extent = tir::GetLoopIntExtent(sch->Get(loop).get())) {
spatial_loop_product *= *extent;
} else {
spatial_loop_product = -1;
}
}
} else if (iter_types[i] == IterVarType::kCommReduce) {
idx = &r_indices_;
} else {
continue;
}
// Do the split
int n_tiles = idx->size();
LoopRV loop = loops[i];
Array<ExprRV> factors = sch->SamplePerfectTile(
/*loop=*/loop,
/*n=*/n_tiles,
Expand All @@ -453,6 +477,17 @@ inline std::vector<State> MultiLevelTilingNode::TileLoopNest(State state) const
tiles[i] = {fused};
}
state.tiles = Array<Array<LoopRV>>{tiles.begin(), tiles.end()};
if (this->thread_warp_size_ != -1) {
int64_t low_inclusive = 1;
int64_t high_inclusive = this->max_threads_per_block_;
if (spatial_loop_product > 2 * this->thread_warp_size_) {
low_inclusive = this->thread_warp_size_;
}
sch->Annotate(block_rv, tir::attr::meta_schedule_thread_extent_low_inclusive,
Integer(low_inclusive));
sch->Annotate(block_rv, tir::attr::meta_schedule_thread_extent_high_inclusive,
Integer(high_inclusive));
}
return {state};
}

Expand Down Expand Up @@ -578,6 +613,8 @@ ScheduleRule ScheduleRule::MultiLevelTiling(String structure, Optional<Array<Str
LOG(FATAL) << "ValueError: Invalid tiling structure: " << structure;
}
}
n->thread_warp_size_ = -1;
n->max_threads_per_block_ = -1;
return ScheduleRule(n);
}

Expand Down
8 changes: 4 additions & 4 deletions src/meta_schedule/space_generator/post_order_apply.cc
Original file line number Diff line number Diff line change
Expand Up @@ -106,10 +106,10 @@ class PostOrderApplyNode : public SpaceGeneratorNode {

Array<tir::Schedule> GenerateDesignSpace(const IRModule& mod_) final {
using ScheduleAndUnvisitedBlocks = std::pair<tir::Schedule, Array<tir::BlockRV>>;
tir::Schedule sch = tir::Schedule::Traced( //
/*mod=*/mod_, //
/*rand_state=*/ForkSeed(&this->rand_state_), //
/*debug_mode=*/tir::kVerifySRefTree | tir::kVerifyCachedFlags, //
tir::Schedule sch = tir::Schedule::Traced( //
/*mod=*/mod_, //
/*rand_state=*/ForkSeed(&this->rand_state_), //
/*debug_mode=*/0, //
/*error_render_level=*/tir::ScheduleErrorRenderLevel::kDetail);

std::vector<ScheduleAndUnvisitedBlocks> stack;
Expand Down
10 changes: 4 additions & 6 deletions tests/python/meta_schedule/run_ansor_cuda.sh
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@ RPC_HOST="192.168.6.66"
RPC_PORT="4445"
RPC_KEY="jetson-agx-xavier"
TARGET="nvidia/jetson-agx-xavier"
NUM_TRIALS=800
LOG_DIR=$HOME/logs/ansor-cuda/
NUM_TRIALS=2000

mkdir -p $LOG_DIR

Expand All @@ -23,19 +23,17 @@ run () {
2>&1 | tee "$LOG_DIR/$name.log"
}

# Single op
run C1D
run C2D
run C3D
run CAP
run DEP
run DIL
run GMM
run GRP
run NRM
run SFM
run T2D
# Subgraph
run C2d-BN-RELU
run TBG

run C3D
run NRM
run SFM
10 changes: 4 additions & 6 deletions tests/python/meta_schedule/run_meta_schedule_cuda.sh
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ RPC_HOST="192.168.6.66"
RPC_PORT="4445"
RPC_KEY="jetson-agx-xavier"
TARGET="nvidia/jetson-agx-xavier"
LOG_DIR=/tmp/logs/ms-cuda/
LOG_DIR=$HOME/logs/ms-cuda/
NUM_TRIALS=2000

mkdir -p $LOG_DIR
Expand All @@ -25,19 +25,17 @@ run () {
2>&1 | tee "$work_dir/$name.log"
}

# Single op
run C1D
run C2D
run C3D
run CAP
run DEP
run DIL
run GMM
run GRP
run NRM
run SFM
run T2D
# Subgraph
run C2d-BN-RELU
run TBG

run C3D
run NRM
run SFM
Loading