 #include "common.cuh"
 #include "dispatch_utils.h"
-
+#include "../vectorization_utils.cuh"
 #include <c10/cuda/CUDAGuard.h>
+#include <ATen/cuda/Exceptions.h>
 
 #ifndef USE_ROCM
   #include <cub/cub.cuh>
 namespace vllm {
 
 template <typename scalar_t, typename fp8_type>
-__global__ void scaled_fp8_quant_kernel(fp8_type* __restrict__ out,
-                                        const scalar_t* __restrict__ input,
-                                        const float* __restrict__ scale,
-                                        int64_t num_elems) {
-  int tid = blockDim.x * blockIdx.x + threadIdx.x;
-
-  // Invert the scale so that we can use multiplications to avoid expensive
-  // division.
-  const float inverted_scale = 1.0f / (*scale);
-  scaled_fp8_conversion_vec<scalar_t, true>(
-      out, input, inverted_scale, num_elems, tid, blockDim.x * gridDim.x);
+__global__ void scaled_fp8_quant_kernel_strided(
+    fp8_type* __restrict__ out, const scalar_t* __restrict__ input,
+    const float* __restrict__ scale, int hidden_size, int64_t in_row_stride,
+    int64_t out_row_stride) {
+  const int64_t token_idx = blockIdx.x;  // one token per block
+  const int tid = threadIdx.x;
+
+  const scalar_t* token_in = input + token_idx * in_row_stride;
+  fp8_type* token_out = out + token_idx * out_row_stride;
+
+  const float inv_scale = 1.0f / (*scale);
+
+  vectorize_with_alignment<16>(
+      token_in, token_out, hidden_size, tid, blockDim.x,
+      [=] __device__(fp8_type& dst, const scalar_t& src) {
+        dst = scaled_fp8_conversion<true, fp8_type>(static_cast<float>(src),
+                                                    inv_scale);
+      });
 }
 
 template <typename scalar_t, typename fp8_type>
-__global__ void dynamic_per_token_scaled_fp8_quant_kernel(
-    fp8_type* __restrict__ out, float* __restrict__ scale,
-    scalar_t const* __restrict__ input, float const* __restrict__ scale_ub,
-    const int hidden_size) {
-  int const tid = threadIdx.x;
-  int const token_idx = blockIdx.x;
+__global__ void segmented_max_reduction_strided(
+    float* __restrict__ scale, const scalar_t* __restrict__ input,
+    int hidden_size, int64_t in_row_stride, int64_t num_tokens) {
+  __shared__ float cache[256];
+  const int tid = threadIdx.x;
+  int64_t token_idx = blockIdx.x;
+
+  // one block per token. Guard in case gridDim.x > num_tokens.
+  if (token_idx >= num_tokens) {
+    return;
+  }
 
-  // Use int64 to avoid overflowing an int32 when calculating this offset
-  int64_t offset = static_cast<int64_t>(token_idx) * hidden_size;
-  scalar_t const* __restrict__ token_input = &input[offset];
-  fp8_type* __restrict__ token_output = &out[offset];
-
-  // For vectorization, token_input and token_output pointers need to be
-  // aligned at 32-byte and 16-byte addresses respectively.
-  bool const can_vectorize = hidden_size % 16 == 0;
-
-  float absmax_val = 0.0f;
-  if (can_vectorize) {
-    absmax_val = thread_max_vec(token_input, hidden_size, tid, blockDim.x);
-  } else {
-    for (int i = tid; i < hidden_size; i += blockDim.x) {
-      float const x = static_cast<float>(token_input[i]);
-      absmax_val = fmaxf(absmax_val, fabsf(x));
+  const scalar_t* row_ptr = input + token_idx * in_row_stride;
+
+  // each thread scans elements of the row in a strided fashion.
+  float thread_max = 0.0f;
+  for (int e = tid; e < hidden_size; e += blockDim.x) {
+    float v = fabsf(static_cast<float>(row_ptr[e]));
+    thread_max = fmaxf(thread_max, v);
+  }
+
+  cache[tid] = thread_max;
+  __syncthreads();
+
+  // parallel reduction to find row max.
+  for (int offset = blockDim.x / 2; offset > 0; offset >>= 1) {
+    if (tid < offset) {
+      cache[tid] = fmaxf(cache[tid], cache[tid + offset]);
     }
+    __syncthreads();
   }
 
+  // thread 0 updates global scale (per-tensor) atomically.
+  if (tid == 0) {
+    atomicMaxFloat(scale, cache[0] / quant_type_max_v<fp8_type>);
+  }
+}
+
+template <typename scalar_t, typename fp8_type>
+__global__ void scaled_fp8_quant_kernel_strided_dynamic(
+    fp8_type* __restrict__ out, const scalar_t* __restrict__ input,
+    const float* __restrict__ scale, int hidden_size, int64_t in_row_stride,
+    int64_t out_row_stride) {
+  const int64_t token_idx = blockIdx.x;
+  const int tid = threadIdx.x;
+
+  const scalar_t* token_in = input + token_idx * in_row_stride;
+  fp8_type* token_out = out + token_idx * out_row_stride;
+
+  const float reciprocal_scale = 1.0f / (*scale);
+  vectorize_with_alignment<16>(
+      token_in, token_out, hidden_size, tid, blockDim.x,
+      [=] __device__(fp8_type& dst, const scalar_t& src) {
+        dst = scaled_fp8_conversion<true, fp8_type>(static_cast<float>(src),
+                                                    reciprocal_scale);
+      });
+}
+
+template <typename scalar_t, typename fp8_type>
+__global__ void dynamic_per_token_scaled_fp8_quant_kernel_strided(
+    fp8_type* __restrict__ out, float* __restrict__ scale,
+    const scalar_t* __restrict__ input, const float* __restrict__ scale_ub,
+    int hidden_size, int64_t in_row_stride, int64_t out_row_stride) {
+  const int64_t token_idx = blockIdx.x;
+  const int tid = threadIdx.x;
+
+  // Use int64 to avoid overflowing an int32 when calculating this offset
+  int64_t in_offset = static_cast<int64_t>(token_idx) * in_row_stride;
+  int64_t out_offset = static_cast<int64_t>(token_idx) * out_row_stride;
+  const scalar_t* token_in = input + in_offset;
+  fp8_type* token_out = out + out_offset;
+
+  // 1) per-token absmax
+  float absmax_val = 0.f;
+  vectorize_read_with_alignment<16>(
+      token_in, hidden_size, tid, blockDim.x, [&] __device__(scalar_t v) {
+        absmax_val = fmaxf(absmax_val, fabsf(static_cast<float>(v)));
+      });
+
   using BlockReduce = cub::BlockReduce<float, 256>;
-  __shared__ typename BlockReduce::TempStorage reduceStorage;
-  float const block_absmax_val_maybe =
-      BlockReduce(reduceStorage).Reduce(absmax_val, cub::Max{}, blockDim.x);
+  __shared__ typename BlockReduce::TempStorage tmp;
+  const float block_max =
+      BlockReduce(tmp).Reduce(absmax_val, cub::Max{}, blockDim.x);
+
   __shared__ float token_scale;
   if (tid == 0) {
-    if (scale_ub) {
-      token_scale = fminf(block_absmax_val_maybe, *scale_ub);
-    } else {
-      token_scale = block_absmax_val_maybe;
-    }
-    // token scale computation
+    token_scale = scale_ub ? fminf(block_max, *scale_ub) : block_max;
     token_scale = fmaxf(token_scale / quant_type_max_v<fp8_type>,
                         min_scaling_factor<fp8_type>::val());
     scale[token_idx] = token_scale;
   }
   __syncthreads();
 
-  // Note that we don't use inverted scales so we can match FBGemm impl.
-  if (can_vectorize) {
-    scaled_fp8_conversion_vec<scalar_t, false>(
-        token_output, token_input, token_scale, hidden_size, tid, blockDim.x);
-  } else {
-    for (int i = tid; i < hidden_size; i += blockDim.x) {
-      token_output[i] = scaled_fp8_conversion<false, fp8_type>(
-          static_cast<float>(token_input[i]), token_scale);
-    }
-  }
+  // 2) quantize
+  vectorize_with_alignment<16>(
+      token_in, token_out, hidden_size, tid, blockDim.x,
+      [=] __device__(fp8_type& dst, const scalar_t& src) {
+        dst = scaled_fp8_conversion<false, fp8_type>(static_cast<float>(src),
+                                                     token_scale);
+      });
 }
 
 }  // namespace vllm
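
For readers who have not seen the `vectorize_with_alignment` helper, the per-token work in the strided kernels above is equivalent to the scalar loop below (a minimal sketch with a hypothetical kernel name; `scaled_fp8_conversion` comes from `common.cuh`, and the real helper additionally issues 16-byte vectorized loads/stores when the row pointers allow it):

```cuda
// Scalar equivalent of one strided, static-scale quantization block:
// one block handles one token (row); threads stride across hidden_size.
template <typename scalar_t, typename fp8_type>
__global__ void scaled_fp8_quant_rowwise_reference(
    fp8_type* __restrict__ out, const scalar_t* __restrict__ input,
    const float* __restrict__ scale, int hidden_size, int64_t in_row_stride,
    int64_t out_row_stride) {
  const int64_t token_idx = blockIdx.x;
  const scalar_t* token_in = input + token_idx * in_row_stride;
  fp8_type* token_out = out + token_idx * out_row_stride;
  const float inv_scale = 1.0f / (*scale);
  for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) {
    token_out[i] = scaled_fp8_conversion<true, fp8_type>(
        static_cast<float>(token_in[i]), inv_scale);
  }
}
```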
@@ -88,23 +142,31 @@ void static_scaled_fp8_quant(torch::Tensor& out,          // [..., d]
                              torch::Tensor const& input,  // [..., d]
                              torch::Tensor const& scale)  // [1]
 {
-  TORCH_CHECK(input.is_contiguous());
-  TORCH_CHECK(out.is_contiguous());
-  int const block_size = 256;
-  int const num_tokens = input.numel() / input.size(-1);
-  int const num_elems = input.numel();
-  dim3 const grid(num_tokens);
-  dim3 const block(block_size);
+  TORCH_CHECK(input.stride(-1) == 1,
+              "last dimension of input must be contiguous");
+  TORCH_CHECK(out.stride(-1) == 1,
+              "last dimension of output must be contiguous");
+
+  const int hidden_size = input.size(-1);
+  const int num_tokens = input.numel() / hidden_size;
+  const int block_size = 256;
+  dim3 grid(num_tokens);
+  dim3 block(block_size);
+
+  const int64_t in_row_stride = input.stride(-2);
+  const int64_t out_row_stride = out.stride(-2);
+
   const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   VLLM_DISPATCH_FLOATING_TYPES(
       input.scalar_type(), "scaled_fp8_quant_kernel_scalar_type", [&] {
         VLLM_DISPATCH_FP8_TYPES(
             out.scalar_type(), "scaled_fp8_quant_kernel_fp8_type", [&] {
-              vllm::scaled_fp8_quant_kernel<scalar_t, fp8_t>
+              vllm::scaled_fp8_quant_kernel_strided<scalar_t, fp8_t>
                   <<<grid, block, 0, stream>>>(
                       out.data_ptr<fp8_t>(), input.data_ptr<scalar_t>(),
-                      scale.data_ptr<float>(), num_elems);
+                      scale.data_ptr<float>(), hidden_size, in_row_stride,
+                      out_row_stride);
             });
       });
 }
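
As an illustrative caller (not part of this change; the shapes, dtypes, availability of `Float8_e4m3fn`, and visibility of the `static_scaled_fp8_quant` declaration are assumptions), the relaxed checks let a column slice pass straight through, since only `stride(-1) == 1` is required while `stride(-2)` may exceed `hidden_size`:

```cuda
#include <torch/torch.h>

// Hypothetical example: narrow a [256, 4096] half tensor to its first 2048
// columns. stride(-1) stays 1 but stride(-2) stays 4096, so the old
// is_contiguous() check would have rejected this view.
void example_static_quant_on_slice() {
  auto full = torch::randn(
      {256, 4096}, torch::dtype(torch::kFloat16).device(torch::kCUDA));
  auto input = full.narrow(/*dim=*/1, /*start=*/0, /*length=*/2048);

  auto out = torch::empty(
      input.sizes(), input.options().dtype(at::ScalarType::Float8_e4m3fn));
  auto scale = torch::ones({1}, input.options().dtype(torch::kFloat32));

  static_scaled_fp8_quant(out, input, scale);  // dispatches the strided kernel
}
```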
@@ -113,27 +175,42 @@ void dynamic_scaled_fp8_quant(torch::Tensor& out,          // [..., d]
                               torch::Tensor const& input,  // [..., d]
                               torch::Tensor& scale)        // [1]
 {
-  TORCH_CHECK(input.is_contiguous());
-  TORCH_CHECK(out.is_contiguous());
-  int const block_size = 256;
-  int const num_tokens = input.numel() / input.size(-1);
-  int const num_elems = input.numel();
-  dim3 const grid(num_tokens);
-  dim3 const block(block_size);
+  TORCH_CHECK(input.stride(-1) == 1,
+              "last dimension of input must be contiguous");
+  TORCH_CHECK(out.stride(-1) == 1,
+              "last dimension of output must be contiguous");
+
+  const int hidden_size = input.size(-1);
+  const int num_tokens = input.numel() / hidden_size;
+  const int block_size = 256;
+  dim3 grid(num_tokens);
+  dim3 block(block_size);
+
+  const int64_t in_row_stride = input.stride(-2);
+  const int64_t out_row_stride = out.stride(-2);
+
   const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+
+  // scale tensor should be initialised to <=0 before reduction
+  AT_CUDA_CHECK(
+      cudaMemsetAsync(scale.data_ptr<float>(), 0, sizeof(float), stream));
+
   VLLM_DISPATCH_FLOATING_TYPES(
       input.scalar_type(), "scaled_fp8_quant_kernel_scalar_type", [&] {
         VLLM_DISPATCH_FP8_TYPES(
             out.scalar_type(), "scaled_fp8_quant_kernel_fp8_type", [&] {
-              vllm::segmented_max_reduction<scalar_t, fp8_t>
-                  <<<grid, block, 0, stream>>>(scale.data_ptr<float>(),
-                                               input.data_ptr<scalar_t>(),
-                                               num_elems);
-              vllm::scaled_fp8_quant_kernel<scalar_t, fp8_t>
+              vllm::segmented_max_reduction_strided<scalar_t, fp8_t>
+                  <<<grid, block, 0, stream>>>(
+                      scale.data_ptr<float>(), input.data_ptr<scalar_t>(),
+                      hidden_size, in_row_stride,
+                      static_cast<int64_t>(num_tokens));
+
+              vllm::scaled_fp8_quant_kernel_strided_dynamic<scalar_t, fp8_t>
                   <<<grid, block, 0, stream>>>(
                       out.data_ptr<fp8_t>(), input.data_ptr<scalar_t>(),
-                      scale.data_ptr<float>(), num_elems);
+                      scale.data_ptr<float>(), hidden_size, in_row_stride,
+                      out_row_stride);
             });
       });
 }
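
The `cudaMemsetAsync` above is what makes the per-tensor reduction correct: `segmented_max_reduction_strided` folds each block's row maximum into `scale` with `atomicMaxFloat`, so the destination must start at a value no larger than any row maximum. A minimal sketch of how such a float atomic max is typically built (illustrative; not necessarily the exact helper defined in `common.cuh`) exploits the fact that non-negative IEEE-754 floats order the same as their bit patterns:

```cuda
#include <cuda_runtime.h>

// Atomic max for non-negative floats: for values >= 0 the raw bit pattern,
// viewed as a signed int, orders exactly like the float value, so an integer
// atomicMax on the reinterpreted bits computes the float maximum. This is why
// zeroing the scale beforehand (rather than leaving stale data) is required.
__device__ __forceinline__ float atomic_max_nonneg_float(float* addr,
                                                         float value) {
  int old_bits =
      atomicMax(reinterpret_cast<int*>(addr), __float_as_int(value));
  return __int_as_float(old_bits);
}
```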
@@ -142,14 +219,19 @@ void dynamic_per_token_scaled_fp8_quant(
     torch::Tensor& out,          // [..., d]
     torch::Tensor const& input,  // [..., d]
     torch::Tensor& scales, std::optional<at::Tensor> const& scale_ub) {
-  TORCH_CHECK(input.is_contiguous());
-  TORCH_CHECK(out.is_contiguous());
+  TORCH_CHECK(input.stride(-1) == 1,
+              "last dimension of input must be contiguous");
+  TORCH_CHECK(out.stride(-1) == 1,
+              "last dimension of output must be contiguous");
 
-  int const hidden_size = input.size(-1);
-  int const num_tokens = input.numel() / hidden_size;
-  int const block_size = 256;
-  dim3 const grid(num_tokens);
-  dim3 const block(std::min(hidden_size, block_size));
+  const int hidden_size = input.size(-1);
+  const int num_tokens = input.numel() / hidden_size;
+  const int block_size = 256;
+  dim3 grid(num_tokens);
+  dim3 block(std::min(hidden_size, block_size));
+
+  const int64_t in_row_stride = input.stride(-2);
+  const int64_t out_row_stride = out.stride(-2);
 
   const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
@@ -159,13 +241,12 @@ void dynamic_per_token_scaled_fp8_quant(
         VLLM_DISPATCH_FP8_TYPES(
             out.scalar_type(),
             "dynamic_per_token_scaled_fp8_quant_kernel_fp8_type", [&] {
-              vllm::dynamic_per_token_scaled_fp8_quant_kernel<scalar_t, fp8_t>
-                  <<<grid, block, 0, stream>>>(
-                      out.data_ptr<fp8_t>(), scales.data_ptr<float>(),
-                      input.data_ptr<scalar_t>(),
-                      scale_ub.has_value() ? scale_ub->data_ptr<float>()
-                                           : nullptr,
-                      hidden_size);
+              vllm::dynamic_per_token_scaled_fp8_quant_kernel_strided<
+                  scalar_t, fp8_t><<<grid, block, 0, stream>>>(
+                  out.data_ptr<fp8_t>(), scales.data_ptr<float>(),
+                  input.data_ptr<scalar_t>(),
+                  scale_ub.has_value() ? scale_ub->data_ptr<float>() : nullptr,
+                  hidden_size, in_row_stride, out_row_stride);
             });
       });
 }
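
For reference, the per-token path boils down to the following row-wise math; the CPU sketch below mirrors it with stand-in constants (`FP8_MAX` in place of `quant_type_max_v<fp8_type>` and `MIN_SCALE` in place of `min_scaling_factor<fp8_type>::val()`, both of which live in the device headers):

```cuda
// CPU reference for one row of dynamic per-token quantization:
//   scale = max(min(absmax, scale_ub) / FP8_MAX, MIN_SCALE)
//   q[i]  = saturate_to_fp8(x[i] / scale)
#include <algorithm>
#include <cmath>
#include <vector>

std::vector<float> quantize_row_reference(const std::vector<float>& x,
                                          const float* scale_ub,
                                          float* out_scale) {
  constexpr float FP8_MAX = 448.0f;  // e4m3fn maximum finite value
  constexpr float MIN_SCALE = 1.0f / (448.0f * 512.0f);  // illustrative floor

  float absmax = 0.0f;
  for (float v : x) absmax = std::max(absmax, std::fabs(v));
  if (scale_ub) absmax = std::min(absmax, *scale_ub);

  const float scale = std::max(absmax / FP8_MAX, MIN_SCALE);
  *out_scale = scale;

  std::vector<float> q;
  q.reserve(x.size());
  for (float v : x) {
    // clamp to the representable range; the device code converts to fp8 here
    q.push_back(std::clamp(v / scale, -FP8_MAX, FP8_MAX));
  }
  return q;
}
```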