Commit c52a664

[Phi]Move elementwise_div grad/double grad Kernel to Phi (#40172)
* move elementwise_div grad
* change mutable_data to alloc
* fix compile bugs
1 parent 0fb6bca commit c52a664

17 files changed: +547 -472 lines changed
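
None of the fluid kernels deleted in the diffs below disappear outright; they are re-implemented as Phi kernels, which live outside the files shown in this commit. As rough orientation only, the Phi grad entry point that the fluid ElementwiseDivGradKernel is assumed to map onto follows the standard Phi kernel signature; the name and parameter list below are an assumption, not taken from this diff:

// Hedged sketch (assumed, not shown in this commit): the Phi-style grad
// kernel replacing the fluid ElementwiseDivGradKernel. Phi kernels take the
// device context plus plain DenseTensor arguments instead of a framework
// ExecutionContext, which is what "move ... Kernel to Phi" refers to.
#include "paddle/phi/core/dense_tensor.h"

namespace phi {

template <typename T, typename Context>
void DivideGradKernel(const Context& dev_ctx,
                      const DenseTensor& x,
                      const DenseTensor& y,
                      const DenseTensor& out,
                      const DenseTensor& dout,
                      int axis,
                      DenseTensor* dx,
                      DenseTensor* dy);

}  // namespace phi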

paddle/fluid/framework/new_executor/standalone_executor_test.cc

+1 -1

@@ -54,7 +54,7 @@ USE_OP(slice_grad);
 USE_OP(lookup_table_grad);
 USE_OP(sqrt);
 USE_OP(elementwise_max);
-USE_OP(elementwise_div);
+USE_OP_ITSELF(elementwise_div);
 USE_OP(sgd);
 USE_OP(squared_l2_norm);
 USE_OP(memcpy_h2d);
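
The switch from USE_OP to USE_OP_ITSELF is what keeps this test linking after the fluid kernel registrations are deleted: USE_OP pulls in the operator definition together with its registered fluid kernels, while USE_OP_ITSELF declares a dependency on the operator definition only, with the compute kernel now resolved through Phi. Restating the one-line change above:

// Before: requires a fluid kernel registration for elementwise_div to exist.
// USE_OP(elementwise_div);

// After: only the operator definition is required; the kernel comes from Phi.
USE_OP_ITSELF(elementwise_div);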

paddle/fluid/operators/elementwise/elementwise_div_op.cc

-36

@@ -102,42 +102,6 @@ REGISTER_OPERATOR(
 REGISTER_OPERATOR(elementwise_div_grad_grad, ops::ElementwiseDivOpDoubleGrad,
                   ops::ElementwiseDoubleGradOpInplaceInferer);

-REGISTER_OP_CPU_KERNEL(
-    elementwise_div,
-    ops::ElementwiseDivKernel<paddle::platform::CPUDeviceContext, float>,
-    ops::ElementwiseDivKernel<paddle::platform::CPUDeviceContext, double>,
-    ops::ElementwiseDivKernel<paddle::platform::CPUDeviceContext, int>,
-    ops::ElementwiseDivKernel<paddle::platform::CPUDeviceContext, int64_t>,
-    ops::ElementwiseDivKernel<paddle::platform::CPUDeviceContext,
-                              paddle::platform::complex<float>>,
-    ops::ElementwiseDivKernel<paddle::platform::CPUDeviceContext,
-                              paddle::platform::complex<double>>);
-REGISTER_OP_CPU_KERNEL(
-    elementwise_div_grad,
-    ops::ElementwiseDivGradKernel<paddle::platform::CPUDeviceContext, float>,
-    ops::ElementwiseDivGradKernel<paddle::platform::CPUDeviceContext, double>,
-    ops::ElementwiseDivGradKernel<paddle::platform::CPUDeviceContext, int>,
-    ops::ElementwiseDivGradKernel<paddle::platform::CPUDeviceContext, int64_t>,
-    ops::ElementwiseDivGradKernel<paddle::platform::CPUDeviceContext,
-                                  paddle::platform::complex<float>>,
-    ops::ElementwiseDivGradKernel<paddle::platform::CPUDeviceContext,
-                                  paddle::platform::complex<double>>);
-
-REGISTER_OP_CPU_KERNEL(
-    elementwise_div_grad_grad,
-    ops::ElementwiseDivDoubleGradKernel<paddle::platform::CPUDeviceContext,
-                                        float>,
-    ops::ElementwiseDivDoubleGradKernel<paddle::platform::CPUDeviceContext,
-                                        double>,
-    ops::ElementwiseDivDoubleGradKernel<paddle::platform::CPUDeviceContext,
-                                        int>,
-    ops::ElementwiseDivDoubleGradKernel<paddle::platform::CPUDeviceContext,
-                                        int64_t>,
-    ops::ElementwiseDivDoubleGradKernel<paddle::platform::CPUDeviceContext,
-                                        paddle::platform::complex<float>>,
-    ops::ElementwiseDivDoubleGradKernel<paddle::platform::CPUDeviceContext,
-                                        paddle::platform::complex<double>>);
-
 REGISTER_OP_VERSION(elementwise_div)
     .AddCheckpoint(
         R"ROC(Register elementwise_div for adding the attribute of Scale_y)ROC",

paddle/fluid/operators/elementwise/elementwise_div_op.cu

-96
This file was deleted.

paddle/fluid/operators/elementwise/elementwise_div_op.h

-211

@@ -20,142 +20,6 @@ limitations under the License. */
 namespace paddle {
 namespace operators {

-template <typename DeviceContext, typename T>
-void default_elementwise_sub(const framework::ExecutionContext& ctx,
-                             const framework::Tensor* x,
-                             const framework::Tensor* y, framework::Tensor* z) {
-  int axis = ctx.Attr<int>("axis");
-  auto x_dims = x->dims();
-  auto y_dims = y->dims();
-  if (x_dims.size() >= y_dims.size()) {
-    ElementwiseComputeEx<SubFunctor<T>, DeviceContext, T>(ctx, x, y, axis,
-                                                          SubFunctor<T>(), z);
-  } else {
-    ElementwiseComputeEx<InverseSubFunctor<T>, DeviceContext, T>(
-        ctx, x, y, axis, InverseSubFunctor<T>(), z);
-  }
-}
-
-template <typename DeviceContext, typename T>
-void default_elementwise_div(const framework::ExecutionContext& ctx,
-                             const framework::Tensor* x,
-                             const framework::Tensor* y, framework::Tensor* z) {
-  int axis = ctx.Attr<int>("axis");
-  auto x_dims = x->dims();
-  auto y_dims = y->dims();
-  if (x_dims.size() >= y_dims.size()) {
-    ElementwiseComputeEx<DivFunctor<T>, DeviceContext, T>(ctx, x, y, axis,
-                                                          DivFunctor<T>(), z);
-  } else {
-    ElementwiseComputeEx<InverseDivFunctor<T>, DeviceContext, T>(
-        ctx, x, y, axis, InverseDivFunctor<T>(), z);
-  }
-}
-
-template <typename DeviceContext, typename T>
-class ElementwiseDivKernel : public framework::OpKernel<T> {
- public:
-  void Compute(const framework::ExecutionContext& ctx) const override {
-    auto* x = ctx.Input<framework::LoDTensor>("X");
-    auto* y = ctx.Input<framework::LoDTensor>("Y");
-    auto* z = ctx.Output<framework::LoDTensor>("Out");
-    z->mutable_data<T>(ctx.GetPlace());
-
-    auto& dev_ctx = ctx.device_context<DeviceContext>();
-    int axis = ctx.Attr<int>("axis");
-    auto pt_x = paddle::experimental::MakePhiDenseTensor(*x);
-    auto pt_y = paddle::experimental::MakePhiDenseTensor(*y);
-    auto pt_z = paddle::experimental::MakePhiDenseTensor(*z);
-    phi::DivideRawKernel<T>(
-        static_cast<const typename framework::ConvertToPhiContext<
-            DeviceContext>::TYPE&>(dev_ctx),
-        *pt_x.get(), *pt_y.get(), axis, pt_z.get());
-  }
-};
-
-template <typename T>
-struct DivGradDX {
-  HOSTDEVICE T operator()(T x, T y, T out, T dout) const { return dout / y; }
-};
-
-template <typename T>
-struct DivGradDX<paddle::platform::complex<T>> {
-  HOSTDEVICE paddle::platform::complex<T> operator()(
-      paddle::platform::complex<T> x, paddle::platform::complex<T> y,
-      paddle::platform::complex<T> out,
-      paddle::platform::complex<T> dout) const {
-    paddle::platform::complex<T> y_conj(y.real, -y.imag);
-    return dout / y_conj;
-  }
-};
-
-template <typename T>
-struct DivGradDY {
-  HOSTDEVICE T operator()(T x, T y, T out, T dout) const {
-    return -dout * out / y;
-  }
-};
-
-template <typename T>
-struct DivGradDY<paddle::platform::complex<T>> {
-  HOSTDEVICE paddle::platform::complex<T> operator()(
-      paddle::platform::complex<T> x, paddle::platform::complex<T> y,
-      paddle::platform::complex<T> out,
-      paddle::platform::complex<T> dout) const {
-    paddle::platform::complex<T> out_div_y_conj((out / y).real,
-                                                -(out / y).imag);
-    return -dout * out_div_y_conj;
-  }
-};
-
-template <typename T>
-struct DivDoubleDY {
-  HOSTDEVICE T operator()(T x, T y, T out, T dout) const {
-    return y * out * dout - x * dout;
-  }
-};
-
-template <typename DeviceContext, typename T>
-typename std::enable_if<
-    std::is_same<DeviceContext, platform::CPUDeviceContext>::value>::type
-ElementwiseDivGrad(const framework::ExecutionContext& ctx,
-                   const framework::Tensor* x, const framework::Tensor* y,
-                   const framework::Tensor* out, const framework::Tensor* dout,
-                   framework::Tensor* dx, framework::Tensor* dy) {
-  int axis = ctx.Attr<int>("axis");
-
-  ElemwiseGradCompute<DeviceContext, T, DivGradDX<T>, DivGradDY<T>>(
-      ctx, *x, *y, *out, *dout, axis, dx, dy, DivGradDX<T>(), DivGradDY<T>());
-}
-
-#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
-template <typename DeviceContext, typename T>
-typename std::enable_if<
-    std::is_same<DeviceContext, platform::CUDADeviceContext>::value>::type
-ElementwiseDivGrad(const framework::ExecutionContext& ctx,
-                   const framework::Tensor* x, const framework::Tensor* y,
-                   const framework::Tensor* out, const framework::Tensor* dout,
-                   framework::Tensor* dx, framework::Tensor* dy);
-#endif
-
-template <typename DeviceContext, typename T>
-class ElementwiseDivGradKernel : public ElemwiseGradKernel<T> {
- public:
-  void Compute(const framework::ExecutionContext& ctx) const override {
-    ElemwiseGradKernel<T>::Compute(ctx);
-    using Tensor = framework::Tensor;
-
-    auto* x = ctx.Input<Tensor>("X");
-    auto* y = ctx.Input<Tensor>("Y");
-    auto* out = ctx.Input<Tensor>("Out");
-    auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
-    auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
-    auto* dy = ctx.Output<Tensor>(framework::GradVarName("Y"));
-
-    ElementwiseDivGrad<DeviceContext, T>(ctx, x, y, out, dout, dx, dy);
-  }
-};
-
 class ElementwiseDivOpDoubleGrad : public framework::OperatorWithKernel {
  public:
   using framework::OperatorWithKernel::OperatorWithKernel;
@@ -206,80 +70,5 @@ class ElementwiseDivOpDoubleGrad : public framework::OperatorWithKernel {
   }
 };

-template <typename DeviceContext, typename T>
-class ElementwiseDivDoubleGradKernel : public framework::OpKernel<T> {
-  using Tensor = framework::Tensor;
-
- public:
-  void Compute(const framework::ExecutionContext& ctx) const override {
-    auto* Y = ctx.Input<Tensor>("Y");
-    auto* Out = ctx.Input<Tensor>("Out");
-    auto* ddX = ctx.Input<Tensor>("DDX");
-    auto* ddY = ctx.Input<Tensor>("DDY");
-    auto* dX = ctx.Input<Tensor>("DX");
-
-    auto* dY = ctx.Output<Tensor>(framework::GradVarName("Y"));
-    auto* dOut = ctx.Output<Tensor>("DOut");
-    auto* ddOut = ctx.Output<Tensor>("DDOut");
-
-    int axis = ctx.Attr<int>("axis");
-
-    if (dY) dY->mutable_data<T>(Y->dims(), ctx.GetPlace());
-    if (dOut) dOut->mutable_data<T>(Out->dims(), ctx.GetPlace());
-    if (ddOut) ddOut->mutable_data<T>(Out->dims(), ctx.GetPlace());
-
-    // ddX_safe == null ? 0 : ddX
-    // ddY_safe == null ? 0 : ddY
-    Tensor ddX_safe, ddY_safe;
-    GetDoubleGradSafeTensor<DeviceContext, T>(ctx, dX, ddX, &ddX_safe);
-    GetDoubleGradSafeTensor<DeviceContext, T>(ctx, Y, ddY, &ddY_safe);
-
-    // ddOut = ddX / Y - Out * ddY / Y = (ddX - Out * ddY) / Y
-    // dY = Out * dX * ddY / Y - dX * ddX / Y
-    // dOut = - dX * ddY
-    // To save memory, (1) dout can be used as 'tmp' tensor, (2) ddout can
-    // inplace ddx
-    Tensor tmp;
-    if (dOut) {
-      tmp = *dOut;
-    } else {
-      auto& dev_ctx = ctx.template device_context<DeviceContext>();
-      tmp = ctx.AllocateTmpTensor<T, DeviceContext>(Out->dims(), dev_ctx);
-    }
-    if (dY) {
-      // dX_div_Y = dX / Y;
-      Tensor dX_div_Y = tmp;
-      default_elementwise_div<DeviceContext, T>(ctx, dX, Y, &dX_div_Y);
-
-      // NOTE(dengkaipeng): in the following ElemwiseGradCompute, for the
-      // first output tensor is nullptr, the branch to calculate first
-      // output tensor will not be activated, DivGradDx function will not
-      // be called and can be ignored, the first branch has little effect
-      // on running speed.
-
-      // dY = Out * dX * ddY / Y - dX * ddX / Y
-      ElemwiseGradCompute<DeviceContext, T, DivGradDX<T>, DivDoubleDY<T>>(
-          ctx, ddX_safe, ddY_safe, *Out, dX_div_Y, axis, nullptr, dY,
-          DivGradDX<T>(), DivDoubleDY<T>());
-    }
-
-    if (ddOut) {
-      // ddOut = ddX / Y - Out * ddY / Y = (ddX - Out * ddY) / Y
-      default_elementwise_mul<DeviceContext, T>(ctx, Out, &ddY_safe, &tmp);
-      default_elementwise_sub<DeviceContext, T>(ctx, &ddX_safe, &tmp, &tmp);
-      default_elementwise_div<DeviceContext, T>(ctx, &tmp, Y, ddOut);
-    }
-
-    if (dOut) {
-      // dOut = - dX * ddY
-      default_elementwise_mul<DeviceContext, T>(ctx, dX, &ddY_safe, dOut);
-      auto& place =
-          *ctx.template device_context<DeviceContext>().eigen_device();
-      auto dout = framework::EigenVector<T>::Flatten(*dOut);
-      dout.device(place) = static_cast<T>(-1) * dout;
-    }
-  }
-};
-
 }  // namespace operators
 }  // namespace paddle
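
For reference, the math implemented by the functors and the double-grad kernel deleted above, taken from the code and its inline comments. For Out = X / Y, the first-order grads are

dX = \frac{dOut}{Y}, \qquad dY = -\,dOut \cdot \frac{Out}{Y}

(with Y and Out/Y conjugated in the complex specializations DivGradDX/DivGradDY), and for the double grad, given incoming grads ddX and ddY of dX and dY:

ddOut = \frac{ddX - Out \cdot ddY}{Y}, \qquad
dY = \frac{Out \cdot dX \cdot ddY - dX \cdot ddX}{Y}, \qquad
dOut = -\,dX \cdot ddY.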
