Skip to content

Commit 2c0072a

Browse files
authored
[Refactor] Update buffer handling in copy and atomic operations (#1247)
* [Refactor] Update buffer handling in copy and atomic operations * Refactored the `copy` and `atomic_add` functions to use element-wise minimum for defining copy extents, ensuring correct handling of overlapping regions. * Updated utility functions to create `BufferLoad` instances with explicit extents, improving memory management and clarity. * Removed unused imports from `atomic.py` and `copy.py` to streamline the codebase. * Adjusted logging in `copy.cc` to provide clearer warnings for fallback scenarios in bulk copy operations. * Remove obsolete .git_commit.txt file * Add unit test for dynamic copy extent handling in TileLang * Introduced a new test file `test_tilelang_issue_1237.py` to verify that the `T.copy` function correctly manages dynamic extents during primitive function building. * The test reproduces a specific issue related to dynamic slice lengths and static buffer sizes, ensuring robustness in the handling of such scenarios. * The test does not require execution of the kernel, as building the primitive function is sufficient to validate the fix. * lint fix * fix * Revert "fix" This reverts commit 828b4c1. * Update TVM submodule and refactor atomic and copy functions * Updated the TVM submodule to a dirty state. * Refactored `atomic_add` and `copy` functions to pass extents explicitly to the `_to_region` helper, improving clarity and correctness in handling buffer regions. * Commented out the main execution call in the test example for `cast` and added a new function call to better demonstrate the example usage. * Enhance extent handling in atomic and copy functions * Introduced `legalize_pairwise_extents` utility to align and broadcast extent lists for `atomic_add` and `copy` functions, ensuring compatibility and correctness in buffer operations. * Updated both functions to utilize the new utility, improving clarity and robustness in handling dynamic and static extents. * Added comments to clarify the extent handling logic. * Enhance `legalize_pairwise_extents` function with early-exit rule * Added an early-exit condition to the `legalize_pairwise_extents` function to return original extents if the number of non-1 dimensions in both source and destination extents is equal, improving performance by avoiding unnecessary adjustments. * Updated the function's documentation to clarify the new behavior and maintain clarity in the extent handling logic. * lint fix
1 parent d7164ab commit 2c0072a

File tree

6 files changed

+113
-16
lines changed

6 files changed

+113
-16
lines changed

src/op/copy.cc

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1504,7 +1504,12 @@ Stmt CopyNode::LowerBulkCopy(const LowerArgs &T, arith::Analyzer *analyzer,
15041504
}
15051505

15061506
auto inner_box_dim = as_const_int(desc.smem_box[0]);
1507-
ICHECK(inner_box_dim != nullptr);
1507+
if (inner_box_dim == nullptr) {
1508+
LOG(WARNING) << "inner_box_dim " << desc.smem_box[0]
1509+
<< " can only be a constant integer for TMA bulk copy, "
1510+
"fallback to normal copy";
1511+
return LowerNormalCopy(T, analyzer);
1512+
}
15081513
int instruction_dim = *inner_box_dim;
15091514
if (desc.swizzle == static_cast<int>(CU_TENSOR_MAP_SWIZZLE_64B)) {
15101515
instruction_dim = 64 / src->dtype.bytes();
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
import tilelang.testing
2+
from tilelang import language as T
3+
4+
5+
def test_issue_1237_dynamic_copy_extent_builds():
6+
# Repro from debug/1113_issues/copy_dyn.py, adapted as a unit test.
7+
# The goal is to ensure T.copy correctly handles dynamic extents
8+
# (e.g., src slice length vs. static dst buffer size) during prim_func building.
9+
10+
length = T.symbolic("len", dtype="int32")
11+
12+
@T.prim_func
13+
def sample_kernel(global_tensor: T.Tensor[(length,), "int32"]): # noqa: F821
14+
with T.Kernel(1, threads=32):
15+
buffer_shared = T.alloc_shared((1024,), dtype="int32")
16+
T.copy(global_tensor[0:length], buffer_shared)
17+
18+
# Building the prim_func is sufficient to exercise the bug path; no need to JIT/execute.
19+
_ = sample_kernel
20+
21+
22+
if __name__ == "__main__":
23+
tilelang.testing.main()

tilelang/language/atomic.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@
66
import tilelang.language as T
77
from tvm import ir, tir
88
from tvm.tir import PrimExpr, Buffer, BufferRegion, Var, op
9-
from tilelang.language.utils import buffer_to_tile_region, buffer_region_to_tile_region, buffer_load_to_tile_region
10-
from tilelang.utils.language import get_buffer_region_from_load
9+
from tilelang.language.utils import buffer_region_to_tile_region, buffer_load_to_tile_region
10+
from tilelang.utils.language import get_buffer_region_from_load, legalize_pairwise_extents
1111

1212
_MEMORY_ORDER_ID_MAP = {
1313
"relaxed": 0,
@@ -201,13 +201,14 @@ def get_extent(data):
201201
assert src_extent or dst_extent, "Can't deduce atomicadd extents from args"
202202
src_extent = list(src_extent) if src_extent else [1] * len(dst_extent)
203203
dst_extent = list(dst_extent) if dst_extent else [1] * len(src_extent)
204-
extent = max(src_extent, dst_extent)
204+
src_extent, dst_extent = legalize_pairwise_extents(src_extent, dst_extent)
205205

206-
def _to_region(data, access_type):
206+
def _to_region(data, access_type, extent):
207207
if isinstance(data, tir.Var) and T.has_let_value(data):
208208
data = T.get_let_value(data)
209209
if isinstance(data, tir.Buffer):
210-
return buffer_to_tile_region(data, access_type)
210+
zeros = [tir.IntImm("int32", 0) for _ in extent]
211+
return buffer_load_to_tile_region(tir.BufferLoad(data, zeros), access_type, extent)
211212
elif isinstance(data, tir.BufferRegion):
212213
return buffer_region_to_tile_region(data, access_type, extent)
213214
elif isinstance(data, tir.BufferLoad):
@@ -218,8 +219,8 @@ def _to_region(data, access_type):
218219
else:
219220
return buffer_load_to_tile_region(data, access_type, extent)
220221

221-
value = _to_region(value, "r")
222-
dst = _to_region(dst, "w")
222+
value = _to_region(value, "r", src_extent)
223+
dst = _to_region(dst, "w", dst_extent)
223224

224225
# Note: tile-region-based atomic operations don't support return_prev yet
225226
# This would need to be implemented in the tile runtime

tilelang/language/copy.py

Lines changed: 22 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,12 @@
33

44
from typing import Literal
55
from tilelang import language as T
6-
from tilelang.utils.language import get_buffer_region_from_load
6+
from tilelang.utils.language import (
7+
get_buffer_region_from_load,
8+
legalize_pairwise_extents,
9+
)
710
from tvm import ir, tir
8-
from tilelang.language.utils import buffer_to_tile_region, buffer_region_to_tile_region, buffer_load_to_tile_region
11+
from tilelang.language.utils import buffer_region_to_tile_region, buffer_load_to_tile_region
912

1013

1114
def copy(src: tir.Buffer | tir.BufferLoad | tir.BufferRegion,
@@ -55,15 +58,26 @@ def get_extent(data):
5558
return tir.BufferStore(dst.buffer, src, dst.indices)
5659

5760
assert src_extent or dst_extent, "Can't deduce copy extents from args"
61+
# Treat missing extent as length-matched ones to enable broadcasting logic.
5862
src_extent = list(src_extent) if src_extent else [1] * len(dst_extent)
5963
dst_extent = list(dst_extent) if dst_extent else [1] * len(src_extent)
60-
extent = max(src_extent, dst_extent)
6164

62-
def _to_region(data, access_type):
65+
# Align and broadcast extents from the right (tail) side independently
66+
# for src and dst, so we can pass them unchanged into _to_region.
67+
# Rules per-dim from the right:
68+
# - equal -> keep both
69+
# - one is 1 -> set that side to the other side's dim
70+
# - otherwise -> error
71+
src_extent, dst_extent = legalize_pairwise_extents(src_extent, dst_extent)
72+
73+
def _to_region(data, access_type, extent):
6374
if isinstance(data, tir.Var) and T.has_let_value(data):
6475
data = T.get_let_value(data)
6576
if isinstance(data, tir.Buffer):
66-
return buffer_to_tile_region(data, access_type)
77+
# Restrict a raw buffer to the computed copy extent by creating
78+
# a BufferLoad at origin and passing the extents explicitly.
79+
zeros = [tir.IntImm("int32", 0) for _ in extent]
80+
return buffer_load_to_tile_region(tir.BufferLoad(data, zeros), access_type, extent)
6781
elif isinstance(data, tir.BufferRegion):
6882
return buffer_region_to_tile_region(data, access_type, extent)
6983
elif isinstance(data, tir.BufferLoad):
@@ -74,8 +88,9 @@ def _to_region(data, access_type):
7488
else:
7589
return buffer_load_to_tile_region(data, access_type, extent)
7690

77-
src = _to_region(src, "r")
78-
dst = _to_region(dst, "w")
91+
# Use legalized extents for src and dst respectively.
92+
src = _to_region(src, "r", src_extent)
93+
dst = _to_region(dst, "w", dst_extent)
7994

8095
if coalesced_width is None:
8196
coalesced_width = -1 # PrimExpr can not be None

tilelang/language/utils.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,14 @@ def buffer_region_to_tile_region(buffer_region: tir.BufferRegion, access_type: s
8585
extents
8686
), f"region_extents must be >= extents, region_extents = {region_extents}, extents = {extents}"
8787

88-
return region(T.BufferLoad(buffer_region.buffer, mins), access_type, *region_extents)
88+
# Clamp extents element-wise so that the produced region respects the
89+
# requested copy/fill extent, supporting dynamic PrimExpr via tir.min.
90+
clamped_extents = [
91+
tir.min(region_extents[i], extents[i]) if i < len(extents) else region_extents[i]
92+
for i in range(len(region_extents))
93+
]
94+
95+
return region(T.BufferLoad(buffer_region.buffer, mins), access_type, *clamped_extents)
8996

9097

9198
def index_to_coordinates(index, shape) -> list[PrimExpr]:

tilelang/utils/language.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -367,6 +367,52 @@ def prim_expr_equal(lhs, rhs) -> bool:
367367
return tir.analysis.expr_deep_equal(lhs, rhs)
368368

369369

370+
def legalize_pairwise_extents(src_extents: list, dst_extents: list) -> tuple[list, list]:
371+
"""
372+
Right-align and broadcast two extent lists to be mutually compatible.
373+
374+
Early-exit rule:
375+
- If the number of non-1 dimensions in `src_extents` equals that in `dst_extents`,
376+
no adjustment is made; the original extents are returned unchanged. This
377+
preserves the per-dimension iteration mapping (one loop var per non-1 dim)
378+
and avoids creating extra varying axes on either side.
379+
380+
Otherwise, for each pair of tail-aligned dimensions (x, y):
381+
- if x == y: keep both
382+
- elif x == 1: set x = y
383+
- elif y == 1: set y = x
384+
- else: promote both to tir.max(x, y) to handle dynamic-vs-static safely
385+
386+
Leading unmatched dimensions are kept as-is.
387+
388+
Returns a tuple of new lists (src_new, dst_new).
389+
"""
390+
a = list(src_extents)
391+
b = list(dst_extents)
392+
393+
# If both sides have the same number of non-1 extents, don't re-broadcast.
394+
def _num_non_one(exts: list) -> int:
395+
return sum(0 if prim_expr_equal(x, 1) else 1 for x in exts)
396+
397+
if _num_non_one(a) == _num_non_one(b):
398+
return a, b
399+
k = min(len(a), len(b))
400+
for i in range(1, k + 1):
401+
x, y = a[-i], b[-i]
402+
if prim_expr_equal(x, y):
403+
continue
404+
elif prim_expr_equal(x, 1):
405+
a[-i] = y
406+
elif prim_expr_equal(y, 1):
407+
b[-i] = x
408+
else:
409+
# Dynamic mismatch: promote to max so downstream clamping/predicates remain safe
410+
m = tir.max(x, y)
411+
a[-i] = m
412+
b[-i] = m
413+
return a, b
414+
415+
370416
def is_full_region(buffer_region: BufferRegion) -> bool:
371417
"""
372418
Check whether a BufferRegion covers the full buffer region.

0 commit comments

Comments
 (0)