@@ -49,77 +49,81 @@ int GetVectorizedSizeImpl(const T *pointer) {
   return 1;
 }
 
-template <typename T>
+template <typename InT, typename OutT>
 int GetVectorizedSize(const std::vector<const framework::Tensor *> &ins,
                       const std::vector<framework::Tensor *> &outs) {
   int vec_size = 4;
   for (auto iter = ins.begin(); iter != ins.end(); ++iter) {
     vec_size =
-        std::min<int>(vec_size, GetVectorizedSizeImpl((*iter)->data<T>()));
+        std::min<int>(vec_size, GetVectorizedSizeImpl((*iter)->data<InT>()));
   }
   for (auto iter = outs.begin(); iter != outs.end(); ++iter) {
     vec_size =
-        std::min<int>(vec_size, GetVectorizedSizeImpl((*iter)->data<T>()));
+        std::min<int>(vec_size, GetVectorizedSizeImpl((*iter)->data<OutT>()));
   }
   return vec_size;
 }
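This hunk widens `GetVectorizedSize` so inputs are probed with `InT` and outputs with `OutT`, taking the minimum vector width every operand pointer supports. Only the scalar fallback (`return 1;`) of `GetVectorizedSizeImpl` is visible above, so the following is a hedged, illustrative sketch of an alignment-based probe of that kind (names here are made up for illustration, not the file's code):

```cpp
#include <algorithm>
#include <cstdint>
#include <vector>

// Illustrative stand-in for the (truncated) GetVectorizedSizeImpl: return the
// widest vector load/store a pointer's alignment allows, else fall back to 1.
template <typename T>
int ProbeVectorizedSize(const T *pointer) {
  const auto addr = reinterpret_cast<std::uintptr_t>(pointer);
  if (addr % (sizeof(T) * 4) == 0) return 4;  // float4-style access
  if (addr % (sizeof(T) * 2) == 0) return 2;  // float2-style access
  return 1;                                   // unaligned: scalar access
}

// Mirrors the loop above: the usable width is the minimum over all operands,
// now with distinct input/output element types.
template <typename InT, typename OutT>
int MinVectorizedSize(const std::vector<const InT *> &ins,
                      const std::vector<OutT *> &outs) {
  int vec_size = 4;
  for (const InT *p : ins) vec_size = std::min(vec_size, ProbeVectorizedSize(p));
  for (OutT *p : outs) vec_size = std::min(vec_size, ProbeVectorizedSize(p));
  return vec_size;
}
```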
 
-template <ElementwiseType ET, int VecSize, typename T>
+template <ElementwiseType ET, int VecSize, typename InT, typename OutT>
 struct ElementwiseDataWrapper {
-  T *out;
-  const T *in0;
-  const T *in1;
-  __device__ ElementwiseDataWrapper(T *out, const T *in0,
-                                    const T *in1 = nullptr)
+  OutT *out;
+  const InT *in0;
+  const InT *in1;
+  __device__ ElementwiseDataWrapper(OutT *out, const InT *in0,
+                                    const InT *in1 = nullptr)
       : out(out), in0(in0), in1(in1) {}
 
-  using VecType = CudaAlignedVector<T, VecSize>;
+  using InVecType = CudaAlignedVector<InT, VecSize>;
+  using OutVecType = CudaAlignedVector<OutT, VecSize>;
 
-  inline __device__ void load_vector(VecType args[], int idx) {
-    const VecType *x_vec = reinterpret_cast<const VecType *>(in0);
+  inline __device__ void load_vector(InVecType args[], int idx) {
+    const InVecType *x_vec = reinterpret_cast<const InVecType *>(in0);
     args[0] = x_vec[idx];
     if (ET == ElementwiseType::kBinary) {
-      const VecType *y_vec = reinterpret_cast<const VecType *>(in1);
+      const InVecType *y_vec = reinterpret_cast<const InVecType *>(in1);
       args[1] = y_vec[idx];
     }
   }
 
-  inline __device__ void load_scalar(T args[], int idx) {
+  inline __device__ void load_scalar(InT args[], int idx) {
     args[0] = in0[idx];
     if (ET == ElementwiseType::kBinary) {
       args[1] = in1[idx];
     }
   }
 
-  inline __device__ void store_vector(VecType res, int idx) {
-    VecType *out_vec = reinterpret_cast<VecType *>(out);
+  inline __device__ void store_vector(OutVecType res, int idx) {
+    OutVecType *out_vec = reinterpret_cast<OutVecType *>(out);
     out_vec[idx] = res;
   }
 
-  inline __device__ void store_scalar(T res, int idx) { out[idx] = res; }
+  inline __device__ void store_scalar(OutT res, int idx) { out[idx] = res; }
 };
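Splitting the single `T` into `InT`/`OutT` lets the same wrapper serve elementwise operations whose output element type differs from the input type (casts, comparisons, and the like). A minimal device-side sketch of the load/convert/store round trip this enables, where `AlignedVec` is a hypothetical stand-in for `CudaAlignedVector` (whose definition is not part of this diff):

```cpp
// Hedged sketch: assume the aligned vector type is essentially an aligned POD
// array of VecSize elements, as AlignedVec models here.
template <typename T, int VecSize>
struct AlignedVec {
  alignas(sizeof(T) * VecSize) T val[VecSize];
};

// One vectorized iteration that reads a chunk of InT and writes a chunk of
// OutT, e.g. a float -> double cast; with a single T this could not be typed.
template <typename InT, typename OutT, int VecSize>
__device__ void CastChunk(const InT *in, OutT *out, int idx) {
  using InVec = AlignedVec<InT, VecSize>;
  using OutVec = AlignedVec<OutT, VecSize>;
  InVec x = reinterpret_cast<const InVec *>(in)[idx];
  OutVec y;
#pragma unroll
  for (int i = 0; i < VecSize; ++i) {
    y.val[i] = static_cast<OutT>(x.val[i]);
  }
  reinterpret_cast<OutVec *>(out)[idx] = y;
}
```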
 
-template <ElementwiseType ET, int VecSize, typename T, typename Functor>
+template <ElementwiseType ET, int VecSize, typename InT, typename OutT,
+          typename Functor>
 __device__ void VectorizedKernelImpl(
-    ElementwiseDataWrapper<ET, VecSize, T> data, Functor func, int tid) {
-  using VecType = CudaAlignedVector<T, VecSize>;
-  VecType ins_vec[ET];
-  VecType out_vec;
-  T *ins_ptr[ET];
-  T *out_ptr;
+    ElementwiseDataWrapper<ET, VecSize, InT, OutT> data, Functor func,
+    int tid) {
+  using InVecType = CudaAlignedVector<InT, VecSize>;
+  using OutVecType = CudaAlignedVector<OutT, VecSize>;
+  InVecType ins_vec[ET];
+  OutVecType out_vec;
+  InT *ins_ptr[ET];
+  OutT *out_ptr;
 #pragma unroll
   for (int i = 0; i < ET; ++i) {
-    ins_ptr[i] = reinterpret_cast<T *>(&(ins_vec[i]));
+    ins_ptr[i] = reinterpret_cast<InT *>(&(ins_vec[i]));
   }
-  out_ptr = reinterpret_cast<T *>(&out_vec);
+  out_ptr = reinterpret_cast<OutT *>(&out_vec);
 
   // load
   data.load_vector(ins_vec, tid);
 
 // compute
 #pragma unroll
   for (int i = 0; i < VecSize; ++i) {
-    T ins[ET];
+    InT ins[ET];
 #pragma unroll
     for (int j = 0; j < ET; ++j) {
       ins[j] = ins_ptr[j][i];
@@ -131,11 +135,13 @@ __device__ void VectorizedKernelImpl(
   data.store_vector(out_vec, tid);
 }
 
-template <ElementwiseType ET, int VecSize, typename T, typename Functor>
-__device__ void ScalarKernelImpl(ElementwiseDataWrapper<ET, VecSize, T> data,
-                                 Functor func, int start, int remain) {
-  T ins[ET];
-  T out;
+template <ElementwiseType ET, int VecSize, typename InT, typename OutT,
+          typename Functor>
+__device__ void ScalarKernelImpl(
+    ElementwiseDataWrapper<ET, VecSize, InT, OutT> data, Functor func,
+    int start, int remain) {
+  InT ins[ET];
+  OutT out;
 
   for (int i = 0; i < remain; ++i) {
     int idx = start + i;
@@ -148,45 +154,47 @@ __device__ void ScalarKernelImpl(ElementwiseDataWrapper<ET, VecSize, T> data,
   }
 }
 
-template <ElementwiseType ET, int VecSize, typename T, typename Functor>
-__global__ void VectorizedKernel(const T *__restrict__ in0,
-                                 const T *__restrict__ in1, T *out, int size,
-                                 Functor func) {
+template <ElementwiseType ET, int VecSize, typename InT, typename OutT,
+          typename Functor>
+__global__ void VectorizedKernel(const InT *__restrict__ in0,
+                                 const InT *__restrict__ in1, OutT *out,
+                                 int size, Functor func) {
   int tid = blockIdx.x * blockDim.x + threadIdx.x;
   int remain = size - VecSize * tid;
   remain = remain > 0 ? remain : 0;
-  auto data = ElementwiseDataWrapper<ET, VecSize, T>(out, in0, in1);
+  auto data = ElementwiseDataWrapper<ET, VecSize, InT, OutT>(out, in0, in1);
   if (remain >= VecSize) {
     VectorizedKernelImpl(data, func, tid);
   } else {
     ScalarKernelImpl(data, func, tid * VecSize, remain);
   }
 }
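Each thread of `VectorizedKernel` owns a contiguous chunk of `VecSize` elements; only the thread that lands on the ragged end of the buffer drops to `ScalarKernelImpl` for the remainder. A small host-side check of that bookkeeping (illustrative code, not part of the patch):

```cpp
#include <algorithm>
#include <cstdio>

// How many elements thread `tid` actually processes for a given size and
// vector width, mirroring the remain/VecSize logic in VectorizedKernel.
int ElementsForThread(int tid, int size, int vec_size) {
  int remain = std::max(size - vec_size * tid, 0);
  return std::min(remain, vec_size);
}

int main() {
  // size = 10, VecSize = 4: threads 0 and 1 take the vectorized path (4 each),
  // thread 2 handles the 2-element tail via ScalarKernelImpl, thread 3 idles.
  for (int tid = 0; tid < 4; ++tid) {
    std::printf("tid=%d elements=%d\n", tid, ElementsForThread(tid, 10, 4));
  }
  return 0;
}
```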
 
-template <ElementwiseType ET, typename T, typename Functor>
-__global__ void ScalarKernel(const T *__restrict__ in0,
-                             const T *__restrict__ in1, T *out, int size,
+template <ElementwiseType ET, typename InT, typename OutT, typename Functor>
+__global__ void ScalarKernel(const InT *__restrict__ in0,
+                             const InT *__restrict__ in1, OutT *out, int size,
                              Functor func) {
-  auto data = ElementwiseDataWrapper<ET, 1, T>(out, in0, in1);
+  auto data = ElementwiseDataWrapper<ET, 1, InT, OutT>(out, in0, in1);
   int tid = blockIdx.x * blockDim.x + threadIdx.x;
   int remain = tid < size ? 1 : 0;
   ScalarKernelImpl(data, func, tid, remain);
 }
 
-template <ElementwiseType ET, typename T, typename Functor>
+template <ElementwiseType ET, typename InT, typename OutT, typename Functor>
 void LaunchElementwiseCudaKernel(
     const platform::CUDADeviceContext &ctx,
     const std::vector<const framework::Tensor *> &ins,
     std::vector<framework::Tensor *> *outs, Functor func) {
   // calculate the max vec_size for all ins and outs
   auto size = ins[0]->numel();
-  int vec_size = GetVectorizedSize<T>(ins, *outs);
+  int vec_size = GetVectorizedSize<InT, OutT>(ins, *outs);
   int block_size = ELEMENTWISE_BLOCK_SIZE;
   int grid_size =
       ((size + vec_size - 1) / vec_size + block_size - 1) / block_size;
-  const T *in0 = ins[0]->data<T>();
-  const T *in1 = (ET == ElementwiseType::kBinary) ? ins[1]->data<T>() : nullptr;
-  T *out = (*outs)[0]->data<T>();
+  const InT *in0 = ins[0]->data<InT>();
+  const InT *in1 =
+      (ET == ElementwiseType::kBinary) ? ins[1]->data<InT>() : nullptr;
+  OutT *out = (*outs)[0]->data<OutT>();
   // cuda kernel
   auto stream = ctx.stream();
   switch (vec_size) {
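The hunk ends just as the `vec_size` dispatch begins, so the actual switch cases are outside this diff. The launch configuration above assigns `vec_size` elements per thread and `block_size` threads per block; a quick host-side check of that arithmetic with illustrative values (the real `ELEMENTWISE_BLOCK_SIZE` is not shown here):

```cpp
#include <cstdio>

// Worked check of the grid-size formula used in LaunchElementwiseCudaKernel:
// each thread covers vec_size elements, and threads are grouped into blocks.
int main() {
  long long size = 1000003;  // hypothetical element count
  int vec_size = 4;          // widest path when every pointer is aligned
  int block_size = 512;      // stand-in for ELEMENTWISE_BLOCK_SIZE
  long long threads = (size + vec_size - 1) / vec_size;           // 250001
  long long grid_size = (threads + block_size - 1) / block_size;  // 489
  std::printf("threads=%lld grid_size=%lld\n", threads, grid_size);
  return 0;
}
```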