Skip to content

Commit 3e10bd2

Browse files
committed
update
1 parent ba1d2fa commit 3e10bd2

File tree

1 file changed

+9
-17
lines changed

1 file changed

+9
-17
lines changed

paddle/fluid/operators/reduce_ops/reduce_op.cu.h

Lines changed: 9 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -603,12 +603,12 @@ __device__ void ReduceHigherDim(const Tx* x, Ty* y, ReduceOp reducer,
603603
// when reduce_dim.size() != 1 and reduce_dim.size() != x_dim.size(), this
604604
// function will be used
605605
template <typename Tx, typename Ty, typename MPType, typename ReduceOp,
606-
typename TransformOp, typename ReduceIndexCal, typename LeftIndexCal>
606+
typename TransformOp>
607607
__device__ void ReduceAny(const Tx* x, Ty* y, ReduceOp reducer,
608608
TransformOp transformer, MPType init, int reduce_num,
609609
int left_num, bool reduce_lastdim,
610-
ReduceIndexCal reduce_index_calculator,
611-
LeftIndexCal left_index_calculator) {
610+
const IndexCalculator& reduce_index_calculator,
611+
const IndexCalculator& left_index_calculator) {
612612
int input_idx, left_idx, stride;
613613
// the last dim gets involved in reduction
614614
if (reduce_lastdim) {
@@ -621,7 +621,7 @@ __device__ void ReduceAny(const Tx* x, Ty* y, ReduceOp reducer,
621621
stride = gridDim.y * blockDim.y;
622622
}
623623
// calculate the offset, means the addr where each thread really start.
624-
int input_offset = left_index_calculator(left_idx);
624+
int input_offset = left_index_calculator.Get(left_idx);
625625
const Tx* input = x + input_offset;
626626
MPType reduce_var = init;
627627

@@ -634,7 +634,7 @@ __device__ void ReduceAny(const Tx* x, Ty* y, ReduceOp reducer,
634634
#pragma unroll
635635
for (int i = 0; i < REDUCE_VEC_SIZE; ++i) {
636636
int reduce_idx = input_idx + i * stride;
637-
int idx_x = reduce_index_calculator(reduce_idx);
637+
int idx_x = reduce_index_calculator.Get(reduce_idx);
638638
input_reg[i] = input[idx_x];
639639
}
640640
#pragma unroll
@@ -653,7 +653,7 @@ __device__ void ReduceAny(const Tx* x, Ty* y, ReduceOp reducer,
653653
break;
654654
}
655655
int reduce_idx = input_idx;
656-
int idx_x = reduce_index_calculator(reduce_idx);
656+
int idx_x = reduce_index_calculator.Get(reduce_idx);
657657
input_reg[i] = input[idx_x];
658658
input_idx += stride;
659659
}
@@ -697,23 +697,15 @@ __device__ void ReduceModule(const Tx* x, Ty* y, ReduceOp reducer,
697697
int reduce_type, bool reduce_lastdim,
698698
const IndexCalculator& reduce_index_calculator,
699699
const IndexCalculator& left_index_calculator) {
700-
if (reduce_type == ReduceType::kReduceLastDim) {
700+
if (reduce_type == ReduceType::kReduceLastDim ||
701+
reduce_type == ReduceType::kReduceAny) {
701702
ReduceAny<Tx, Ty, MPType, ReduceOp, TransformOp>(
702703
x, y, reducer, transformer, init, reduce_num, left_num, reduce_lastdim,
703-
[&](int idx) { return idx; },
704-
[&](int idx) { return idx * reduce_num; });
705-
704+
reduce_index_calculator, left_index_calculator);
706705
// reduce_rank == 1 && reduce_dim[0] != x_dim.size() - 1
707706
} else if (reduce_type == ReduceType::kReduceHigherDim) {
708707
ReduceHigherDim<Tx, Ty, MPType, ReduceOp, TransformOp>(
709708
x, y, reducer, transformer, init, reduce_num, left_num, blocking_size);
710-
711-
// reduce_rank >= 2
712-
} else {
713-
ReduceAny<Tx, Ty, MPType, ReduceOp, TransformOp>(
714-
x, y, reducer, transformer, init, reduce_num, left_num, reduce_lastdim,
715-
[&](int idx) { return reduce_index_calculator.Get(idx); },
716-
[&](int idx) { return left_index_calculator.Get(idx); });
717709
}
718710
}
719711

0 commit comments

Comments
 (0)