Skip to content

Commit 1f6870e

Browse files
authored
Handle special case reducing into shape-1 tensor
Allow reducing [X, d, Y] into [X, Y] or [X, 1, Y]
1 parent e295ba1 commit 1f6870e

File tree

1 file changed

+8
-3
lines changed

1 file changed

+8
-3
lines changed

tilelang/language/reduce.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,11 +24,16 @@ def reduce(buffer: tir.Buffer, out: tir.Buffer, reduce_type: str, dim: int, clea
2424
Returns:
2525
tir.Call: Handle to the reduction operation
2626
"""
27-
expected_shape = buffer.shape[:dim] + buffer.shape[dim + 1:]
28-
if expected_shape != list(out.shape):
27+
# input shape: [X, d, Y], expected output shape: [X, Y] or [X, 1, Y]
28+
expected_shapes = [
29+
buffer.shape[:dim] + buffer.shape[dim + 1:],
30+
buffer.shape[:dim] + [1] + buffer.shape[dim + 1:]
31+
]
32+
if list(out.shape) not in expected_shapes:
33+
expected_shapes_str = ' or '.join(map(str, expected_shapes))
2934
raise ValueError(
3035
f"Invalid reduce output shape, buffer shape is {buffer.shape}, dim is {dim}, "
31-
f"expected output shape {expected_shape}, got output shape {out.shape}")
36+
f"output shape is {out.shape}, expected shapes are {expected_shapes_str}")
3237
buffer = buffer.access_ptr("r")
3338
out = out.access_ptr("w")
3439
return tir.call_intrin(

0 commit comments

Comments
 (0)