@@ -19,6 +19,43 @@ limitations under the License. */
 namespace paddle {
 namespace operators {

+// aligned vector generates vectorized load/store on CUDA
+template <typename T, int Size>
+struct alignas(sizeof(T) * Size) AlignedVector {
+  T val[Size];
+};
+
+template <typename T>
+inline int VectorizedSize(const T* pointer) {
+  uint64_t address = reinterpret_cast<uint64_t>(pointer);
+  constexpr int vec4 = std::alignment_of<AlignedVector<T, 4>>::value;  // NOLINT
+  if (address % vec4 == 0) {
+    return 4;
+  }
+  return 1;
+}
+
+template <typename InT, typename OutT, int VecSize>
+__global__ void VecCastCUDAKernel(const InT* in, const int64_t N, OutT* out) {
+  int64_t idx = blockDim.x * blockIdx.x + threadIdx.x;
+  using LoadT = AlignedVector<InT, VecSize>;
+  using StoreT = AlignedVector<OutT, VecSize>;
+  for (int i = idx * VecSize; i < N; i += blockDim.x * gridDim.x * VecSize) {
+    InT in_vec[VecSize];
+    LoadT* in_value = reinterpret_cast<LoadT*>(&in_vec);
+    *in_value = *reinterpret_cast<const LoadT*>(&in[i]);
+
+    OutT out_vec[VecSize];
+#pragma unroll
+    for (int ii = 0; ii < VecSize; ii++) {
+      out_vec[ii] = static_cast<OutT>(in_vec[ii]);
+    }
+
+    *(reinterpret_cast<StoreT*>(&out[i])) =
+        *reinterpret_cast<StoreT*>(&out_vec[0]);
+  }
+}
+
 template <typename InT, typename OutT>
 __global__ void CastCUDAKernel(const InT* in, const int64_t N, OutT* out) {
   CUDA_KERNEL_LOOP(index, N) { out[index] = static_cast<OutT>(in[index]); }
@@ -40,8 +77,16 @@ struct CastOpFunctor<platform::CUDADeviceContext, InT> {
     auto* out = out_->mutable_data<OutT>(ctx_.GetPlace());
     platform::GpuLaunchConfig config =
         platform::GetGpuLaunchConfig1D(ctx_, size);
-    CastCUDAKernel<InT, OutT><<<config.block_per_grid, config.thread_per_block,
-                                0, ctx_.stream()>>>(in, size, out);
+    int vec_size = VectorizedSize<OutT>(out);
+    if (!std::is_same<InT, OutT>::value && vec_size == 4 && size % 4 == 0) {
+      VecCastCUDAKernel<InT, OutT, 4><<<
+          config.block_per_grid, config.thread_per_block, 0, ctx_.stream()>>>(
+          in, size, out);
+    } else {
+      CastCUDAKernel<InT, OutT><<<config.block_per_grid,
+                                  config.thread_per_block, 0, ctx_.stream()>>>(
+          in, size, out);
+    }
   }
 };

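For context, the following is a minimal standalone sketch of the dispatch idea in this diff: the 4-wide vectorized kernel is used only when the pointers are aligned to a 4-element AlignedVector and the element count is a multiple of 4, and a scalar kernel handles everything else. This is an illustration only, not the Paddle code; names such as CastSketch, VecCastSketch, ScalarCastSketch, and kThreads are assumptions introduced here.

// Standalone illustration (assumed names, not the Paddle implementation).
#include <cuda_runtime.h>
#include <cstdint>

template <typename T, int Size>
struct alignas(sizeof(T) * Size) AlignedVec {
  T val[Size];
};

// Vectorized path: each thread casts VecSize contiguous elements per step,
// so the compiler can emit one wide (e.g. 128-bit) load and store.
template <typename InT, typename OutT, int VecSize>
__global__ void VecCastSketch(const InT* in, int64_t n, OutT* out) {
  using LoadT = AlignedVec<InT, VecSize>;
  using StoreT = AlignedVec<OutT, VecSize>;
  int64_t idx = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  int64_t stride = static_cast<int64_t>(blockDim.x) * gridDim.x * VecSize;
  for (int64_t i = idx * VecSize; i < n; i += stride) {
    LoadT in_vec = *reinterpret_cast<const LoadT*>(&in[i]);
    StoreT out_vec;
#pragma unroll
    for (int k = 0; k < VecSize; ++k) {
      out_vec.val[k] = static_cast<OutT>(in_vec.val[k]);
    }
    *reinterpret_cast<StoreT*>(&out[i]) = out_vec;
  }
}

// Scalar fallback: one element per loop iteration.
template <typename InT, typename OutT>
__global__ void ScalarCastSketch(const InT* in, int64_t n, OutT* out) {
  for (int64_t i = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
       i < n; i += static_cast<int64_t>(blockDim.x) * gridDim.x) {
    out[i] = static_cast<OutT>(in[i]);
  }
}

// Host-side dispatch: take the vectorized path only when both pointers are
// aligned to a 4-element vector and n is divisible by 4.
template <typename InT, typename OutT>
void CastSketch(const InT* in, int64_t n, OutT* out, cudaStream_t stream) {
  constexpr int kThreads = 256;
  int blocks = static_cast<int>((n + kThreads - 1) / kThreads);
  bool aligned =
      reinterpret_cast<uintptr_t>(in) % alignof(AlignedVec<InT, 4>) == 0 &&
      reinterpret_cast<uintptr_t>(out) % alignof(AlignedVec<OutT, 4>) == 0;
  if (aligned && n % 4 == 0) {
    VecCastSketch<InT, OutT, 4><<<blocks, kThreads, 0, stream>>>(in, n, out);
  } else {
    ScalarCastSketch<InT, OutT><<<blocks, kThreads, 0, stream>>>(in, n, out);
  }
}

The point of the 4-wide path is that each thread moves its four elements with a single wide load and a single wide store instead of four separate transactions, which matters because a cast kernel is memory-bandwidth bound. The sketch checks the alignment of both pointers; the diff above checks the output pointer and additionally skips vectorization when InT and OutT are the same type.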