
[Perf] Tune scaled_fp8_quant by increasing vectorization #18844

Merged · 10 commits · Jun 3, 2025
35 changes: 19 additions & 16 deletions csrc/quantization/fp8/common.cu
@@ -39,33 +39,33 @@ __global__ void dynamic_per_token_scaled_fp8_quant_kernel(
fp8_type* __restrict__ token_output = &out[offset];

// For vectorization, token_input and token_output pointers need to be
// aligned at 8-byte and 4-byte addresses respectively.
bool const can_vectorize = hidden_size % 4 == 0;
// aligned at 32-byte and 16-byte addresses respectively.
bool const can_vectorize = hidden_size % 16 == 0;

float absmax_val = 0.0f;
if (can_vectorize) {
absmax_val = thread_max_vec(token_input, hidden_size, tid, blockDim.x);
} else {
for (int i = tid; i < hidden_size; i += blockDim.x) {
float const x = static_cast<float>(token_input[i]);
absmax_val = max(absmax_val, fabs(x));
absmax_val = fmaxf(absmax_val, fabsf(x));
}
}

using BlockReduce = cub::BlockReduce<float, 1024>;
using BlockReduce = cub::BlockReduce<float, 256>;
__shared__ typename BlockReduce::TempStorage reduceStorage;
float const block_absmax_val_maybe =
BlockReduce(reduceStorage).Reduce(absmax_val, cub::Max{}, blockDim.x);
__shared__ float token_scale;
if (tid == 0) {
if (scale_ub) {
token_scale = min(block_absmax_val_maybe, *scale_ub);
token_scale = fminf(block_absmax_val_maybe, *scale_ub);
} else {
token_scale = block_absmax_val_maybe;
}
// token scale computation
token_scale = max(token_scale / quant_type_max_v<fp8_type>,
min_scaling_factor<fp8_type>::val());
token_scale = fmaxf(token_scale / quant_type_max_v<fp8_type>,
min_scaling_factor<fp8_type>::val());
scale[token_idx] = token_scale;
}
__syncthreads();
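
The main change in this hunk is the wider vectorization gate: `hidden_size % 16 == 0` instead of `% 4 == 0`, so a 16-element packet of a 2-byte input type spans 32 bytes and the matching fp8 packet spans 16 bytes, which is what the updated comment refers to. A standalone host-side sketch of that arithmetic (not part of the PR; the 4096 row width and the 2-byte input type are illustrative assumptions):

```cuda
// Standalone sketch, not vLLM code: with hidden_size % 16 == 0, every per-token
// row offset preserves the 32-byte (16 x 2-byte input) and 16-byte
// (16 x 1-byte fp8 output) alignment of the tensor base pointers.
#include <cstdint>
#include <cstdio>

int main() {
  const int64_t hidden_size = 4096;  // illustrative row width, divisible by 16
  const bool can_vectorize = hidden_size % 16 == 0;
  for (int64_t token = 0; token < 3; ++token) {
    const int64_t offset = token * hidden_size;  // element offset of the row
    printf("token %lld: input byte offset %% 32 = %lld, output byte offset %% 16 = %lld\n",
           (long long)token, (long long)((offset * 2) % 32), (long long)(offset % 16));
  }
  printf("can_vectorize = %s\n", can_vectorize ? "true" : "false");
  return 0;
}
```

Because each token row starts at element offset `token_idx * hidden_size`, divisibility by 16 keeps every row pointer on the same 32-/16-byte boundary as the base pointer.
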
@@ -88,10 +88,11 @@ void static_scaled_fp8_quant(torch::Tensor& out, // [..., d]
torch::Tensor const& input, // [..., d]
torch::Tensor const& scale) // [1]
{
int64_t num_tokens = input.numel() / input.size(-1);
int64_t num_elems = input.numel();
dim3 grid(num_tokens);
dim3 block(1024);
int const block_size = 256;
int const num_tokens = input.numel() / input.size(-1);
int const num_elems = input.numel();
dim3 const grid(num_tokens);
dim3 const block(block_size);
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
VLLM_DISPATCH_FLOATING_TYPES(
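
For reference, a toy program with the same launch shape as the updated launcher: one block per token row, 256 threads per block, each thread striding over the row. The kernel body only rescales floats; it is a stand-in for the dispatched fp8 conversion kernel, and the names and sizes (`scale_rows`, 8 tokens, 4096 width) are made up for the sketch.

```cuda
// Toy sketch (not the vLLM kernel): grid = one block per token, block = 256.
#include <cstdio>
#include <cuda_runtime.h>

__global__ void scale_rows(const float* __restrict__ in, float* __restrict__ out,
                           int hidden_size, float inv_scale) {
  const float* row_in = in + (size_t)blockIdx.x * hidden_size;
  float* row_out = out + (size_t)blockIdx.x * hidden_size;
  for (int i = threadIdx.x; i < hidden_size; i += blockDim.x)
    row_out[i] = row_in[i] * inv_scale;  // stand-in for the fp8 conversion
}

int main() {
  const int num_tokens = 8, hidden_size = 4096, block_size = 256;
  float *in, *out;
  cudaMallocManaged(&in, sizeof(float) * num_tokens * hidden_size);
  cudaMallocManaged(&out, sizeof(float) * num_tokens * hidden_size);
  for (int i = 0; i < num_tokens * hidden_size; ++i) in[i] = 0.5f * (i % 100);

  dim3 grid(num_tokens);    // one block per token row
  dim3 block(block_size);   // 256 threads, matching the new block_size
  scale_rows<<<grid, block>>>(in, out, hidden_size, 0.25f);
  cudaDeviceSynchronize();
  printf("out[0] = %f, out[last] = %f\n", out[0], out[num_tokens * hidden_size - 1]);
  cudaFree(in);
  cudaFree(out);
  return 0;
}
```
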
@@ -110,10 +111,11 @@ void dynamic_scaled_fp8_quant(torch::Tensor& out, // [..., d]
torch::Tensor const& input, // [..., d]
torch::Tensor& scale) // [1]
{
int64_t num_tokens = input.numel() / input.size(-1);
int64_t num_elems = input.numel();
dim3 grid(num_tokens);
dim3 block(1024);
int const block_size = 256;
int const num_tokens = input.numel() / input.size(-1);
int const num_elems = input.numel();
dim3 const grid(num_tokens);
dim3 const block(block_size);
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
VLLM_DISPATCH_FLOATING_TYPES(
@@ -141,8 +143,9 @@ void dynamic_per_token_scaled_fp8_quant(

int const hidden_size = input.size(-1);
int const num_tokens = input.numel() / hidden_size;
int const block_size = 256;
dim3 const grid(num_tokens);
dim3 const block(std::min(hidden_size, 1024));
dim3 const block(std::min(hidden_size, block_size));

const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
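
This launcher caps the block at `min(hidden_size, block_size)`, which matters because the kernel's `cub::BlockReduce<float, 256>` is templated on 256 threads: the launched block may be smaller than the template bound but must never exceed it, and the actual thread count is passed to `Reduce` as the `num_valid` argument. A self-contained sketch of that pattern with an assumed toy kernel (`row_absmax`), not the PR's kernel:

```cuda
// Sketch: BlockReduce bound to 256 threads, launched with a possibly smaller block.
#include <algorithm>
#include <cstdio>
#include <cub/cub.cuh>

__global__ void row_absmax(const float* __restrict__ in, float* __restrict__ out,
                           int hidden_size) {
  const float* row = in + (size_t)blockIdx.x * hidden_size;
  float thread_max = 0.0f;
  for (int i = threadIdx.x; i < hidden_size; i += blockDim.x)
    thread_max = fmaxf(thread_max, fabsf(row[i]));

  using BlockReduce = cub::BlockReduce<float, 256>;  // template bound: 256 threads
  __shared__ typename BlockReduce::TempStorage tmp;
  // Tell cub how many threads actually carry valid data (blockDim.x <= 256).
  float block_max = BlockReduce(tmp).Reduce(thread_max, cub::Max{}, blockDim.x);
  if (threadIdx.x == 0) out[blockIdx.x] = block_max;  // result valid in thread 0 only
}

int main() {
  const int num_tokens = 4, hidden_size = 100;  // deliberately smaller than 256
  float *in, *out;
  cudaMallocManaged(&in, sizeof(float) * num_tokens * hidden_size);
  cudaMallocManaged(&out, sizeof(float) * num_tokens);
  for (int i = 0; i < num_tokens * hidden_size; ++i) in[i] = (float)(i % 7) - 3.0f;

  dim3 grid(num_tokens);
  dim3 block(std::min(hidden_size, 256));  // never exceeds the template bound
  row_absmax<<<grid, block>>>(in, out, hidden_size);
  cudaDeviceSynchronize();
  for (int t = 0; t < num_tokens; ++t) printf("token %d absmax = %f\n", t, out[t]);
  cudaFree(in);
  cudaFree(out);
  return 0;
}
```
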
68 changes: 35 additions & 33 deletions csrc/quantization/fp8/common.cuh
@@ -46,7 +46,7 @@ __device__ __forceinline__ fp8_type scaled_fp8_conversion(float const val,
}

float r =
fmax(-quant_type_max_v<fp8_type>, fmin(x, quant_type_max_v<fp8_type>));
fmaxf(-quant_type_max_v<fp8_type>, fminf(x, quant_type_max_v<fp8_type>));
#ifndef USE_ROCM
return static_cast<fp8_type>(r);
#else
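
The edits in this file also swap the double-precision `fmax`/`fmin`/`fabs` calls for the single-precision `fmaxf`/`fminf`/`fabsf` intrinsics so the saturation math stays in fp32. A minimal sketch of the clamp, assuming `quant_type_max_v<fp8_type>` resolves to 448.0f for e4m3 (an assumption for illustration; the fnuz variant has a smaller finite max):

```cuda
// Minimal sketch of the saturation step above (assumed constant, not vLLM's
// quant_type_max_v): clamp to the finite fp8 range entirely in fp32.
__device__ __forceinline__ float clamp_to_fp8_e4m3_range(float x) {
  constexpr float kFp8E4m3Max = 448.0f;  // assumed finite max of e4m3
  return fmaxf(-kFp8E4m3Max, fminf(x, kFp8E4m3Max));
}
```
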
@@ -65,15 +65,15 @@ template <typename scalar_t, typename fp8_type>
__global__ void segmented_max_reduction(float* __restrict__ scale,
const scalar_t* __restrict__ input,
int64_t num_elems) {
__shared__ float cache[1024];
__shared__ float cache[256];
int64_t i = blockDim.x * blockIdx.x + threadIdx.x;

// First store maximum for all values processed by
// the current thread in cache[threadIdx.x]
scalar_t tmp = 0.0;
while (i < num_elems) {
float x = static_cast<float>(input[i]);
tmp = max(tmp, fabs(x));
tmp = fmaxf(tmp, fabsf(x));
i += blockDim.x * gridDim.x;
}
cache[threadIdx.x] = tmp;
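
`segmented_max_reduction` keeps one shared slot per thread, so `cache[256]` has to match the new 256-thread launch. The rest of the function is elided by the diff; a generic sketch of the shared-memory pattern it relies on (not the elided vLLM code) looks like this:

```cuda
// Generic sketch (not the elided vLLM code): one shared slot per thread, then a
// tree reduction over the block, so the cache size must equal blockDim.x (256).
#include <cstdint>

__global__ void block_absmax(const float* __restrict__ in, float* __restrict__ out,
                             int64_t num_elems) {
  __shared__ float cache[256];
  float tmp = 0.0f;
  for (int64_t i = (int64_t)blockDim.x * blockIdx.x + threadIdx.x; i < num_elems;
       i += (int64_t)blockDim.x * gridDim.x)
    tmp = fmaxf(tmp, fabsf(in[i]));
  cache[threadIdx.x] = tmp;
  __syncthreads();

  // Pairwise halving over the 256 slots.
  for (int stride = blockDim.x / 2; stride > 0; stride >>= 1) {
    if (threadIdx.x < stride)
      cache[threadIdx.x] = fmaxf(cache[threadIdx.x], cache[threadIdx.x + stride]);
    __syncthreads();
  }
  if (threadIdx.x == 0) out[blockIdx.x] = cache[0];  // one partial max per block
}
```
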
@@ -100,25 +100,27 @@ template <typename scalar_t>
__device__ float thread_max_vec(scalar_t const* __restrict__ input,
int64_t const num_elems, int const tid,
int const step) {
constexpr size_t VEC_SIZE = 16;
using scalarxN_t = vec_n_t<scalar_t, VEC_SIZE>;
// Vectorized input/output to better utilize memory bandwidth.
vec4_t<scalar_t> const* vectorized_in =
reinterpret_cast<vec4_t<scalar_t> const*>(input);
auto const* vectorized_in = reinterpret_cast<scalarxN_t const*>(input);

int64_t const num_vec_elems = num_elems >> 2;
// num_elems / VEC_SIZE (which is 16)
int64_t const num_vec_elems = num_elems >> 4;
float absmax_val = 0.0f;

#pragma unroll 4
#pragma unroll
for (int64_t i = tid; i < num_vec_elems; i += step) {
vec4_t<scalar_t> in_vec = vectorized_in[i];
absmax_val = max(absmax_val, fabs(in_vec.x));
absmax_val = max(absmax_val, fabs(in_vec.y));
absmax_val = max(absmax_val, fabs(in_vec.z));
absmax_val = max(absmax_val, fabs(in_vec.w));
scalarxN_t in_vec = vectorized_in[i];
#pragma unroll
for (int j = 0; j < VEC_SIZE; ++j) {
absmax_val = fmaxf(absmax_val, fabsf(in_vec.val[j]));
}
}

// Handle the remaining elements if num_elems is not divisible by 4
for (int64_t i = num_vec_elems * 4 + tid; i < num_elems; i += step) {
absmax_val = max(absmax_val, fabs(input[i]));
// Handle the remaining elements if num_elems is not divisible by VEC_SIZE
for (int64_t i = num_vec_elems * VEC_SIZE + tid; i < num_elems; i += step) {
absmax_val = fmaxf(absmax_val, fabsf(input[i]));
}

return absmax_val;
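
The rewritten `thread_max_vec` walks the row in 16-element packets and then a scalar tail starting at `num_vec_elems * VEC_SIZE`. A host-side sketch (same loop structure, made-up sizes) that checks the packet loop plus the tail visit every element exactly once under the tid/step striding:

```cuda
// Host-side sketch (assumption: same index math as thread_max_vec above).
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  const int64_t num_elems = 1000;   // deliberately not a multiple of 16
  const int num_threads = 256;      // plays the role of blockDim.x (step)
  const int64_t VEC_SIZE = 16;
  std::vector<int> visits(num_elems, 0);

  for (int tid = 0; tid < num_threads; ++tid) {
    const int64_t num_vec_elems = num_elems / VEC_SIZE;
    for (int64_t i = tid; i < num_vec_elems; i += num_threads)        // vector body
      for (int64_t j = 0; j < VEC_SIZE; ++j) ++visits[i * VEC_SIZE + j];
    for (int64_t i = num_vec_elems * VEC_SIZE + tid; i < num_elems; i += num_threads)
      ++visits[i];                                                    // scalar tail
  }
  bool ok = true;
  for (int64_t i = 0; i < num_elems; ++i) ok &= (visits[i] == 1);
  printf("every element visited exactly once: %s\n", ok ? "yes" : "no");
  return 0;
}
```
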
@@ -130,31 +132,31 @@ __device__ void scaled_fp8_conversion_vec(fp8_type* __restrict__ out,
float const scale,
int64_t const num_elems,
int const tid, int const step) {
using float8x4_t = q8x4_t<fp8_type>;
constexpr size_t VEC_SIZE = 16;
using scalarxN_t = vec_n_t<scalar_t, VEC_SIZE>;
using float8xN_t = q8_n_t<fp8_type, VEC_SIZE>;
// Vectorized input/output to better utilize memory bandwidth.
auto const* vectorized_in = reinterpret_cast<vec4_t<scalar_t> const*>(input);
auto* vectorized_out = reinterpret_cast<float8x4_t*>(out);
auto const* vectorized_in = reinterpret_cast<scalarxN_t const*>(input);
auto* vectorized_out = reinterpret_cast<float8xN_t*>(out);

int64_t const num_vec_elems = num_elems >> 2;
// num_elems / VEC_SIZE (which is 16)
int64_t const num_vec_elems = num_elems >> 4;

#pragma unroll 4
#pragma unroll
for (int64_t i = tid; i < num_vec_elems; i += step) {
vec4_t<scalar_t> in_vec = vectorized_in[i];
float8x4_t out_vec;

out_vec.x = scaled_fp8_conversion<is_scale_inverted, fp8_type>(
static_cast<float>(in_vec.x), scale);
out_vec.y = scaled_fp8_conversion<is_scale_inverted, fp8_type>(
static_cast<float>(in_vec.y), scale);
out_vec.z = scaled_fp8_conversion<is_scale_inverted, fp8_type>(
static_cast<float>(in_vec.z), scale);
out_vec.w = scaled_fp8_conversion<is_scale_inverted, fp8_type>(
static_cast<float>(in_vec.w), scale);
scalarxN_t in_vec = vectorized_in[i];
float8xN_t out_vec;

#pragma unroll
for (int j = 0; j < VEC_SIZE; ++j) {
out_vec.val[j] = scaled_fp8_conversion<is_scale_inverted, fp8_type>(
static_cast<float>(in_vec.val[j]), scale);
}
vectorized_out[i] = out_vec;
}

// Handle the remaining elements if num_elems is not divisible by 4
for (int64_t i = num_vec_elems * 4 + tid; i < num_elems; i += step) {
// Handle the remaining elements if num_elems is not divisible by VEC_SIZE
for (int64_t i = num_vec_elems * VEC_SIZE + tid; i < num_elems; i += step) {
out[i] = scaled_fp8_conversion<is_scale_inverted, fp8_type>(
static_cast<float>(input[i]), scale);
}
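
`scaled_fp8_conversion_vec` now reinterprets both the input and output rows as 16-element packets, so each iteration issues one wide load and one wide store. A compact device-side sketch of that pattern with stand-in types (a local `packet` struct, `__half` input, `uint8_t` in place of the fp8 type, and a made-up clamp instead of `scaled_fp8_conversion`):

```cuda
// Device-side sketch with stand-in types; caller is assumed to guarantee
// divisibility by 16 and 32-/16-byte aligned row pointers.
#include <cstdint>
#include <cuda_fp16.h>

template <typename T, int N>
struct __align__(sizeof(T) * N) packet {
  T val[N];
};

__device__ void convert_row_u8(const __half* __restrict__ in, uint8_t* __restrict__ out,
                               int hidden_size, float inv_scale) {
  constexpr int VEC = 16;
  auto const* vin = reinterpret_cast<const packet<__half, VEC>*>(in);
  auto* vout = reinterpret_cast<packet<uint8_t, VEC>*>(out);
  const int num_vec = hidden_size / VEC;
  for (int i = threadIdx.x; i < num_vec; i += blockDim.x) {
    packet<__half, VEC> x = vin[i];  // one 32-byte load
    packet<uint8_t, VEC> y;
#pragma unroll
    for (int j = 0; j < VEC; ++j)    // stand-in for scaled_fp8_conversion
      y.val[j] = static_cast<uint8_t>(
          fminf(fmaxf(__half2float(x.val[j]) * inv_scale, 0.f), 255.f));
    vout[i] = y;                     // one 16-byte store
  }
}
```
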
99 changes: 50 additions & 49 deletions csrc/quantization/fused_kernels/layernorm_utils.cuh
@@ -140,29 +140,31 @@ __device__ void compute_rms(float* rms, scalar_t const* __restrict__ input,
// sum of squares
float ss = 0.0f;

const int VEC_SIZE = 4;
int32_t const num_vec_elems = hidden_size >> 2;

#pragma unroll 4
for (auto i = threadIdx.x; i < num_vec_elems; i += blockDim.x) {
vec4_t<scalar_t> in = vec_input[i];

vec4_t<float> x;
x.x = static_cast<float>(in.x);
x.y = static_cast<float>(in.y);
x.z = static_cast<float>(in.z);
x.w = static_cast<float>(in.w);
#pragma unroll
for (int j = 0; j < VEC_SIZE; ++j) {
x.val[j] = static_cast<float>(in.val[j]);
}

if constexpr (has_residual) {
vec4_t<scalar_t> r = vec_residual[i];
x.x += static_cast<float>(r.x);
x.y += static_cast<float>(r.y);
x.z += static_cast<float>(r.z);
x.w += static_cast<float>(r.w);
#pragma unroll
for (int j = 0; j < VEC_SIZE; ++j) {
x.val[j] += static_cast<float>(r.val[j]);
}
}

ss += x.x * x.x;
ss += x.y * x.y;
ss += x.z * x.z;
ss += x.w * x.w;
#pragma unroll
for (int j = 0; j < VEC_SIZE; ++j) {
ss += x.val[j] * x.val[j];
}
}

using BlockReduce = cub::BlockReduce<float, 1024>;
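
The layernorm changes in this file are behavior-preserving: the per-lane `.x/.y/.z/.w` chains become loops over `val[j]`, which `#pragma unroll` flattens back out. A tiny host-side sketch of the restructured sum-of-squares step (the struct and values are illustrative, not vLLM code):

```cuda
// Host-side sketch: array-member form of the 4-wide accumulation above.
#include <cstdio>

template <typename T, int N>
struct vec { T val[N]; };

int main() {
  vec<float, 4> in{{1.f, -2.f, 3.f, -4.f}};
  vec<float, 4> residual{{0.5f, 0.5f, 0.5f, 0.5f}};

  float ss = 0.f;
  for (int j = 0; j < 4; ++j) {
    float x = in.val[j] + residual.val[j];  // optional residual add, as in the kernel
    ss += x * x;
  }
  printf("partial sum of squares = %f\n", ss);
  return 0;
}
```
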
@@ -203,6 +205,7 @@ __device__ void compute_dynamic_per_token_scales(

constexpr scalar_out_t qmax{quant_type_max_v<scalar_out_t>};

const int VEC_SIZE = 4;
int32_t const num_vec_elems = hidden_size >> 2;
float block_absmax_val_maybe = 0.0f;

@@ -212,26 +215,25 @@
vec4_t<scalar_t> const w = vec_weight[i];

vec4_t<float> x;
x.x = static_cast<float>(in.x);
x.y = static_cast<float>(in.y);
x.z = static_cast<float>(in.z);
x.w = static_cast<float>(in.w);
#pragma unroll
for (int j = 0; j < VEC_SIZE; ++j) {
x.val[j] = static_cast<float>(in.val[j]);
}

if constexpr (has_residual) {
vec4_t<scalar_t> r = vec_residual[i];
x.x += static_cast<float>(r.x);
x.y += static_cast<float>(r.y);
x.z += static_cast<float>(r.z);
x.w += static_cast<float>(r.w);
#pragma unroll
for (int j = 0; j < VEC_SIZE; ++j) {
x.val[j] += static_cast<float>(r.val[j]);
}
}

block_absmax_val_maybe = fmaxf(
block_absmax_val_maybe, fabs(static_cast<scalar_t>(x.x * rms) * w.x));
block_absmax_val_maybe = fmaxf(
block_absmax_val_maybe, fabs(static_cast<scalar_t>(x.y * rms) * w.y));
block_absmax_val_maybe = fmaxf(
block_absmax_val_maybe, fabs(static_cast<scalar_t>(x.z * rms) * w.z));
block_absmax_val_maybe = fmaxf(
block_absmax_val_maybe, fabs(static_cast<scalar_t>(x.w * rms) * w.w));
#pragma unroll
for (int j = 0; j < VEC_SIZE; ++j) {
block_absmax_val_maybe =
fmaxf(block_absmax_val_maybe,
fabs(static_cast<scalar_t>(x.val[j] * rms) * w.val[j]));
}
}

using BlockReduce = cub::BlockReduce<float, 1024>;
@@ -282,6 +284,7 @@ __device__ void norm_and_quant(scalar_out_t* __restrict__ output,
vec_residual = reinterpret_cast<vec4_t<scalar_t>*>(&residual[token_offset]);
}

const int VEC_SIZE = 4;
int32_t const num_vec_elems = hidden_size >> 2;

// TODO(luka/varun) extract into type-agnostic vectorized quant function to
@@ -292,33 +295,31 @@
vec4_t<scalar_t> const w = vec_weight[i];

vec4_t<float> x;
x.x = static_cast<float>(in.x);
x.y = static_cast<float>(in.y);
x.z = static_cast<float>(in.z);
x.w = static_cast<float>(in.w);
#pragma unroll
for (int j = 0; j < VEC_SIZE; ++j) {
x.val[j] = static_cast<float>(in.val[j]);
}

if constexpr (has_residual) {
vec4_t<scalar_t> r = vec_residual[i];
x.x += static_cast<float>(r.x);
x.y += static_cast<float>(r.y);
x.z += static_cast<float>(r.z);
x.w += static_cast<float>(r.w);
// Update residual
r.x = static_cast<scalar_t>(x.x);
r.y = static_cast<scalar_t>(x.y);
r.z = static_cast<scalar_t>(x.z);
r.w = static_cast<scalar_t>(x.w);
#pragma unroll
for (int j = 0; j < VEC_SIZE; ++j) {
x.val[j] += static_cast<float>(r.val[j]);
}
// Update residual
#pragma unroll
for (int j = 0; j < VEC_SIZE; ++j) {
r.val[j] = static_cast<scalar_t>(x.val[j]);
}
vec_residual[i] = r;
}

q8x4_t<scalar_out_t> out;
out.x = ScaledQuant<scalar_out_t, is_scale_inverted>::quant_fn(
static_cast<scalar_t>(x.x * rms) * w.x, scale);
out.y = ScaledQuant<scalar_out_t, is_scale_inverted>::quant_fn(
static_cast<scalar_t>(x.y * rms) * w.y, scale);
out.z = ScaledQuant<scalar_out_t, is_scale_inverted>::quant_fn(
static_cast<scalar_t>(x.z * rms) * w.z, scale);
out.w = ScaledQuant<scalar_out_t, is_scale_inverted>::quant_fn(
static_cast<scalar_t>(x.w * rms) * w.w, scale);
#pragma unroll
for (int j = 0; j < VEC_SIZE; ++j) {
out.val[j] = ScaledQuant<scalar_out_t, is_scale_inverted>::quant_fn(
static_cast<scalar_t>(x.val[j] * rms) * w.val[j], scale);
}
vec_output[i] = out;
}
}
23 changes: 11 additions & 12 deletions csrc/quantization/vectorization.cuh
@@ -10,23 +10,12 @@
namespace vllm {

// Vectorization containers
template <typename scalar_t>
struct __align__(8) vec4_t {
scalar_t x;
scalar_t y;
scalar_t z;
scalar_t w;
template <typename scalar_t, size_t vec_size>
struct __align__(vec_size * sizeof(scalar_t)) vec_n_t {
scalar_t val[vec_size];
};

template <typename quant_type_t>
struct __align__(4) q8x4_t {
template <typename quant_type_t, size_t vec_size>
struct __align__(vec_size * sizeof(quant_type_t)) q8_n_t {
static_assert(std::is_same_v<quant_type_t, int8_t> ||
std::is_same_v<quant_type_t, c10::Float8_e4m3fn> ||
std::is_same_v<quant_type_t, c10::Float8_e4m3fnuz>);
quant_type_t x;
quant_type_t y;
quant_type_t z;
quant_type_t w;
quant_type_t val[vec_size];
};

template <typename scalar_t>
using vec4_t = vec_n_t<scalar_t, 4>;
template <typename quant_type_t>
using q8x4_t = q8_n_t<quant_type_t, 4>;

} // namespace vllm
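
The generalized containers keep the old 4-wide types as aliases, so existing call sites compile unchanged while the fp8 path can pick a 16-wide packet. A standalone sketch (mirroring the structs above, with the c10 quant types replaced by plain integers for illustration) that checks the size and alignment guarantees the kernels rely on:

```cuda
// Standalone sketch mirroring vec_n_t above: alignment scales with packet width.
#include <cstdint>
#include <cstdio>
#include <cuda_fp16.h>

template <typename scalar_t, size_t vec_size>
struct __align__(vec_size * sizeof(scalar_t)) vec_n_t {
  scalar_t val[vec_size];
};

template <typename scalar_t>
using vec4_t = vec_n_t<scalar_t, 4>;  // old 4-wide type, now just an alias

static_assert(alignof(vec4_t<__half>) == 8, "legacy 8-byte alignment preserved");
static_assert(alignof(vec_n_t<__half, 16>) == 32, "32-byte loads for 16 x 2-byte input");
static_assert(sizeof(vec_n_t<uint8_t, 16>) == 16, "16-byte stores for 16 x fp8-sized output");

int main() {
  printf("vec_n_t<__half, 16>: size=%zu align=%zu\n",
         sizeof(vec_n_t<__half, 16>), alignof(vec_n_t<__half, 16>));
  return 0;
}
```
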