@@ -107,6 +107,42 @@ class UnfoldOp : public framework::OperatorWithKernel {
107107            " But recieved dims(strides: %u) != dims(dilations: %u)." 
108108            strides.size (), dilations.size ()));
109109
110+     //  check kernel_sizes
111+     PADDLE_ENFORCE_GT (kernel_sizes[0 ], 0 ,
112+                       platform::errors::InvalidArgument (
113+                           " The `kernel_sizes` should be greater than zero, " 
114+                           " but recieved kernel_height: %d kernel_width: %d." 
115+                           kernel_sizes[0 ], kernel_sizes[1 ]));
116+     PADDLE_ENFORCE_GT (kernel_sizes[1 ], 0 ,
117+                       platform::errors::InvalidArgument (
118+                           " The `kernel_sizes` should be greater than zero, " 
119+                           " but recieved kernel_height: %d kernel_width: %d." 
120+                           kernel_sizes[0 ], kernel_sizes[1 ]));
121+     //  check strides
122+     PADDLE_ENFORCE_GT (strides[0 ], 0 ,
123+                       platform::errors::InvalidArgument (
124+                           " The `strides` should be greater than zero, " 
125+                           " but recieved strides_height: %d strides_width: %d." 
126+                           strides[0 ], strides[1 ]));
127+     PADDLE_ENFORCE_GT (strides[1 ], 0 ,
128+                       platform::errors::InvalidArgument (
129+                           " The `strides` should be greater than zero, " 
130+                           " but recieved strides_height: %d strides_width: %d." 
131+                           strides[0 ], strides[1 ]));
132+     //  check dilations
133+     PADDLE_ENFORCE_GT (
134+         dilations[0 ], 0 ,
135+         platform::errors::InvalidArgument (
136+             " The `dilations` should be greater than zero, " 
137+             " but recieved dilations_height: %d dilations_width: %d." 
138+             dilations[0 ], dilations[1 ]));
139+     PADDLE_ENFORCE_GT (
140+         dilations[1 ], 0 ,
141+         platform::errors::InvalidArgument (
142+             " The `dilations` should be greater than zero, " 
143+             " but recieved dilations_height: %d dilations_width: %d." 
144+             dilations[0 ], dilations[1 ]));
145+ 
110146    std::vector<int > out_dims;
111147    out_dims.push_back (in_dims[0 ]);
112148
0 commit comments