Skip to content

Commit 0fedb3e

Browse files
committed
Fix bug in ew_add_grad (elementwise add grad) when running in place, *test=kunlun
1 parent 4004935 commit 0fedb3e

File tree

1 file changed

+18
-10
lines changed

1 file changed

+18
-10
lines changed

paddle/fluid/operators/elementwise/elementwise_add_op_xpu.cc

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,8 @@
1-
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
2-
1+
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
32
Licensed under the Apache License, Version 2.0 (the "License");
43
you may not use this file except in compliance with the License.
54
You may obtain a copy of the License at
6-
75
http://www.apache.org/licenses/LICENSE-2.0
8-
96
Unless required by applicable law or agreed to in writing, software
107
distributed under the License is distributed on an "AS IS" BASIS,
118
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
@@ -52,9 +49,15 @@ class ElementwiseAddGradXPUKernel : public ElemwiseGradKernel<T> {
5249
ctx.template device_context<paddle::platform::XPUDeviceContext>();
5350

5451
if (dx != nullptr) {
52+
T* dx_data = dx->mutable_data<T>(ctx.GetPlace());
5553
if (dx->dims() == dz_dims) {
56-
// TensorCopy will call mutable_data of dx.
57-
framework::TensorCopy(*dz, ctx.GetPlace(), dev_ctx, dx);
54+
if (dx_data != dz_data) {
55+
framework::TensorCopy(
56+
*dz,
57+
ctx.GetPlace(),
58+
ctx.template device_context<platform::DeviceContext>(),
59+
dx);
60+
}
5861
} else {
5962
// For inplace strategy, dx will be stored in addr of dz, which makes
6063
// the result of dy wrong.
@@ -65,24 +68,29 @@ class ElementwiseAddGradXPUKernel : public ElemwiseGradKernel<T> {
6568
std::vector<int> reduce_dims = GetReduceDim(dx->dims(), dz_dims, axis);
6669
std::vector<int> dz_vector = phi::vectorize<int>(dz_dims);
6770

68-
T* dx_data = dx->mutable_data<T>(ctx.GetPlace());
6971
int ret =
7072
xpu::reduce_sum<XPUType>(dev_ctx.x_context(),
7173
reinterpret_cast<const XPUType*>(dz_data),
72-
reinterpret_cast<XPUType*>(dx_data),
74+
reinterpret_cast<XPUType*>(dx->data<T>()),
7375
dz_vector,
7476
reduce_dims);
7577
PADDLE_ENFORCE_XDNN_SUCCESS(ret, "reduce_sum");
7678
}
7779
}
7880

7981
if (dy != nullptr) {
82+
T* dy_data = dy->mutable_data<T>(ctx.GetPlace());
8083
if (dy->dims() == dz_dims) {
81-
framework::TensorCopy(*dz, ctx.GetPlace(), dev_ctx, dy);
84+
if (dy_data != dz_data) {
85+
framework::TensorCopy(
86+
*dz,
87+
ctx.GetPlace(),
88+
ctx.template device_context<platform::DeviceContext>(),
89+
dy);
90+
}
8291
} else {
8392
std::vector<int> reduce_dims = GetReduceDim(dy->dims(), dz_dims, axis);
8493
std::vector<int> dz_vector = phi::vectorize<int>(dz_dims);
85-
T* dy_data = dy->mutable_data<T>(ctx.GetPlace());
8694
int ret =
8795
xpu::reduce_sum<XPUType>(dev_ctx.x_context(),
8896
reinterpret_cast<const XPUType*>(dz_data),

0 commit comments

Comments
 (0)