
Commit 6df070a

AndrewZhaoLuo and Andrew Zhao Luo authored
[ONNX][TOPI] Support select_last_index for argmin/max (#8816)
* support select_last_index for argmin/max
* reverse conditions which were made on accident
* forward args in reduce.py
* make proper nodes for reduction ops
* remove complicated nested lambdas
* fix lambda capture for conversion
* forward more arguments
* forward more args
* enable onnx tests
* wrap casts to remove ambiguity
* revert extraneous changes
* correct incorrect attrs being used for ops
* change attributes
* remove old impl
* register new attribute node
* clean up test
* reformat
* coolio
* stable comparison
* casts to avoid ambiguity
* casting more
* correct arg passing
* fix broken input
* rename OneElementReduceAttrs --> ArgReduceAttrs
* reduce boilerplate
* change names
* remove log statement
* jostle ci

Co-authored-by: Andrew Zhao Luo <andrewzhaoluo@system76-pc.localdomain>
1 parent 9a9cd70 commit 6df070a

9 files changed: +245 / -87 lines

include/tvm/relay/attrs/reduce.h

36 additions, 0 deletions

@@ -61,6 +61,42 @@ struct ReduceAttrs : public tvm::AttrsNode<ReduceAttrs> {
   }
 };
 
+/*! \brief Attributes for Reduce operators which reduce by finding a single element. E.g. argmin */
+struct ArgReduceAttrs : public tvm::AttrsNode<ArgReduceAttrs> {
+  Array<Integer> axis;
+  bool keepdims;
+  bool select_last_index;
+  bool exclude;
+
+  TVM_DECLARE_ATTRS(ArgReduceAttrs, "relay.attrs.ArgReduceAttrs") {
+    TVM_ATTR_FIELD(axis)
+        .set_default(NullValue<Array<Integer>>())
+        .describe(R"code(The axis or axes along which to perform the reduction.
+
+      The default, `axis=()`, will compute over all elements into a
+      scalar array with shape `(1,)`.
+
+      If `axis` is int, a reduction is performed on a particular axis.
+
+      If `axis` is a tuple of ints, a reduction is performed on all the axes
+      specified in the tuple.
+
+      If `exclude` is true, reduction will be performed on the axes that are
+      NOT in axis instead.)code");
+
+    TVM_ATTR_FIELD(keepdims).set_default(false).describe(
+        "If this is set to `True`, the reduced axes are left "
+        "in the result as dimensions with size one.");
+    TVM_ATTR_FIELD(select_last_index)
+        .set_default(false)
+        .describe(
+            "Whether to select the last index if the target element appears multiple times, "
+            "else select the first index at which the target element appears.");
+    TVM_ATTR_FIELD(exclude).set_default(false).describe(
+        "Whether to perform the reduction on the axes that are NOT in axis instead.");
+  }
+};
+
 struct VarianceAttrs : public tvm::AttrsNode<VarianceAttrs> {
   Array<Integer> axis;
   bool keepdims;
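
The new select_last_index attribute only changes how ties are broken when the extreme value occurs more than once: by default the first matching index wins, with the flag set the last one does. A rough reference for the intended semantics, using NumPy purely for illustration (NumPy is not part of this change):

import numpy as np

x = np.array([1, 5, 3, 5, 2])

first = int(np.argmax(x))                    # 1: first occurrence of the maximum
last = len(x) - 1 - int(np.argmax(x[::-1]))  # 3: last occurrence of the maximum
print(first, last)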

include/tvm/topi/reduction.h

76 additions, 23 deletions

@@ -431,6 +431,45 @@ inline Tensor max(const Tensor& data, const Array<Integer>& axis, bool keepdims
   return CommReduce(data, axis, MaxOp, keepdims, atleast1d);
 }
 
+inline FCommReduce MakeArgminReducer(bool select_last_index = false) {
+  // Create a commutative reducer with a comparison operation and a method to get the initial value.
+  auto fcombine = [=](Array<Var> lhs, Array<Var> rhs) {
+    Array<PrimExpr> result;
+
+    // Casting to avoid operator ambiguity.
+    PrimExpr lhs_idx = static_cast<PrimExpr>(lhs[0]);
+    PrimExpr rhs_idx = static_cast<PrimExpr>(rhs[0]);
+    PrimExpr lhs_val = static_cast<PrimExpr>(lhs[1]);
+    PrimExpr rhs_val = static_cast<PrimExpr>(rhs[1]);
+
+    // These expressions compare the actual values of the array.
+    auto is_smaller = lhs_val < rhs_val;
+    auto is_same = lhs_val == rhs_val;
+
+    // Tie-breaking on the indices: select_last_index gives precedence to the later
+    // index of equal elements, otherwise the earlier index wins.
+    PrimExpr proper_index;
+    if (select_last_index) {
+      proper_index = lhs_idx > rhs_idx;
+    } else {
+      proper_index = lhs_idx < rhs_idx;
+    }
+
+    PrimExpr update_index = is_smaller || (is_same && proper_index);
+    result.push_back(tvm::tir::Select(update_index, lhs[0], rhs[0]));  // idx
+    result.push_back(tvm::tir::Select(is_smaller, lhs[1], rhs[1]));    // val
+    return result;
+  };
+  auto fidentity = [&](std::vector<DataType> types) {
+    Array<PrimExpr> result;
+    result.push_back(tvm::tir::make_const(types[0], -1));  // idx
+    result.push_back(tvm::max_value(types[1]));            // val
+    return result;
+  };
+  return MakeCommReducer(fcombine, fidentity, "argmin");
+}
+
 /*!
  * \brief Creates an operation that finds the indices of the minimum
  * values over a given axis.
@@ -442,35 +481,48 @@ inline Tensor max(const Tensor& data, const Array<Integer>& axis, bool keepdims
  * left in the result as dimensions with size one. This enables the result
  * to broadcast correctly against the input array.
  * \param atleast1d Whether the output need to be atleast1d.
+ * \param select_last_index Whether to select the last index if the minimum element
+ * appears multiple times, else select the first index.
  *
  * \return A Tensor whose op member is the argmin operation
  */
 inline Tensor argmin(const Tensor& data, const Array<Integer>& axis, bool keepdims = false,
-                     bool atleast1d = false) {
-  auto fcombine = [](Array<Var> lhs, Array<Var> rhs) {
-    Array<PrimExpr> result;
-    result.push_back(tvm::tir::Select(lhs[1] <= rhs[1], lhs[0], rhs[0]));  // idx
-    result.push_back(tvm::tir::Select(lhs[1] <= rhs[1], lhs[1], rhs[1]));  // val
-    return result;
-  };
-  auto fidentity = [](std::vector<DataType> types) {
-    Array<PrimExpr> result;
-    result.push_back(tvm::tir::make_const(types[0], -1));  // idx
-    result.push_back(tvm::max_value(types[1]));            // val
-    return result;
-  };
-  auto func = MakeCommReducer(fcombine, fidentity, "argmin");
-  return CommReduceIdx(data, axis, func, keepdims, atleast1d);
+                     bool atleast1d = false, bool select_last_index = false) {
+  auto reducer = MakeArgminReducer(select_last_index);
+  return CommReduceIdx(data, axis, reducer, keepdims, atleast1d);
 }
 
-inline FCommReduce MakeArgmaxReducer() {
-  auto fcombine = [](Array<Var> lhs, Array<Var> rhs) {
+inline FCommReduce MakeArgmaxReducer(bool select_last_index = false) {
+  // Create a commutative reducer with a comparison operation and a method to get the initial value.
+  auto fcombine = [=](Array<Var> lhs, Array<Var> rhs) {
     Array<PrimExpr> result;
-    result.push_back(tvm::tir::Select(lhs[1] >= rhs[1], lhs[0], rhs[0]));  // idx
-    result.push_back(tvm::tir::Select(lhs[1] >= rhs[1], lhs[1], rhs[1]));  // val
+
+    // Casting to avoid operator ambiguity.
+    PrimExpr lhs_idx = static_cast<PrimExpr>(lhs[0]);
+    PrimExpr rhs_idx = static_cast<PrimExpr>(rhs[0]);
+    PrimExpr lhs_val = static_cast<PrimExpr>(lhs[1]);
+    PrimExpr rhs_val = static_cast<PrimExpr>(rhs[1]);
+
+    // These expressions compare the actual values of the array.
+    auto is_bigger = lhs_val > rhs_val;
+    auto is_same = lhs_val == rhs_val;
+
+    // Tie-breaking on the indices: select_last_index gives precedence to the later
+    // index of equal elements, otherwise the earlier index wins.
+    PrimExpr proper_index;
+    if (select_last_index) {
+      proper_index = lhs_idx > rhs_idx;
+    } else {
+      proper_index = lhs_idx < rhs_idx;
+    }
+
+    PrimExpr update_index = is_bigger || (is_same && proper_index);
+    result.push_back(tvm::tir::Select(update_index, lhs[0], rhs[0]));  // idx
+    result.push_back(tvm::tir::Select(is_bigger, lhs[1], rhs[1]));     // val
     return result;
   };
-  auto fidentity = [](std::vector<DataType> types) {
+  auto fidentity = [&](std::vector<DataType> types) {
     Array<PrimExpr> result;
     result.push_back(tvm::tir::make_const(types[0], -1));  // idx
     result.push_back(tvm::min_value(types[1]));            // val
@@ -490,12 +542,13 @@ inline FCommReduce MakeArgmaxReducer() {
  * left in the result as dimensions with size one. This enables the result
  * to broadcast correctly against the input array.
  * \param atleast1d Whether the output need to be atleast1d.
- *
+ * \param select_last_index Whether to select the last index if the maximum element
+ * appears multiple times, else select the first index.
  * \return A Tensor whose op member is the argmax operation
  */
 inline Tensor argmax(const Tensor& data, const Array<Integer>& axis, bool keepdims = false,
-                     bool atleast1d = false) {
-  auto reducer = MakeArgmaxReducer();
+                     bool atleast1d = false, bool select_last_index = false) {
+  auto reducer = MakeArgmaxReducer(select_last_index);
   return CommReduceIdx(data, axis, reducer, keepdims, atleast1d);
 }
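
Both reducers share the same combine rule; only the value comparison (smaller vs. bigger) and the identity element differ. A minimal Python sketch of that rule using plain (index, value) tuples instead of PrimExpr (the combine helper below is illustrative only, not a TVM API):

from functools import reduce

def combine(lhs, rhs, select_last_index=False):
    # Merge two (index, value) pairs the way the argmax fcombine above does.
    lhs_idx, lhs_val = lhs
    rhs_idx, rhs_val = rhs
    is_bigger = lhs_val > rhs_val
    is_same = lhs_val == rhs_val
    # On ties, prefer the later index when select_last_index is set, else the earlier one.
    proper_index = lhs_idx > rhs_idx if select_last_index else lhs_idx < rhs_idx
    return lhs if is_bigger or (is_same and proper_index) else rhs

pairs = list(enumerate([3.0, 7.0, 7.0, 1.0]))
print(reduce(combine, pairs))                                             # (1, 7.0)
print(reduce(lambda a, b: combine(a, b, select_last_index=True), pairs))  # (2, 7.0)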

python/tvm/relay/frontend/onnx.py

9 additions, 11 deletions

@@ -32,23 +32,23 @@
 from .. import loops as _loops
 from .. import op as _op
 from .. import qnn as _qnn
+from .. import random as _random
 from .. import ty as _ty
 from .. import vision as _vision
-from .. import random as _random
 from .common import (
     AttrCvt,
     Renamer,
     fold_constant,
     get_name,
     get_relay_op,
+    gru_cell,
     infer_channels,
     infer_shape,
     infer_type,
     infer_value,
+    lstm_cell,
     new_var,
     unbind,
-    gru_cell,
-    lstm_cell,
 )
 
 __all__ = ["from_onnx"]
@@ -1786,25 +1786,23 @@ class ArgMax(OnnxOpConverter):
     """Operator converter for ArgMax."""
 
     @classmethod
-    def _impl_v1(cls, inputs, attr, params):
-        if "select_last_index" in attr:
-            raise NotImplementedError("select_last_index not supported in ArgMax")
+    def _impl_v13(cls, inputs, attr, params):
         axis = attr.get("axis", 0)
         keepdims = attr.get("keepdims", True)
-        attr = {"axis": axis, "keepdims": keepdims}
+        select_last_index = attr.get("select_last_index", False)
+        attr = {"axis": axis, "keepdims": keepdims, "select_last_index": select_last_index}
         return _op.cast(AttrCvt("argmax")(inputs, attr), "int64")
 
 
 class ArgMin(OnnxOpConverter):
     """Operator converter for ArgMin."""
 
     @classmethod
-    def _impl_v1(cls, inputs, attr, params):
-        if "select_last_index" in attr:
-            raise NotImplementedError("select_last_index not supported in ArgMin")
+    def _impl_v13(cls, inputs, attr, params):
         axis = attr.get("axis", 0)
         keepdims = attr.get("keepdims", True)
-        attr = {"axis": axis, "keepdims": keepdims}
+        select_last_index = attr.get("select_last_index", False)
+        attr = {"axis": axis, "keepdims": keepdims, "select_last_index": select_last_index}
         return _op.cast(AttrCvt("argmin")(inputs, attr), "int64")
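
With the v13 converters in place, an ONNX ArgMax/ArgMin node that sets select_last_index can be imported directly instead of raising NotImplementedError. A minimal sketch, assuming the onnx package is installed and TVM is built with this change (the model and tensor names here are made up for the example):

import numpy as np
import onnx
from onnx import TensorProto, helper

import tvm
from tvm import relay

# One ArgMax node with select_last_index=1 (the attribute exists from ONNX opset 12 on).
node = helper.make_node("ArgMax", ["x"], ["y"], axis=1, keepdims=0, select_last_index=1)
graph = helper.make_graph(
    [node],
    "argmax_last_index",
    inputs=[helper.make_tensor_value_info("x", TensorProto.FLOAT, [2, 4])],
    outputs=[helper.make_tensor_value_info("y", TensorProto.INT64, [2])],
)
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])

mod, params = relay.frontend.from_onnx(model, shape={"x": (2, 4)})
x = np.array([[1, 3, 3, 0], [2, 2, 1, 2]], dtype="float32")
out = relay.create_executor("graph", mod=mod, target="llvm").evaluate()(x, **params)
print(out)  # expected [2 3]: the last occurrence of each row's maximum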

python/tvm/relay/op/reduce.py

14 additions, 6 deletions

@@ -17,13 +17,13 @@
 """Reduce operators."""
 # pylint: disable=redefined-builtin
 
+from ..expr import Tuple, TupleWrapper
 from . import _make
-from .tensor import sqrt, log, exp
+from .tensor import exp, log, sqrt
 from .transform import squeeze
-from ..expr import Tuple, TupleWrapper
 
 
-def argmax(data, axis=None, keepdims=False, exclude=False):
+def argmax(data, axis=None, keepdims=False, exclude=False, select_last_index=False):
     """Returns the indices of the maximum values along an axis.
 
     Parameters
@@ -45,16 +45,20 @@ def argmax(data, axis=None, keepdims=False, exclude=False):
     If `exclude` is true, reduction will be performed on the axes that are
     NOT in axis instead.
 
+    select_last_index : bool
+        Whether to select the last index or the first index if the max element appears in
+        multiple indices, default is False (first index).
+
     Returns
     -------
     result : relay.Expr
         The computed result.
     """
     axis = [axis] if isinstance(axis, int) else axis
-    return _make.argmax(data, axis, keepdims, exclude)
+    return _make.argmax(data, axis, keepdims, exclude, select_last_index)
 
 
-def argmin(data, axis=None, keepdims=False, exclude=False):
+def argmin(data, axis=None, keepdims=False, exclude=False, select_last_index=False):
     """Returns the indices of the minimum values along an axis.
 
     Parameters
@@ -76,13 +80,17 @@ def argmin(data, axis=None, keepdims=False, exclude=False):
     If `exclude` is true, reduction will be performed on the axes that are
     NOT in axis instead.
 
+    select_last_index : bool
+        Whether to select the last index or the first index if the min element appears in
+        multiple indices, default is False (first index).
+
     Returns
     -------
     result : relay.Expr
         The computed result.
     """
     axis = [axis] if isinstance(axis, int) else axis
-    return _make.argmin(data, axis, keepdims, exclude)
+    return _make.argmin(data, axis, keepdims, exclude, select_last_index)
 
 
 def sum(data, axis=None, keepdims=False, exclude=False):
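
At the Relay level the new flag is just an extra keyword argument on relay.argmax and relay.argmin. A small usage sketch, assuming a TVM build that includes this change (the run helper is only for the example):

import numpy as np
import tvm
from tvm import relay

data = np.array(
    [[0.0, 9.0, 9.0, 1.0],
     [5.0, 5.0, 5.0, 5.0]], dtype="float32")

def run(select_last_index):
    x = relay.var("x", shape=data.shape, dtype="float32")
    out = relay.argmax(x, axis=1, select_last_index=select_last_index)
    mod = tvm.IRModule.from_expr(relay.Function([x], out))
    return relay.create_executor("graph", mod=mod, target="llvm").evaluate()(data)

print(run(False))  # [1 0]: first occurrence of each row's maximum
print(run(True))   # [2 3]: last occurrence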

python/tvm/topi/reduction.py

12 additions, 4 deletions

@@ -167,7 +167,7 @@ def min(data, axis=None, keepdims=False):
     return cpp.min(data, axis, keepdims)
 
 
-def argmax(data, axis=None, keepdims=False):
+def argmax(data, axis=None, keepdims=False, select_last_index=False):
     """Returns the indices of the maximum values along an axis.
 
     Parameters
@@ -185,14 +185,18 @@ def argmax(data, axis=None, keepdims=False):
         with size one.
         With this option, the result will broadcast correctly against the input array.
 
+    select_last_index: bool
+        Whether to select the last index if the maximum element appears multiple times, else
+        select the first index.
+
     Returns
    -------
     ret : tvm.te.Tensor
     """
-    return cpp.argmax(data, axis, keepdims)
+    return cpp.argmax(data, axis, keepdims, select_last_index)
 
 
-def argmin(data, axis=None, keepdims=False):
+def argmin(data, axis=None, keepdims=False, select_last_index=False):
     """Returns the indices of the minimum values along an axis.
 
     Parameters
@@ -210,11 +214,15 @@ def argmin(data, axis=None, keepdims=False):
         with size one.
        With this option, the result will broadcast correctly against the input array.
 
+    select_last_index: bool
+        Whether to select the last index if the minimum element appears multiple times, else
+        select the first index.
+
     Returns
     -------
     ret : tvm.te.Tensor
     """
-    return cpp.argmin(data, axis, keepdims)
+    return cpp.argmin(data, axis, keepdims, select_last_index)
 
 
 def prod(data, axis=None, keepdims=False):
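
The TOPI wrappers simply forward the flag to the C++ reducers above, so it can also be used when assembling compute graphs by hand. A minimal sketch, assuming a TVM build that includes this change (graph construction only; scheduling and building are target-specific and omitted):

import tvm
from tvm import te, topi

data = te.placeholder((8, 16), name="data", dtype="float32")
idx = topi.argmax(data, axis=1, keepdims=False, select_last_index=True)
print(idx.shape)  # [8]: one index per row, pointing at the last maximum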
