
Question encountered while implementing argmax #510

@sxjscience


Currently, my implementation consists of two stages:

```python
# Pseudocode: COMPUTE_ARGMAX and CHOOSE_INDEX stand for the two compute stages.
T_idx, T_val = COMPUTE_ARGMAX(data, name='T')
real_idx = CHOOSE_INDEX(T_idx, T_val)
```

During scheduling, the first stage reuses the reduction scheduling code (https://github.com/dmlc/tvm/blob/master/topi/python/topi/cuda/reduction.py#L7-L42). In the second stage, T_idx and T_val are to be attached to some axis with compute_at.
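For reference, the computation the two stages are meant to perform can be sketched in plain Python, independent of TVM. This is a minimal sketch under assumptions: it takes the first stage to produce one partial (index, value) pair per chunk of the reduction axis, as a split reduction would, and the second stage to fold those partials and keep only the index. `compute_argmax`, `choose_index`, `combine`, `IDENTITY` and the chunk size are hypothetical stand-ins for COMPUTE_ARGMAX and CHOOSE_INDEX, not TVM APIs.

```python
import math
from typing import List, Tuple

Pair = Tuple[int, float]  # (index, value)

# Hypothetical neutral element: index -1 marks "no element seen yet".
IDENTITY: Pair = (-1, -math.inf)


def combine(a: Pair, b: Pair) -> Pair:
    """Keep the pair with the larger value; on ties, keep the smaller index."""
    if a[0] < 0:
        return b
    if b[0] < 0:
        return a
    if b[1] > a[1] or (b[1] == a[1] and b[0] < a[0]):
        return b
    return a


def compute_argmax(row: List[float], chunk: int) -> List[Pair]:
    """Stage 1 (stand-in for COMPUTE_ARGMAX): one partial (idx, val) per chunk."""
    partials = []
    for start in range(0, len(row), chunk):
        acc = IDENTITY
        for i in range(start, min(start + chunk, len(row))):
            acc = combine(acc, (i, row[i]))
        partials.append(acc)
    return partials


def choose_index(partials: List[Pair]) -> int:
    """Stage 2 (stand-in for CHOOSE_INDEX): fold the partials, return the index."""
    acc = IDENTITY
    for p in partials:
        acc = combine(acc, p)
    return acc[0]


row = [0.5, 3.0, -1.0, 3.0, 2.5, 7.0, 7.0]
partials = compute_argmax(row, chunk=3)
print(partials)                # [(1, 3.0), (5, 7.0), (6, 7.0)]
print(choose_index(partials))  # 5
```

For non-NaN inputs, `combine` is a max under the order (larger value, then smaller index), so it is associative and commutative with `IDENTITY` as its neutral element; that is what allows the partials to be produced independently and folded in any order.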

The problem is that I cannot attach T_idx and T_val with compute_at to an axis of the first stage, because they are themselves the outputs of the first stage.

If I call T_idx.compute_at(FIRST_STAGE_OUT, FIRST_STAGE_AXIS), I get this error:

tvm._ffi.base.TVMError: [23:43:06] D:\HKUST\tvm\src\schedule\graph.cc:179: Check failed: !visited.count(s.get()) Find loop in compute_at attach group
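One plausible reading of this first error: if T_idx and T_val belong to the same stage as FIRST_STAGE_OUT (they are that stage's outputs), then the call attaches a stage to itself, and following the attach links never terminates. The toy model below is a hypothetical illustration only, not TVM's actual `graph.cc` logic; the stage names and the `attach` dictionary are assumptions made for the sketch.

```python
from typing import Dict, Optional


def attach_chain_has_loop(attach: Dict[str, Optional[str]], start: str) -> bool:
    """Follow stage -> attach-parent links; return True if a stage repeats."""
    visited = set()
    node: Optional[str] = start
    while node is not None:
        if node in visited:
            return True  # the chain came back to a stage already seen
        visited.add(node)
        node = attach.get(node)
    return False


# Hypothetical stage names: the first stage attached to its own output.
self_attach = {"argmax_stage1": "argmax_stage1", "choose_index": None}
# Hypothetical stage names: the first stage attached to the second stage.
nested = {"argmax_stage1": "choose_index", "choose_index": None}

print(attach_chain_has_loop(self_attach, "argmax_stage1"))  # True
print(attach_chain_has_loop(nested, "argmax_stage1"))       # False
```

The model only captures the cycle check; it says nothing about which axes a parent stage exposes, which is what the second error below concerns.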

If I call T_idx.compute_at(SECOND_STAGE_OUT, FIRST_STAGE_AXIS) instead, I get this error:

tvm._ffi.base.TVMError: [23:55:14] D:\HKUST\tvm\src\schedule\schedule_lang.cc:133: Check failed: found Cannot find the axis iter_var(ax0.ax1.fused.outer, ) in parent's leaf_iter_vars parent=stage(argmax, 000002248AB6E980)

Is there a way to compute_at T_idx and T_val at an axis of the first stage?
