Skip to content

Commit aecf630

Browse files
committed
casting more
1 parent 71ab1f3 commit aecf630

File tree

1 file changed

+20
-8
lines changed

1 file changed

+20
-8
lines changed

include/tvm/topi/reduction.h

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -436,18 +436,24 @@ inline FCommReduce MakeArgminReducer(bool select_last_index = false) {
436436
auto fcombine = [=](Array<Var> lhs, Array<Var> rhs) {
437437
Array<PrimExpr> result;
438438

439+
// Casting to avoid operator ambiguity
440+
PrimExpr lhs_idx = static_cast<PrimExpr>(lhs[0]);
441+
PrimExpr rhs_idx = static_cast<PrimExpr>(rhs[0]);
442+
PrimExpr lhs_val = static_cast<PrimExpr>(lhs[1]);
443+
PrimExpr rhs_val = static_cast<PrimExpr>(rhs[1]);
444+
439445
// These variables compare the actual values of the array
440-
auto is_smaller = lhs[1] < rhs[1];
441-
auto is_same = lhs[1] == rhs[1];
446+
auto is_smaller = lhs_val < rhs_val;
447+
auto is_same = lhs_val == rhs_val;
442448

443449
// This checks if the indices are correct for the reduction. E.g. for select_last_index
444450
// it gives precedence for later indices of the same element and precedence for sooner
445451
// indices if not select_last_index;
446452
PrimExpr proper_index;
447453
if (select_last_index) {
448-
proper_index = PrimExpr(lhs[0]) > PrimExpr(rhs[0]);
454+
proper_index = lhs_idx > rhs_idx;
449455
} else {
450-
proper_index = PrimExpr(lhs[0]) < PrimExpr(rhs[0]);
456+
proper_index = lhs_idx < rhs_idx;
451457
}
452458

453459
PrimExpr update_index = is_smaller || (is_same && proper_index);
@@ -491,18 +497,24 @@ inline FCommReduce MakeArgmaxReducer(bool select_last_index = false) {
491497
auto fcombine = [=](Array<Var> lhs, Array<Var> rhs) {
492498
Array<PrimExpr> result;
493499

500+
// Casting to avoid operator ambiguity
501+
PrimExpr lhs_idx = static_cast<PrimExpr>(lhs[0]);
502+
PrimExpr rhs_idx = static_cast<PrimExpr>(rhs[0]);
503+
PrimExpr lhs_val = static_cast<PrimExpr>(lhs[1]);
504+
PrimExpr rhs_val = static_cast<PrimExpr>(rhs[1]);
505+
494506
// These variables compare the actual values of the array
495-
auto is_bigger = lhs[1] > rhs[1];
496-
auto is_same = lhs[1] == rhs[1];
507+
auto is_bigger = lhs_val > rhs_val;
508+
auto is_same = lhs_val == rhs_val;
497509

498510
// This checks if the indices are correct for the reduction. E.g. for select_last_index
499511
// it gives precedence for later indices of the same element and precedence for sooner
500512
// indices if not select_last_index;
501513
PrimExpr proper_index;
502514
if (select_last_index) {
503-
proper_index = PrimExpr(lhs[0]) > PrimExpr(rhs[0]);
515+
proper_index = lhs_idx > rhs_idx;
504516
} else {
505-
proper_index = PrimExpr(lhs[0]) < PrimExpr(rhs[0]);
517+
proper_index = lhs_idx < rhs_idx;
506518
}
507519

508520
PrimExpr update_index = is_bigger || (is_same && proper_index);

0 commit comments

Comments
 (0)