1414
1515#include " paddle/cinn/frontend/decomposer_registry.h"
1616#include " paddle/cinn/frontend/syntax.h"
17-
17+ # include " paddle/common/enforce.h "
1818namespace cinn {
1919namespace frontend {
2020namespace decomposer {
@@ -25,9 +25,12 @@ struct BatchNormHelper {
2525 const std::vector<int >& arg_param_shape,
2626 std::string data_layout,
2727 std::string bn_op_type) {
28- CHECK_EQ (arg_x_shape.size (), 4UL )
29- << " Only 4-D input tensor is supported, but get " << arg_x_shape.size ()
30- << " -D input tensor." ;
28+ PADDLE_ENFORCE_EQ (arg_x_shape.size (),
29+ 4UL ,
30+ phi::errors::InvalidArgument (
31+ " Only 4-D input tensor is supported, but get %d" ,
32+ arg_x_shape.size (),
33+ " -D input tensor." ));
3134
3235 builder = net_builder;
3336 x_shape = arg_x_shape;
@@ -162,21 +165,34 @@ struct BatchNormHelper {
162165
163166void batch_norm_train (const Instruction& instr,
164167 const DecomposerContext& context) {
165- CHECK_EQ (instr->inputs .size (), 5UL )
166- << " The number of the given inputs is not equal to the required for op "
167- << instr->op_type ;
168- CHECK_EQ (instr->outputs .size (), 5UL )
169- << " The number of the given outputs is not equal to the required for op "
170- << instr->op_type ;
168+ PADDLE_ENFORCE_EQ (
169+ instr->inputs .size (),
170+ 5UL ,
171+ phi::errors::InvalidArgument (
172+ " The number of the given inputs is not equal to the required" ));
173+ PADDLE_ENFORCE_EQ (
174+ instr->outputs .size (),
175+ 5UL ,
176+ phi::errors::InvalidArgument (
177+ " The number of the given outputs is not equal to the required" ));
171178
172179 auto & x = instr->inputs [0 ];
173180 auto & scale = instr->inputs [1 ];
174181 auto & bias = instr->inputs [2 ];
175182 auto & moving_mean = instr->inputs [3 ];
176183 auto & moving_variance = instr->inputs [4 ];
177- CHECK_EQ (scale->type , bias->type );
178- CHECK_EQ (scale->type , moving_mean->type );
179- CHECK_EQ (scale->type , moving_variance->type );
184+ PADDLE_ENFORCE_EQ (
185+ scale->type == bias->type ,
186+ true ,
187+ phi::errors::InvalidArgument (" The type of scale and bias is not equal" ));
188+ PADDLE_ENFORCE_EQ (scale->type == moving_mean->type ,
189+ true ,
190+ phi::errors::InvalidArgument (
191+ " The type of scale and moving_mean is not equal" ));
192+ PADDLE_ENFORCE_EQ (scale->type == moving_variance->type ,
193+ true ,
194+ phi::errors::InvalidArgument (
195+ " The type of scale and moving_variance is not equal" ));
180196
181197 float epsilon = instr.GetAttrs <float >(" epsilon" );
182198 float momentum = instr.GetAttrs <float >(" momentum" );
@@ -219,21 +235,34 @@ void batch_norm_train(const Instruction& instr,
219235
220236void batch_norm_grad (const Instruction& instr,
221237 const DecomposerContext& context) {
222- CHECK_EQ (instr->inputs .size (), 5UL )
223- << " The number of the given inputs is not equal to the required "
224- << instr->op_type ;
225- CHECK_EQ (instr->outputs .size (), 3UL )
226- << " The number of the given outputs is not equal to the required"
227- << instr->op_type ;
238+ PADDLE_ENFORCE_EQ (
239+ instr->inputs .size (),
240+ 5UL ,
241+ phi::errors::InvalidArgument (
242+ " The number of the given inputs is not equal to the required" ));
243+ PADDLE_ENFORCE_EQ (
244+ instr->outputs .size (),
245+ 3UL ,
246+ phi::errors::InvalidArgument (
247+ " The number of the given outputs is not equal to the required" ));
228248
229249 auto & y_grad = instr->inputs [0 ];
230250 auto & x = instr->inputs [1 ];
231251 auto & scale = instr->inputs [2 ];
232252 auto & save_mean = instr->inputs [3 ];
233253 auto & save_variance = instr->inputs [4 ];
234- CHECK_EQ (y_grad->type , x->type );
235- CHECK_EQ (scale->type , save_mean->type );
236- CHECK_EQ (scale->type , save_variance->type );
254+ PADDLE_ENFORCE_EQ (
255+ y_grad->type == x->type ,
256+ true ,
257+ phi::errors::InvalidArgument (" The type of y_grad and x is not equal" ));
258+ PADDLE_ENFORCE_EQ (scale->type == save_mean->type ,
259+ true ,
260+ phi::errors::InvalidArgument (
261+ " The type of scale and save_mean is not equal" ));
262+ PADDLE_ENFORCE_EQ (scale->type == save_variance->type ,
263+ true ,
264+ phi::errors::InvalidArgument (
265+ " The type of scale and save_variance is not equal" ));
237266
238267 auto epsilon = instr.GetAttrs <float >(" epsilon" );
239268 auto layout = instr.GetAttrs <std::string>(" data_layout" );
@@ -304,21 +333,35 @@ void batch_norm_grad(const Instruction& instr,
304333}
305334
306335void batch_norm (const Instruction& instr, const DecomposerContext& context) {
307- CHECK_EQ (instr->inputs .size (), 5UL )
308- << " The number of the given inputs is not equal to the required for op "
309- << instr->op_type ;
310- CHECK_EQ (instr->outputs .size (), 1UL )
311- << " The number of the given outputs is not equal to the required for op "
312- << instr->op_type ;
336+ PADDLE_ENFORCE_EQ (
337+ instr->inputs .size (),
338+ 5UL ,
339+ phi::errors::InvalidArgument (
340+ " The number of the given inputs is not equal to the required" ));
341+
342+ PADDLE_ENFORCE_EQ (
343+ instr->outputs .size (),
344+ 1UL ,
345+ phi::errors::InvalidArgument (
346+ " The number of the given outputs is not equal to the required" ));
313347
314348 auto & x = instr->inputs [0 ];
315349 auto & scale = instr->inputs [1 ];
316350 auto & bias = instr->inputs [2 ];
317351 auto & moving_mean = instr->inputs [3 ];
318352 auto & moving_variance = instr->inputs [4 ];
319- CHECK_EQ (scale->type , bias->type );
320- CHECK_EQ (scale->type , moving_mean->type );
321- CHECK_EQ (scale->type , moving_variance->type );
353+ PADDLE_ENFORCE_EQ (
354+ scale->type == bias->type ,
355+ true ,
356+ phi::errors::InvalidArgument (" The type of scale and bias is not equal" ));
357+ PADDLE_ENFORCE_EQ (scale->type == moving_mean->type ,
358+ true ,
359+ phi::errors::InvalidArgument (
360+ " The type of scale and moving_mean is not equal" ));
361+ PADDLE_ENFORCE_EQ (scale->type == moving_variance->type ,
362+ true ,
363+ phi::errors::InvalidArgument (
364+ " The type of scale and moving_variance is not equal" ));
322365
323366 float epsilon = instr.GetAttrs <float >(" epsilon" );
324367 float momentum = instr.GetAttrs <float >(" momentum" );
0 commit comments