@@ -15,6 +15,8 @@ limitations under the License. */
 #ifdef PADDLE_WITH_XPU
 
 #include "paddle/fluid/operators/batch_norm_op.h"
+#include <iterator>
+#include <vector>
 
 namespace paddle {
 namespace operators {
@@ -25,23 +27,25 @@ using DDim = framework::DDim;
 template <typename DeviceContext, typename T>
 class BatchNormXPUKernel : public framework::OpKernel<T> {
  public:
-  void Compute(const framework::ExecutionContext& ctx) const override {
+  void Compute(const framework::ExecutionContext &ctx) const override {
     const auto epsilon = ctx.Attr<float>("epsilon");
-    const auto momentum = ctx.Attr<float>("momentum");
+    float momentum = ctx.Attr<float>("momentum");
     const auto is_test = ctx.Attr<bool>("is_test");
     const auto use_global_stats = ctx.Attr<bool>("use_global_stats");
     const auto trainable_stats = ctx.Attr<bool>("trainable_statistics");
     bool test_mode = is_test && (!trainable_stats);
+
     bool global_stats = test_mode || use_global_stats;
-    const auto& data_layout_str = ctx.Attr<std::string>("data_layout");
+    const auto &data_layout_str = ctx.Attr<std::string>("data_layout");
     const auto data_layout = framework::StringToDataLayout(data_layout_str);
     PADDLE_ENFORCE_EQ(data_layout, DataLayout::kNCHW,
                       platform::errors::InvalidArgument(
                           "The 'data_layout' attribute must be NCHW. But "
                           "received 'data_layout' is [%s].",
                           data_layout_str));
-    const auto* x = ctx.Input<Tensor>("X");
-    const auto& x_dims = x->dims();
+
+    const auto *x = ctx.Input<Tensor>("X");
+    const auto &x_dims = x->dims();
     PADDLE_ENFORCE_EQ(x_dims.size(), 4,
                       platform::errors::InvalidArgument(
                           "The input tensor X's dimension must equal to 4. But "
@@ -51,27 +55,42 @@ class BatchNormXPUKernel : public framework::OpKernel<T> {
     const int C = x_dims[1];
     const int H = x_dims[2];
     const int W = x_dims[3];
-    const auto* scale = ctx.Input<Tensor>("Scale");
-    const auto* bias = ctx.Input<Tensor>("Bias");
-    const auto* x_data = x->data<T>();
-    const auto* scale_data = scale->data<T>();
-    const auto* bias_data = bias->data<T>();
-    auto* y = ctx.Output<Tensor>("Y");
-    auto* y_data = y->mutable_data<T>(ctx.GetPlace());
-    auto& dev_ctx = ctx.template device_context<DeviceContext>();
+    const auto *scale = ctx.Input<Tensor>("Scale");
+    const auto *bias = ctx.Input<Tensor>("Bias");
+    const auto *x_data = x->data<T>();
+    const auto *scale_data = scale->data<float>();
+    const auto *bias_data = bias->data<float>();
+
+    auto *y = ctx.Output<Tensor>("Y");
+    auto *mean_out = ctx.Output<Tensor>("MeanOut");
+    auto *variance_out = ctx.Output<Tensor>("VarianceOut");
+    auto *saved_mean = ctx.Output<Tensor>("SavedMean");
+    auto *saved_variance = ctx.Output<Tensor>("SavedVariance");
+
+    // alloc memory
+    auto *y_data = y->mutable_data<T>(ctx.GetPlace());
+    mean_out->mutable_data<float>(ctx.GetPlace());
+    variance_out->mutable_data<float>(ctx.GetPlace());
+    saved_mean->mutable_data<float>(ctx.GetPlace());
+    saved_variance->mutable_data<float>(ctx.GetPlace());
+
+    auto &dev_ctx = ctx.template device_context<DeviceContext>();
+
     if (!global_stats) {
-      auto* mean_out = ctx.Output<Tensor>("MeanOut");
-      auto* variance_out = ctx.Output<Tensor>("VarianceOut");
-      auto* saved_mean = ctx.Output<Tensor>("SavedMean");
-      auto* saved_variance = ctx.Output<Tensor>("SavedVariance");
-      mean_out->mutable_data<T>(ctx.GetPlace());
-      variance_out->mutable_data<T>(ctx.GetPlace());
-      saved_mean->mutable_data<T>(ctx.GetPlace());
-      saved_variance->mutable_data<T>(ctx.GetPlace());
-      auto* mean_out_data = mean_out->data<T>();
-      auto* variance_out_data = variance_out->data<T>();
-      auto* saved_mean_data = saved_mean->data<T>();
-      auto* saved_variance_data = saved_variance->data<T>();
+      auto *mean_out_data = mean_out->data<float>();
+      auto *variance_out_data = variance_out->data<float>();
+      auto *saved_mean_data = saved_mean->data<float>();
+      auto *saved_variance_data = saved_variance->data<float>();
+
+      // if MomentumTensor is set, use its value; momentum is
+      // only used in this training branch
+      if (ctx.HasInput("MomentumTensor")) {
+        const auto *mom_tensor = ctx.Input<Tensor>("MomentumTensor");
+        Tensor mom_cpu;
+        TensorCopySync(*mom_tensor, platform::CPUPlace(), &mom_cpu);
+        momentum = mom_cpu.data<float>()[0];
+      }
+
       int r = xpu::batch_norm<T>(dev_ctx.x_context(), x_data, y_data, N, C, H,
                                  W, epsilon, momentum, scale_data, bias_data,
                                  saved_mean_data, saved_variance_data,
@@ -81,12 +100,10 @@ class BatchNormXPUKernel : public framework::OpKernel<T> {
81100 " The batch_norm XPU API return wrong value[%d %s]" ,
82101 r, XPUAPIErrorMsg[r]));
83102 } else {
84- const auto * mean = ctx.Input <Tensor>(" Mean" );
85- const auto * variance = ctx.Input <Tensor>(" Variance" );
86- const auto * mean_data = mean->data <float >();
87- const auto * variance_data = variance->data <float >();
88- const auto * x_data = x->data <float >();
89- auto * y_data = y->mutable_data <float >(ctx.GetPlace ());
103+ const auto *mean = ctx.Input <Tensor>(" Mean" );
104+ const auto *variance = ctx.Input <Tensor>(" Variance" );
105+ const auto *mean_data = mean->data <float >();
106+ const auto *variance_data = variance->data <float >();
90107 int r = xpu::batch_norm_infer (dev_ctx.x_context (), x_data, y_data, N, C,
91108 H, W, epsilon, scale_data, bias_data,
92109 mean_data, variance_data, true );
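The two branches above map onto two XPU primitives: training (`xpu::batch_norm`) computes per-batch statistics and folds them into the running estimates using `momentum`, while inference (`xpu::batch_norm_infer`) reads the stored running statistics directly. Assuming Paddle's usual momentum convention (not visible in this hunk), the running-statistics update amounts to

$$\mu_{\text{run}} \leftarrow m\,\mu_{\text{run}} + (1-m)\,\mu_{\text{batch}},\qquad \sigma^2_{\text{run}} \leftarrow m\,\sigma^2_{\text{run}} + (1-m)\,\sigma^2_{\text{batch}}.$$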
@@ -99,24 +116,96 @@ class BatchNormXPUKernel : public framework::OpKernel<T> {
   }
 };
 
+template <typename T>
+static int calculate_inv_BN_Y(xpu::Context *ctx, T *x, const T *scale,
+                              const T *bias, const T *mean, const T *variance,
+                              const int N, const int C, const int M,
+                              const T *y) {
+  PADDLE_ENFORCE_EQ(x, y, platform::errors::InvalidArgument(
+                              "X and Y should be inplaced in inplace mode"));
+  std::vector<int> tensor_shape_vec({N, C, M});
+  std::vector<int> array_shape_vec({1, C, 1});
+  // y - bias
+  int r1 =
+      xpu::broadcast_sub<T>(ctx, bias, y, x, array_shape_vec, tensor_shape_vec);
+  // (y - bias) / scale
+  int r2 = xpu::broadcast_div<T>(ctx, scale, x, x, array_shape_vec,
+                                 tensor_shape_vec);
+  // (y - bias) / scale / variance
+  int r3 = xpu::broadcast_div<T>(ctx, variance, x, x, array_shape_vec,
+                                 tensor_shape_vec);
+  // (y - bias) / scale / variance + mean
+  int r4 =
+      xpu::broadcast_add<T>(ctx, mean, x, x, array_shape_vec, tensor_shape_vec);
+
+  return r1 + r2 + r3 + r4;
+}
+
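`calculate_inv_BN_Y` undoes the forward batch-norm transform. Despite its name, the `variance` parameter receives the inverse standard deviation saved by the forward pass, so the second `broadcast_div` in effect multiplies by the standard deviation. Restating the code as math (no new content, just the forward transform and its inverse):

$$y = \gamma\,(x-\mu)\,\hat{\sigma}^{-1} + \beta \;\Longrightarrow\; x = \frac{(y-\beta)/\gamma}{\hat{\sigma}^{-1}} + \mu, \qquad \hat{\sigma}^{-1} = \frac{1}{\sqrt{\sigma^2+\varepsilon}}.$$

The four broadcast kernels implement exactly this chain: subtract $\beta$, divide by $\gamma$, divide by the saved inverse std, add $\mu$.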
+template <typename T>
+static int calculate_inv_var(xpu::Context *ctx, const T *var, const T epsilon,
+                             const int C, T *epsilon_data, T *inv_var) {
+  int r1 = constant(ctx, epsilon_data, 1, epsilon);
+  std::vector<int> tensor_shape_vec({C});
+  std::vector<int> array_shape_vec({1});
+  int r2 = xpu::broadcast_add<T>(ctx, epsilon_data, var, inv_var,
+                                 array_shape_vec, tensor_shape_vec);
+  int r3 = xpu::rsqrt<T>(ctx, inv_var, inv_var, C);
+  return r1 + r2 + r3;
+}
+
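`calculate_inv_var` converts a running variance into the inverse standard deviation that `batch_norm_grad` expects, in three XPU calls: `constant` materializes $\varepsilon$ as a one-element tensor, `broadcast_add` adds it across the $C$ channel variances, and `rsqrt` finishes the job:

$$\text{inv\_var}_c = \frac{1}{\sqrt{\sigma_c^2 + \varepsilon}}, \qquad c = 0,\dots,C-1.$$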
 template <typename DeviceContext, typename T>
 class BatchNormGradXPUKernel : public framework::OpKernel<T> {
  public:
-  void Compute(const framework::ExecutionContext& ctx) const override {
-    const auto* x = ctx.Input<Tensor>("X");
-    const auto* dy = ctx.Input<Tensor>(framework::GradVarName("Y"));
-    const auto* scale = ctx.Input<Tensor>("Scale");
-    const auto* saved_mean = ctx.Input<Tensor>("SavedMean");
-    // SavedVariance have been reverted in forward operator
-    const auto* saved_inv_variance = ctx.Input<Tensor>("SavedVariance");
-    const auto& data_layout_str = ctx.Attr<std::string>("data_layout");
+  void Compute(const framework::ExecutionContext &ctx) const override {
+    const auto *d_y = ctx.Input<Tensor>(framework::GradVarName("Y"));
+    const auto *scale = ctx.Input<Tensor>("Scale");
+    const auto *bias = ctx.Input<Tensor>("Bias");
+
+    const auto &data_layout_str = ctx.Attr<std::string>("data_layout");
+    bool use_global_stats = ctx.Attr<bool>("use_global_stats");
+    const bool is_test = ctx.Attr<bool>("is_test");
+    const float epsilon = ctx.Attr<float>("epsilon");
     const auto data_layout = framework::StringToDataLayout(data_layout_str);
+
+    // TODO(guozbin): Transform input tensor from NHWC to NCHW
     PADDLE_ENFORCE_EQ(data_layout, DataLayout::kNCHW,
                       platform::errors::InvalidArgument(
                           "The 'data_layout' attribute must be NCHW. But "
                           "received 'data_layout' is [%s].",
                           data_layout_str));
-    const auto& x_dims = x->dims();
+
+    auto *d_x = ctx.Output<Tensor>(framework::GradVarName("X"));
+    auto *d_scale = ctx.Output<Tensor>(framework::GradVarName("Scale"));
+    auto *d_bias = ctx.Output<Tensor>(framework::GradVarName("Bias"));
+
+    use_global_stats = is_test || use_global_stats;
+
+    // batch_norm with inplace == false takes X as the grad input, the same
+    // as the cuDNN batch_norm backward calculation; batch_norm with
+    // inplace == true only takes Y as input, and X has to be recovered by
+    // applying the inverse of batch_norm to Y
+    const Tensor *x;
+    bool is_inplace;
+    if (ctx.HasInput("Y")) {
+      x = ctx.Input<Tensor>("Y");
+      is_inplace = true;
+      // if the input of batch norm is stop_gradient, d_x is null
+      if (d_x) {
+        PADDLE_ENFORCE_EQ(d_x, d_y,
+                          platform::errors::InvalidArgument(
+                              "X@GRAD and Y@GRAD not inplaced in inplace mode"));
+      }
+    } else {
+      x = ctx.Input<Tensor>("X");
+      is_inplace = false;
+      if (d_x) {
+        PADDLE_ENFORCE_NE(
+            d_x, d_y, platform::errors::InvalidArgument(
+                          "X@GRAD and Y@GRAD inplaced in non-inplace mode"));
+      }
+    }
+
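When the op runs in place, X and Y share one buffer, so before computing gradients the kernel must rebuild X from Y; that is what the `calculate_inv_BN_Y` call further down does on device. A minimal CPU sketch of the same recovery, for illustration only — the function name `inverse_bn_cpu` and the flattened N×C×M (M = H·W) layout are assumptions, not part of this patch:

```cpp
#include <cstddef>
#include <vector>

// Recover x in place from y = scale * (x - mean) * inv_std + bias,
// channel-wise over an NCHW tensor flattened to N x C x M.
void inverse_bn_cpu(std::vector<float>* y, const std::vector<float>& scale,
                    const std::vector<float>& bias,
                    const std::vector<float>& mean,
                    const std::vector<float>& inv_std, int N, int C, int M) {
  for (int n = 0; n < N; ++n) {
    for (int c = 0; c < C; ++c) {
      for (int m = 0; m < M; ++m) {
        float& v = (*y)[(static_cast<size_t>(n) * C + c) * M + m];
        // Undo bias and scale, then undo the normalization.
        v = (v - bias[c]) / scale[c] / inv_std[c] + mean[c];
      }
    }
  }
}
```

On the XPU the same four steps run as broadcast kernels over the C-sized parameter arrays instead of an explicit loop.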
+    const auto &x_dims = x->dims();
     PADDLE_ENFORCE_EQ(x_dims.size(), 4,
                       platform::errors::InvalidArgument(
                           "The input tensor X's dimension must equal to 4. But "
@@ -126,26 +215,96 @@ class BatchNormGradXPUKernel : public framework::OpKernel<T> {
     const int C = x_dims[1];
     const int H = x_dims[2];
     const int W = x_dims[3];
-    const auto* x_data = x->data<T>();
-    const auto* dy_data = dy->data<T>();
-    const auto* scale_data = scale->data<T>();
-    const auto* saved_mean_data = saved_mean->data<T>();
-    const auto* saved_inv_variance_data = saved_inv_variance->data<T>();
-    auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
-    auto* dscale = ctx.Output<Tensor>(framework::GradVarName("Scale"));
-    auto* dbias = ctx.Output<Tensor>(framework::GradVarName("Bias"));
-    auto* dx_data = dx->mutable_data<T>(ctx.GetPlace());
-    auto* dscale_data = dscale->mutable_data<T>(ctx.GetPlace());
-    auto* dbias_data = dbias->mutable_data<T>(ctx.GetPlace());
-    auto& dev_ctx = ctx.template device_context<DeviceContext>();
-    int r = xpu::batch_norm_grad<T>(dev_ctx.x_context(), x_data, dy_data,
-                                    dx_data, N, C, H, W, scale_data,
-                                    saved_mean_data, saved_inv_variance_data,
-                                    dscale_data, dbias_data, true);
-    PADDLE_ENFORCE_EQ(r, XPU_SUCCESS, platform::errors::External(
-                                          "XPU API(batch_norm_grad) return "
-                                          "wrong value[%d %s]",
-                                          r, XPUAPIErrorMsg[r]));
+
+    const auto *x_data = x->data<T>();
+    const auto *d_y_data = d_y->data<T>();
+    const auto *scale_data = scale->data<float>();
+
+    // init output
+    T *d_x_data = nullptr;
+    T *d_bias_data = nullptr;
+    T *d_scale_data = nullptr;
+    if (d_x) {
+      d_x_data = d_x->mutable_data<T>(ctx.GetPlace());
+    }
+    if (d_scale && d_bias) {
+      d_scale_data = d_scale->mutable_data<float>(ctx.GetPlace());
+      d_bias_data = d_bias->mutable_data<float>(ctx.GetPlace());
+    }
+
+    PADDLE_ENFORCE_EQ(
+        scale->dims().size(), 1UL,
+        platform::errors::InvalidArgument(
+            "The size of scale's dimensions must equal to 1. But received: "
+            "the size of scale's dimensions is [%d], the dimensions of scale "
+            "is [%s].",
+            scale->dims().size(), scale->dims()));
+    PADDLE_ENFORCE_EQ(
+        scale->dims()[0], C,
+        platform::errors::InvalidArgument(
+            "The first dimension of scale must equal to Channels[%d]. But "
+            "received: the first dimension of scale is [%d]",
+            C, scale->dims()[0]));
+
+    auto &dev_ctx = ctx.template device_context<DeviceContext>();
+    xpu::ctx_guard RAII_GUARD(dev_ctx.x_context());
+
+    const T *mean_data = nullptr;
+    const T *inv_var_data = nullptr;
+
+    // TODO(guozibin): handle the case of N * H * W == 1
+    if (!use_global_stats) {
+      const auto *saved_mean = ctx.Input<Tensor>("SavedMean");
+      // SavedVariance has been inverted in the forward operator
+      const auto *saved_inv_variance = ctx.Input<Tensor>("SavedVariance");
+      mean_data = saved_mean->data<float>();
+      inv_var_data = saved_inv_variance->data<float>();
+    } else {
+      const auto *running_mean = ctx.Input<Tensor>("Mean");
+      const auto *running_variance = ctx.Input<Tensor>("Variance");
+      mean_data = running_mean->data<float>();
+      inv_var_data = running_variance->data<float>();
+      float *running_inv_var_data =
+          RAII_GUARD.alloc_l3_or_gm<float>(running_variance->numel());
+      float *epsilon_data = RAII_GUARD.alloc_l3_or_gm<float>(1);
+      int r1 = calculate_inv_var(dev_ctx.x_context(), inv_var_data, epsilon, C,
+                                 epsilon_data, running_inv_var_data);
+      PADDLE_ENFORCE_EQ(r1, XPU_SUCCESS, platform::errors::External(
+                                             "XPU API(batch_norm_grad "
+                                             "calculate_inv_var function) "
+                                             "return wrong value[%d %s]",
+                                             r1, XPUAPIErrorMsg[r1]));
+      inv_var_data = running_inv_var_data;
+    }
+    if (is_inplace) {
+      auto px = *x;
+      int r2 = calculate_inv_BN_Y(
+          dev_ctx.x_context(), px.mutable_data<T>(ctx.GetPlace()),
+          scale->data<float>(), bias->data<float>(), mean_data, inv_var_data,
+          N, C, H * W, x->data<T>());
+      PADDLE_ENFORCE_EQ(r2, XPU_SUCCESS, platform::errors::External(
+                                             "XPU API(batch_norm_grad "
+                                             "calculate_inv_BN_Y function) "
+                                             "return wrong value[%d %s]",
+                                             r2, XPUAPIErrorMsg[r2]));
+    }
+    if (!d_x) {
+      d_x_data = RAII_GUARD.alloc_l3_or_gm<T>(x->numel());
+    }
+    if (!d_scale) {
+      d_scale_data = RAII_GUARD.alloc_l3_or_gm<float>(C);
+    }
+    if (!d_bias_data) {
+      d_bias_data = RAII_GUARD.alloc_l3_or_gm<float>(C);
+    }
+
+    int r3 = xpu::batch_norm_grad<T>(
+        dev_ctx.x_context(), x_data, d_y_data, d_x_data, N, C, H, W,
+        scale_data, mean_data, inv_var_data, d_scale_data, d_bias_data, true);
+    PADDLE_ENFORCE_EQ(r3, XPU_SUCCESS, platform::errors::External(
+                                           "XPU API(batch_norm_grad) return "
+                                           "wrong value[%d %s]",
+                                           r3, XPUAPIErrorMsg[r3]));
   }
 };
 
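The kernel registrations fall outside the hunks shown here. For a Paddle fluid operator of this shape, the float-only XPU registration would typically look like the sketch below; treat it as an assumption rather than part of this diff:

```cpp
// Hypothetical registration sketch; the real file's registration lines
// are not shown in this excerpt.
namespace ops = paddle::operators;

REGISTER_OP_XPU_KERNEL(
    batch_norm,
    ops::BatchNormXPUKernel<paddle::platform::XPUDeviceContext, float>);
REGISTER_OP_XPU_KERNEL(
    batch_norm_grad,
    ops::BatchNormGradXPUKernel<paddle::platform::XPUDeviceContext, float>);

#endif  // PADDLE_WITH_XPU
```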