@@ -20,142 +20,6 @@ limitations under the License. */
namespace paddle {
namespace operators {
- template <typename DeviceContext, typename T>
- void default_elementwise_sub(const framework::ExecutionContext& ctx,
-                              const framework::Tensor* x,
-                              const framework::Tensor* y, framework::Tensor* z) {
-   int axis = ctx.Attr<int>("axis");
-   auto x_dims = x->dims();
-   auto y_dims = y->dims();
-   if (x_dims.size() >= y_dims.size()) {
-     ElementwiseComputeEx<SubFunctor<T>, DeviceContext, T>(ctx, x, y, axis,
-                                                           SubFunctor<T>(), z);
-   } else {
-     ElementwiseComputeEx<InverseSubFunctor<T>, DeviceContext, T>(
-         ctx, x, y, axis, InverseSubFunctor<T>(), z);
-   }
- }
-
- template <typename DeviceContext, typename T>
- void default_elementwise_div(const framework::ExecutionContext& ctx,
-                              const framework::Tensor* x,
-                              const framework::Tensor* y, framework::Tensor* z) {
-   int axis = ctx.Attr<int>("axis");
-   auto x_dims = x->dims();
-   auto y_dims = y->dims();
-   if (x_dims.size() >= y_dims.size()) {
-     ElementwiseComputeEx<DivFunctor<T>, DeviceContext, T>(ctx, x, y, axis,
-                                                           DivFunctor<T>(), z);
-   } else {
-     ElementwiseComputeEx<InverseDivFunctor<T>, DeviceContext, T>(
-         ctx, x, y, axis, InverseDivFunctor<T>(), z);
-   }
- }
-
- template <typename DeviceContext, typename T>
- class ElementwiseDivKernel : public framework::OpKernel<T> {
-  public:
-   void Compute(const framework::ExecutionContext& ctx) const override {
-     auto* x = ctx.Input<framework::LoDTensor>("X");
-     auto* y = ctx.Input<framework::LoDTensor>("Y");
-     auto* z = ctx.Output<framework::LoDTensor>("Out");
-     z->mutable_data<T>(ctx.GetPlace());
-
-     auto& dev_ctx = ctx.device_context<DeviceContext>();
-     int axis = ctx.Attr<int>("axis");
-     auto pt_x = paddle::experimental::MakePhiDenseTensor(*x);
-     auto pt_y = paddle::experimental::MakePhiDenseTensor(*y);
-     auto pt_z = paddle::experimental::MakePhiDenseTensor(*z);
-     phi::DivideRawKernel<T>(
-         static_cast<const typename framework::ConvertToPhiContext<
-             DeviceContext>::TYPE&>(dev_ctx),
-         *pt_x.get(), *pt_y.get(), axis, pt_z.get());
-   }
- };
-
- template <typename T>
- struct DivGradDX {
-   HOSTDEVICE T operator()(T x, T y, T out, T dout) const { return dout / y; }
- };
-
- template <typename T>
- struct DivGradDX<paddle::platform::complex<T>> {
-   HOSTDEVICE paddle::platform::complex<T> operator()(
-       paddle::platform::complex<T> x, paddle::platform::complex<T> y,
-       paddle::platform::complex<T> out,
-       paddle::platform::complex<T> dout) const {
-     paddle::platform::complex<T> y_conj(y.real, -y.imag);
-     return dout / y_conj;
-   }
- };
-
- template <typename T>
- struct DivGradDY {
-   HOSTDEVICE T operator()(T x, T y, T out, T dout) const {
-     return -dout * out / y;
-   }
- };
-
- template <typename T>
- struct DivGradDY<paddle::platform::complex<T>> {
-   HOSTDEVICE paddle::platform::complex<T> operator()(
-       paddle::platform::complex<T> x, paddle::platform::complex<T> y,
-       paddle::platform::complex<T> out,
-       paddle::platform::complex<T> dout) const {
-     paddle::platform::complex<T> out_div_y_conj((out / y).real,
-                                                 -(out / y).imag);
-     return -dout * out_div_y_conj;
-   }
- };
-
- template <typename T>
- struct DivDoubleDY {
-   HOSTDEVICE T operator()(T x, T y, T out, T dout) const {
-     return y * out * dout - x * dout;
-   }
- };
-
- template <typename DeviceContext, typename T>
- typename std::enable_if<
-     std::is_same<DeviceContext, platform::CPUDeviceContext>::value>::type
- ElementwiseDivGrad(const framework::ExecutionContext& ctx,
-                    const framework::Tensor* x, const framework::Tensor* y,
-                    const framework::Tensor* out, const framework::Tensor* dout,
-                    framework::Tensor* dx, framework::Tensor* dy) {
-   int axis = ctx.Attr<int>("axis");
-
-   ElemwiseGradCompute<DeviceContext, T, DivGradDX<T>, DivGradDY<T>>(
-       ctx, *x, *y, *out, *dout, axis, dx, dy, DivGradDX<T>(), DivGradDY<T>());
- }
-
- #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
- template <typename DeviceContext, typename T>
- typename std::enable_if<
-     std::is_same<DeviceContext, platform::CUDADeviceContext>::value>::type
- ElementwiseDivGrad(const framework::ExecutionContext& ctx,
-                    const framework::Tensor* x, const framework::Tensor* y,
-                    const framework::Tensor* out, const framework::Tensor* dout,
-                    framework::Tensor* dx, framework::Tensor* dy);
- #endif
-
- template <typename DeviceContext, typename T>
- class ElementwiseDivGradKernel : public ElemwiseGradKernel<T> {
-  public:
-   void Compute(const framework::ExecutionContext& ctx) const override {
-     ElemwiseGradKernel<T>::Compute(ctx);
-     using Tensor = framework::Tensor;
-
-     auto* x = ctx.Input<Tensor>("X");
-     auto* y = ctx.Input<Tensor>("Y");
-     auto* out = ctx.Input<Tensor>("Out");
-     auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
-     auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
-     auto* dy = ctx.Output<Tensor>(framework::GradVarName("Y"));
-
-     ElementwiseDivGrad<DeviceContext, T>(ctx, x, y, out, dout, dx, dy);
-   }
- };
-
class ElementwiseDivOpDoubleGrad : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
@@ -206,80 +70,5 @@ class ElementwiseDivOpDoubleGrad : public framework::OperatorWithKernel {
}
};
- template <typename DeviceContext, typename T>
- class ElementwiseDivDoubleGradKernel : public framework::OpKernel<T> {
-   using Tensor = framework::Tensor;
-
-  public:
-   void Compute(const framework::ExecutionContext& ctx) const override {
-     auto* Y = ctx.Input<Tensor>("Y");
-     auto* Out = ctx.Input<Tensor>("Out");
-     auto* ddX = ctx.Input<Tensor>("DDX");
-     auto* ddY = ctx.Input<Tensor>("DDY");
-     auto* dX = ctx.Input<Tensor>("DX");
-
-     auto* dY = ctx.Output<Tensor>(framework::GradVarName("Y"));
-     auto* dOut = ctx.Output<Tensor>("DOut");
-     auto* ddOut = ctx.Output<Tensor>("DDOut");
-
-     int axis = ctx.Attr<int>("axis");
-
-     if (dY) dY->mutable_data<T>(Y->dims(), ctx.GetPlace());
-     if (dOut) dOut->mutable_data<T>(Out->dims(), ctx.GetPlace());
-     if (ddOut) ddOut->mutable_data<T>(Out->dims(), ctx.GetPlace());
-
-     // ddX_safe == null ? 0 : ddX
-     // ddY_safe == null ? 0 : ddY
-     Tensor ddX_safe, ddY_safe;
-     GetDoubleGradSafeTensor<DeviceContext, T>(ctx, dX, ddX, &ddX_safe);
-     GetDoubleGradSafeTensor<DeviceContext, T>(ctx, Y, ddY, &ddY_safe);
-
-     // ddOut = ddX / Y - Out * ddY / Y = (ddX - Out * ddY) / Y
-     // dY = Out * dX * ddY / Y - dX * ddX / Y
-     // dOut = - dX * ddY
-     // To save memory, (1) dout can be used as 'tmp' tensor, (2) ddout can
-     // inplace ddx
-     Tensor tmp;
-     if (dOut) {
-       tmp = *dOut;
-     } else {
-       auto& dev_ctx = ctx.template device_context<DeviceContext>();
-       tmp = ctx.AllocateTmpTensor<T, DeviceContext>(Out->dims(), dev_ctx);
-     }
-     if (dY) {
-       // dX_div_Y = dX / Y;
-       Tensor dX_div_Y = tmp;
-       default_elementwise_div<DeviceContext, T>(ctx, dX, Y, &dX_div_Y);
-
-       // NOTE(dengkaipeng): in the following ElemwiseGradCompute, since the
-       // first output tensor is nullptr, the branch that calculates the first
-       // output tensor is not activated, so DivGradDX is never called and this
-       // branch has little effect on running speed.
-
-       // dY = Out * dX * ddY / Y - dX * ddX / Y
-       ElemwiseGradCompute<DeviceContext, T, DivGradDX<T>, DivDoubleDY<T>>(
-           ctx, ddX_safe, ddY_safe, *Out, dX_div_Y, axis, nullptr, dY,
-           DivGradDX<T>(), DivDoubleDY<T>());
-     }
-
-     if (ddOut) {
-       // ddOut = ddX / Y - Out * ddY / Y = (ddX - Out * ddY) / Y
-       default_elementwise_mul<DeviceContext, T>(ctx, Out, &ddY_safe, &tmp);
-       default_elementwise_sub<DeviceContext, T>(ctx, &ddX_safe, &tmp, &tmp);
-       default_elementwise_div<DeviceContext, T>(ctx, &tmp, Y, ddOut);
-     }
-
-     if (dOut) {
-       // dOut = - dX * ddY
-       default_elementwise_mul<DeviceContext, T>(ctx, dX, &ddY_safe, dOut);
-       auto& place =
-           *ctx.template device_context<DeviceContext>().eigen_device();
-       auto dout = framework::EigenVector<T>::Flatten(*dOut);
-       dout.device(place) = static_cast<T>(-1) * dout;
-     }
-   }
- };
-
} // namespace operators
} // namespace paddle
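
For reference, a brief sketch (not part of the diff itself) of the math the removed functors implement, written in the same notation as the code comments. With the forward computation Out = X / Y, the first-order gradients coded in DivGradDX and DivGradDY are

    dX = dOut / Y
    dY = -dOut * Out / Y    (equivalently -dOut * X / Y^2)

and the complex specializations substitute conj(Y) and conj(Out / Y) for Y and Out / Y. The second-order formulas (ddOut, dY, dOut) are the ones spelled out in the comments of the removed ElementwiseDivDoubleGradKernel above.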