Closed
Description
Currently, my implementation consists of two stages:

```python
T_idx, T_val = COMPUTE_ARGMAX(data, name='T')
real_idx = CHOOSE_INDEX(T_idx, T_val)
```

During scheduling, the first stage reuses the reduce scheduling code (https://github.com/dmlc/tvm/blob/master/topi/python/topi/cuda/reduction.py#L7-L42). In the second stage, T_idx and T_val are attached to some axis via compute_at.
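For reference, the two stages compute something like the following NumPy sketch. This is only an illustration of the data flow, not the TVM implementation; in particular, the selection logic inside `choose_index` is an assumption, since CHOOSE_INDEX is not shown in the issue.

```python
import numpy as np

def compute_argmax(data):
    # Stage 1: a tuple-output reduction over the last axis, producing
    # both the argmax index and the corresponding maximum value.
    idx = np.argmax(data, axis=-1)
    val = np.max(data, axis=-1)
    return idx, val

def choose_index(t_idx, t_val):
    # Stage 2 (hypothetical logic): pick the column index belonging to
    # the row whose reduced value is largest.
    return t_idx[np.argmax(t_val)]

data = np.array([[1.0, 5.0, 2.0],
                 [7.0, 3.0, 4.0]])
T_idx, T_val = compute_argmax(data)     # T_idx = [1, 0], T_val = [5.0, 7.0]
real_idx = choose_index(T_idx, T_val)   # 0: row 1 has the largest value, at column 0
```

The scheduling question below is about stage 1: because `T_idx` and `T_val` are the outputs of a tuple reduction, they form a single stage, and attaching that stage's outputs to its own axes is what triggers the errors.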
The problem is that I cannot compute_at T_idx and T_val at an axis inside the first stage, since they are the outputs of that stage.
If I try `T_idx.compute_at(FIRST_STAGE_OUT, FIRST_STAGE_AXIS)`, I receive this error:

```
tvm._ffi.base.TVMError: [23:43:06] D:\HKUST\tvm\src\schedule\graph.cc:179: Check failed: !visited.count(s.get()) Find loop in compute_at attach group
```

If I try `T_idx.compute_at(SECOND_STAGE_OUT, FIRST_STAGE_AXIS)`, I receive this error:

```
tvm._ffi.base.TVMError: [23:55:14] D:\HKUST\tvm\src\schedule\schedule_lang.cc:133: Check failed: found Cannot find the axis iter_var(ax0.ax1.fused.outer, ) in parent's leaf_iter_vars parent=stage(argmax, 000002248AB6E980)
```
Is there a way to compute_at T_idx and T_val at an axis of the first stage?