Skip to content

Commit c000631

Browse files
authored
[TOPI] cuda reduction schedule (#7131)
* complex reduce * fix * fix * fix
1 parent e51bcdd commit c000631

File tree

2 files changed

+32
-3
lines changed

2 files changed

+32
-3
lines changed

python/tvm/topi/cuda/reduction.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -130,12 +130,14 @@ def traverse_after_reduce(operator):
130130
for tensor in operator.input_tensors:
131131
traverse_after_reduce(tensor.op)
132132
elif operator.tag == "comm_reduce":
133-
_schedule_reduce(operator, sch, is_idx_reduce=False)
133+
if operator not in scheduled_ops:
134+
_schedule_reduce(operator, sch, is_idx_reduce=False)
134135
for tensor in operator.input_tensors:
135136
if tensor.op not in scheduled_ops:
136137
traverse_before_reduce(tensor.op)
137138
elif operator.tag == "comm_reduce_idx":
138-
_schedule_reduce(operator, sch, is_idx_reduce=True)
139+
if operator not in scheduled_ops:
140+
_schedule_reduce(operator, sch, is_idx_reduce=True)
139141
input_tensors = operator.input_tensors[0].op.input_tensors
140142
for tensor in input_tensors:
141143
if tensor.op not in scheduled_ops:
@@ -147,5 +149,6 @@ def traverse_after_reduce(operator):
147149

148150
scheduled_ops.append(operator)
149151

150-
traverse_after_reduce(outs[0].op)
152+
for out in outs:
153+
traverse_after_reduce(out.op)
151154
return sch

tests/python/topi/python/test_topi_reduce.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,5 +152,31 @@ def test_reduce_map():
152152
)
153153

154154

155+
@tvm.testing.uses_gpu
156+
def test_complex_reduce():
157+
in_shape = (2, 3)
158+
dtype = "float32"
159+
axis = 0
160+
keepdims = False
161+
A = te.placeholder(shape=in_shape, name="A", dtype=dtype)
162+
B = topi.sum(A, axis=axis, keepdims=keepdims)
163+
C = topi.add(B, B)
164+
D = topi.multiply(B, B)
165+
E = topi.add(C, D)
166+
for device, ctx in tvm.testing.enabled_targets():
167+
print("Running on target: %s" % device)
168+
with tvm.target.Target(device):
169+
s = tvm.topi.testing.get_reduce_schedule(device)(E)
170+
foo = tvm.build(s, [A, E], device, name="sum")
171+
in_npy = np.random.uniform(-1, 1, size=in_shape).astype(dtype)
172+
sum_npy = in_npy.sum(axis=axis, keepdims=keepdims)
173+
out_npy = sum_npy * 2 + sum_npy * sum_npy
174+
data_tvm = tvm.nd.array(in_npy, ctx=ctx)
175+
out_tvm = tvm.nd.empty(shape=out_npy.shape, ctx=ctx, dtype=dtype)
176+
foo(data_tvm, out_tvm)
177+
tvm.testing.assert_allclose(out_tvm.asnumpy(), out_npy, 1e-3, 1e-3)
178+
179+
155180
if __name__ == "__main__":
156181
test_reduce_map()
182+
test_complex_reduce()

0 commit comments

Comments
 (0)