/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */
@@ -52,9 +49,15 @@ class ElementwiseAddGradXPUKernel : public ElemwiseGradKernel<T> {
5249        ctx.template  device_context <paddle::platform::XPUDeviceContext>();
5350
5451    if  (dx != nullptr ) {
52+       T* dx_data = dx->mutable_data <T>(ctx.GetPlace ());
5553      if  (dx->dims () == dz_dims) {
56-         //  TensorCopy will call mutable_data of dx.
57-         framework::TensorCopy (*dz, ctx.GetPlace (), dev_ctx, dx);
54+         if  (dx_data != dz_data) {
55+           framework::TensorCopy (
56+               *dz,
57+               ctx.GetPlace (),
58+               ctx.template  device_context <platform::DeviceContext>(),
59+               dx);
60+         }
5861      } else  {
5962        //  For inplace strategy, dx will be stored in addr of dz, which makes
6063        //  the result of dy wrong.
@@ -65,24 +68,29 @@ class ElementwiseAddGradXPUKernel : public ElemwiseGradKernel<T> {
6568        std::vector<int > reduce_dims = GetReduceDim (dx->dims (), dz_dims, axis);
6669        std::vector<int > dz_vector = phi::vectorize<int >(dz_dims);
6770
68-         T* dx_data = dx->mutable_data <T>(ctx.GetPlace ());
6971        int  ret =
7072            xpu::reduce_sum<XPUType>(dev_ctx.x_context (),
7173                                     reinterpret_cast <const  XPUType*>(dz_data),
72-                                      reinterpret_cast <XPUType*>(dx_data ),
74+                                      reinterpret_cast <XPUType*>(dx-> data <T>() ),
7375                                     dz_vector,
7476                                     reduce_dims);
7577        PADDLE_ENFORCE_XDNN_SUCCESS (ret, " reduce_sum" 
7678      }
7779    }
7880
7981    if  (dy != nullptr ) {
82+       T* dy_data = dy->mutable_data <T>(ctx.GetPlace ());
8083      if  (dy->dims () == dz_dims) {
81-         framework::TensorCopy (*dz, ctx.GetPlace (), dev_ctx, dy);
84+         if  (dy_data != dz_data) {
85+           framework::TensorCopy (
86+               *dz,
87+               ctx.GetPlace (),
88+               ctx.template  device_context <platform::DeviceContext>(),
89+               dy);
90+         }
8291      } else  {
8392        std::vector<int > reduce_dims = GetReduceDim (dy->dims (), dz_dims, axis);
8493        std::vector<int > dz_vector = phi::vectorize<int >(dz_dims);
85-         T* dy_data = dy->mutable_data <T>(ctx.GetPlace ());
8694        int  ret =
8795            xpu::reduce_sum<XPUType>(dev_ctx.x_context (),
8896                                     reinterpret_cast <const  XPUType*>(dz_data),
0 commit comments