@@ -436,18 +436,24 @@ inline FCommReduce MakeArgminReducer(bool select_last_index = false) {
436
436
auto fcombine = [=](Array<Var> lhs, Array<Var> rhs) {
437
437
Array<PrimExpr> result;
438
438
439
+ // Casting to avoid operator ambiguity
440
+ PrimExpr lhs_idx = static_cast <PrimExpr>(lhs[0 ]);
441
+ PrimExpr rhs_idx = static_cast <PrimExpr>(rhs[0 ]);
442
+ PrimExpr lhs_val = static_cast <PrimExpr>(lhs[1 ]);
443
+ PrimExpr rhs_val = static_cast <PrimExpr>(rhs[1 ]);
444
+
439
445
// These variables compare the actual values of the array
440
- auto is_smaller = lhs[ 1 ] < rhs[ 1 ] ;
441
- auto is_same = lhs[ 1 ] == rhs[ 1 ] ;
446
+ auto is_smaller = lhs_val < rhs_val ;
447
+ auto is_same = lhs_val == rhs_val ;
442
448
443
449
// This checks if the indices are correct for the reduction. E.g. for select_last_index
444
450
// it gives precedence for later indices of the same element and precedence for sooner
445
451
// indices if not select_last_index;
446
452
PrimExpr proper_index;
447
453
if (select_last_index) {
448
- proper_index = PrimExpr (lhs[ 0 ]) > PrimExpr (rhs[ 0 ]) ;
454
+ proper_index = lhs_idx > rhs_idx ;
449
455
} else {
450
- proper_index = PrimExpr (lhs[ 0 ]) < PrimExpr (rhs[ 0 ]) ;
456
+ proper_index = lhs_idx < rhs_idx ;
451
457
}
452
458
453
459
PrimExpr update_index = is_smaller || (is_same && proper_index);
@@ -491,18 +497,24 @@ inline FCommReduce MakeArgmaxReducer(bool select_last_index = false) {
491
497
auto fcombine = [=](Array<Var> lhs, Array<Var> rhs) {
492
498
Array<PrimExpr> result;
493
499
500
+ // Casting to avoid operator ambiguity
501
+ PrimExpr lhs_idx = static_cast <PrimExpr>(lhs[0 ]);
502
+ PrimExpr rhs_idx = static_cast <PrimExpr>(rhs[0 ]);
503
+ PrimExpr lhs_val = static_cast <PrimExpr>(lhs[1 ]);
504
+ PrimExpr rhs_val = static_cast <PrimExpr>(rhs[1 ]);
505
+
494
506
// These variables compare the actual values of the array
495
- auto is_bigger = lhs[ 1 ] > rhs[ 1 ] ;
496
- auto is_same = lhs[ 1 ] == rhs[ 1 ] ;
507
+ auto is_bigger = lhs_val > rhs_val ;
508
+ auto is_same = lhs_val == rhs_val ;
497
509
498
510
// This checks if the indices are correct for the reduction. E.g. for select_last_index
499
511
// it gives precedence for later indices of the same element and precedence for sooner
500
512
// indices if not select_last_index;
501
513
PrimExpr proper_index;
502
514
if (select_last_index) {
503
- proper_index = PrimExpr (lhs[ 0 ]) > PrimExpr (rhs[ 0 ]) ;
515
+ proper_index = lhs_idx > rhs_idx ;
504
516
} else {
505
- proper_index = PrimExpr (lhs[ 0 ]) < PrimExpr (rhs[ 0 ]) ;
517
+ proper_index = lhs_idx < rhs_idx ;
506
518
}
507
519
508
520
PrimExpr update_index = is_bigger || (is_same && proper_index);
0 commit comments