@@ -603,12 +603,12 @@ __device__ void ReduceHigherDim(const Tx* x, Ty* y, ReduceOp reducer,
603603// Used when reduce_dim.size() != 1 and reduce_dim.size() != x_dim.size();
604604// ReduceModule also routes ReduceType::kReduceLastDim through this function.
605605template <typename Tx, typename Ty, typename MPType, typename ReduceOp,
606- typename TransformOp, typename ReduceIndexCal, typename LeftIndexCal >
606+ typename TransformOp>
607607__device__ void ReduceAny (const Tx* x, Ty* y, ReduceOp reducer,
608608 TransformOp transformer, MPType init, int reduce_num,
609609 int left_num, bool reduce_lastdim,
610- ReduceIndexCal reduce_index_calculator,
611- LeftIndexCal left_index_calculator) {
610+ const IndexCalculator& reduce_index_calculator,
611+ const IndexCalculator& left_index_calculator) {
612612 int input_idx, left_idx, stride;
613613 // the last dim gets involved in reduction
614614 if (reduce_lastdim) {
@@ -621,7 +621,7 @@ __device__ void ReduceAny(const Tx* x, Ty* y, ReduceOp reducer,
621621 stride = gridDim.y * blockDim.y ;
622622 }
623623 // calculate the offset, i.e. the address where each thread starts reading.
624- int input_offset = left_index_calculator (left_idx);
624+ int input_offset = left_index_calculator. Get (left_idx);
625625 const Tx* input = x + input_offset;
626626 MPType reduce_var = init;
627627
@@ -634,7 +634,7 @@ __device__ void ReduceAny(const Tx* x, Ty* y, ReduceOp reducer,
634634#pragma unroll
635635 for (int i = 0 ; i < REDUCE_VEC_SIZE; ++i) {
636636 int reduce_idx = input_idx + i * stride;
637- int idx_x = reduce_index_calculator (reduce_idx);
637+ int idx_x = reduce_index_calculator. Get (reduce_idx);
638638 input_reg[i] = input[idx_x];
639639 }
640640#pragma unroll
@@ -653,7 +653,7 @@ __device__ void ReduceAny(const Tx* x, Ty* y, ReduceOp reducer,
653653 break ;
654654 }
655655 int reduce_idx = input_idx;
656- int idx_x = reduce_index_calculator (reduce_idx);
656+ int idx_x = reduce_index_calculator. Get (reduce_idx);
657657 input_reg[i] = input[idx_x];
658658 input_idx += stride;
659659 }
@@ -697,23 +697,15 @@ __device__ void ReduceModule(const Tx* x, Ty* y, ReduceOp reducer,
697697 int reduce_type, bool reduce_lastdim,
698698 const IndexCalculator& reduce_index_calculator,
699699 const IndexCalculator& left_index_calculator) {
700- if (reduce_type == ReduceType::kReduceLastDim ) {
700+ if (reduce_type == ReduceType::kReduceLastDim ||
701+ reduce_type == ReduceType::kReduceAny ) {
701702 ReduceAny<Tx, Ty, MPType, ReduceOp, TransformOp>(
702703 x, y, reducer, transformer, init, reduce_num, left_num, reduce_lastdim,
703- [&](int idx) { return idx; },
704- [&](int idx) { return idx * reduce_num; });
705-
704+ reduce_index_calculator, left_index_calculator);
706705 // reduce_rank == 1 && reduce_dim[0] != x_dim.size() - 1
707706 } else if (reduce_type == ReduceType::kReduceHigherDim ) {
708707 ReduceHigherDim<Tx, Ty, MPType, ReduceOp, TransformOp>(
709708 x, y, reducer, transformer, init, reduce_num, left_num, blocking_size);
710-
711- // reduce_rank >= 2
712- } else {
713- ReduceAny<Tx, Ty, MPType, ReduceOp, TransformOp>(
714- x, y, reducer, transformer, init, reduce_num, left_num, reduce_lastdim,
715- [&](int idx) { return reduce_index_calculator.Get (idx); },
716- [&](int idx) { return left_index_calculator.Get (idx); });
717709 }
718710}
719711
0 commit comments