Skip to content

Commit ee01a93

Browse files
committed
Enhance bulk copy and store checks in Copy class
- Updated scope validation for source and destination tensors in `CheckBulkLoad` and `CheckBulkStore` methods to include both `shared.dyn` and `shared` as valid options.
- Modified `CheckLDSMCopy` and `CheckSTSMCopy` methods to accommodate the new scope validation, ensuring compatibility with shared memory configurations.
- Improved logging in `LowerBulkCopy` to provide clearer warnings regarding unsupported swizzle layouts, including source and destination names for better debugging.
1 parent 00b2fd8 commit ee01a93

File tree

10 files changed

+28
-19
lines changed

10 files changed

+28
-19
lines changed

benchmark/matmul/benchmark_matmul_sp.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -192,7 +192,7 @@ def main(
192192

193193
# Clear out the accumulation buffer
194194
T.clear(C_local)
195-
T.no_set_max_nreg()
195+
T.disable_warp_group_reg_alloc()
196196

197197
T.use_swizzle(panel_size=10, enable=enable_rasterization)
198198
T.annotate_layout({

examples/deepseek_mla/experimental/example_mla_decode_kv_fp8.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ def main_no_split(
5252
T.fill(acc_o, 0)
5353
T.fill(logsum, 0)
5454
T.fill(scores_max, -T.infinity(accum_dtype))
55-
T.no_set_max_nreg()
55+
T.disable_warp_group_reg_alloc()
5656
loop_range = T.ceildiv(seqlen_kv, block_N)
5757
for k in T.Pipelined(loop_range, num_stages=2):
5858
T.copy(KV[bx, k * block_N:(k + 1) * block_N, cur_kv_head, :], qKV_shared)

examples/dequantize_gemm/example_dequant_gemm_bf16_mxfp4_hopper.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -338,7 +338,7 @@ def main(
338338
C_shared: tilelang.layout.make_swizzled_layout(C_shared),
339339
})
340340
if threads == 512:
341-
T.no_set_max_nreg()
341+
T.disable_warp_group_reg_alloc()
342342

343343
T.clear(C_local)
344344
for k in T.Pipelined(K // block_K, num_stages=num_stages):

examples/gdn/example_chunk_o.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,7 @@ def kernel(
122122

123123
T.clear(A_fragment)
124124
T.clear(O_fragment)
125-
T.no_set_max_nreg()
125+
T.disable_warp_group_reg_alloc()
126126
for i_k in T.Pipelined(T.ceildiv(DK, block_DK), num_stages=num_stages):
127127
T.copy(
128128
Q[bb, bs * block_S:(bs + 1) * block_S, bh, i_k * block_DK:(i_k + 1) * block_DK],

examples/gdn/example_chunk_scaled_dot_kkt.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ def kernel(
101101
})
102102

103103
T.fill(A_fragment, 0)
104-
T.no_set_max_nreg()
104+
T.disable_warp_group_reg_alloc()
105105
for i_s in T.Parallel(block_S):
106106
Beta_shared[i_s] = Beta[bb, bs * block_S + i_s, bh]
107107

examples/gdn/example_wy_fast.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ def kernel(
107107
U_Beta_shared: tilelang.layout.make_swizzled_layout(U_Beta_shared),
108108
})
109109

110-
T.no_set_max_nreg()
110+
T.disable_warp_group_reg_alloc()
111111
for i_s in T.Parallel(block_S):
112112
Beta_shared[i_s] = Beta[bb, bs * block_S + i_s, bh]
113113
G_shared[i_s] = T.exp(G[bb, bs * block_S + i_s, bh])

src/op/copy.cc

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -363,8 +363,8 @@ bool Copy::CheckBulkLoad(Target target) const {
363363
// 1. arch must have bulk copy support
364364
if (!TargetHasBulkCopy(target))
365365
return false;
366-
// 2. src and dst must be shared.dyn and local.fragment
367-
if (src.scope() != "global" || dst.scope() != "shared.dyn")
366+
// 2. src and dst must be global and shared
367+
if (src.scope() != "global" || (dst.scope() != "shared.dyn" && dst.scope() != "shared"))
368368
return false;
369369
// 3. check shape.
370370
// TODO(lei): validate if we can utilize tma under this shape.
@@ -391,7 +391,7 @@ bool Copy::CheckBulkStore(Target target) const {
391391
if (!TargetHasBulkCopy(target))
392392
return false;
393393
// 2. src and dst must be shared.dyn and local.fragment
394-
if (src.scope() != "shared.dyn" || dst.scope() != "global")
394+
if ((src.scope() != "shared.dyn" && src.scope() != "shared") || dst.scope() != "global")
395395
return false;
396396
// 3. check shape.
397397
// TODO(lei): validate if we can utilize tma under this shape.
@@ -414,7 +414,7 @@ bool Copy::CheckBulkStore(Target target) const {
414414
* otherwise.
415415
*/
416416
bool Copy::CheckLDSMCopy(Target target) const {
417-
return TargetHasLdmatrix(target) && src.scope() == "shared.dyn" &&
417+
return TargetHasLdmatrix(target) && (src.scope() == "shared.dyn" || src.scope() == "shared") &&
418418
dst.scope() == "local.fragment";
419419
}
420420

@@ -428,7 +428,7 @@ bool Copy::CheckLDSMCopy(Target target) const {
428428
*/
429429
bool Copy::CheckSTSMCopy(Target target) const {
430430
return TargetHasStmatrix(target) && src.scope() == "local.fragment" &&
431-
dst.scope() == "shared.dyn";
431+
(dst.scope() == "shared.dyn" || dst.scope() == "shared");
432432
}
433433

434434
/*!
@@ -883,11 +883,7 @@ Stmt Copy::LowerBulkCopy(const LowerArgs &T, arith::Analyzer *analyzer,
883883
ICHECK(stride != nullptr && continuous != nullptr);
884884
// We also need to check if the shape satisfies the following doc:
885885
// https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__TENSOR__MEMORY.html#group__CUDA__TENSOR__MEMORY_1ga7c7d2aaac9e49294304e755e6f341d7
886-
if (StructuralEqual()(shared_layout, makeGemmABLayoutPadded(
887-
*stride, *continuous,
888-
shared_tensor->dtype.bits()))) {
889-
desc.swizzle = static_cast<int>(CU_TENSOR_MAP_SWIZZLE_NONE);
890-
} else if (StructuralEqual()(
886+
if (StructuralEqual()(
891887
shared_layout,
892888
makeQuarterBankSwizzleLayout(*stride, *continuous,
893889
shared_tensor->dtype.bits()))) {
@@ -902,9 +898,18 @@ Stmt Copy::LowerBulkCopy(const LowerArgs &T, arith::Analyzer *analyzer,
902898
makeFullBankSwizzleLayout(*stride, *continuous,
903899
shared_tensor->dtype.bits()))) {
904900
desc.swizzle = static_cast<int>(CU_TENSOR_MAP_SWIZZLE_128B);
901+
} else if (StructuralEqual()(shared_layout, makeGemmABLayoutPadded(
902+
*stride, *continuous,
903+
shared_tensor->dtype.bits()))) {
904+
LOG(WARNING) << "Bulk copy cannot support a padded layout for src: "
905+
<< src->name << ", dst: " << dst->name
906+
<< ", fallback to normal copy";
907+
return LowerNormalCopy(T, analyzer);
905908
} else {
906909
LOG(WARNING)
907-
<< "Came across unsupported swizzle layout, fallback to normal copy";
910+
<< "Came across unsupported swizzle layout for src: "
911+
<< src->name << ", dst: " << dst->name
912+
<< ", fallback to normal copy";
908913
return LowerNormalCopy(T, analyzer);
909914
}
910915
}

testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ def main(
7070
backend="cutlass",
7171
block_k=block_K),
7272
})
73-
T.no_set_max_nreg()
73+
T.disable_warp_group_reg_alloc()
7474
T.clear(C_local)
7575
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
7676
T.copy(E[by * block_M, k * block_K // E_factor], E_shared)

testing/python/transform/test_tilelang_transform_inject_set_max_nreg.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ def before_no_set_max_nreg(A: T.Tensor((512, 512), "float16")):
100100
T.writes()
101101

102102
# Add no_set_max_nreg to disable register hinting
103-
T.no_set_max_nreg()
103+
T.disable_warp_group_reg_alloc()
104104

105105
T.create_list_of_mbarrier(128, 128)
106106
T.attr([128, 128], "kWarpSpecializationScope", 0)

tilelang/language/builtin.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,10 @@ def no_set_max_nreg():
159159
"""
160160
return tir.call_intrin("handle", tir.op.Op.get("tl.no_set_max_nreg"))
161161

162+
def disable_warp_group_reg_alloc():
163+
"""Disable the warp group reg alloc.
164+
"""
165+
return no_set_max_nreg()
162166

163167
def mbarrier_wait_parity(mbarrier: Union[int, PrimExpr, tir.Call], parity: Union[int, Var]):
164168
"""Wait for memory barrier parity condition.

0 commit comments

Comments (0)