
Commit 452c75b

move elementwise mul grad (#40252)
1 parent 0604df9 commit 452c75b

12 files changed: +539 −401 lines

paddle/fluid/framework/new_executor/standalone_executor_test.cc

+1 −1

@@ -46,7 +46,7 @@ USE_OP(matmul_grad);
 USE_OP(square);
 USE_OP(transpose2_grad);
 USE_OP(concat_grad);
-USE_OP(elementwise_mul_grad);
+USE_OP_ITSELF(elementwise_mul_grad);
 USE_OP(sigmoid_grad);
 USE_OP(tanh_grad);
 USE_OP(sum);
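Why this one-line change: the commit deletes the fluid CPU/GPU kernel registrations for elementwise_mul_grad (see the .cc and .cu diffs below), so a test that previously pulled in the op together with its fluid CPU kernel via USE_OP can now reference only the operator definition itself. As rough orientation, paraphrased from the macro definitions in paddle/fluid/framework/op_registry.h rather than quoted from this commit:

// USE_OP_ITSELF(op)  -> forces a reference to the registered operator only.
// USE_OP(op)         -> roughly USE_OP_ITSELF(op) plus
//                       USE_OP_DEVICE_KERNEL(op, CPU), which would no longer
//                       link here because the fluid CPU kernel registration
//                       for elementwise_mul_grad is removed by this commit.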

paddle/fluid/operators/elementwise/elementwise_functor.h

+0 −41

@@ -196,47 +196,6 @@ struct MinGradXYFunctor {
   }
 };
 
-template <typename T>
-struct MulGradFunctor {
-  inline HOSTDEVICE T operator()(const T a, const T b) const { return a * b; }
-};
-template <typename T>
-struct MulGradFunctor<Complex<T>> {
-  inline HOSTDEVICE Complex<T> operator()(const Complex<T> a,
-                                          const Complex<T> b) const {
-    Complex<T> b_conj(b.real, -b.imag);
-    return a * b_conj;
-  }
-};
-
-template <typename InT, typename OutT>
-struct MulGradXYFunctor {
-  inline HOSTDEVICE phi::Array<OutT, 2> operator()(const InT a, const InT b,
-                                                   const InT c) {
-    phi::Array<OutT, 2> outs;
-    // dx = dout * y
-    outs[0] = a * b;
-    // dy = dout * x
-    outs[1] = a * c;
-    return outs;
-  }
-};
-
-template <typename InT, typename OutT>
-struct MulGradXYFunctor<Complex<InT>, Complex<OutT>> {
-  inline HOSTDEVICE phi::Array<Complex<OutT>, 2> operator()(
-      const Complex<InT> a, const Complex<InT> b, const Complex<InT> c) {
-    phi::Array<Complex<OutT>, 2> outs;
-    // dx = dout * y
-    Complex<InT> b_conj(b.real, -b.imag);
-    outs[0] = a * b_conj;
-    // dy = dout * x
-    Complex<InT> c_conj(c.real, -c.imag);
-    outs[1] = a * c_conj;
-    return outs;
-  }
-};
-
 // Ternary compare
 template <typename T>
 struct MaxGradXFunctor {
paddle/fluid/operators/elementwise/elementwise_mul_op.cc

+0 −49

@@ -173,55 +173,6 @@ REGISTER_OP_CPU_KERNEL(
                               paddle::platform::complex<float>>,
     ops::ElementwiseMulKernel<paddle::platform::CPUDeviceContext,
                               paddle::platform::complex<double>>);
-REGISTER_OP_CPU_KERNEL(
-    elementwise_mul_grad,
-    ops::ElementwiseMulGradKernel<paddle::platform::CPUDeviceContext, float>,
-    ops::ElementwiseMulGradKernel<paddle::platform::CPUDeviceContext, double>,
-    ops::ElementwiseMulGradKernel<paddle::platform::CPUDeviceContext, int>,
-    ops::ElementwiseMulGradKernel<paddle::platform::CPUDeviceContext, int64_t>,
-    ops::ElementwiseMulGradKernel<paddle::platform::CPUDeviceContext, bool>,
-    ops::ElementwiseMulGradKernel<paddle::platform::CPUDeviceContext,
-                                  paddle::platform::bfloat16>,
-    ops::ElementwiseMulGradKernel<paddle::platform::CPUDeviceContext,
-                                  paddle::platform::complex<float>>,
-    ops::ElementwiseMulGradKernel<paddle::platform::CPUDeviceContext,
-                                  paddle::platform::complex<double>>);
-REGISTER_OP_CPU_KERNEL(
-    elementwise_mul_grad_grad,
-    ops::ElementwiseMulDoubleGradKernel<paddle::platform::CPUDeviceContext,
-                                        float>,
-    ops::ElementwiseMulDoubleGradKernel<paddle::platform::CPUDeviceContext,
-                                        double>,
-    ops::ElementwiseMulDoubleGradKernel<paddle::platform::CPUDeviceContext,
-                                        int>,
-    ops::ElementwiseMulDoubleGradKernel<paddle::platform::CPUDeviceContext,
-                                        int64_t>,
-    ops::ElementwiseMulDoubleGradKernel<paddle::platform::CPUDeviceContext,
-                                        bool>,
-    ops::ElementwiseMulDoubleGradKernel<paddle::platform::CPUDeviceContext,
-                                        paddle::platform::bfloat16>,
-    ops::ElementwiseMulDoubleGradKernel<paddle::platform::CPUDeviceContext,
-                                        paddle::platform::complex<float>>,
-    ops::ElementwiseMulDoubleGradKernel<paddle::platform::CPUDeviceContext,
-                                        paddle::platform::complex<double>>);
-REGISTER_OP_CPU_KERNEL(
-    elementwise_mul_triple_grad,
-    ops::ElementwiseMulTripleGradKernel<paddle::platform::CPUDeviceContext,
-                                        float>,
-    ops::ElementwiseMulTripleGradKernel<paddle::platform::CPUDeviceContext,
-                                        double>,
-    ops::ElementwiseMulTripleGradKernel<paddle::platform::CPUDeviceContext,
-                                        int>,
-    ops::ElementwiseMulTripleGradKernel<paddle::platform::CPUDeviceContext,
-                                        int64_t>,
-    ops::ElementwiseMulTripleGradKernel<paddle::platform::CPUDeviceContext,
-                                        bool>,
-    ops::ElementwiseMulTripleGradKernel<paddle::platform::CPUDeviceContext,
-                                        paddle::platform::bfloat16>,
-    ops::ElementwiseMulTripleGradKernel<paddle::platform::CPUDeviceContext,
-                                        paddle::platform::complex<float>>,
-    ops::ElementwiseMulTripleGradKernel<paddle::platform::CPUDeviceContext,
-                                        paddle::platform::complex<double>>);
 
 REGISTER_OP_VERSION(elementwise_mul)
     .AddCheckpoint(
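These three fluid registrations are deleted without replacement in this file; the grad kernels move into the phi library (the +539 added lines of this commit, whose files are not shown in this excerpt). A minimal sketch of the phi-style registration that typically replaces REGISTER_OP_CPU_KERNEL, where the kernel name phi::MultiplyGradKernel and the registered op name multiply_grad are assumptions, not something this excerpt confirms:

// Hypothetical phi-side replacement for the deleted CPU registration:
PD_REGISTER_KERNEL(multiply_grad,            // phi name for elementwise_mul_grad (assumed)
                   CPU,                      // backend
                   ALL_LAYOUT,               // data layout
                   phi::MultiplyGradKernel,  // kernel function template (assumed)
                   float,
                   double,
                   int,
                   int64_t,
                   bool,
                   phi::dtype::bfloat16,
                   phi::dtype::complex<float>,
                   phi::dtype::complex<double>) {}

Design note: a single macro parameterized over a dtype list replaces the per-type ops::ElementwiseMul*GradKernel<Context, T> enumeration above.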

paddle/fluid/operators/elementwise/elementwise_mul_op.cu

+0 −68

@@ -63,33 +63,6 @@ class ElementwiseMulKernel<platform::CUDADeviceContext, T>
   }
 };
 
-template <typename DeviceContext, typename T>
-typename std::enable_if<
-    std::is_same<DeviceContext, platform::CUDADeviceContext>::value>::type
-ElementwiseMulGrad(const framework::ExecutionContext& ctx,
-                   const framework::Tensor* x, const framework::Tensor* y,
-                   const framework::Tensor* out, const framework::Tensor* dout,
-                   framework::Tensor* dx, framework::Tensor* dy) {
-  int axis = ctx.Attr<int>("axis");
-  const auto& dev_ctx =
-      ctx.template device_context<platform::CUDADeviceContext>();
-  const auto place = ctx.GetPlace();
-
-  if (dx != nullptr && dy != nullptr) {
-    std::vector<const framework::Tensor*> ins = {dout, y, x};
-    GetGradXAndYOut<ElementwiseType::kTernary, T>(
-        dev_ctx, place, axis, ins, dout, dx, dy, MulGradXYFunctor<T, T>());
-  } else if (dx != nullptr && dy == nullptr) {
-    std::vector<const framework::Tensor*> ins = {dout, y};
-    GetGradXOrYOut<ElementwiseType::kBinary, T>(dev_ctx, place, axis, ins, dout,
-                                                dx, MulGradFunctor<T>());
-  } else if (dx == nullptr && dy != nullptr) {
-    std::vector<const framework::Tensor*> ins = {dout, x};
-    GetGradXOrYOut<ElementwiseType::kBinary, T>(dev_ctx, place, axis, ins, dout,
-                                                dy, MulGradFunctor<T>());
-  }
-}
-
 }  // namespace operators
 }  // namespace paddle
 
@@ -103,44 +76,3 @@ REGISTER_OP_CUDA_KERNEL(
     ops::ElementwiseMulKernel<plat::CUDADeviceContext, plat::bfloat16>,
     ops::ElementwiseMulKernel<plat::CUDADeviceContext, plat::complex<float>>,
     ops::ElementwiseMulKernel<plat::CUDADeviceContext, plat::complex<double>>);
-REGISTER_OP_CUDA_KERNEL(
-    elementwise_mul_grad,
-    ops::ElementwiseMulGradKernel<plat::CUDADeviceContext, float>,
-    ops::ElementwiseMulGradKernel<plat::CUDADeviceContext, double>,
-    ops::ElementwiseMulGradKernel<plat::CUDADeviceContext, int>,
-    ops::ElementwiseMulGradKernel<plat::CUDADeviceContext, int64_t>,
-    ops::ElementwiseMulGradKernel<plat::CUDADeviceContext, bool>,
-    ops::ElementwiseMulGradKernel<plat::CUDADeviceContext, plat::float16>,
-    ops::ElementwiseMulGradKernel<plat::CUDADeviceContext, plat::bfloat16>,
-    ops::ElementwiseMulGradKernel<plat::CUDADeviceContext,
-                                  plat::complex<float>>,
-    ops::ElementwiseMulGradKernel<plat::CUDADeviceContext,
-                                  plat::complex<double>>);
-REGISTER_OP_CUDA_KERNEL(
-    elementwise_mul_grad_grad,
-    ops::ElementwiseMulDoubleGradKernel<plat::CUDADeviceContext, float>,
-    ops::ElementwiseMulDoubleGradKernel<plat::CUDADeviceContext, double>,
-    ops::ElementwiseMulDoubleGradKernel<plat::CUDADeviceContext, int>,
-    ops::ElementwiseMulDoubleGradKernel<plat::CUDADeviceContext, int64_t>,
-    ops::ElementwiseMulDoubleGradKernel<plat::CUDADeviceContext, bool>,
-    ops::ElementwiseMulDoubleGradKernel<plat::CUDADeviceContext, plat::float16>,
-    ops::ElementwiseMulDoubleGradKernel<plat::CUDADeviceContext,
-                                        plat::bfloat16>,
-    ops::ElementwiseMulDoubleGradKernel<plat::CUDADeviceContext,
-                                        plat::complex<float>>,
-    ops::ElementwiseMulDoubleGradKernel<plat::CUDADeviceContext,
-                                        plat::complex<double>>);
-REGISTER_OP_CUDA_KERNEL(
-    elementwise_mul_triple_grad,
-    ops::ElementwiseMulTripleGradKernel<plat::CUDADeviceContext, float>,
-    ops::ElementwiseMulTripleGradKernel<plat::CUDADeviceContext, double>,
-    ops::ElementwiseMulTripleGradKernel<plat::CUDADeviceContext, int>,
-    ops::ElementwiseMulTripleGradKernel<plat::CUDADeviceContext, int64_t>,
-    ops::ElementwiseMulTripleGradKernel<plat::CUDADeviceContext, bool>,
-    ops::ElementwiseMulTripleGradKernel<plat::CUDADeviceContext, plat::float16>,
-    ops::ElementwiseMulTripleGradKernel<plat::CUDADeviceContext,
-                                        plat::bfloat16>,
-    ops::ElementwiseMulTripleGradKernel<plat::CUDADeviceContext,
-                                        plat::complex<float>>,
-    ops::ElementwiseMulTripleGradKernel<plat::CUDADeviceContext,
-                                        plat::complex<double>>);
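The ElementwiseMulGrad helper removed in the first hunk of this file chose between a fused path (both dx and dy requested, ternary functor over {dout, y, x}) and two single-output paths (binary functor over {dout, y} or {dout, x}). A standalone host-side sketch of that dispatch logic for float, hypothetical and for illustration only (MulGradReference is not a Paddle function):

#include <cstddef>
#include <vector>

// Mirrors the removed three-way dispatch: fuse when both gradients are
// requested, otherwise compute exactly the one that is.
void MulGradReference(const std::vector<float>& x,
                      const std::vector<float>& y,
                      const std::vector<float>& dout,
                      std::vector<float>* dx,    // pass nullptr to skip
                      std::vector<float>* dy) {  // pass nullptr to skip
  const std::size_t n = dout.size();
  if (dx != nullptr && dy != nullptr) {
    // Fused path: one pass over dout yields both gradients, which is what
    // MulGradXYFunctor achieved on GPU in a single kernel launch.
    dx->resize(n);
    dy->resize(n);
    for (std::size_t i = 0; i < n; ++i) {
      (*dx)[i] = dout[i] * y[i];  // dx = dout * y
      (*dy)[i] = dout[i] * x[i];  // dy = dout * x
    }
  } else if (dx != nullptr) {
    dx->resize(n);
    for (std::size_t i = 0; i < n; ++i) (*dx)[i] = dout[i] * y[i];
  } else if (dy != nullptr) {
    dy->resize(n);
    for (std::size_t i = 0; i < n; ++i) (*dy)[i] = dout[i] * x[i];
  }
}

(The real code also handled broadcasting via the axis attribute, which this sketch omits.)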
