@tqchen When we do an argmax, we basically need to do:

```python
T_idx, T_val = tvm.compute((m, ), lambda i: argmax((idx[i, k], val[i, k]), axis=k))
```

We can avoid using `idx[i, k]` by using the reduce axis variable directly as the index:

```python
T_idx, T_val = tvm.compute((m, ), lambda i: argmax((k.var, val[i, k]), axis=k))
```

However, we **cannot** use an expression built from the reduce axis variable as the index, as in the snippet below. That is exactly what we need when we want to calculate the idx w.r.t. multiple reduce axes.

```python
T_idx, T_val = tvm.compute((m, ), lambda i: argmax((k.var + k.var, val[i, k]), axis=k))
```

Is there a way to solve this problem?
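For reference, here is a plain-Python sketch of the semantics we are after (not TVM code). It assumes two reduce axes `k1`, `k2` and that the combined index is the row-major flat index `k1 * n2 + k2`; that index formula and the helper names `argmax_combine`/`argmax_2axes` are illustrative assumptions. The point is that the index fed into the reducer is an expression of the reduce variables, not a load from an `idx` tensor.

```python
from typing import List, Tuple

Pair = Tuple[int, float]


def argmax_combine(x: Pair, y: Pair) -> Pair:
    # Combiner in the style of a tuple comm_reducer: keep the (index, value)
    # pair with the larger value; on ties keep x (the accumulated, earlier pair).
    return x if x[1] >= y[1] else y


def argmax_2axes(val: List[List[List[float]]]) -> Tuple[List[int], List[float]]:
    # val has shape (m, n1, n2), assumed rectangular; reduce over k1 and k2 for each i.
    t_idx: List[int] = []
    t_val: List[float] = []
    for row in val:
        n2 = len(row[0])
        acc: Pair = (-1, float("-inf"))  # identity element of the reduction
        for k1, inner in enumerate(row):
            for k2, v in enumerate(inner):
                # Hypothetical combined index: an expression of both reduce variables.
                acc = argmax_combine(acc, (k1 * n2 + k2, v))
        t_idx.append(acc[0])
        t_val.append(acc[1])
    return t_idx, t_val
```

For example, `argmax_2axes([[[1, 5, 2], [7, 0, 3]], [[9, 9, 1], [0, 0, 0]]])` returns `([3, 0], [7, 9])`, and `divmod(3, 3)` recovers `(k1, k2) = (1, 0)` for the first row.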