@@ -63,33 +63,6 @@ class ElementwiseMulKernel<platform::CUDADeviceContext, T>
63
63
}
64
64
};
65
65
66
- template <typename DeviceContext, typename T>
67
- typename std::enable_if<
68
- std::is_same<DeviceContext, platform::CUDADeviceContext>::value>::type
69
- ElementwiseMulGrad (const framework::ExecutionContext& ctx,
70
- const framework::Tensor* x, const framework::Tensor* y,
71
- const framework::Tensor* out, const framework::Tensor* dout,
72
- framework::Tensor* dx, framework::Tensor* dy) {
73
- int axis = ctx.Attr <int >(" axis" );
74
- const auto & dev_ctx =
75
- ctx.template device_context <platform::CUDADeviceContext>();
76
- const auto place = ctx.GetPlace ();
77
-
78
- if (dx != nullptr && dy != nullptr ) {
79
- std::vector<const framework::Tensor*> ins = {dout, y, x};
80
- GetGradXAndYOut<ElementwiseType::kTernary , T>(
81
- dev_ctx, place, axis, ins, dout, dx, dy, MulGradXYFunctor<T, T>());
82
- } else if (dx != nullptr && dy == nullptr ) {
83
- std::vector<const framework::Tensor*> ins = {dout, y};
84
- GetGradXOrYOut<ElementwiseType::kBinary , T>(dev_ctx, place, axis, ins, dout,
85
- dx, MulGradFunctor<T>());
86
- } else if (dx == nullptr && dy != nullptr ) {
87
- std::vector<const framework::Tensor*> ins = {dout, x};
88
- GetGradXOrYOut<ElementwiseType::kBinary , T>(dev_ctx, place, axis, ins, dout,
89
- dy, MulGradFunctor<T>());
90
- }
91
- }
92
-
93
66
} // namespace operators
94
67
} // namespace paddle
95
68
@@ -103,44 +76,3 @@ REGISTER_OP_CUDA_KERNEL(
103
76
ops::ElementwiseMulKernel<plat::CUDADeviceContext, plat::bfloat16>,
104
77
ops::ElementwiseMulKernel<plat::CUDADeviceContext, plat::complex<float >>,
105
78
ops::ElementwiseMulKernel<plat::CUDADeviceContext, plat::complex<double >>);
106
- REGISTER_OP_CUDA_KERNEL (
107
- elementwise_mul_grad,
108
- ops::ElementwiseMulGradKernel<plat::CUDADeviceContext, float >,
109
- ops::ElementwiseMulGradKernel<plat::CUDADeviceContext, double >,
110
- ops::ElementwiseMulGradKernel<plat::CUDADeviceContext, int >,
111
- ops::ElementwiseMulGradKernel<plat::CUDADeviceContext, int64_t >,
112
- ops::ElementwiseMulGradKernel<plat::CUDADeviceContext, bool >,
113
- ops::ElementwiseMulGradKernel<plat::CUDADeviceContext, plat::float16>,
114
- ops::ElementwiseMulGradKernel<plat::CUDADeviceContext, plat::bfloat16>,
115
- ops::ElementwiseMulGradKernel<plat::CUDADeviceContext,
116
- plat::complex<float >>,
117
- ops::ElementwiseMulGradKernel<plat::CUDADeviceContext,
118
- plat::complex<double >>);
119
- REGISTER_OP_CUDA_KERNEL (
120
- elementwise_mul_grad_grad,
121
- ops::ElementwiseMulDoubleGradKernel<plat::CUDADeviceContext, float >,
122
- ops::ElementwiseMulDoubleGradKernel<plat::CUDADeviceContext, double >,
123
- ops::ElementwiseMulDoubleGradKernel<plat::CUDADeviceContext, int >,
124
- ops::ElementwiseMulDoubleGradKernel<plat::CUDADeviceContext, int64_t >,
125
- ops::ElementwiseMulDoubleGradKernel<plat::CUDADeviceContext, bool >,
126
- ops::ElementwiseMulDoubleGradKernel<plat::CUDADeviceContext, plat::float16>,
127
- ops::ElementwiseMulDoubleGradKernel<plat::CUDADeviceContext,
128
- plat::bfloat16>,
129
- ops::ElementwiseMulDoubleGradKernel<plat::CUDADeviceContext,
130
- plat::complex<float >>,
131
- ops::ElementwiseMulDoubleGradKernel<plat::CUDADeviceContext,
132
- plat::complex<double >>);
133
- REGISTER_OP_CUDA_KERNEL (
134
- elementwise_mul_triple_grad,
135
- ops::ElementwiseMulTripleGradKernel<plat::CUDADeviceContext, float >,
136
- ops::ElementwiseMulTripleGradKernel<plat::CUDADeviceContext, double >,
137
- ops::ElementwiseMulTripleGradKernel<plat::CUDADeviceContext, int >,
138
- ops::ElementwiseMulTripleGradKernel<plat::CUDADeviceContext, int64_t >,
139
- ops::ElementwiseMulTripleGradKernel<plat::CUDADeviceContext, bool >,
140
- ops::ElementwiseMulTripleGradKernel<plat::CUDADeviceContext, plat::float16>,
141
- ops::ElementwiseMulTripleGradKernel<plat::CUDADeviceContext,
142
- plat::bfloat16>,
143
- ops::ElementwiseMulTripleGradKernel<plat::CUDADeviceContext,
144
- plat::complex<float >>,
145
- ops::ElementwiseMulTripleGradKernel<plat::CUDADeviceContext,
146
- plat::complex<double >>);
0 commit comments