@@ -13,42 +13,12 @@ limitations under the License. */
1313
1414#include  < vector> 
1515
16- #include  < boost/preprocessor/arithmetic/div.hpp> 
17- #include  < boost/preprocessor/arithmetic/mod.hpp> 
18- #include  < boost/preprocessor/comparison/greater.hpp> 
19- #include  < boost/preprocessor/comparison/greater_equal.hpp> 
20- #include  < boost/preprocessor/control/if.hpp> 
21- #include  < boost/preprocessor/repetition/repeat.hpp> 
2216#include  " paddle/fluid/framework/eigen.h" 
2317#include  " paddle/fluid/framework/op_registry.h" 
2418#include  " paddle/fluid/framework/operator.h" 
2519#include  " paddle/fluid/operators/eigen/eigen_function.h" 
2620
2721#define  MAX_RANK_SUPPORTED  6 
28- //  1. BOOST_PP_REPEAT macro represents a fast horizontal repetition construct.
29- //     Usage: BOOST_PP_REPEAT(count, macro, data).
30- //     This macro expands to the sequence:
31- //     macro(z, 0, data) macro(z, 1, data) ... macro(z, count - 1, data).
32- //  2. As for our case, count = MAX_RANK_SUPPORTED(which is 6).
33- //     So the range of n is 0-5(which is count-1).
34- //     We want to generate case 1-6 instead of case 0-5.
35- //     So we need to change n to n + 1.
36- #define  EXPAND_AS_TEMPLATE (z, n, data ) \
37-   case  n + 1 : {                        \
38-     ExpandAs<n + 1 >(context);          \
39-     break ;                             \
40-   }
41- #define  REP_EXPAND_AS_TEMPLATE (n ) BOOST_PP_REPEAT(n, EXPAND_AS_TEMPLATE, ~)
42- #define  COND (n ) BOOST_PP_GREATER_EQUAL(n, BOOST_PP_MOD(n, MAX_RANK_SUPPORTED))
43- #define  EXPAND_AS_GRAD_CASE (n )                                           \
44-   case  n + 1 : {                                                          \
45-     ExpandAsBackward<n + 1 >(context, reshape_dims_vec, reduce_dims_vec); \
46-     break ;                                                               \
47-   }
48- #define  EXPAND_AS_GRAD_TEMPLATE (z, n, data ) \
49-   BOOST_PP_IF (COND(n), EXPAND_AS_GRAD_CASE(n), )
50- #define  REP_EXPAND_AS_GRAD_TEMPLATE (n ) \
51-   BOOST_PP_REPEAT (n, EXPAND_AS_GRAD_TEMPLATE, ~)
5222
5323namespace  paddle  {
5424namespace  operators  {
@@ -67,7 +37,24 @@ class ExpandAsKernel : public framework::OpKernel<T> {
6737  void  Compute (const  framework::ExecutionContext& context) const  override  {
6838    auto  rank = context.Input <Tensor>(" X" dims ().size ();
6939    switch  (rank) {
70-       REP_EXPAND_AS_TEMPLATE (MAX_RANK_SUPPORTED)
40+       case  1 :
41+         ExpandAs<1 >(context);
42+         break ;
43+       case  2 :
44+         ExpandAs<2 >(context);
45+         break ;
46+       case  3 :
47+         ExpandAs<3 >(context);
48+         break ;
49+       case  4 :
50+         ExpandAs<4 >(context);
51+         break ;
52+       case  5 :
53+         ExpandAs<5 >(context);
54+         break ;
55+       case  6 :
56+         ExpandAs<6 >(context);
57+         break ;
7158      default :
7259        PADDLE_THROW (platform::errors::InvalidArgument (
7360            " Only support tensor with rank being between 1 and 6. But received " 
@@ -165,7 +152,24 @@ class ExpandAsGradKernel : public framework::OpKernel<T> {
165152                            " to %d, but the value received is %d." 
166153                            MAX_RANK_SUPPORTED, dims));
167154      switch  (dims) {
168-         REP_EXPAND_AS_GRAD_TEMPLATE (MAX_RANK_SUPPORTED)
155+         case  1 :
156+           ExpandAsBackward<1 >(context, reshape_dims_vec, reduce_dims_vec);
157+           break ;
158+         case  2 :
159+           ExpandAsBackward<2 >(context, reshape_dims_vec, reduce_dims_vec);
160+           break ;
161+         case  3 :
162+           ExpandAsBackward<3 >(context, reshape_dims_vec, reduce_dims_vec);
163+           break ;
164+         case  4 :
165+           ExpandAsBackward<4 >(context, reshape_dims_vec, reduce_dims_vec);
166+           break ;
167+         case  5 :
168+           ExpandAsBackward<5 >(context, reshape_dims_vec, reduce_dims_vec);
169+           break ;
170+         case  6 :
171+           ExpandAsBackward<6 >(context, reshape_dims_vec, reduce_dims_vec);
172+           break ;
169173        default :
170174          PADDLE_THROW (platform::errors::InvalidArgument (
171175              " Only support tensor with rank being between 1 and 6. But " 
0 commit comments