@@ -431,6 +431,45 @@ inline Tensor max(const Tensor& data, const Array<Integer>& axis, bool keepdims
431
431
return CommReduce (data, axis, MaxOp, keepdims, atleast1d);
432
432
}
433
433
434
+ inline FCommReduce MakeArgminReducer (bool select_last_index = false ) {
435
+ // Create a Commutative Reducer with a comparison operation, and method to get the initial value.
436
+ auto fcombine = [=](Array<Var> lhs, Array<Var> rhs) {
437
+ Array<PrimExpr> result;
438
+
439
+ // Casting to avoid operator ambiguity
440
+ PrimExpr lhs_idx = static_cast <PrimExpr>(lhs[0 ]);
441
+ PrimExpr rhs_idx = static_cast <PrimExpr>(rhs[0 ]);
442
+ PrimExpr lhs_val = static_cast <PrimExpr>(lhs[1 ]);
443
+ PrimExpr rhs_val = static_cast <PrimExpr>(rhs[1 ]);
444
+
445
+ // These variables compare the actual values of the array
446
+ auto is_smaller = lhs_val < rhs_val;
447
+ auto is_same = lhs_val == rhs_val;
448
+
449
+ // This checks if the indices are correct for the reduction. E.g. for select_last_index
450
+ // it gives precedence for later indices of the same element and precedence for sooner
451
+ // indices if not select_last_index;
452
+ PrimExpr proper_index;
453
+ if (select_last_index) {
454
+ proper_index = lhs_idx > rhs_idx;
455
+ } else {
456
+ proper_index = lhs_idx < rhs_idx;
457
+ }
458
+
459
+ PrimExpr update_index = is_smaller || (is_same && proper_index);
460
+ result.push_back (tvm::tir::Select (update_index, lhs[0 ], rhs[0 ])); // idx
461
+ result.push_back (tvm::tir::Select (is_smaller, lhs[1 ], rhs[1 ])); // val
462
+ return result;
463
+ };
464
+ auto fidentity = [&](std::vector<DataType> types) {
465
+ Array<PrimExpr> result;
466
+ result.push_back (tvm::tir::make_const (types[0 ], -1 )); // idx
467
+ result.push_back (tvm::max_value (types[1 ])); // val
468
+ return result;
469
+ };
470
+ return MakeCommReducer (fcombine, fidentity, " argmin" );
471
+ }
472
+
434
473
/* !
435
474
* \brief Creates an operation that finds the indices of the minimum
436
475
* values over a given axis.
@@ -442,35 +481,48 @@ inline Tensor max(const Tensor& data, const Array<Integer>& axis, bool keepdims
442
481
* left in the result as dimensions with size one. This enables the result
443
482
* to broadcast correctly against the input array.
444
483
* \param atleast1d Whether the output need to be atleast1d.
484
+ * \param select_last_index Whether to select the last index if the minimum element
485
+ * appears multiple times, else select the first index.
445
486
*
446
487
* \return A Tensor whose op member is the argmin operation
447
488
*/
448
489
inline Tensor argmin (const Tensor& data, const Array<Integer>& axis, bool keepdims = false ,
449
- bool atleast1d = false ) {
450
- auto fcombine = [](Array<Var> lhs, Array<Var> rhs) {
451
- Array<PrimExpr> result;
452
- result.push_back (tvm::tir::Select (lhs[1 ] <= rhs[1 ], lhs[0 ], rhs[0 ])); // idx
453
- result.push_back (tvm::tir::Select (lhs[1 ] <= rhs[1 ], lhs[1 ], rhs[1 ])); // val
454
- return result;
455
- };
456
- auto fidentity = [](std::vector<DataType> types) {
457
- Array<PrimExpr> result;
458
- result.push_back (tvm::tir::make_const (types[0 ], -1 )); // idx
459
- result.push_back (tvm::max_value (types[1 ])); // val
460
- return result;
461
- };
462
- auto func = MakeCommReducer (fcombine, fidentity, " argmin" );
463
- return CommReduceIdx (data, axis, func, keepdims, atleast1d);
490
+ bool atleast1d = false , bool select_last_index = false ) {
491
+ auto reducer = MakeArgminReducer (select_last_index);
492
+ return CommReduceIdx (data, axis, reducer, keepdims, atleast1d);
464
493
}
465
494
466
- inline FCommReduce MakeArgmaxReducer () {
467
- auto fcombine = [](Array<Var> lhs, Array<Var> rhs) {
495
+ inline FCommReduce MakeArgmaxReducer (bool select_last_index = false ) {
496
+ // Create a Commutative Reducer with a comparison operation, and method to get the initial value.
497
+ auto fcombine = [=](Array<Var> lhs, Array<Var> rhs) {
468
498
Array<PrimExpr> result;
469
- result.push_back (tvm::tir::Select (lhs[1 ] >= rhs[1 ], lhs[0 ], rhs[0 ])); // idx
470
- result.push_back (tvm::tir::Select (lhs[1 ] >= rhs[1 ], lhs[1 ], rhs[1 ])); // val
499
+
500
+ // Casting to avoid operator ambiguity
501
+ PrimExpr lhs_idx = static_cast <PrimExpr>(lhs[0 ]);
502
+ PrimExpr rhs_idx = static_cast <PrimExpr>(rhs[0 ]);
503
+ PrimExpr lhs_val = static_cast <PrimExpr>(lhs[1 ]);
504
+ PrimExpr rhs_val = static_cast <PrimExpr>(rhs[1 ]);
505
+
506
+ // These variables compare the actual values of the array
507
+ auto is_bigger = lhs_val > rhs_val;
508
+ auto is_same = lhs_val == rhs_val;
509
+
510
+ // This checks if the indices are correct for the reduction. E.g. for select_last_index
511
+ // it gives precedence for later indices of the same element and precedence for sooner
512
+ // indices if not select_last_index;
513
+ PrimExpr proper_index;
514
+ if (select_last_index) {
515
+ proper_index = lhs_idx > rhs_idx;
516
+ } else {
517
+ proper_index = lhs_idx < rhs_idx;
518
+ }
519
+
520
+ PrimExpr update_index = is_bigger || (is_same && proper_index);
521
+ result.push_back (tvm::tir::Select (update_index, lhs[0 ], rhs[0 ])); // idx
522
+ result.push_back (tvm::tir::Select (is_bigger, lhs[1 ], rhs[1 ])); // val
471
523
return result;
472
524
};
473
- auto fidentity = [](std::vector<DataType> types) {
525
+ auto fidentity = [& ](std::vector<DataType> types) {
474
526
Array<PrimExpr> result;
475
527
result.push_back (tvm::tir::make_const (types[0 ], -1 )); // idx
476
528
result.push_back (tvm::min_value (types[1 ])); // val
@@ -490,12 +542,13 @@ inline FCommReduce MakeArgmaxReducer() {
490
542
* left in the result as dimensions with size one. This enables the result
491
543
* to broadcast correctly against the input array.
492
544
* \param atleast1d Whether the output need to be atleast1d.
493
- *
545
+ * \param select_last_index Whether to select the last index if the maximum element
546
+ * appears multiple times, else select the first index.
494
547
* \return A Tensor whose op member is the argmax operation
495
548
*/
496
549
inline Tensor argmax (const Tensor& data, const Array<Integer>& axis, bool keepdims = false ,
497
- bool atleast1d = false ) {
498
- auto reducer = MakeArgmaxReducer ();
550
+ bool atleast1d = false , bool select_last_index = false ) {
551
+ auto reducer = MakeArgmaxReducer (select_last_index );
499
552
return CommReduceIdx (data, axis, reducer, keepdims, atleast1d);
500
553
}
501
554
0 commit comments