Skip to content

Commit cc83c95

Browse files
ZibinGuo and Zibin authored
Fix the bug of batch_norm and batch_norm_grad op. (PaddlePaddle#38288)
* Fix the bug of batch_norm and batch_norm_grad op. Add the "roi_align" and "roi_align_grad" op in xpu2 op list. * Fix the bug of batch_norm and batch_norm_grad op. Add the "roi_align" and "roi_align_grad" op in xpu2 op list. test=kunlun Co-authored-by: Zibin <guozibin@baidu.com>
1 parent 9e0a03e commit cc83c95

File tree

3 files changed

+275
-60
lines changed

3 files changed

+275
-60
lines changed

paddle/fluid/operators/batch_norm_op_xpu.cc

Lines changed: 219 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@ limitations under the License. */
1515
#ifdef PADDLE_WITH_XPU
1616

1717
#include "paddle/fluid/operators/batch_norm_op.h"
18+
#include <iterator>
19+
#include <vector>
1820

1921
namespace paddle {
2022
namespace operators {
@@ -25,23 +27,25 @@ using DDim = framework::DDim;
2527
template <typename DeviceContext, typename T>
2628
class BatchNormXPUKernel : public framework::OpKernel<T> {
2729
public:
28-
void Compute(const framework::ExecutionContext& ctx) const override {
30+
void Compute(const framework::ExecutionContext &ctx) const override {
2931
const auto epsilon = ctx.Attr<float>("epsilon");
30-
const auto momentum = ctx.Attr<float>("momentum");
32+
float momentum = ctx.Attr<float>("momentum");
3133
const auto is_test = ctx.Attr<bool>("is_test");
3234
const auto use_global_stats = ctx.Attr<bool>("use_global_stats");
3335
const auto trainable_stats = ctx.Attr<bool>("trainable_statistics");
3436
bool test_mode = is_test && (!trainable_stats);
37+
3538
bool global_stats = test_mode || use_global_stats;
36-
const auto& data_layout_str = ctx.Attr<std::string>("data_layout");
39+
const auto &data_layout_str = ctx.Attr<std::string>("data_layout");
3740
const auto data_layout = framework::StringToDataLayout(data_layout_str);
3841
PADDLE_ENFORCE_EQ(data_layout, DataLayout::kNCHW,
3942
platform::errors::InvalidArgument(
4043
"The 'data_layout' attribute must be NCHW. But "
4144
"recevived 'data_layout' is [%s].",
4245
data_layout_str));
43-
const auto* x = ctx.Input<Tensor>("X");
44-
const auto& x_dims = x->dims();
46+
47+
const auto *x = ctx.Input<Tensor>("X");
48+
const auto &x_dims = x->dims();
4549
PADDLE_ENFORCE_EQ(x_dims.size(), 4,
4650
platform::errors::InvalidArgument(
4751
"The input tensor X's dimension must equal to 4. But "
@@ -51,27 +55,42 @@ class BatchNormXPUKernel : public framework::OpKernel<T> {
5155
const int C = x_dims[1];
5256
const int H = x_dims[2];
5357
const int W = x_dims[3];
54-
const auto* scale = ctx.Input<Tensor>("Scale");
55-
const auto* bias = ctx.Input<Tensor>("Bias");
56-
const auto* x_data = x->data<T>();
57-
const auto* scale_data = scale->data<T>();
58-
const auto* bias_data = bias->data<T>();
59-
auto* y = ctx.Output<Tensor>("Y");
60-
auto* y_data = y->mutable_data<T>(ctx.GetPlace());
61-
auto& dev_ctx = ctx.template device_context<DeviceContext>();
58+
const auto *scale = ctx.Input<Tensor>("Scale");
59+
const auto *bias = ctx.Input<Tensor>("Bias");
60+
const auto *x_data = x->data<T>();
61+
const auto *scale_data = scale->data<float>();
62+
const auto *bias_data = bias->data<float>();
63+
64+
auto *y = ctx.Output<Tensor>("Y");
65+
auto *mean_out = ctx.Output<Tensor>("MeanOut");
66+
auto *variance_out = ctx.Output<Tensor>("VarianceOut");
67+
auto *saved_mean = ctx.Output<Tensor>("SavedMean");
68+
auto *saved_variance = ctx.Output<Tensor>("SavedVariance");
69+
70+
// alloc memory
71+
auto *y_data = y->mutable_data<T>(ctx.GetPlace());
72+
mean_out->mutable_data<float>(ctx.GetPlace());
73+
variance_out->mutable_data<float>(ctx.GetPlace());
74+
saved_mean->mutable_data<float>(ctx.GetPlace());
75+
saved_variance->mutable_data<float>(ctx.GetPlace());
76+
77+
auto &dev_ctx = ctx.template device_context<DeviceContext>();
78+
6279
if (!global_stats) {
63-
auto* mean_out = ctx.Output<Tensor>("MeanOut");
64-
auto* variance_out = ctx.Output<Tensor>("VarianceOut");
65-
auto* saved_mean = ctx.Output<Tensor>("SavedMean");
66-
auto* saved_variance = ctx.Output<Tensor>("SavedVariance");
67-
mean_out->mutable_data<T>(ctx.GetPlace());
68-
variance_out->mutable_data<T>(ctx.GetPlace());
69-
saved_mean->mutable_data<T>(ctx.GetPlace());
70-
saved_variance->mutable_data<T>(ctx.GetPlace());
71-
auto* mean_out_data = mean_out->data<T>();
72-
auto* variance_out_data = variance_out->data<T>();
73-
auto* saved_mean_data = saved_mean->data<T>();
74-
auto* saved_variance_data = saved_variance->data<T>();
80+
auto *mean_out_data = mean_out->data<float>();
81+
auto *variance_out_data = variance_out->data<float>();
82+
auto *saved_mean_data = saved_mean->data<float>();
83+
auto *saved_variance_data = saved_variance->data<float>();
84+
85+
// if MomentumTensor is set, use MomentumTensor value, momentum
86+
// is only used in this training branch
87+
if (ctx.HasInput("MomentumTensor")) {
88+
const auto *mom_tensor = ctx.Input<Tensor>("MomentumTensor");
89+
Tensor mom_cpu;
90+
TensorCopySync(*mom_tensor, platform::CPUPlace(), &mom_cpu);
91+
momentum = mom_tensor->data<float>()[0];
92+
}
93+
7594
int r = xpu::batch_norm<T>(dev_ctx.x_context(), x_data, y_data, N, C, H,
7695
W, epsilon, momentum, scale_data, bias_data,
7796
saved_mean_data, saved_variance_data,
@@ -81,12 +100,10 @@ class BatchNormXPUKernel : public framework::OpKernel<T> {
81100
"The batch_norm XPU API return wrong value[%d %s]",
82101
r, XPUAPIErrorMsg[r]));
83102
} else {
84-
const auto* mean = ctx.Input<Tensor>("Mean");
85-
const auto* variance = ctx.Input<Tensor>("Variance");
86-
const auto* mean_data = mean->data<float>();
87-
const auto* variance_data = variance->data<float>();
88-
const auto* x_data = x->data<float>();
89-
auto* y_data = y->mutable_data<float>(ctx.GetPlace());
103+
const auto *mean = ctx.Input<Tensor>("Mean");
104+
const auto *variance = ctx.Input<Tensor>("Variance");
105+
const auto *mean_data = mean->data<float>();
106+
const auto *variance_data = variance->data<float>();
90107
int r = xpu::batch_norm_infer(dev_ctx.x_context(), x_data, y_data, N, C,
91108
H, W, epsilon, scale_data, bias_data,
92109
mean_data, variance_data, true);
@@ -99,24 +116,96 @@ class BatchNormXPUKernel : public framework::OpKernel<T> {
99116
}
100117
};
101118

119+
template <typename T>
120+
static int calculate_inv_BN_Y(xpu::Context *ctx, T *x, const T *scale,
121+
const T *bias, const T *mean, const T *variance,
122+
const int N, const int C, const int M,
123+
const T *y) {
124+
PADDLE_ENFORCE_EQ(x, y, platform::errors::InvalidArgument(
125+
"X and Y should be inplaced in inplace mode"));
126+
std::vector<int> tensor_shape_vec({N, C, M});
127+
std::vector<int> array_shape_vec({1, C, 1});
128+
// y - bias
129+
int r1 =
130+
xpu::broadcast_sub<T>(ctx, bias, y, x, array_shape_vec, tensor_shape_vec);
131+
// (y - bias) / scale
132+
int r2 = xpu::broadcast_div<T>(ctx, scale, x, x, array_shape_vec,
133+
tensor_shape_vec);
134+
// (y - bias) / scale / variance
135+
int r3 = xpu::broadcast_div<T>(ctx, variance, x, x, array_shape_vec,
136+
tensor_shape_vec);
137+
// (y - bias) / scale / variance + mean
138+
int r4 =
139+
xpu::broadcast_add<T>(ctx, mean, x, x, array_shape_vec, tensor_shape_vec);
140+
141+
return r1 + r2 + r3 + r4;
142+
}
143+
144+
template <typename T>
145+
static int calculate_inv_var(xpu::Context *ctx, const T *var, const T epsilon,
146+
const int C, T *epsilon_data, T *inv_var) {
147+
int r1 = constant(ctx, epsilon_data, 1, epsilon);
148+
std::vector<int> tensor_shape_vec({C});
149+
std::vector<int> array_shape_vec({1});
150+
int r2 = xpu::broadcast_add<T>(ctx, epsilon_data, var, inv_var,
151+
array_shape_vec, tensor_shape_vec);
152+
int r3 = xpu::rsqrt<T>(ctx, inv_var, inv_var, C);
153+
return r1 + r2 + r3;
154+
}
155+
102156
template <typename DeviceContext, typename T>
103157
class BatchNormGradXPUKernel : public framework::OpKernel<T> {
104158
public:
105-
void Compute(const framework::ExecutionContext& ctx) const override {
106-
const auto* x = ctx.Input<Tensor>("X");
107-
const auto* dy = ctx.Input<Tensor>(framework::GradVarName("Y"));
108-
const auto* scale = ctx.Input<Tensor>("Scale");
109-
const auto* saved_mean = ctx.Input<Tensor>("SavedMean");
110-
// SavedVariance have been reverted in forward operator
111-
const auto* saved_inv_variance = ctx.Input<Tensor>("SavedVariance");
112-
const auto& data_layout_str = ctx.Attr<std::string>("data_layout");
159+
void Compute(const framework::ExecutionContext &ctx) const override {
160+
const auto *d_y = ctx.Input<Tensor>(framework::GradVarName("Y"));
161+
const auto *scale = ctx.Input<Tensor>("Scale");
162+
const auto *bias = ctx.Input<Tensor>("Bias");
163+
164+
const auto &data_layout_str = ctx.Attr<std::string>("data_layout");
165+
bool use_global_stats = ctx.Attr<bool>("use_global_stats");
166+
const bool is_test = ctx.Attr<bool>("is_test");
167+
const float epsilon = ctx.Attr<float>("epsilon");
113168
const auto data_layout = framework::StringToDataLayout(data_layout_str);
169+
170+
// TODO(guozbin): Transform input tensor from NHWC to NCHW
114171
PADDLE_ENFORCE_EQ(data_layout, DataLayout::kNCHW,
115172
platform::errors::InvalidArgument(
116173
"The 'data_layout' attribute must be NCHW. But "
117174
"recevived 'data_layout' is [%s].",
118175
data_layout_str));
119-
const auto& x_dims = x->dims();
176+
177+
auto *d_x = ctx.Output<Tensor>(framework::GradVarName("X"));
178+
auto *d_scale = ctx.Output<Tensor>(framework::GradVarName("Scale"));
179+
auto *d_bias = ctx.Output<Tensor>(framework::GradVarName("Bias"));
180+
181+
use_global_stats = is_test || use_global_stats;
182+
183+
// batch_norm with inplace as false will take X as grad input, which
184+
// is same as cuDNN batch_norm backward calculation, batch_norm
185+
// with inplace as true only take Y as input and X should be calculate
186+
// by inverse operation of batch_norm on Y
187+
const Tensor *x;
188+
bool is_inplace;
189+
if (ctx.HasInput("Y")) {
190+
x = ctx.Input<Tensor>("Y");
191+
is_inplace = true;
192+
// if the input of batch norm is stop_gradient, d_x is null.
193+
if (d_x) {
194+
PADDLE_ENFORCE_EQ(d_x, d_y,
195+
platform::errors::InvalidArgument(
196+
"X@GRAD and Y@GRAD not inplace in inplace mode"));
197+
}
198+
} else {
199+
x = ctx.Input<Tensor>("X");
200+
is_inplace = false;
201+
if (d_x) {
202+
PADDLE_ENFORCE_NE(
203+
d_x, d_y, platform::errors::InvalidArgument(
204+
"X@GRAD and Y@GRAD inplaced in non-inplace mode"));
205+
}
206+
}
207+
208+
const auto &x_dims = x->dims();
120209
PADDLE_ENFORCE_EQ(x_dims.size(), 4,
121210
platform::errors::InvalidArgument(
122211
"The input tensor X's dimension must equal to 4. But "
@@ -126,26 +215,96 @@ class BatchNormGradXPUKernel : public framework::OpKernel<T> {
126215
const int C = x_dims[1];
127216
const int H = x_dims[2];
128217
const int W = x_dims[3];
129-
const auto* x_data = x->data<T>();
130-
const auto* dy_data = dy->data<T>();
131-
const auto* scale_data = scale->data<T>();
132-
const auto* saved_mean_data = saved_mean->data<T>();
133-
const auto* saved_inv_variance_data = saved_inv_variance->data<T>();
134-
auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
135-
auto* dscale = ctx.Output<Tensor>(framework::GradVarName("Scale"));
136-
auto* dbias = ctx.Output<Tensor>(framework::GradVarName("Bias"));
137-
auto* dx_data = dx->mutable_data<T>(ctx.GetPlace());
138-
auto* dscale_data = dscale->mutable_data<T>(ctx.GetPlace());
139-
auto* dbias_data = dbias->mutable_data<T>(ctx.GetPlace());
140-
auto& dev_ctx = ctx.template device_context<DeviceContext>();
141-
int r = xpu::batch_norm_grad<T>(dev_ctx.x_context(), x_data, dy_data,
142-
dx_data, N, C, H, W, scale_data,
143-
saved_mean_data, saved_inv_variance_data,
144-
dscale_data, dbias_data, true);
145-
PADDLE_ENFORCE_EQ(r, XPU_SUCCESS, platform::errors::External(
146-
"XPU API(batch_norm_grad) return "
147-
"wrong value[%d %s]",
148-
r, XPUAPIErrorMsg[r]));
218+
219+
const auto *x_data = x->data<T>();
220+
const auto *d_y_data = d_y->data<T>();
221+
const auto *scale_data = scale->data<float>();
222+
223+
// init output
224+
T *d_x_data = nullptr;
225+
T *d_bias_data = nullptr;
226+
T *d_scale_data = nullptr;
227+
if (d_x) {
228+
d_x_data = d_x->mutable_data<T>(ctx.GetPlace());
229+
}
230+
if (d_scale && d_bias) {
231+
d_scale_data = d_scale->mutable_data<float>(ctx.GetPlace());
232+
d_bias_data = d_bias->mutable_data<float>(ctx.GetPlace());
233+
}
234+
235+
PADDLE_ENFORCE_EQ(
236+
scale->dims().size(), 1UL,
237+
platform::errors::InvalidArgument(
238+
"The size of scale's dimensions must equal to 1. But received: "
239+
"the size of scale's dimensions is [%d], the dimensions of scale "
240+
"is [%s].",
241+
scale->dims().size(), scale->dims()));
242+
PADDLE_ENFORCE_EQ(
243+
scale->dims()[0], C,
244+
platform::errors::InvalidArgument(
245+
"The first dimension of scale must equal to Channels[%d]. But "
246+
"received: the first dimension of scale is [%d]",
247+
C, scale->dims()[0]));
248+
249+
auto &dev_ctx = ctx.template device_context<DeviceContext>();
250+
xpu::ctx_guard RAII_GUARD(dev_ctx.x_context());
251+
252+
const T *mean_data = nullptr;
253+
const T *inv_var_data = nullptr;
254+
255+
// TODO(guozibin): hadle the situation case of N * H * W = 1
256+
if (!use_global_stats) {
257+
const auto *saved_mean = ctx.Input<Tensor>("SavedMean");
258+
// SavedVariance have been reverted in forward operator
259+
const auto *saved_inv_variance = ctx.Input<Tensor>("SavedVariance");
260+
mean_data = saved_mean->data<float>();
261+
inv_var_data = saved_inv_variance->data<float>();
262+
} else {
263+
const auto *running_mean = ctx.Input<Tensor>("Mean");
264+
const auto *running_variance = ctx.Input<Tensor>("Variance");
265+
mean_data = running_mean->data<float>();
266+
inv_var_data = running_variance->data<float>();
267+
float *running_inv_var_data =
268+
RAII_GUARD.alloc_l3_or_gm<float>(running_variance->numel());
269+
float *epsilon_data = RAII_GUARD.alloc_l3_or_gm<float>(1);
270+
int r1 = calculate_inv_var(dev_ctx.x_context(), inv_var_data, epsilon, C,
271+
epsilon_data, running_inv_var_data);
272+
PADDLE_ENFORCE_EQ(r1, XPU_SUCCESS, platform::errors::External(
273+
"XPU API(batch_norm_grad "
274+
"calculate_inv_var function) "
275+
"return wrong value[%d %s]",
276+
r1, XPUAPIErrorMsg[r1]));
277+
inv_var_data = running_inv_var_data;
278+
}
279+
if (is_inplace) {
280+
auto px = *x;
281+
int r2 = calculate_inv_BN_Y(
282+
dev_ctx.x_context(), px.mutable_data<T>(ctx.GetPlace()),
283+
scale->data<float>(), bias->data<float>(), mean_data, inv_var_data, N,
284+
C, H * W, x->data<T>());
285+
PADDLE_ENFORCE_EQ(r2, XPU_SUCCESS, platform::errors::External(
286+
"XPU API(batch_norm_grad "
287+
"calculate_inv_BN_Y function) "
288+
"return wrong value[%d %s]",
289+
r2, XPUAPIErrorMsg[r2]));
290+
}
291+
if (!d_x) {
292+
d_x_data = RAII_GUARD.alloc_l3_or_gm<T>(x->numel());
293+
}
294+
if (!d_scale) {
295+
d_scale_data = RAII_GUARD.alloc_l3_or_gm<float>(C);
296+
}
297+
if (!d_bias_data) {
298+
d_bias_data = RAII_GUARD.alloc_l3_or_gm<float>(C);
299+
}
300+
301+
int r3 = xpu::batch_norm_grad<T>(
302+
dev_ctx.x_context(), x_data, d_y_data, d_x_data, N, C, H, W, scale_data,
303+
mean_data, inv_var_data, d_scale_data, d_bias_data, true);
304+
PADDLE_ENFORCE_EQ(r3, XPU_SUCCESS, platform::errors::External(
305+
"XPU API(batch_norm_grad) return "
306+
"wrong value[%d %s]",
307+
r3, XPUAPIErrorMsg[r3]));
149308
}
150309
};
151310

paddle/fluid/platform/device/xpu/xpu2_op_list.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -262,6 +262,9 @@ XPUOpMap& get_kl2_ops() {
262262
pOpKernelType(vartype::INT32, XPUPlace()),
263263
pOpKernelType(vartype::BOOL, XPUPlace()),
264264
pOpKernelType(vartype::FP32, XPUPlace())})},
265+
{"roi_align", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
266+
{"roi_align_grad",
267+
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
265268
{"scale", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
266269
{"scale", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
267270
pOpKernelType(vartype::FP16, XPUPlace()),

0 commit comments

Comments
 (0)