26
26
namespace phi {
27
27
28
28
template <typename T, typename IndexT, typename Functor>
29
- void graph_send_recv_cpu_for_loop (const int & input_size,
30
- const int & index_size,
31
- const IndexT* s_index,
32
- const IndexT* d_index,
33
- const DenseTensor& src,
34
- DenseTensor* dst,
35
- const std::string& pool_type,
36
- int * dst_count = nullptr ) {
29
+ void GraphSendRecvCpuLoop (const int & input_size,
30
+ const int & index_size,
31
+ const IndexT* s_index,
32
+ const IndexT* d_index,
33
+ const DenseTensor& src,
34
+ DenseTensor* dst,
35
+ const std::string& pool_type,
36
+ int * dst_count = nullptr ) {
37
37
Functor functor;
38
38
if (pool_type == " SUM" ) {
39
39
for (int i = 0 ; i < index_size; ++i) {
40
40
const IndexT& src_idx = s_index[i];
41
41
const IndexT& dst_idx = d_index[i];
42
- elementwise_inner_operation <T, IndexT, Functor>(
42
+ ElementwiseInnerOperation <T, IndexT, Functor>(
43
43
src, dst, src_idx, dst_idx, false , functor);
44
44
}
45
45
} else if (pool_type == " MEAN" ) {
46
46
for (int i = 0 ; i < index_size; ++i) {
47
47
const IndexT& src_idx = s_index[i];
48
48
const IndexT& dst_idx = d_index[i];
49
- elementwise_inner_operation <T, IndexT, Functor>(
49
+ ElementwiseInnerOperation <T, IndexT, Functor>(
50
50
src, dst, src_idx, dst_idx, false , functor);
51
51
}
52
52
for (int i = 0 ; i < index_size; ++i) {
@@ -66,11 +66,11 @@ void graph_send_recv_cpu_for_loop(const int& input_size,
66
66
const IndexT& dst_idx = d_index[i];
67
67
bool in_set = existed_dst.find (dst_idx) != existed_dst.end ();
68
68
if (!in_set) {
69
- elementwise_inner_operation <T, IndexT, Functor>(
69
+ ElementwiseInnerOperation <T, IndexT, Functor>(
70
70
src, dst, src_idx, dst_idx, true , functor);
71
71
existed_dst.emplace (dst_idx);
72
72
} else {
73
- elementwise_inner_operation <T, IndexT, Functor>(
73
+ ElementwiseInnerOperation <T, IndexT, Functor>(
74
74
src, dst, src_idx, dst_idx, false , functor);
75
75
}
76
76
}
@@ -100,27 +100,26 @@ void GraphSendRecvOpKernelLaunchHelper(const Context& ctx,
100
100
const IndexT* s_index = src_index.data <IndexT>();
101
101
const IndexT* d_index = dst_index.data <IndexT>();
102
102
if (pool_type == " SUM" ) {
103
- graph_send_recv_cpu_for_loop <T, IndexT, GraphSendRecvSumFunctor<T>>(
103
+ GraphSendRecvCpuLoop <T, IndexT, GraphSendRecvSumFunctor<T>>(
104
104
src_dims[0 ], index_size, s_index, d_index, x, out, pool_type);
105
105
} else if (pool_type == " MIN" ) {
106
- graph_send_recv_cpu_for_loop <T, IndexT, GraphSendRecvMinFunctor<T>>(
106
+ GraphSendRecvCpuLoop <T, IndexT, GraphSendRecvMinFunctor<T>>(
107
107
src_dims[0 ], index_size, s_index, d_index, x, out, pool_type);
108
108
} else if (pool_type == " MAX" ) {
109
- graph_send_recv_cpu_for_loop <T, IndexT, GraphSendRecvMaxFunctor<T>>(
109
+ GraphSendRecvCpuLoop <T, IndexT, GraphSendRecvMaxFunctor<T>>(
110
110
src_dims[0 ], index_size, s_index, d_index, x, out, pool_type);
111
111
} else if (pool_type == " MEAN" ) {
112
112
ctx.template Alloc <int >(dst_count);
113
113
int * p_dst_count = dst_count->data <int >();
114
114
memset (p_dst_count, 0 , src_dims[0 ] * sizeof (int ));
115
- graph_send_recv_cpu_for_loop<T, IndexT, GraphSendRecvSumFunctor<T>>(
116
- src_dims[0 ],
117
- index_size,
118
- s_index,
119
- d_index,
120
- x,
121
- out,
122
- pool_type,
123
- p_dst_count);
115
+ GraphSendRecvCpuLoop<T, IndexT, GraphSendRecvSumFunctor<T>>(src_dims[0 ],
116
+ index_size,
117
+ s_index,
118
+ d_index,
119
+ x,
120
+ out,
121
+ pool_type,
122
+ p_dst_count);
124
123
}
125
124
}
126
125
0 commit comments