Skip to content

Commit 1429e04

Browse files
committed
fix some formatting problems
1 parent 7f3613b commit 1429e04

File tree

3 files changed

+59
-63
lines changed

3 files changed

+59
-63
lines changed

paddle/phi/kernels/cpu/graph_send_recv_funcs.h

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
#include <algorithm>
1717
#include <vector>
1818

19-
#include "paddle/fluid/operators/eigen/eigen_function.h"
2019
#include "paddle/phi/backends/cpu/cpu_context.h"
2120
#include "paddle/phi/core/dense_tensor.h"
2221
#include "paddle/phi/core/hostdevice.h"
@@ -66,12 +65,12 @@ struct GraphSendRecvMaxFunctor {
6665
};
6766

6867
template <typename T, typename IndexT, typename Functor>
69-
void elementwise_inner_operation(const DenseTensor& src,
70-
DenseTensor* dst,
71-
const IndexT& src_index,
72-
const IndexT& dst_index,
73-
const bool& first_flag,
74-
Functor functor) {
68+
void ElementwiseInnerOperation(const DenseTensor& src,
69+
DenseTensor* dst,
70+
const IndexT& src_index,
71+
const IndexT& dst_index,
72+
const bool& first_flag,
73+
Functor functor) {
7574
auto src_slice = src.Slice(src_index, src_index + 1);
7675
auto dst_slice = dst->Slice(dst_index, dst_index + 1);
7776

paddle/phi/kernels/cpu/graph_send_recv_grad_kernel.cc

Lines changed: 30 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -23,22 +23,22 @@
2323
namespace phi {
2424

2525
template <typename T, typename IndexT, typename Functor>
26-
void graph_send_recv_cpu_for_loop_grad(const int& input_size,
27-
const int& index_size,
28-
const IndexT* s_index,
29-
const IndexT* d_index,
30-
const DenseTensor& src,
31-
DenseTensor* dst,
32-
const std::string& pool_type,
33-
const int* dst_count = nullptr,
34-
const DenseTensor* input = nullptr,
35-
const DenseTensor* output = nullptr) {
26+
void GraphSendRecvCpuGradLoop(const int& input_size,
27+
const int& index_size,
28+
const IndexT* s_index,
29+
const IndexT* d_index,
30+
const DenseTensor& src,
31+
DenseTensor* dst,
32+
const std::string& pool_type,
33+
const int* dst_count = nullptr,
34+
const DenseTensor* input = nullptr,
35+
const DenseTensor* output = nullptr) {
3636
if (pool_type == "SUM") {
3737
Functor functor;
3838
for (int i = 0; i < index_size; ++i) {
3939
const IndexT& src_idx = s_index[i];
4040
const IndexT& dst_idx = d_index[i];
41-
elementwise_inner_operation<T, IndexT, Functor>(
41+
ElementwiseInnerOperation<T, IndexT, Functor>(
4242
src, dst, src_idx, dst_idx, false, functor);
4343
}
4444
} else if (pool_type == "MEAN") {
@@ -96,33 +96,31 @@ void GraphSendRecvGradOpKernelLaunchHelper(
9696
const IndexT* d_index = dst_index.data<IndexT>();
9797

9898
if (pool_type == "SUM") {
99-
graph_send_recv_cpu_for_loop_grad<T, IndexT, GraphSendRecvSumFunctor<T>>(
99+
GraphSendRecvCpuGradLoop<T, IndexT, GraphSendRecvSumFunctor<T>>(
100100
src_dims[0], index_size, d_index, s_index, out_grad, x_grad, pool_type);
101101
} else if (pool_type == "MEAN") {
102102
const int* s_count = dst_count->data<int>();
103103
// Functor not used here.
104-
graph_send_recv_cpu_for_loop_grad<T, IndexT, GraphSendRecvSumFunctor<T>>(
105-
src_dims[0],
106-
index_size,
107-
d_index,
108-
s_index,
109-
out_grad,
110-
x_grad,
111-
pool_type,
112-
s_count);
104+
GraphSendRecvCpuGradLoop<T, IndexT, GraphSendRecvSumFunctor<T>>(src_dims[0],
105+
index_size,
106+
d_index,
107+
s_index,
108+
out_grad,
109+
x_grad,
110+
pool_type,
111+
s_count);
113112
} else if (pool_type == "MIN" || pool_type == "MAX") {
114113
// Functor not used here.
115-
graph_send_recv_cpu_for_loop_grad<T, IndexT, GraphSendRecvMinFunctor<T>>(
116-
src_dims[0],
117-
index_size,
118-
d_index,
119-
s_index,
120-
out_grad,
121-
x_grad,
122-
pool_type,
123-
nullptr,
124-
x,
125-
out);
114+
GraphSendRecvCpuGradLoop<T, IndexT, GraphSendRecvMinFunctor<T>>(src_dims[0],
115+
index_size,
116+
d_index,
117+
s_index,
118+
out_grad,
119+
x_grad,
120+
pool_type,
121+
nullptr,
122+
x,
123+
out);
126124
}
127125
}
128126

paddle/phi/kernels/cpu/graph_send_recv_kernel.cc

Lines changed: 23 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -26,27 +26,27 @@
2626
namespace phi {
2727

2828
template <typename T, typename IndexT, typename Functor>
29-
void graph_send_recv_cpu_for_loop(const int& input_size,
30-
const int& index_size,
31-
const IndexT* s_index,
32-
const IndexT* d_index,
33-
const DenseTensor& src,
34-
DenseTensor* dst,
35-
const std::string& pool_type,
36-
int* dst_count = nullptr) {
29+
void GraphSendRecvCpuLoop(const int& input_size,
30+
const int& index_size,
31+
const IndexT* s_index,
32+
const IndexT* d_index,
33+
const DenseTensor& src,
34+
DenseTensor* dst,
35+
const std::string& pool_type,
36+
int* dst_count = nullptr) {
3737
Functor functor;
3838
if (pool_type == "SUM") {
3939
for (int i = 0; i < index_size; ++i) {
4040
const IndexT& src_idx = s_index[i];
4141
const IndexT& dst_idx = d_index[i];
42-
elementwise_inner_operation<T, IndexT, Functor>(
42+
ElementwiseInnerOperation<T, IndexT, Functor>(
4343
src, dst, src_idx, dst_idx, false, functor);
4444
}
4545
} else if (pool_type == "MEAN") {
4646
for (int i = 0; i < index_size; ++i) {
4747
const IndexT& src_idx = s_index[i];
4848
const IndexT& dst_idx = d_index[i];
49-
elementwise_inner_operation<T, IndexT, Functor>(
49+
ElementwiseInnerOperation<T, IndexT, Functor>(
5050
src, dst, src_idx, dst_idx, false, functor);
5151
}
5252
for (int i = 0; i < index_size; ++i) {
@@ -66,11 +66,11 @@ void graph_send_recv_cpu_for_loop(const int& input_size,
6666
const IndexT& dst_idx = d_index[i];
6767
bool in_set = existed_dst.find(dst_idx) != existed_dst.end();
6868
if (!in_set) {
69-
elementwise_inner_operation<T, IndexT, Functor>(
69+
ElementwiseInnerOperation<T, IndexT, Functor>(
7070
src, dst, src_idx, dst_idx, true, functor);
7171
existed_dst.emplace(dst_idx);
7272
} else {
73-
elementwise_inner_operation<T, IndexT, Functor>(
73+
ElementwiseInnerOperation<T, IndexT, Functor>(
7474
src, dst, src_idx, dst_idx, false, functor);
7575
}
7676
}
@@ -100,27 +100,26 @@ void GraphSendRecvOpKernelLaunchHelper(const Context& ctx,
100100
const IndexT* s_index = src_index.data<IndexT>();
101101
const IndexT* d_index = dst_index.data<IndexT>();
102102
if (pool_type == "SUM") {
103-
graph_send_recv_cpu_for_loop<T, IndexT, GraphSendRecvSumFunctor<T>>(
103+
GraphSendRecvCpuLoop<T, IndexT, GraphSendRecvSumFunctor<T>>(
104104
src_dims[0], index_size, s_index, d_index, x, out, pool_type);
105105
} else if (pool_type == "MIN") {
106-
graph_send_recv_cpu_for_loop<T, IndexT, GraphSendRecvMinFunctor<T>>(
106+
GraphSendRecvCpuLoop<T, IndexT, GraphSendRecvMinFunctor<T>>(
107107
src_dims[0], index_size, s_index, d_index, x, out, pool_type);
108108
} else if (pool_type == "MAX") {
109-
graph_send_recv_cpu_for_loop<T, IndexT, GraphSendRecvMaxFunctor<T>>(
109+
GraphSendRecvCpuLoop<T, IndexT, GraphSendRecvMaxFunctor<T>>(
110110
src_dims[0], index_size, s_index, d_index, x, out, pool_type);
111111
} else if (pool_type == "MEAN") {
112112
ctx.template Alloc<int>(dst_count);
113113
int* p_dst_count = dst_count->data<int>();
114114
memset(p_dst_count, 0, src_dims[0] * sizeof(int));
115-
graph_send_recv_cpu_for_loop<T, IndexT, GraphSendRecvSumFunctor<T>>(
116-
src_dims[0],
117-
index_size,
118-
s_index,
119-
d_index,
120-
x,
121-
out,
122-
pool_type,
123-
p_dst_count);
115+
GraphSendRecvCpuLoop<T, IndexT, GraphSendRecvSumFunctor<T>>(src_dims[0],
116+
index_size,
117+
s_index,
118+
d_index,
119+
x,
120+
out,
121+
pool_type,
122+
p_dst_count);
124123
}
125124
}
126125

0 commit comments

Comments
 (0)