Skip to content

Commit c2fe91e

Browse files
[Enhancement] Add shape checking for reduce options (#748)
* Add shape checking for reduce options * lint fix * Handle special case reducing into shape-1 tensor Allow reducing [X, d, Y] into [X, Y] or [X, 1, Y] --------- Co-authored-by: LeiWang1999 <leiwang1999@outlook.com>
1 parent e68fdab commit c2fe91e

File tree

1 file changed

+10
-0
lines changed

1 file changed

+10
-0
lines changed

tilelang/language/reduce.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,16 @@ def reduce(buffer: tir.Buffer, out: tir.Buffer, reduce_type: str, dim: int, clea
2424
Returns:
2525
tir.Call: Handle to the reduction operation
2626
"""
27+
# input shape: [X, d, Y], expected output shape: [X, Y] or [X, 1, Y]
28+
expected_shapes = [
29+
buffer.shape[:dim] + buffer.shape[dim + 1:],
30+
buffer.shape[:dim] + [1] + buffer.shape[dim + 1:]
31+
]
32+
if list(out.shape) not in expected_shapes:
33+
expected_shapes_str = ' or '.join(map(str, expected_shapes))
34+
raise ValueError(
35+
f"Invalid reduce output shape, buffer shape is {buffer.shape}, dim is {dim}, "
36+
f"output shape is {out.shape}, expected shapes are {expected_shapes_str}")
2737
buffer = buffer.access_ptr("r")
2838
out = out.access_ptr("w")
2939
return tir.call_intrin(

0 commit comments

Comments
 (0)