
Commit 3419de5

Support different data type between input and output (#32823)
1 parent fbbc339 commit 3419de5

File tree

4 files changed: +129 -86 lines changed

paddle/fluid/operators/abs_op.cu

Lines changed: 66 additions & 31 deletions
@@ -13,44 +13,79 @@
 // limitations under the License.
 
 #include "paddle/fluid/operators/abs_op.h"
+#include "paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h"
 #include "paddle/fluid/platform/complex128.h"
 #include "paddle/fluid/platform/complex64.h"
 #include "paddle/fluid/platform/float16.h"
 
+namespace paddle {
+namespace operators {
+
+template <typename T, typename Enable = void>
+struct CudaAbsFunctor;
+
+template <typename T>
+struct CudaAbsFunctor<T, math::Complex<T, math::Real<T>>> {
+  __device__ __forceinline__ math::Real<T> operator()(const T* args) const {
+    return abs(args[0]);
+  }
+};
+
+template <typename T>
+struct CudaAbsFunctor<T, math::NoComplex<T, math::Real<T>>> {
+  __device__ __forceinline__ T operator()(const T* args) const {
+    return std::abs(args[0]);
+  }
+};
+
+template <typename T>
+class AbsKernel<platform::CUDADeviceContext, T>
+    : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& context) const override {
+    const Tensor* x = context.Input<Tensor>("X");
+    Tensor* out = context.Output<Tensor>("Out");
+    out->mutable_data<math::Real<T>>(context.GetPlace());
+
+    auto& dev_ctx =
+        context.template device_context<platform::CUDADeviceContext>();
+    std::vector<const framework::Tensor*> ins = {x};
+    std::vector<framework::Tensor*> outs = {out};
+    auto functor = CudaAbsFunctor<T>();
+    LaunchElementwiseCudaKernel<ElementwiseType::kUnary, T, math::Real<T>>(
+        dev_ctx, ins, &outs, functor);
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
 
 namespace ops = paddle::operators;
+namespace plat = paddle::platform;
 
 REGISTER_OP_CUDA_KERNEL(
-    abs, ops::AbsKernel<paddle::platform::CUDADeviceContext, float>,
-    ops::AbsKernel<paddle::platform::CUDADeviceContext, double>,
-    ops::AbsKernel<paddle::platform::CUDADeviceContext, int>,
-    ops::AbsKernel<paddle::platform::CUDADeviceContext, int64_t>,
-    ops::AbsKernel<paddle::platform::CUDADeviceContext,
-                   paddle::platform::float16>,
-    ops::AbsKernel<paddle::platform::CUDADeviceContext,
-                   paddle::platform::complex64>,
-    ops::AbsKernel<paddle::platform::CUDADeviceContext,
-                   paddle::platform::complex128>);
+    abs, ops::AbsKernel<plat::CUDADeviceContext, float>,
+    ops::AbsKernel<plat::CUDADeviceContext, double>,
+    ops::AbsKernel<plat::CUDADeviceContext, int>,
+    ops::AbsKernel<plat::CUDADeviceContext, int64_t>,
+    ops::AbsKernel<plat::CUDADeviceContext, plat::float16>,
+    ops::AbsKernel<plat::CUDADeviceContext, plat::complex64>,
+    ops::AbsKernel<plat::CUDADeviceContext, plat::complex128>);
 
 REGISTER_OP_CUDA_KERNEL(
-    abs_grad, ops::AbsGradKernel<paddle::platform::CUDADeviceContext, float>,
-    ops::AbsGradKernel<paddle::platform::CUDADeviceContext, double>,
-    ops::AbsGradKernel<paddle::platform::CUDADeviceContext, int>,
-    ops::AbsGradKernel<paddle::platform::CUDADeviceContext, int64_t>,
-    ops::AbsGradKernel<paddle::platform::CUDADeviceContext,
-                       paddle::platform::float16>,
-    ops::AbsGradKernel<paddle::platform::CUDADeviceContext,
-                       paddle::platform::complex64>,
-    ops::AbsGradKernel<paddle::platform::CUDADeviceContext,
-                       paddle::platform::complex128>);
+    abs_grad, ops::AbsGradKernel<plat::CUDADeviceContext, float>,
+    ops::AbsGradKernel<plat::CUDADeviceContext, double>,
+    ops::AbsGradKernel<plat::CUDADeviceContext, int>,
+    ops::AbsGradKernel<plat::CUDADeviceContext, int64_t>,
+    ops::AbsGradKernel<plat::CUDADeviceContext, plat::float16>,
+    ops::AbsGradKernel<plat::CUDADeviceContext, plat::complex64>,
+    ops::AbsGradKernel<plat::CUDADeviceContext, plat::complex128>);
 
 REGISTER_OP_CUDA_KERNEL(
-    abs_grad_grad,
-    ops::AbsDoubleGradKernel<paddle::platform::CUDADeviceContext, float>,
-    ops::AbsDoubleGradKernel<paddle::platform::CUDADeviceContext, double>,
-    ops::AbsDoubleGradKernel<paddle::platform::CUDADeviceContext, int>,
-    ops::AbsDoubleGradKernel<paddle::platform::CUDADeviceContext, int64_t>,
-    ops::AbsDoubleGradKernel<paddle::platform::CUDADeviceContext,
-                             paddle::platform::float16>,
-    ops::AbsDoubleGradKernel<paddle::platform::CUDADeviceContext,
-                             paddle::platform::complex64>,
-    ops::AbsDoubleGradKernel<paddle::platform::CUDADeviceContext,
-                             paddle::platform::complex128>);
+    abs_grad_grad, ops::AbsDoubleGradKernel<plat::CUDADeviceContext, float>,
+    ops::AbsDoubleGradKernel<plat::CUDADeviceContext, double>,
+    ops::AbsDoubleGradKernel<plat::CUDADeviceContext, int>,
+    ops::AbsDoubleGradKernel<plat::CUDADeviceContext, int64_t>,
+    ops::AbsDoubleGradKernel<plat::CUDADeviceContext, plat::float16>,
+    ops::AbsDoubleGradKernel<plat::CUDADeviceContext, plat::complex64>,
+    ops::AbsDoubleGradKernel<plat::CUDADeviceContext, plat::complex128>);
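The new functor pair and the math::Real<T> output above are what let abs write a real-valued result for a complex input. As a minimal standalone sketch of that pattern (the kernel, functor, and launch names below are illustrative assumptions, not Paddle's actual API), a unary CUDA kernel templated on separate input and output element types looks like this:

#include <cuda_runtime.h>
#include <thrust/complex.h>

// Sketch only: unary elementwise kernel whose output type differs from its
// input type, mirroring AbsKernel producing math::Real<T> for complex T.
template <typename InT, typename OutT, typename Functor>
__global__ void UnaryKernel(const InT* in, OutT* out, int n, Functor func) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) out[i] = func(in[i]);
}

struct AbsOfComplex {
  // complex<float> in, float out: the "different data type" case.
  __device__ float operator()(const thrust::complex<float>& v) const {
    return thrust::abs(v);
  }
};

// Hypothetical launch: UnaryKernel<<<grid, block>>>(d_in, d_out, n, AbsOfComplex{});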

paddle/fluid/operators/activation_op.cu

Lines changed: 8 additions & 8 deletions
@@ -1315,8 +1315,8 @@ class ActivationCudaKernel
     for (auto& attr : attrs) {
       *attr.second = ctx.Attr<float>(attr.first);
     }
-    LaunchElementwiseCudaKernel<ElementwiseType::kUnary, T>(dev_ctx, ins, &outs,
-                                                            functor);
+    LaunchElementwiseCudaKernel<ElementwiseType::kUnary, T, T>(dev_ctx, ins,
+                                                               &outs, functor);
   }
 };
 
@@ -1345,17 +1345,17 @@ class ActivationGradCudaKernel
     if (static_cast<int>(Functor::FwdDeps()) == static_cast<int>(kDepOut)) {
       // Only need forward output Out
       ins.push_back(out);
-      LaunchElementwiseCudaKernel<ElementwiseType::kBinary, T>(dev_ctx, ins,
-                                                               &outs, functor);
+      LaunchElementwiseCudaKernel<ElementwiseType::kBinary, T, T>(
+          dev_ctx, ins, &outs, functor);
     } else if (static_cast<int>(Functor::FwdDeps()) ==
                static_cast<int>(kDepX)) {
       // Only need forward input X
       ins.push_back(x);
-      LaunchElementwiseCudaKernel<ElementwiseType::kBinary, T>(dev_ctx, ins,
-                                                               &outs, functor);
+      LaunchElementwiseCudaKernel<ElementwiseType::kBinary, T, T>(
+          dev_ctx, ins, &outs, functor);
     } else {
-      LaunchElementwiseCudaKernel<ElementwiseType::kUnary, T>(dev_ctx, ins,
-                                                              &outs, functor);
+      LaunchElementwiseCudaKernel<ElementwiseType::kUnary, T, T>(
+          dev_ctx, ins, &outs, functor);
     }
   }
 };
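For the activation kernels the input and output element types are the same, so every call site above simply passes T twice. A standalone sketch of that binary same-type case (the names here are illustrative, not Paddle's launcher):

// Sketch only: binary elementwise kernel with InT == OutT, mirroring the
// <ElementwiseType::kBinary, T, T> calls above.
template <typename InT, typename OutT, typename Functor>
__global__ void BinaryKernel(const InT* x, const InT* y, OutT* out, int n,
                             Functor func) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) out[i] = func(x[i], y[i]);
}

struct ReluGradFunctor {
  // dx = dout when out > 0, else 0; input and output are both float.
  __device__ float operator()(float dout, float out) const {
    return out > 0.0f ? dout : 0.0f;
  }
};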

paddle/fluid/operators/elementwise/elementwise_add_op.cu

Lines changed: 1 addition & 1 deletion
@@ -45,7 +45,7 @@ struct SameDimsElemwiseAdd<platform::CUDADeviceContext, T> {
                   framework::Tensor* z) {
     std::vector<const framework::Tensor*> ins = {x, y};
     std::vector<framework::Tensor*> outs = {z};
-    LaunchElementwiseCudaKernel<ElementwiseType::kBinary, T>(
+    LaunchElementwiseCudaKernel<ElementwiseType::kBinary, T, T>(
         ctx.template device_context<platform::CUDADeviceContext>(), ins, &outs,
         CudaAddFunctor<T>());
   }

paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h

Lines changed: 54 additions & 46 deletions
@@ -49,77 +49,81 @@ int GetVectorizedSizeImpl(const T *pointer) {
   return 1;
 }
 
-template <typename T>
+template <typename InT, typename OutT>
 int GetVectorizedSize(const std::vector<const framework::Tensor *> &ins,
                       const std::vector<framework::Tensor *> &outs) {
   int vec_size = 4;
   for (auto iter = ins.begin(); iter != ins.end(); ++iter) {
     vec_size =
-        std::min<int>(vec_size, GetVectorizedSizeImpl((*iter)->data<T>()));
+        std::min<int>(vec_size, GetVectorizedSizeImpl((*iter)->data<InT>()));
   }
   for (auto iter = outs.begin(); iter != outs.end(); ++iter) {
     vec_size =
-        std::min<int>(vec_size, GetVectorizedSizeImpl((*iter)->data<T>()));
+        std::min<int>(vec_size, GetVectorizedSizeImpl((*iter)->data<OutT>()));
   }
   return vec_size;
 }
 
-template <ElementwiseType ET, int VecSize, typename T>
+template <ElementwiseType ET, int VecSize, typename InT, typename OutT>
 struct ElementwiseDataWrapper {
-  T *out;
-  const T *in0;
-  const T *in1;
-  __device__ ElementwiseDataWrapper(T *out, const T *in0,
-                                    const T *in1 = nullptr)
+  OutT *out;
+  const InT *in0;
+  const InT *in1;
+  __device__ ElementwiseDataWrapper(OutT *out, const InT *in0,
+                                    const InT *in1 = nullptr)
       : out(out), in0(in0), in1(in1) {}
 
-  using VecType = CudaAlignedVector<T, VecSize>;
+  using InVecType = CudaAlignedVector<InT, VecSize>;
+  using OutVecType = CudaAlignedVector<OutT, VecSize>;
 
-  inline __device__ void load_vector(VecType args[], int idx) {
-    const VecType *x_vec = reinterpret_cast<const VecType *>(in0);
+  inline __device__ void load_vector(InVecType args[], int idx) {
+    const InVecType *x_vec = reinterpret_cast<const InVecType *>(in0);
     args[0] = x_vec[idx];
     if (ET == ElementwiseType::kBinary) {
-      const VecType *y_vec = reinterpret_cast<const VecType *>(in1);
+      const InVecType *y_vec = reinterpret_cast<const InVecType *>(in1);
      args[1] = y_vec[idx];
    }
  }
 
-  inline __device__ void load_scalar(T args[], int idx) {
+  inline __device__ void load_scalar(InT args[], int idx) {
    args[0] = in0[idx];
    if (ET == ElementwiseType::kBinary) {
      args[1] = in1[idx];
    }
  }
 
-  inline __device__ void store_vector(VecType res, int idx) {
-    VecType *out_vec = reinterpret_cast<VecType *>(out);
+  inline __device__ void store_vector(OutVecType res, int idx) {
+    OutVecType *out_vec = reinterpret_cast<OutVecType *>(out);
    out_vec[idx] = res;
  }
 
-  inline __device__ void store_scalar(T res, int idx) { out[idx] = res; }
+  inline __device__ void store_scalar(OutT res, int idx) { out[idx] = res; }
 };
 
-template <ElementwiseType ET, int VecSize, typename T, typename Functor>
+template <ElementwiseType ET, int VecSize, typename InT, typename OutT,
+          typename Functor>
 __device__ void VectorizedKernelImpl(
-    ElementwiseDataWrapper<ET, VecSize, T> data, Functor func, int tid) {
-  using VecType = CudaAlignedVector<T, VecSize>;
-  VecType ins_vec[ET];
-  VecType out_vec;
-  T *ins_ptr[ET];
-  T *out_ptr;
+    ElementwiseDataWrapper<ET, VecSize, InT, OutT> data, Functor func,
+    int tid) {
+  using InVecType = CudaAlignedVector<InT, VecSize>;
+  using OutVecType = CudaAlignedVector<OutT, VecSize>;
+  InVecType ins_vec[ET];
+  OutVecType out_vec;
+  InT *ins_ptr[ET];
+  OutT *out_ptr;
 #pragma unroll
  for (int i = 0; i < ET; ++i) {
-    ins_ptr[i] = reinterpret_cast<T *>(&(ins_vec[i]));
+    ins_ptr[i] = reinterpret_cast<InT *>(&(ins_vec[i]));
  }
-  out_ptr = reinterpret_cast<T *>(&out_vec);
+  out_ptr = reinterpret_cast<OutT *>(&out_vec);
 
  // load
  data.load_vector(ins_vec, tid);
 
  // compute
 #pragma unroll
  for (int i = 0; i < VecSize; ++i) {
-    T ins[ET];
+    InT ins[ET];
 #pragma unroll
    for (int j = 0; j < ET; ++j) {
      ins[j] = ins_ptr[j][i];
@@ -131,11 +135,13 @@ __device__ void VectorizedKernelImpl(
  data.store_vector(out_vec, tid);
 }
 
-template <ElementwiseType ET, int VecSize, typename T, typename Functor>
-__device__ void ScalarKernelImpl(ElementwiseDataWrapper<ET, VecSize, T> data,
-                                 Functor func, int start, int remain) {
-  T ins[ET];
-  T out;
+template <ElementwiseType ET, int VecSize, typename InT, typename OutT,
+          typename Functor>
+__device__ void ScalarKernelImpl(
+    ElementwiseDataWrapper<ET, VecSize, InT, OutT> data, Functor func,
+    int start, int remain) {
+  InT ins[ET];
+  OutT out;
 
  for (int i = 0; i < remain; ++i) {
    int idx = start + i;
@@ -148,45 +154,47 @@ __device__ void ScalarKernelImpl(ElementwiseDataWrapper<ET, VecSize, T> data,
  }
 }
 
-template <ElementwiseType ET, int VecSize, typename T, typename Functor>
-__global__ void VectorizedKernel(const T *__restrict__ in0,
-                                 const T *__restrict__ in1, T *out, int size,
-                                 Functor func) {
+template <ElementwiseType ET, int VecSize, typename InT, typename OutT,
+          typename Functor>
+__global__ void VectorizedKernel(const InT *__restrict__ in0,
+                                 const InT *__restrict__ in1, OutT *out,
+                                 int size, Functor func) {
  int tid = blockIdx.x * blockDim.x + threadIdx.x;
  int remain = size - VecSize * tid;
  remain = remain > 0 ? remain : 0;
-  auto data = ElementwiseDataWrapper<ET, VecSize, T>(out, in0, in1);
+  auto data = ElementwiseDataWrapper<ET, VecSize, InT, OutT>(out, in0, in1);
  if (remain >= VecSize) {
    VectorizedKernelImpl(data, func, tid);
  } else {
    ScalarKernelImpl(data, func, tid * VecSize, remain);
  }
 }
 
-template <ElementwiseType ET, typename T, typename Functor>
-__global__ void ScalarKernel(const T *__restrict__ in0,
-                             const T *__restrict__ in1, T *out, int size,
+template <ElementwiseType ET, typename InT, typename OutT, typename Functor>
+__global__ void ScalarKernel(const InT *__restrict__ in0,
+                             const InT *__restrict__ in1, OutT *out, int size,
                              Functor func) {
-  auto data = ElementwiseDataWrapper<ET, 1, T>(out, in0, in1);
+  auto data = ElementwiseDataWrapper<ET, 1, InT, OutT>(out, in0, in1);
  int tid = blockIdx.x * blockDim.x + threadIdx.x;
  int remain = tid < size ? 1 : 0;
  ScalarKernelImpl(data, func, tid, remain);
 }
 
-template <ElementwiseType ET, typename T, typename Functor>
+template <ElementwiseType ET, typename InT, typename OutT, typename Functor>
 void LaunchElementwiseCudaKernel(
    const platform::CUDADeviceContext &ctx,
    const std::vector<const framework::Tensor *> &ins,
    std::vector<framework::Tensor *> *outs, Functor func) {
  // calculate the max vec_size for all ins and outs
  auto size = ins[0]->numel();
-  int vec_size = GetVectorizedSize<T>(ins, *outs);
+  int vec_size = GetVectorizedSize<InT, OutT>(ins, *outs);
  int block_size = ELEMENTWISE_BLOCK_SIZE;
  int grid_size =
      ((size + vec_size - 1) / vec_size + block_size - 1) / block_size;
-  const T *in0 = ins[0]->data<T>();
-  const T *in1 = (ET == ElementwiseType::kBinary) ? ins[1]->data<T>() : nullptr;
-  T *out = (*outs)[0]->data<T>();
+  const InT *in0 = ins[0]->data<InT>();
+  const InT *in1 =
+      (ET == ElementwiseType::kBinary) ? ins[1]->data<InT>() : nullptr;
+  OutT *out = (*outs)[0]->data<OutT>();
  // cuda kernel
  auto stream = ctx.stream();
  switch (vec_size) {
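With the launcher templated on both InT and OutT, the vectorization width is the minimum supported width over all inputs and outputs, and the raw input/output pointers are typed independently. A compilable standalone sketch of the same idea, reduced to the scalar unary path (the helper names and the fixed block size are assumptions, not the actual Paddle implementation):

#include <cuda_runtime.h>

// Sketch only: launch a unary elementwise functor whose output element type
// may differ from its input element type.
template <typename InT, typename OutT, typename Functor>
__global__ void ElementwiseUnary(const InT* in, OutT* out, int n, Functor f) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) out[i] = f(in[i]);
}

template <typename InT, typename OutT, typename Functor>
void LaunchUnary(const InT* d_in, OutT* d_out, int n, Functor f,
                 cudaStream_t stream) {
  const int block = 256;                     // assumed block size
  const int grid = (n + block - 1) / block;  // one thread per element
  ElementwiseUnary<InT, OutT, Functor><<<grid, block, 0, stream>>>(d_in, d_out,
                                                                   n, f);
}

// Example functor: double in, float out (a "different data type" pair).
struct AbsToFloat {
  __device__ float operator()(double v) const {
    return static_cast<float>(v < 0.0 ? -v : v);
  }
};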

0 commit comments