44import tilelang .language as T
55from tvm import tir
66from typing import Union , List
7+ from tilelang .utils .language import get_buffer_region_from_load
78
89
910def gemm (
@@ -66,8 +67,15 @@ def retrieve_shape(object: Union[tir.Buffer, tir.BufferRegion]) -> List[int]:
6667 for r in region :
6768 shape .append (r .extent )
6869 return shape
70+ elif isinstance (object , tir .BufferLoad ):
71+ region = get_buffer_region_from_load (object ).region
72+ shape = []
73+ for r in region :
74+ shape .append (r .extent )
75+ return shape
6976 else :
70- raise ValueError (f"Unsupported argument type: { type (object )} for buffer { object } " )
77+ raise ValueError (
78+ f"Unsupported retrieve_shape argument type: { type (object )} for buffer { object } " )
7179
7280 def retrieve_stride (object : Union [tir .Buffer , tir .BufferRegion ]) -> List [int ]:
7381 if isinstance (object , tir .Buffer ):
@@ -85,8 +93,17 @@ def retrieve_stride(object: Union[tir.Buffer, tir.BufferRegion]) -> List[int]:
8593 strides .insert (0 , stride )
8694 stride *= s
8795 return strides
96+ elif isinstance (object , tir .BufferLoad ):
97+ buffer = object .buffer
98+ strides = []
99+ stride = 1
100+ for s in reversed (buffer .shape ):
101+ strides .insert (0 , stride )
102+ stride *= s
103+ return strides
88104 else :
89- raise ValueError (f"Unsupported argument type: { type (object )} for buffer { object } " )
105+ raise ValueError (
106+ f"Unsupported retrieve_stride argument type: { type (object )} for buffer { object } " )
90107
91108 A_shape = retrieve_shape (A )
92109 B_shape = retrieve_shape (B )
@@ -134,8 +151,24 @@ def retrieve_ptr(object: Union[tir.Buffer, tir.BufferRegion],
134151 for i in range (len (indices ) - 2 ):
135152 offset += indices [i ] * strides [i ]
136153 return buffer .access_ptr (access_mask = access_type , offset = offset )
154+ elif isinstance (object , tir .BufferLoad ):
155+ buffer = object .buffer
156+ region = get_buffer_region_from_load (object ).region
157+ indices = []
158+ for r in region :
159+ indices .append (r .min )
160+ strides = []
161+ stride = 1
162+ for s in reversed (buffer .shape ):
163+ strides .insert (0 , stride )
164+ stride *= s
165+ offset = 0
166+ for i in range (len (indices ) - 2 ):
167+ offset += indices [i ] * strides [i ]
168+ return buffer .access_ptr (access_mask = access_type , offset = offset )
137169 else :
138- raise ValueError (f"Unsupported argument type: { type (object )} for buffer { object } " )
170+ raise ValueError (
171+ f"Unsupported retrieve_ptr argument type: { type (object )} for buffer { object } " )
139172
140173 def retrieve_offset (object : Union [tir .Buffer , tir .BufferRegion ]) -> tir .PrimExpr :
141174 """Retrieve the offset of the buffer or buffer region."""
@@ -147,8 +180,15 @@ def retrieve_offset(object: Union[tir.Buffer, tir.BufferRegion]) -> tir.PrimExpr
147180 for r in region :
148181 indices .append (r .min )
149182 return indices
183+ elif isinstance (object , tir .BufferLoad ):
184+ region = get_buffer_region_from_load (object ).region
185+ indices = []
186+ for r in region :
187+ indices .append (r .min )
188+ return indices
150189 else :
151- raise ValueError (f"Unsupported argument type: { type (object )} for buffer { object } " )
190+ raise ValueError (
191+ f"Unsupported retrieve_offset argument type: { type (object )} for buffer { object } " )
152192
153193 A_offset = retrieve_offset (A )
154194 B_offset = retrieve_offset (B )
@@ -243,8 +283,15 @@ def retrieve_shape(object: Union[tir.Buffer, tir.BufferRegion]) -> List[int]:
243283 for r in region :
244284 shape .append (r .extent )
245285 return shape
286+ elif isinstance (object , tir .BufferLoad ):
287+ region = get_buffer_region_from_load (object ).region
288+ shape = []
289+ for r in region :
290+ shape .append (r .extent )
291+ return shape
246292 else :
247- raise ValueError (f"Unsupported argument type: { type (object )} for buffer { object } " )
293+ raise ValueError (
294+ f"Unsupported retrieve_shape argument type: { type (object )} for buffer { object } " )
248295
249296 def retrieve_stride (object : Union [tir .Buffer , tir .BufferRegion ]) -> List [int ]:
250297 if isinstance (object , tir .Buffer ):
@@ -262,8 +309,17 @@ def retrieve_stride(object: Union[tir.Buffer, tir.BufferRegion]) -> List[int]:
262309 strides .insert (0 , stride )
263310 stride *= s
264311 return strides
312+ elif isinstance (object , tir .BufferLoad ):
313+ buffer = object .buffer
314+ strides = []
315+ stride = 1
316+ for s in reversed (buffer .shape ):
317+ strides .insert (0 , stride )
318+ stride *= s
319+ return strides
265320 else :
266- raise ValueError (f"Unsupported argument type: { type (object )} for buffer { object } " )
321+ raise ValueError (
322+ f"Unsupported retrieve_stride argument type: { type (object )} for buffer { object } " )
267323
268324 A_shape = retrieve_shape (A )
269325 B_shape = retrieve_shape (B )
@@ -311,8 +367,24 @@ def retrieve_ptr(object: Union[tir.Buffer, tir.BufferRegion],
311367 for i in range (len (indices ) - 2 ):
312368 offset += indices [i ] * strides [i ]
313369 return buffer .access_ptr (access_mask = access_type , offset = offset )
370+ elif isinstance (object , tir .BufferLoad ):
371+ buffer = object .buffer
372+ region = get_buffer_region_from_load (object ).region
373+ indices = []
374+ for r in region :
375+ indices .append (r .min )
376+ strides = []
377+ stride = 1
378+ for s in reversed (buffer .shape ):
379+ strides .insert (0 , stride )
380+ stride *= s
381+ offset = 0
382+ for i in range (len (indices ) - 2 ):
383+ offset += indices [i ] * strides [i ]
384+ return buffer .access_ptr (access_mask = access_type , offset = offset )
314385 else :
315- raise ValueError (f"Unsupported argument type: { type (object )} for buffer { object } " )
386+ raise ValueError (
387+ f"Unsupported retrieve_ptr argument type: { type (object )} for buffer { object } " )
316388
317389 def retrieve_offset (object : Union [tir .Buffer , tir .BufferRegion ]) -> tir .PrimExpr :
318390 """Retrieve the offset of the buffer or buffer region."""
@@ -324,8 +396,15 @@ def retrieve_offset(object: Union[tir.Buffer, tir.BufferRegion]) -> tir.PrimExpr
324396 for r in region :
325397 indices .append (r .min )
326398 return indices
399+ elif isinstance (object , tir .BufferLoad ):
400+ region = get_buffer_region_from_load (object ).region
401+ indices = []
402+ for r in region :
403+ indices .append (r .min )
404+ return indices
327405 else :
328- raise ValueError (f"Unsupported argument type: { type (object )} for buffer { object } " )
406+ raise ValueError (
407+ f"Unsupported retrieve_offset argument type: { type (object )} for buffer { object } " )
329408
330409 A_offset = retrieve_offset (A )
331410 B_offset = retrieve_offset (B )
0 commit comments