Skip to content

Commit ebea77d

Browse files
authored
[CI] Test Fix: Handle BufferLoad nodes when T.gemm input has a stride (#843)
* bugfix * fix * test fix
1 parent 232782d commit ebea77d

File tree

1 file changed

+87
-8
lines changed

1 file changed

+87
-8
lines changed

tilelang/language/gemm.py

Lines changed: 87 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import tilelang.language as T
55
from tvm import tir
66
from typing import Union, List
7+
from tilelang.utils.language import get_buffer_region_from_load
78

89

910
def gemm(
@@ -66,8 +67,15 @@ def retrieve_shape(object: Union[tir.Buffer, tir.BufferRegion]) -> List[int]:
6667
for r in region:
6768
shape.append(r.extent)
6869
return shape
70+
elif isinstance(object, tir.BufferLoad):
71+
region = get_buffer_region_from_load(object).region
72+
shape = []
73+
for r in region:
74+
shape.append(r.extent)
75+
return shape
6976
else:
70-
raise ValueError(f"Unsupported argument type: {type(object)} for buffer {object}")
77+
raise ValueError(
78+
f"Unsupported retrieve_shape argument type: {type(object)} for buffer {object}")
7179

7280
def retrieve_stride(object: Union[tir.Buffer, tir.BufferRegion]) -> List[int]:
7381
if isinstance(object, tir.Buffer):
@@ -85,8 +93,17 @@ def retrieve_stride(object: Union[tir.Buffer, tir.BufferRegion]) -> List[int]:
8593
strides.insert(0, stride)
8694
stride *= s
8795
return strides
96+
elif isinstance(object, tir.BufferLoad):
97+
buffer = object.buffer
98+
strides = []
99+
stride = 1
100+
for s in reversed(buffer.shape):
101+
strides.insert(0, stride)
102+
stride *= s
103+
return strides
88104
else:
89-
raise ValueError(f"Unsupported argument type: {type(object)} for buffer {object}")
105+
raise ValueError(
106+
f"Unsupported retrieve_stride argument type: {type(object)} for buffer {object}")
90107

91108
A_shape = retrieve_shape(A)
92109
B_shape = retrieve_shape(B)
@@ -134,8 +151,24 @@ def retrieve_ptr(object: Union[tir.Buffer, tir.BufferRegion],
134151
for i in range(len(indices) - 2):
135152
offset += indices[i] * strides[i]
136153
return buffer.access_ptr(access_mask=access_type, offset=offset)
154+
elif isinstance(object, tir.BufferLoad):
155+
buffer = object.buffer
156+
region = get_buffer_region_from_load(object).region
157+
indices = []
158+
for r in region:
159+
indices.append(r.min)
160+
strides = []
161+
stride = 1
162+
for s in reversed(buffer.shape):
163+
strides.insert(0, stride)
164+
stride *= s
165+
offset = 0
166+
for i in range(len(indices) - 2):
167+
offset += indices[i] * strides[i]
168+
return buffer.access_ptr(access_mask=access_type, offset=offset)
137169
else:
138-
raise ValueError(f"Unsupported argument type: {type(object)} for buffer {object}")
170+
raise ValueError(
171+
f"Unsupported retrieve_ptr argument type: {type(object)} for buffer {object}")
139172

140173
def retrieve_offset(object: Union[tir.Buffer, tir.BufferRegion]) -> tir.PrimExpr:
141174
"""Retrieve the offset of the buffer or buffer region."""
@@ -147,8 +180,15 @@ def retrieve_offset(object: Union[tir.Buffer, tir.BufferRegion]) -> tir.PrimExpr
147180
for r in region:
148181
indices.append(r.min)
149182
return indices
183+
elif isinstance(object, tir.BufferLoad):
184+
region = get_buffer_region_from_load(object).region
185+
indices = []
186+
for r in region:
187+
indices.append(r.min)
188+
return indices
150189
else:
151-
raise ValueError(f"Unsupported argument type: {type(object)} for buffer {object}")
190+
raise ValueError(
191+
f"Unsupported retrieve_offset argument type: {type(object)} for buffer {object}")
152192

153193
A_offset = retrieve_offset(A)
154194
B_offset = retrieve_offset(B)
@@ -243,8 +283,15 @@ def retrieve_shape(object: Union[tir.Buffer, tir.BufferRegion]) -> List[int]:
243283
for r in region:
244284
shape.append(r.extent)
245285
return shape
286+
elif isinstance(object, tir.BufferLoad):
287+
region = get_buffer_region_from_load(object).region
288+
shape = []
289+
for r in region:
290+
shape.append(r.extent)
291+
return shape
246292
else:
247-
raise ValueError(f"Unsupported argument type: {type(object)} for buffer {object}")
293+
raise ValueError(
294+
f"Unsupported retrieve_shape argument type: {type(object)} for buffer {object}")
248295

249296
def retrieve_stride(object: Union[tir.Buffer, tir.BufferRegion]) -> List[int]:
250297
if isinstance(object, tir.Buffer):
@@ -262,8 +309,17 @@ def retrieve_stride(object: Union[tir.Buffer, tir.BufferRegion]) -> List[int]:
262309
strides.insert(0, stride)
263310
stride *= s
264311
return strides
312+
elif isinstance(object, tir.BufferLoad):
313+
buffer = object.buffer
314+
strides = []
315+
stride = 1
316+
for s in reversed(buffer.shape):
317+
strides.insert(0, stride)
318+
stride *= s
319+
return strides
265320
else:
266-
raise ValueError(f"Unsupported argument type: {type(object)} for buffer {object}")
321+
raise ValueError(
322+
f"Unsupported retrieve_stride argument type: {type(object)} for buffer {object}")
267323

268324
A_shape = retrieve_shape(A)
269325
B_shape = retrieve_shape(B)
@@ -311,8 +367,24 @@ def retrieve_ptr(object: Union[tir.Buffer, tir.BufferRegion],
311367
for i in range(len(indices) - 2):
312368
offset += indices[i] * strides[i]
313369
return buffer.access_ptr(access_mask=access_type, offset=offset)
370+
elif isinstance(object, tir.BufferLoad):
371+
buffer = object.buffer
372+
region = get_buffer_region_from_load(object).region
373+
indices = []
374+
for r in region:
375+
indices.append(r.min)
376+
strides = []
377+
stride = 1
378+
for s in reversed(buffer.shape):
379+
strides.insert(0, stride)
380+
stride *= s
381+
offset = 0
382+
for i in range(len(indices) - 2):
383+
offset += indices[i] * strides[i]
384+
return buffer.access_ptr(access_mask=access_type, offset=offset)
314385
else:
315-
raise ValueError(f"Unsupported argument type: {type(object)} for buffer {object}")
386+
raise ValueError(
387+
f"Unsupported retrieve_ptr argument type: {type(object)} for buffer {object}")
316388

317389
def retrieve_offset(object: Union[tir.Buffer, tir.BufferRegion]) -> tir.PrimExpr:
318390
"""Retrieve the offset of the buffer or buffer region."""
@@ -324,8 +396,15 @@ def retrieve_offset(object: Union[tir.Buffer, tir.BufferRegion]) -> tir.PrimExpr
324396
for r in region:
325397
indices.append(r.min)
326398
return indices
399+
elif isinstance(object, tir.BufferLoad):
400+
region = get_buffer_region_from_load(object).region
401+
indices = []
402+
for r in region:
403+
indices.append(r.min)
404+
return indices
327405
else:
328-
raise ValueError(f"Unsupported argument type: {type(object)} for buffer {object}")
406+
raise ValueError(
407+
f"Unsupported retrieve_offset argument type: {type(object)} for buffer {object}")
329408

330409
A_offset = retrieve_offset(A)
331410
B_offset = retrieve_offset(B)

0 commit comments

Comments
 (0)