Skip to content

Commit b602260

Browse files
zhangting2020zhangting2020
authored andcommitted
add VecCastCUDAKernel (#30296)
1 parent 35dfec6 commit b602260

File tree

1 file changed

+47
-2
lines changed

1 file changed

+47
-2
lines changed

paddle/fluid/operators/cast_op.cu

Lines changed: 47 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,43 @@ limitations under the License. */
1919
namespace paddle {
2020
namespace operators {
2121

22+
// aligned vector generates vectorized load/store on CUDA
23+
template <typename T, int Size>
24+
struct alignas(sizeof(T) * Size) AlignedVector {
25+
T val[Size];
26+
};
27+
28+
template <typename T>
29+
inline int VectorizedSize(const T* pointer) {
30+
uint64_t address = reinterpret_cast<uint64_t>(pointer);
31+
constexpr int vec4 = std::alignment_of<AlignedVector<T, 4>>::value; // NOLINT
32+
if (address % vec4 == 0) {
33+
return 4;
34+
}
35+
return 1;
36+
}
37+
38+
template <typename InT, typename OutT, int VecSize>
39+
__global__ void VecCastCUDAKernel(const InT* in, const int64_t N, OutT* out) {
40+
int64_t idx = blockDim.x * blockIdx.x + threadIdx.x;
41+
using LoadT = AlignedVector<InT, VecSize>;
42+
using StoreT = AlignedVector<OutT, VecSize>;
43+
for (int i = idx * VecSize; i < N; i += blockDim.x * gridDim.x * VecSize) {
44+
InT in_vec[VecSize];
45+
LoadT* in_value = reinterpret_cast<LoadT*>(&in_vec);
46+
*in_value = *reinterpret_cast<const LoadT*>(&in[i]);
47+
48+
OutT out_vec[VecSize];
49+
#pragma unroll
50+
for (int ii = 0; ii < VecSize; ii++) {
51+
out_vec[ii] = static_cast<OutT>(in_vec[ii]);
52+
}
53+
54+
*(reinterpret_cast<StoreT*>(&out[i])) =
55+
*reinterpret_cast<StoreT*>(&out_vec[0]);
56+
}
57+
}
58+
2259
template <typename InT, typename OutT>
2360
__global__ void CastCUDAKernel(const InT* in, const int64_t N, OutT* out) {
2461
CUDA_KERNEL_LOOP(index, N) { out[index] = static_cast<OutT>(in[index]); }
@@ -40,8 +77,16 @@ struct CastOpFunctor<platform::CUDADeviceContext, InT> {
4077
auto* out = out_->mutable_data<OutT>(ctx_.GetPlace());
4178
platform::GpuLaunchConfig config =
4279
platform::GetGpuLaunchConfig1D(ctx_, size);
43-
CastCUDAKernel<InT, OutT><<<config.block_per_grid, config.thread_per_block,
44-
0, ctx_.stream()>>>(in, size, out);
80+
int vec_size = VectorizedSize<OutT>(out);
81+
if (!std::is_same<InT, OutT>::value && vec_size == 4 && size % 4 == 0) {
82+
VecCastCUDAKernel<InT, OutT, 4><<<
83+
config.block_per_grid, config.thread_per_block, 0, ctx_.stream()>>>(
84+
in, size, out);
85+
} else {
86+
CastCUDAKernel<InT, OutT><<<config.block_per_grid,
87+
config.thread_per_block, 0, ctx_.stream()>>>(
88+
in, size, out);
89+
}
4590
}
4691
};
4792

0 commit comments

Comments
 (0)