
[Refactor] Phaseout Pass ParallelLoopTransformer #611


Merged
merged 2 commits on Jul 4, 2025
153 changes: 0 additions & 153 deletions src/transform/layout_inference.cc
@@ -50,158 +50,6 @@ using namespace tir;
using arith::IRMutatorWithAnalyzer;
using arith::IRVisitorWithAnalyzer;

class ParallelLoopTransformer : public IRMutatorWithAnalyzer {
public:
static Stmt Substitute(Stmt stmt, bool skip_thread_partition = false) {
arith::Analyzer analyzer;
ParallelLoopTransformer transformer(&analyzer);
return transformer.VisitStmt(stmt);
}

ParallelLoopTransformer(arith::Analyzer *analyzer)
: IRMutatorWithAnalyzer(analyzer) {}

Stmt VisitStmt_(const ForNode *op) final {
if (op->kind != ForKind::kParallel)
return StmtMutator::VisitStmt_(op);

// Collect loop variables and ranges
auto for_node = GetRef<For>(op);
Array<Var> loop_vars;
Array<PrimExpr> loop_extents;
Stmt body = op->body;

// Bind the range of outer loop variables
analyzer_->Bind(op->loop_var, Range::FromMinExtent(0, op->extent));
loop_vars.push_back(op->loop_var);
loop_extents.push_back(op->extent);

// If there are inner loops, bind their ranges as well
while (const ForNode *inner = body.as<ForNode>()) {
analyzer_->Bind(inner->loop_var, Range::FromMinExtent(0, inner->extent));
loop_vars.push_back(inner->loop_var);
loop_extents.push_back(inner->extent);
body = inner->body;
}

ICHECK(loop_vars.size() == loop_extents.size())
<< "loop_vars and loop_extents size mismatch";

// Collect buffer access information
BufferAccessCollector collector;
collector(op->body);

PrimExpr condition;

for (const auto &[buffer, indices] : collector.buffer_indices) {
ICHECK(indices.size() == buffer->shape.size())
<< "indices size mismatch with buffer shape";

for (size_t i = 0; i < indices.size(); ++i) {
auto index = indices[i];
auto bound = analyzer_->const_int_bound(index);
int64_t upper_bound = bound->max_value + 1;
int64_t shape = Downcast<IntImm>(buffer->shape[i])->value;

// Collect the variables that are used in the index
std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual> used_vars;
// post order visit the index
PostOrderVisit(index, [&](const ObjectRef &obj) {
if (const VarNode *v = obj.as<VarNode>()) {
used_vars.insert(GetRef<Var>(v));
}
});
if (used_vars.size() == 0) {
continue;
}

// find related loop vars
Array<Var> related_loop_vars;
for (size_t j = 0; j < loop_vars.size(); ++j) {
auto loop_var = loop_vars[j];
// if the index uses this loop var, record it as related
if (used_vars.count(loop_var)) {
related_loop_vars.push_back(loop_var);
}
ICHECK(related_loop_vars.size() <= 1)
<< "Only one related loop var is supported currently, but got "
<< related_loop_vars
<< ". Supporting multiple loop vars should not be hard to "
<< "implement; please open an issue if you run into this message.";

auto bound = analyzer_->const_int_bound(index);
int64_t upper_bound = bound->max_value + 1;
int64_t shape = Downcast<IntImm>(buffer->shape[i])->value;
if (upper_bound < shape) {
PrimExpr predicate = LT(index, IntImm(index.dtype(), upper_bound));
condition =
condition.defined() ? And(condition, predicate) : predicate;

// replace the buffer index from A[i, r * 2] with A[i, j]
// where r is the original index, j is the loop_var
auto index_map = tir::IndexMap({loop_var}, {index});
auto inverse_index_map = index_map.Inverse(
{Range::FromMinExtent(0, IntImm(index.dtype(), upper_bound))},
analyzer_);

loop_extents.Set(i, IntImm(index.dtype(), shape));
body = tir::Substitute(body,
{{loop_var, inverse_index_map->MapIndices(
{loop_var}, analyzer_)[0]}});
}
}
}
}
if (condition.defined()) {
body = IfThenElse(condition, body);
for (int j = loop_vars.size() - 1; j >= 0; --j) {
auto loop_var = loop_vars[j];
auto loop_extent = loop_extents[j];
body = For(loop_var, 0, loop_extent, ForKind::kParallel, body);
}
return Downcast<For>(body);
}
// Only traverse the outer loop
return for_node;
}

private:
// Helper class for collecting buffer access information; only records
// accesses to fragment buffers
class BufferAccessCollector : public StmtExprVisitor {
public:
void VisitExpr_(const BufferLoadNode *op) final {
if (op->buffer.scope() == "local.fragment") {
if (buffer_indices.find(op->buffer) == buffer_indices.end()) {
buffer_indices[op->buffer] = op->indices;
} else {
// check equal
ICHECK(StructuralEqual()(buffer_indices[op->buffer], op->indices))
<< "indices mismatch for buffer: " << op->buffer;
}
}
StmtExprVisitor::VisitExpr_(op);
}

void VisitStmt_(const BufferStoreNode *op) final {
if (op->buffer.scope() == "local.fragment") {
if (buffer_indices.find(op->buffer) == buffer_indices.end()) {
buffer_indices[op->buffer] = op->indices;
} else {
// check equal
ICHECK(StructuralEqual()(buffer_indices[op->buffer], op->indices))
<< "indices mismatch for buffer: " << op->buffer;
}
}
StmtExprVisitor::VisitStmt_(op);
}

std::unordered_map<Buffer, Array<PrimExpr>, ObjectPtrHash, ObjectPtrEqual>
buffer_indices;
};
};

struct LayoutInferenceResult {
Map<Buffer, Layout> layout_map;
Map<For, Fragment> for_map;
@@ -656,7 +504,6 @@ class LayoutInferencer : public IRMutatorWithAnalyzer {
tvm::transform::Pass LayoutInference() {
using namespace tir::transform;
auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
f.CopyOnWrite()->body = ParallelLoopTransformer::Substitute(f->body);
ThreadBindingCollector collector;
collector(f->body);
bool has_thread_binding = collector.thread_binding_.size() > 0;
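For reference, the deleted ParallelLoopTransformer rewrote a parallel loop whose fragment-buffer index covered only part of a buffer dimension: it widened the loop to the full buffer extent, guarded the body with a bound predicate, and substituted the inverse index map so the buffer is indexed directly by the loop variable (e.g. turning A[r * 2] into A[j]). The plain-Python toy below is only a sketch of that before/after relationship; the function names, stride, and sizes are illustrative and are not TVM or TileLang APIs.

# Toy sketch (plain Python, not TVM) of the rewrite performed by the deleted pass.
def before(n_iter, stride, buf_len):
    # Original loop: the index r * stride touches only part of the buffer dimension.
    buf = [0] * buf_len
    for r in range(n_iter):            # for r in T.Parallel(n_iter)
        buf[r * stride] = 1            # A[r * stride] = ...
    return buf

def after(n_iter, stride, buf_len):
    # Rewritten loop: iterate over the full buffer extent, guard with the bound
    # predicate, and recover the original loop variable via the inverse index map.
    buf = [0] * buf_len
    upper_bound = (n_iter - 1) * stride + 1      # const_int_bound(index).max_value + 1
    for j in range(buf_len):                     # loop widened to the buffer extent
        if j < upper_bound and j % stride == 0:  # predicate; index map assumed invertible
            r = j // stride                      # inverse index map: j -> r
            buf[j] = 1                           # A[j] = ... (index now equals the loop var)
    return buf

assert before(64, 2, 128) == after(64, 2, 128)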
4 changes: 2 additions & 2 deletions testing/python/kernel/test_tilelang_kernel_mha_bwd.py
@@ -310,8 +310,8 @@ def assert_mha_equal(batch, h, n_ctx, d_head, causal):


def test_mha_bwd():
-    assert_mha_equal(8, 32, 256, 64, False)
-    assert_mha_equal(8, 32, 256, 64, True)
+    assert_mha_equal(8, 32, 128, 64, False)
+    assert_mha_equal(8, 32, 128, 64, True)


if __name__ == "__main__":