Skip to content

Commit ca4a10f

Browse files
committed
Fix typos in intrinsic names and remove unused print statement in block_sparse_attn_tilelang.py. Updated references from ptx_ldmatirx to ptx_ldmatrix across multiple files for consistency.
1 parent eb098c6 commit ca4a10f

File tree

6 files changed

+8
-9
lines changed

6 files changed

+8
-9
lines changed

examples/seer_attention/block_sparse_attn_tilelang.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -178,7 +178,6 @@ def test_topk_sparse_attention():
178178
# Run tilelang kernel
179179
kernel = blocksparse_flashattn(
180180
BATCH, N_HEADS, SEQ_LEN, SEQ_LEN, D_HEAD, downsample_len, is_causal=True)
181-
print(kernel.get_kernel_source())
182181
tilelang_output = kernel(q, k, v, block_mask.to(torch.int8))
183182

184183
# Compute reference

src/op/builtin.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ TIR_DEFINE_TL_BUILTIN(mbarrier_expect_tx)
8080
.set_attr<TCallEffectKind>("TCallEffectKind",
8181
Integer(CallEffectKind::kOpaque));
8282

83-
TIR_DEFINE_TL_BUILTIN(ptx_ldmatirx)
83+
TIR_DEFINE_TL_BUILTIN(ptx_ldmatrix)
8484
.set_num_inputs(4)
8585
.set_attr<TCallEffectKind>("TCallEffectKind",
8686
Integer(CallEffectKind::kOpaque));

src/op/builtin.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -146,15 +146,15 @@ TVM_DLL const Op &mbarrier_expect_tx();
146146
/*!
147147
* \brief tvm intrinsics for ldmatrix
148148
*
149-
* ptx_ldmatirx(transposed, num, shared_addr, local_addr)
149+
* ptx_ldmatrix(transposed, num, shared_addr, local_addr)
150150
*
151151
*/
152-
TVM_DLL const Op &ptx_ldmatirx();
152+
TVM_DLL const Op &ptx_ldmatrix();
153153

154154
/*!
155155
* \brief tvm intrinsics for stmatrix
156156
*
157-
* ptx_ldmatirx(transposed, num, shared_addr, int32_values...)
157+
* ptx_ldmatrix(transposed, num, shared_addr, int32_values...)
158158
*
159159
*/
160160
TVM_DLL const Op &ptx_stmatrix();

src/op/copy.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -330,7 +330,7 @@ LayoutMap Copy::InferLayout(const LayoutInferArgs &T, InferLevel level) {
330330
Buffer shared_tensor = is_load ? dst : src;
331331
// check shared layout is non-swizzle
332332
// skip layout inference if shared layout is already annotated
333-
if (!T.layout_map.count(shared_tensor)) {
333+
if (level == InferLevel::kFree && !T.layout_map.count(shared_tensor)) {
334334
// create a new layout map for tma linear layout
335335
Layout linear_layout = ComputeLinearLayout(shared_tensor);
336336
return Map<Buffer, Layout>({{shared_tensor, linear_layout}});
@@ -646,7 +646,7 @@ Stmt Copy::LowerLDSMCopy(const LowerArgs &T, arith::Analyzer *analyzer,
646646
num = 2;
647647

648648
Array<PrimExpr> args;
649-
const Op &op = is_ldmatrix ? tl::ptx_ldmatirx() : tl::ptx_stmatirx();
649+
const Op &op = is_ldmatrix ? tl::ptx_ldmatrix() : tl::ptx_stmatrix();
650650
args.push_back(static_cast<int>(is_transposed));
651651
args.push_back(num);
652652

src/target/codegen_cuda.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1099,7 +1099,7 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) {
10991099
ss << "tl::tma_store";
11001100
}
11011101
print_extern_call_stmt(ss.str(), 0, 1);
1102-
} else if (op->op.same_as(tl::ptx_ldmatirx())) {
1102+
} else if (op->op.same_as(tl::ptx_ldmatrix())) {
11031103
int trans = Downcast<IntImm>(op->args[0])->value;
11041104
int num = Downcast<IntImm>(op->args[1])->value;
11051105
std::string func_name = "tl::ptx_ldmatrix_x" + std::to_string(num);

src/transform/inject_fence_proxy.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ class ProxyMarker : public StmtVisitor {
5757
void VisitStmt_(const EvaluateNode *op) final {
5858
Proxy proxy = Proxy::kAsync;
5959
if (auto call = op->value.as<CallNode>()) {
60-
if (call->op.same_as(ptx_ldmatirx()) ||
60+
if (call->op.same_as(ptx_ldmatrix()) ||
6161
call->op.same_as(ptx_stmatrix())) {
6262
proxy = Proxy::kGeneric;
6363
}

0 commit comments

Comments
 (0)