@@ -21,74 +21,74 @@ namespace math {
 
 template <typename T, bool Padding>
 __global__ void SequencePaddingKernel(
-    T* padding_data, T* seq_data, const size_t* abs_offset,
-    const size_t& seq_num, const size_t& max_seq_len, const size_t& seq_width,
-    const PaddingLayout& padding_layout, bool norm_by_times = false,
-    const T& padding_value = 0) {
-  size_t padding_idx = blockIdx.y;
-  size_t seq_start = abs_offset[padding_idx];
-  size_t seq_len = abs_offset[padding_idx + 1] - seq_start;
+    T* pad_data, T* seq_data, const size_t* seq_offset, const size_t& seq_num,
+    const size_t& max_seq_len, const size_t& seq_width, bool norm_by_times,
+    const T& pad_value, const OutputLayout& output_layout) {
+  size_t seq_idx = blockIdx.y;
+  size_t seq_start = seq_offset[seq_idx];
+  size_t seq_len = seq_offset[seq_idx + 1] - seq_start;
 
-  size_t seq_idx = blockIdx.x * blockDim.y + threadIdx.y;
+  size_t seq_step_idx = blockIdx.x * blockDim.y + threadIdx.y;
 
-  size_t seq_offset = (seq_start + seq_idx) * seq_width;
+  size_t seq_data_offset = (seq_start + seq_step_idx) * seq_width;
 
-  size_t padding_offset = 0;
+  size_t pad_data_offset = 0;
 
-  if (padding_layout == LENGTH_BATCH_WIDTH) {
-    padding_offset = (seq_idx * seq_num + padding_idx) * seq_width;
+  if (output_layout == kLengthBatchWidth) {
+    pad_data_offset = (seq_step_idx * seq_num + seq_idx) * seq_width;
   } else {
-    padding_offset = (padding_idx * max_seq_len + seq_idx) * seq_width;
+    pad_data_offset = (seq_idx * max_seq_len + seq_step_idx) * seq_width;
   }
 
-  if (seq_idx < seq_len) {
+  if (seq_step_idx < seq_len) {
     T scale = norm_by_times ? (1.0f / static_cast<T>(seq_len)) : 1.0f;
     if (Padding) {
-      /* sequence -> padding */
+      /* seq -> pad */
       for (size_t i = threadIdx.x; i < seq_width; i += blockDim.x) {
-        padding_data[padding_offset + i] = scale * seq_data[seq_offset + i];
+        pad_data[pad_data_offset + i] = scale * seq_data[seq_data_offset + i];
       }
     } else {
-      /* padding -> sequence */
+      /* pad -> seq */
       for (size_t i = threadIdx.x; i < seq_width; i += blockDim.x) {
-        seq_data[seq_offset + i] = scale * padding_data[padding_offset + i];
+        seq_data[seq_data_offset + i] = scale * pad_data[pad_data_offset + i];
       }
     }
-  } else if (seq_idx < max_seq_len) {
+  } else if (seq_step_idx < max_seq_len) {
     if (Padding) {
-      /* sequence -> padding */
+      /* seq -> pad */
       for (size_t i = threadIdx.x; i < seq_width; i += blockDim.x) {
-        padding_data[padding_offset + i] = padding_value;
+        pad_data[pad_data_offset + i] = pad_value;
       }
     }
   }
 }
 
-template <typename T, PaddingLayout padding_layout>
-class PaddingLoDTensorFunctor<platform::CUDADeviceContext, T, padding_layout> {
+template <typename T>
+class PaddingLoDTensorFunctor<platform::CUDADeviceContext, T> {
  public:
   void operator()(const platform::CUDADeviceContext& context,
                   const framework::LoDTensor& seq_tensor,
-                  framework::Tensor* padding_tensor,
-                  T padding_value = static_cast<T>(0),
-                  bool norm_by_times = false, size_t lod_level = 0) {
-    ValidateLoD(seq_tensor, lod_level);
+                  framework::Tensor* pad_tensor,
+                  T pad_value = static_cast<T>(0), bool norm_by_times = false,
+                  size_t lod_level = 0,
+                  OutputLayout output_layout = kBatchLengthWidth) {
+    CheckLoD(seq_tensor, lod_level);
 
     auto& lod = seq_tensor.lod();
-    auto& abs_offset = framework::ToAbsOffset(lod)[lod_level];
+    auto& seq_offset = framework::ToAbsOffset(lod)[lod_level];
 
-    auto seq_dims = seq_tensor.dims();
-    auto padding_dims = padding_tensor->dims();
-    int64_t max_seq_len = MaximumSequenceLength(lod, lod_level);
-    const int64_t seq_num = abs_offset.size() - 1;
-    const int64_t seq_width = seq_tensor.numel() / seq_dims[0];
+    auto seq_tensor_dims = seq_tensor.dims();
+    auto pad_tensor_dims = pad_tensor->dims();
+    int64_t max_seq_len = MaximumSequenceLength(seq_offset);
+    int64_t seq_num = seq_offset.size() - 1;
+    int64_t seq_width = seq_tensor.numel() / seq_tensor_dims[0];
 
-    ValidateShape(seq_dims, abs_offset.back(), padding_dims, max_seq_len,
-                  seq_num, seq_width, padding_layout);
+    CheckDims(seq_tensor_dims, seq_offset.back(), pad_tensor_dims, max_seq_len,
+              seq_num, seq_width, output_layout);
 
     if (!norm_by_times && seq_num == 1UL) {
-      TensorCopy(seq_tensor, context.GetPlace(), context, padding_tensor);
-      padding_tensor->Resize(padding_dims);
+      TensorCopy(seq_tensor, context.GetPlace(), context, pad_tensor);
+      pad_tensor->Resize(pad_tensor_dims);
       return;
     }
 
@@ -107,37 +107,40 @@ class PaddingLoDTensorFunctor<platform::CUDADeviceContext, T, padding_layout> {
     dim3 grid(grid_dim_x, grid_dim_y);
 
     const T* seq_data = seq_tensor.data<T>();
-    T* padding_data = padding_tensor->data<T>();
+    T* pad_data = pad_tensor->data<T>();
 
     SequencePaddingKernel<T, 1><<<grid, threads, 0, context.stream()>>>(
-        padding_data, const_cast<T*>(seq_data),
-        abs_offset.CUDAData(context.GetPlace()), seq_num, max_seq_len,
-        seq_width, padding_layout, norm_by_times, padding_value);
+        pad_data, const_cast<T*>(seq_data),
+        seq_offset.CUDAData(context.GetPlace()), seq_num, max_seq_len,
+        seq_width, norm_by_times, pad_value, output_layout);
   }
 };
 
-template <typename T, PaddingLayout padding_layout>
-class UnpaddingLoDTensorFunctor<platform::CUDADeviceContext, T,
-                                padding_layout> {
+template <typename T>
+class UnpaddingLoDTensorFunctor<platform::CUDADeviceContext, T> {
  public:
   void operator()(const platform::CUDADeviceContext& context,
                   framework::LoDTensor* seq_tensor,
-                  const framework::Tensor& padding_tensor,
-                  bool norm_by_times = false, size_t lod_level = 0) {
-    ValidateLoD(*seq_tensor, lod_level);
+                  const framework::Tensor& pad_tensor,
+                  bool norm_by_times = false, size_t lod_level = 0,
+                  OutputLayout output_layout = kBatchLengthWidth) {
+    CheckLoD(*seq_tensor, lod_level);
 
     auto& lod = seq_tensor->lod();
-    auto& abs_offset = framework::ToAbsOffset(lod)[lod_level];
+    auto& seq_offset = framework::ToAbsOffset(lod)[lod_level];
 
-    auto seq_dims = seq_tensor->dims();
-    auto padding_dims = padding_tensor.dims();
-    int64_t max_seq_len = MaximumSequenceLength(lod, lod_level);
-    int64_t seq_num = abs_offset.size() - 1;
-    int64_t seq_width = seq_tensor->numel() / seq_dims[0];
+    auto seq_tensor_dims = seq_tensor->dims();
+    auto pad_tensor_dims = pad_tensor.dims();
+    int64_t max_seq_len = MaximumSequenceLength(seq_offset);
+    int64_t seq_num = seq_offset.size() - 1;
+    int64_t seq_width = seq_tensor->numel() / seq_tensor_dims[0];
+
+    CheckDims(seq_tensor_dims, seq_offset.back(), pad_tensor_dims, max_seq_len,
+              seq_num, seq_width, output_layout);
 
     if (!norm_by_times && seq_num == 1UL) {
-      TensorCopy(padding_tensor, context.GetPlace(), context, seq_tensor);
-      seq_tensor->Resize(seq_dims);
+      TensorCopy(pad_tensor, context.GetPlace(), context, seq_tensor);
+      seq_tensor->Resize(seq_tensor_dims);
       return;
     }
 
@@ -155,20 +158,25 @@ class UnpaddingLoDTensorFunctor<platform::CUDADeviceContext, T,
     size_t grid_dim_y = seq_num;
     dim3 grid(grid_dim_x, grid_dim_y);
 
-    const T* padding_data = padding_tensor.data<T>();
+    const T* pad_data = pad_tensor.data<T>();
     T* seq_data = seq_tensor->data<T>();
 
-    SequencePaddingKernel<T, 1><<<grid, threads, 0, context.stream()>>>(
-        const_cast<T*>(padding_data), seq_data,
-        abs_offset.CUDAData(context.GetPlace()), seq_num, max_seq_len,
-        seq_width, padding_layout, norm_by_times);
+    SequencePaddingKernel<T, 0><<<grid, threads, 0, context.stream()>>>(
+        const_cast<T*>(pad_data), seq_data,
+        seq_offset.CUDAData(context.GetPlace()), seq_num, max_seq_len,
+        seq_width, norm_by_times, static_cast<T>(0), output_layout);
   }
 };
 
-template class PaddingLoDTensorFunctor<platform::CUDADeviceContext, float,
-                                       LENGTH_BATCH_WIDTH>;
-template class UnpaddingLoDTensorFunctor<platform::CUDADeviceContext, float,
-                                         LENGTH_BATCH_WIDTH>;
+template class PaddingLoDTensorFunctor<platform::CUDADeviceContext, int>;
+template class PaddingLoDTensorFunctor<platform::CUDADeviceContext, int64_t>;
+template class PaddingLoDTensorFunctor<platform::CUDADeviceContext, float>;
+template class PaddingLoDTensorFunctor<platform::CUDADeviceContext, double>;
+
+template class UnpaddingLoDTensorFunctor<platform::CUDADeviceContext, int>;
+template class UnpaddingLoDTensorFunctor<platform::CUDADeviceContext, int64_t>;
+template class UnpaddingLoDTensorFunctor<platform::CUDADeviceContext, float>;
+template class UnpaddingLoDTensorFunctor<platform::CUDADeviceContext, double>;
 
 }  // namespace math
 }  // namespace operators
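For reference, a minimal call-site sketch of the interface after this change: both functors are now templated only on the element type, and the pad layout is chosen at run time via the OutputLayout argument rather than a template parameter. The device context and tensor variables below are placeholders for illustration, and the example assumes the OutputLayout enumerators are visible in the math namespace as the diff suggests; none of this code is part of the diff itself.

    // Hypothetical host-side usage; dev_ctx, seq_tensor and pad_tensor stand
    // in for real objects at an actual call site.
    math::PaddingLoDTensorFunctor<platform::CUDADeviceContext, float> pad;
    pad(dev_ctx, seq_tensor, &pad_tensor, /*pad_value=*/0.0f,
        /*norm_by_times=*/false, /*lod_level=*/0,
        /*output_layout=*/math::kLengthBatchWidth);

    math::UnpaddingLoDTensorFunctor<platform::CUDADeviceContext, float> unpad;
    unpad(dev_ctx, &seq_tensor, pad_tensor, /*norm_by_times=*/false,
          /*lod_level=*/0, /*output_layout=*/math::kLengthBatchWidth);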