Skip to content

Commit e371056

Browse files
committed
fix ci and pass bug
1 parent 3705141 commit e371056

File tree

2 files changed

+21
-9
lines changed

2 files changed

+21
-9
lines changed

src/transform/warp_specialized_rewriter.cc

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -376,21 +376,32 @@ class ThreadIdxRewriter : public StmtExprMutator {
376376
eq_op->b.as<VarNode>() == thread_var_.get()) {
377377
maybe_thread_opt_ = true;
378378
}
379-
maybe_thread_opt_ = do_shuffle_ && maybe_thread_opt_;
379+
auto then_case = StmtExprMutator::VisitStmt(op->then_case);
380+
maybe_thread_opt_ = do_shuffle_ && maybe_thread_opt_ && has_tma_op_;
381+
if (maybe_thread_opt_) {
382+
return IfThenElse(
383+
Call(DataType::Bool(), tl_shuffle_elect(), {thread_extent_}),
384+
StmtExprMutator::VisitStmt(op->then_case), std::nullopt);
385+
}
380386
}
381-
if (maybe_thread_opt_)
382-
return IfThenElse(
383-
Call(DataType::Bool(), tl_shuffle_elect(), {thread_extent_}),
384-
StmtExprMutator::VisitStmt(op->then_case), std::nullopt);
385-
else
386-
return StmtExprMutator::VisitStmt_(op);
387+
return StmtExprMutator::VisitStmt_(op);
388+
}
389+
390+
PrimExpr VisitExpr_(const CallNode *op) final {
391+
if (op->op.same_as(tl::tma_load()) ||
392+
op->op.same_as(tl::tma_load_im2col()) ||
393+
op->op.same_as(tl::tma_store())) {
394+
has_tma_op_ = true;
395+
}
396+
return StmtExprMutator::VisitExpr_(op);
387397
}
388398

389399
Var thread_var_;
390400
PrimExpr replaced_;
391401
PrimExpr thread_extent_;
392402
bool maybe_thread_opt_ = false;
393403
bool do_shuffle_;
404+
bool has_tma_op_ = false;
394405
};
395406

396407
Block MakeGroupBlock(const Stmt &stmt,

testing/python/webgpu/test_webgpu_codegen.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,8 +43,9 @@ def assert_gemm_codegen(
4343
accum_dtype="float",
4444
):
4545
func = matmul(M, N, K, block_M, block_N, block_K, dtype=dtype, accum_dtype=accum_dtype)
46-
47-
artifact = tilelang.lower(func, target="webgpu")
46+
# Because the current pass context have been polluted by previous testing.
47+
with tvm.transform.PassContext():
48+
artifact = tilelang.lower(func, target="webgpu")
4849

4950
src_code = artifact.kernel_source
5051

0 commit comments

Comments
 (0)