Skip to content

Commit 59f0df2

Browse files
committed
update
1 parent 66ec9bf commit 59f0df2

File tree

1 file changed

+5
-7
lines changed

1 file changed

+5
-7
lines changed

paddle/fluid/operators/reduce_ops/reduce_op.cu.h

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -338,11 +338,9 @@ struct ReduceConfig {
338338
void SetReduceType() {
339339
int rank = x_dim.size();
340340
int reduce_rank = reduce_dim.size();
341-
bool is_large_enough = (reduce_num > REDUCE_SPLIT_BOUNDARY / 2) ||
342-
(left_num > REDUCE_SPLIT_BOUNDARY);
343-
344-
if (rank == reduce_rank ||
345-
rank == 2 && reduce_rank == 1 && reduce_dim[0] == 1) {
341+
bool is_last_dim =
342+
(rank == 2) && (reduce_rank == 1) && (reduce_dim[0] == 1);
343+
if (rank == reduce_rank || is_last_dim) {
346344
reduce_type = static_cast<int>(ReduceType::kReduceLastDim);
347345
} else if (reduce_rank == 1) {
348346
// ReduceFirstDim and reduceSecondDim
@@ -788,9 +786,9 @@ void TensorReduceFunctorImpl(const framework::Tensor& x, framework::Tensor* y,
788786
}
789787

790788
config.SetOutputData(y_data, x.place(), &tmp);
791-
bool use_cub_Reduce = (config.left_num == 1) &&
789+
bool use_cub_reduce = (config.left_num == 1) &&
792790
(!std::is_same<Tx, paddle::platform::float16>::value);
793-
if (use_cub_Reduce) {
791+
if (use_cub_reduce) {
794792
// launch CUB::Reduce
795793
using TransformOp = typename ReduceOp<Tx, Ty>::Transformer;
796794
auto reducer = ReduceOp<Tx, Ty>();

0 commit comments

Comments
 (0)