@tqchen
When we do an argmax, we basically need to do:
T_idx, T_val = tvm.compute((m, ), lambda i: argmax((idx[i, k], val[i, k]), axis=k))
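To spell out the semantics we have in mind, here is a minimal plain-Python sketch of argmax as a reduction over (idx, val) pairs. It does not use the TVM API; `argmax_combine`, `argmax_row`, and the sample data are hypothetical names and values chosen only for illustration, and the tie-breaking rule (keep the earliest index) is an assumption.

```python
from functools import reduce

def argmax_combine(lhs, rhs):
    """Combine two (idx, val) pairs, keeping the pair with the larger value."""
    # Assumption: on ties, keep the left operand (the earliest index).
    return lhs if lhs[1] >= rhs[1] else rhs

def argmax_row(idx_row, val_row):
    """Reduce one row of (idx, val) pairs along the reduce axis k."""
    identity = (-1, float("-inf"))  # identity element of the reduction
    return reduce(argmax_combine, zip(idx_row, val_row), identity)

# Hypothetical inputs with m = 2 rows and a reduce axis of length 3.
val = [[1.0, 5.0, 3.0], [4.0, 2.0, 4.0]]
idx = [[0, 1, 2], [0, 1, 2]]

# One (T_idx, T_val) pair per output row i.
result = [argmax_row(i, v) for i, v in zip(idx, val)]
```

The point is that the reduction carries two values at once, and only the second one decides which pair wins.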
We can avoid using idx[i, k] by doing the following:
T_idx, T_val = tvm.compute((m, ), lambda i: argmax((k.var, val[i, k]), axis=k))
However, the following does not work, and we need it when we want to compute the index w.r.t. multiple reduce axes:
T_idx, T_val = tvm.compute((m, ), lambda i: argmax((k.var + k.var, val[i, k]), axis=k))
Is there a way to solve this problem?
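To make the multi-axis case concrete, here is a rough plain-Python sketch of the result we are after: the index part is an arbitrary expression of the reduce variables. This is not TVM code; `argmax_two_axes`, the two axes `k1` and `k2`, and the flattening expression `k1 * n2 + k2` are assumptions made up for the example.

```python
def argmax_two_axes(val_i, index_expr):
    """Reduce over two axes (k1, k2), tracking index_expr(k1, k2) with the value."""
    best = (-1, float("-inf"))  # identity element of the reduction
    for k1, row in enumerate(val_i):
        for k2, v in enumerate(row):
            # Strict comparison, so the first maximum encountered wins.
            if v > best[1]:
                best = (index_expr(k1, k2), v)
    return best

# Hypothetical slice val[i, :, :] with reduce extents 2 x 3.
n2 = 3
val_i = [[0.5, 2.0, 1.0], [7.0, 3.0, 7.0]]

# The index is an expression of both reduce variables, here a flattened offset.
flat = argmax_two_axes(val_i, lambda k1, k2: k1 * n2 + k2)
```

Here `flat` holds the flattened position of the maximum together with its value, which is what the index expression in the compute above is meant to produce.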