Skip to content

Commit a15b990

Browse files
committed
[TIR][LowerMatchBuffer] Fix lowering strides when source region has higher dimension than the buffer
1 parent 4905a8c commit a15b990

File tree

2 files changed

+58
-5
lines changed

2 files changed

+58
-5
lines changed

src/tir/transforms/lower_match_buffer.cc

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -188,25 +188,25 @@ class MatchBufferLower : public StmtExprMutator {
188188
Load load = Downcast<Load>(source_buffer.vload(indices, source_buffer->dtype));
189189
Bind(buffer->elem_offset, load->index, buffer->name + ".elem_offset");
190190
CHECK(analyzer_.CanProve(truncmod(buffer->elem_offset, buffer->offset_factor) == 0))
191-
<< "The source elem_offset " << buffer->elem_offset
192-
<< " does not satisfy the offset_factor " << buffer->offset_factor << ".";
191+
<< "The source elem_offset " << load->index << " does not satisfy the offset_factor "
192+
<< buffer->offset_factor << ".";
193193
}
194194

195195
// Step 2.3. Check and update strides
196196
// Check if target buffer strides are defined
197+
ICHECK(source->region.size() >= buffer->shape.size());
198+
size_t offset = source->region.size() - buffer->shape.size();
197199
if (!buffer->strides.empty()) {
198200
ICHECK_EQ(buffer->strides.size(), buffer->shape.size());
199201
PrimExpr stride = make_const(DataType::Int(32), 1);
200202
for (size_t i = buffer->shape.size(); i > 0; --i) {
201-
const PrimExpr& shape = source_buffer->shape[i - 1];
203+
const PrimExpr& shape = source_buffer->shape[i - 1 + offset];
202204
Bind(buffer->strides[i - 1], stride, buffer->name + ".strides_" + std::to_string(i - 1));
203205
stride *= shape;
204206
}
205207
}
206208

207209
// Step 2.4. Check and update shape
208-
ICHECK(source->region.size() >= buffer->shape.size());
209-
size_t offset = source->region.size() - buffer->shape.size();
210210
for (size_t i = 0; i < buffer->shape.size(); ++i) {
211211
const Range& range = source->region[i + offset];
212212
Bind(buffer->shape[i], range->extent, buffer->name + ".shape_" + std::to_string(i));

tests/python/unittest/test_tir_lower_match_buffer.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,54 @@ def transformed_opaque_access(a: ty.handle, b: ty.handle) -> None:
156156
)
157157

158158

159+
@tvm.script.tir
160+
def high_dim_opaque_access(a: ty.handle) -> None:
161+
A = tir.match_buffer(a, (16, 32, 64))
162+
for i, j, k in tir.grid(16, 2, 4):
163+
with tir.block([]):
164+
As_0 = tir.var("int32")
165+
As_1 = tir.var("int32")
166+
tir.reads([])
167+
tir.writes(A[i, j * 16 : j * 16 + 16, k * 16 : k * 16 + 16])
168+
sub_A = tir.match_buffer(
169+
A[i, j * 16 : j * 16 + 16, k * 16 : k * 16 + 16],
170+
(16, 16),
171+
strides=[As_0, As_1],
172+
offset_factor=1,
173+
)
174+
tir.evaluate(
175+
tir.intrin_test(
176+
sub_A.data,
177+
sub_A.elem_offset,
178+
sub_A.strides[0],
179+
sub_A.strides[1],
180+
sub_A.shape[0],
181+
sub_A.shape[1],
182+
dtype="handle",
183+
)
184+
)
185+
186+
187+
@tvm.script.tir
188+
def transformed_high_dim_opaque_access(a: ty.handle) -> None:
189+
A = tir.match_buffer(a, (16, 32, 64))
190+
for i, j, k in tir.grid(16, 2, 4):
191+
with tir.block([]):
192+
tir.reads([])
193+
tir.writes(A[i, j * 16 : j * 16 + 16, k * 16 : k * 16 + 16])
194+
tir.evaluate(
195+
tir.intrin_test(
196+
A.data,
197+
i * 2048 + j * 1024 + k * 16,
198+
64,
199+
1,
200+
16,
201+
16,
202+
dtype="handle",
203+
)
204+
)
205+
206+
159207
@tvm.script.tir
160208
def recursive_match(a: ty.handle, b: ty.handle) -> None:
161209
A = tir.match_buffer(a, (64, 64, 64))
@@ -419,6 +467,10 @@ def test_opaque_access():
419467
_check(opaque_access, transformed_opaque_access)
420468

421469

470+
def test_high_dim_opaque_access():
471+
_check(high_dim_opaque_access, transformed_high_dim_opaque_access)
472+
473+
422474
def test_recursive_match():
423475
_check(recursive_match, transformed_recursive_match)
424476

@@ -447,6 +499,7 @@ def test_fail_match_func_param():
447499
if __name__ == "__main__":
448500
test_buffer_load_store()
449501
test_opaque_access()
502+
test_high_dim_opaque_access()
450503
test_recursive_match()
451504
test_symbolic_match()
452505
test_rank0_buffer()

0 commit comments

Comments
 (0)