
[Do not merge] Add out of place layernorm #20197


Open
wants to merge 13 commits into base: main
8 changes: 7 additions & 1 deletion benchmarks/fused_kernels/layernorm_rms_benchmarks.py
@@ -85,11 +85,17 @@ def unfused_fp8_impl(
def fused_impl(
rms_norm_layer: RMSNorm, # this stores the weights
x: torch.Tensor,
residual_out: Optional[torch.Tensor],
residual: Optional[torch.Tensor],
quant_dtype: torch.dtype,
):
out, _ = ops.rms_norm_dynamic_per_token_quant(
x, rms_norm_layer.weight, 1e-6, quant_dtype, residual=residual
x,
rms_norm_layer.weight,
1e-6,
quant_dtype,
residual_out=residual_out,
residual=residual,
)


68 changes: 41 additions & 27 deletions csrc/layernorm_kernels.cu
@@ -50,9 +50,11 @@ __global__ void rms_norm_kernel(
template <typename scalar_t, int width>
__global__ std::enable_if_t<(width > 0) && _typeConvert<scalar_t>::exists>
fused_add_rms_norm_kernel(
scalar_t* __restrict__ input, // [..., hidden_size]
scalar_t* __restrict__ residual, // [..., hidden_size]
const scalar_t* __restrict__ weight, // [hidden_size]
scalar_t* __restrict__ output, // [..., hidden_size]
const scalar_t* __restrict__ input, // [..., hidden_size]
scalar_t* __restrict__ residual_out, // [..., hidden_size]
const scalar_t* __restrict__ residual, // [..., hidden_size]
const scalar_t* __restrict__ weight, // [hidden_size]
const float epsilon, const int num_tokens, const int hidden_size) {
// Sanity checks on our vector struct and type-punned pointer arithmetic
static_assert(std::is_pod_v<_f16Vec<scalar_t, width>>);
@@ -64,10 +66,14 @@ fused_add_rms_norm_kernel(
/* These and the argument pointers are all declared `restrict` as they are
not aliased in practice. Argument pointers should not be dereferenced
in this kernel as that would be undefined behavior */
auto* __restrict__ output_v =
reinterpret_cast<_f16Vec<scalar_t, width>*>(output);
auto* __restrict__ input_v =
reinterpret_cast<_f16Vec<scalar_t, width>*>(input);
reinterpret_cast<const _f16Vec<scalar_t, width>*>(input);
auto* __restrict__ residual_out_v =
reinterpret_cast<_f16Vec<scalar_t, width>*>(residual_out);
auto* __restrict__ residual_v =
reinterpret_cast<_f16Vec<scalar_t, width>*>(residual);
reinterpret_cast<const _f16Vec<scalar_t, width>*>(residual);
auto* __restrict__ weight_v =
reinterpret_cast<const _f16Vec<scalar_t, width>*>(weight);

@@ -76,7 +82,7 @@ fused_add_rms_norm_kernel(
_f16Vec<scalar_t, width> temp = input_v[id];
temp += residual_v[id];
variance += temp.sum_squares();
residual_v[id] = temp;
residual_out_v[id] = temp;
}

using BlockReduce = cub::BlockReduce<float, 1024>;
@@ -90,10 +96,10 @@ fused_add_rms_norm_kernel(

for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) {
int id = blockIdx.x * vec_hidden_size + idx;
_f16Vec<scalar_t, width> temp = residual_v[id];
_f16Vec<scalar_t, width> temp = residual_out_v[id];
temp *= s_variance;
temp *= weight_v[idx];
input_v[id] = temp;
output_v[id] = temp;
}
}

@@ -103,9 +109,11 @@ fused_add_rms_norm_kernel(
template <typename scalar_t, int width>
__global__ std::enable_if_t<(width == 0) || !_typeConvert<scalar_t>::exists>
fused_add_rms_norm_kernel(
scalar_t* __restrict__ input, // [..., hidden_size]
scalar_t* __restrict__ residual, // [..., hidden_size]
const scalar_t* __restrict__ weight, // [hidden_size]
scalar_t* __restrict__ output, // [..., hidden_size]
const scalar_t* __restrict__ input, // [..., hidden_size]
scalar_t* __restrict__ residual_out, // [..., hidden_size]
const scalar_t* __restrict__ residual, // [..., hidden_size]
const scalar_t* __restrict__ weight, // [hidden_size]
const float epsilon, const int num_tokens, const int hidden_size) {
__shared__ float s_variance;
float variance = 0.0f;
@@ -115,7 +123,7 @@ fused_add_rms_norm_kernel(
z += residual[blockIdx.x * hidden_size + idx];
float x = (float)z;
variance += x * x;
residual[blockIdx.x * hidden_size + idx] = z;
residual_out[blockIdx.x * hidden_size + idx] = z;
}

using BlockReduce = cub::BlockReduce<float, 1024>;
Expand All @@ -128,8 +136,8 @@ fused_add_rms_norm_kernel(
__syncthreads();

for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
float x = (float)residual[blockIdx.x * hidden_size + idx];
input[blockIdx.x * hidden_size + idx] =
float x = (float)residual_out[blockIdx.x * hidden_size + idx];
output[blockIdx.x * hidden_size + idx] =
((scalar_t)(x * s_variance)) * weight[idx];
}
}
@@ -158,19 +166,22 @@ void rms_norm(torch::Tensor& out, // [..., hidden_size]
});
}

#define LAUNCH_FUSED_ADD_RMS_NORM(width) \
VLLM_DISPATCH_FLOATING_TYPES( \
input.scalar_type(), "fused_add_rms_norm_kernel", [&] { \
vllm::fused_add_rms_norm_kernel<scalar_t, width> \
<<<grid, block, 0, stream>>>(input.data_ptr<scalar_t>(), \
residual.data_ptr<scalar_t>(), \
weight.data_ptr<scalar_t>(), epsilon, \
num_tokens, hidden_size); \
#define LAUNCH_FUSED_ADD_RMS_NORM(width) \
VLLM_DISPATCH_FLOATING_TYPES( \
input.scalar_type(), "fused_add_rms_norm_kernel", [&] { \
vllm::fused_add_rms_norm_kernel<scalar_t, width> \
<<<grid, block, 0, stream>>>( \
out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), \
residual_out.data_ptr<scalar_t>(), \
residual.data_ptr<scalar_t>(), weight.data_ptr<scalar_t>(), \
epsilon, num_tokens, hidden_size); \
});

void fused_add_rms_norm(torch::Tensor& input, // [..., hidden_size]
torch::Tensor& residual, // [..., hidden_size]
torch::Tensor& weight, // [hidden_size]
void fused_add_rms_norm(torch::Tensor& out, // [..., hidden_size]
torch::Tensor& input, // [..., hidden_size]
torch::Tensor& residual_out, // [..., hidden_size]
torch::Tensor& residual, // [..., hidden_size]
torch::Tensor& weight, // [hidden_size]
double epsilon) {
Comment on lines +180 to 185
Contributor
medium

The input, residual, and weight tensors are not modified within this function; they are read-only. For const-correctness, they should be passed as const torch::Tensor&.

void fused_add_rms_norm(torch::Tensor& out,           // [..., hidden_size]
                        const torch::Tensor& input,         // [..., hidden_size]
                        torch::Tensor& residual_out,  // [..., hidden_size]
                        const torch::Tensor& residual,      // [..., hidden_size]
                        const torch::Tensor& weight,        // [hidden_size]

int hidden_size = input.size(-1);
int num_tokens = input.numel() / hidden_size;
@@ -191,11 +202,14 @@ void fused_add_rms_norm(torch::Tensor& input, // [..., hidden_size]
However, this requires each tensor's data to be aligned to 16
bytes.
*/
auto out_ptr = reinterpret_cast<std::uintptr_t>(out.data_ptr());
auto inp_ptr = reinterpret_cast<std::uintptr_t>(input.data_ptr());
auto res_out_ptr = reinterpret_cast<std::uintptr_t>(residual_out.data_ptr());
auto res_ptr = reinterpret_cast<std::uintptr_t>(residual.data_ptr());
auto wt_ptr = reinterpret_cast<std::uintptr_t>(weight.data_ptr());
bool ptrs_are_aligned =
inp_ptr % 16 == 0 && res_ptr % 16 == 0 && wt_ptr % 16 == 0;
bool ptrs_are_aligned = out_ptr % 16 == 0 && inp_ptr % 16 == 0 &&
res_out_ptr % 16 == 0 && res_ptr % 16 == 0 &&
wt_ptr % 16 == 0;
if (ptrs_are_aligned && hidden_size % 8 == 0) {
LAUNCH_FUSED_ADD_RMS_NORM(8);
} else {
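
For reference, here is a minimal caller-side sketch of the new out-of-place signature. It is illustrative only: the helper name, the allocation via torch::empty_like, and the epsilon value are assumptions, not part of this diff.

#include <torch/torch.h>
#include "ops.h"  // assumed include for the fused_add_rms_norm declaration

// Hypothetical caller: outputs go to freshly allocated buffers instead of
// overwriting `input` and `residual` in place.
void call_fused_add_rms_norm(torch::Tensor& input,      // [..., hidden_size]
                             torch::Tensor& residual,   // [..., hidden_size]
                             torch::Tensor& weight) {   // [hidden_size]
  auto out = torch::empty_like(input);
  auto residual_out = torch::empty_like(residual);
  fused_add_rms_norm(out, input, residual_out, residual, weight,
                     /*epsilon=*/1e-6);
  // After the call: `residual_out` holds input + residual, and `out` holds
  // RMSNorm(input + residual) * weight; `input` and `residual` are unchanged.
}
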
24 changes: 15 additions & 9 deletions csrc/layernorm_quant_kernels.cu
@@ -65,6 +65,7 @@ __global__ std::enable_if_t<(width > 0) && _typeConvert<scalar_t>::exists>
fused_add_rms_norm_static_fp8_quant_kernel(
fp8_type* __restrict__ out, // [..., hidden_size]
scalar_t* __restrict__ input, // [..., hidden_size]
scalar_t* __restrict__ residual_out, // [..., hidden_size]
scalar_t* __restrict__ residual, // [..., hidden_size]
Comment on lines 67 to 69
Contributor
medium

The input and residual pointers are only used for reading in this kernel. They should be declared as const scalar_t* __restrict__ to enforce const-correctness.

    const scalar_t* __restrict__ input,         // [..., hidden_size]
    scalar_t* __restrict__ residual_out,  // [..., hidden_size]
    const scalar_t* __restrict__ residual,      // [..., hidden_size]

const scalar_t* __restrict__ weight, // [hidden_size]
const float* __restrict__ scale, // [1]
@@ -81,6 +82,8 @@ fused_add_rms_norm_static_fp8_quant_kernel(
in this kernel as that would be undefined behavior */
auto* __restrict__ input_v =
reinterpret_cast<_f16Vec<scalar_t, width>*>(input);
auto* __restrict__ residual_out_v =
reinterpret_cast<_f16Vec<scalar_t, width>*>(residual_out);
auto* __restrict__ residual_v =
reinterpret_cast<_f16Vec<scalar_t, width>*>(residual);
auto* __restrict__ weight_v =
@@ -91,7 +94,7 @@ fused_add_rms_norm_static_fp8_quant_kernel(
_f16Vec<scalar_t, width> temp = input_v[id];
temp += residual_v[id];
variance += temp.sum_squares();
residual_v[id] = temp;
residual_out_v[id] = temp;
}

using BlockReduce = cub::BlockReduce<float, 1024>;
@@ -108,7 +111,7 @@ fused_add_rms_norm_static_fp8_quant_kernel(

for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) {
int id = blockIdx.x * vec_hidden_size + idx;
_f16Vec<scalar_t, width> temp = residual_v[id];
_f16Vec<scalar_t, width> temp = residual_out_v[id];
temp *= s_variance;
temp *= weight_v[idx];
#pragma unroll
@@ -127,6 +130,7 @@ __global__ std::enable_if_t<(width == 0) || !_typeConvert<scalar_t>::exists>
fused_add_rms_norm_static_fp8_quant_kernel(
fp8_type* __restrict__ out, // [..., hidden_size]
scalar_t* __restrict__ input, // [..., hidden_size]
scalar_t* __restrict__ residual_out, // [..., hidden_size]
scalar_t* __restrict__ residual, // [..., hidden_size]
const scalar_t* __restrict__ weight, // [hidden_size]
const float* __restrict__ scale, // [1]
@@ -139,7 +143,7 @@ fused_add_rms_norm_static_fp8_quant_kernel(
z += residual[blockIdx.x * hidden_size + idx];
float x = (float)z;
variance += x * x;
residual[blockIdx.x * hidden_size + idx] = z;
residual_out[blockIdx.x * hidden_size + idx] = z;
}

using BlockReduce = cub::BlockReduce<float, 1024>;
Expand All @@ -155,7 +159,7 @@ fused_add_rms_norm_static_fp8_quant_kernel(
float const scale_inv = 1.0f / *scale;

for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
float x = (float)residual[blockIdx.x * hidden_size + idx];
float x = (float)residual_out[blockIdx.x * hidden_size + idx];
float const out_norm = ((scalar_t)(x * s_variance)) * weight[idx];
out[blockIdx.x * hidden_size + idx] =
scaled_fp8_conversion<true, fp8_type>(out_norm, scale_inv);
@@ -198,17 +202,19 @@ void rms_norm_static_fp8_quant(torch::Tensor& out, // [..., hidden_size]
width, fp8_t> \
<<<grid, block, 0, stream>>>( \
out.data_ptr<fp8_t>(), input.data_ptr<scalar_t>(), \
residual_out.data_ptr<scalar_t>(), \
residual.data_ptr<scalar_t>(), \
weight.data_ptr<scalar_t>(), scale.data_ptr<float>(), \
epsilon, num_tokens, hidden_size); \
}); \
});
void fused_add_rms_norm_static_fp8_quant(
torch::Tensor& out, // [..., hidden_size],
torch::Tensor& input, // [..., hidden_size]
torch::Tensor& residual, // [..., hidden_size]
torch::Tensor& weight, // [hidden_size]
torch::Tensor& scale, // [1]
torch::Tensor& out, // [..., hidden_size],
torch::Tensor& input, // [..., hidden_size]
torch::Tensor& residual_out, // [..., hidden_size]
torch::Tensor& residual, // [..., hidden_size]
torch::Tensor& weight, // [hidden_size]
torch::Tensor& scale, // [1]
double epsilon) {
Comment on lines 211 to 218
Contributor
medium

The input, residual, weight, and scale tensors are read-only in this function. They should be passed as const torch::Tensor& to ensure they are not modified and to improve code clarity.

void fused_add_rms_norm_static_fp8_quant(
    torch::Tensor& out,           // [..., hidden_size],
    const torch::Tensor& input,         // [..., hidden_size]
    torch::Tensor& residual_out,  // [..., hidden_size]
    const torch::Tensor& residual,      // [..., hidden_size]
    const torch::Tensor& weight,        // [hidden_size]
    const torch::Tensor& scale,         // [1]
    double epsilon) {

int hidden_size = input.size(-1);
int num_tokens = input.numel() / hidden_size;
5 changes: 4 additions & 1 deletion csrc/ops.h
@@ -89,7 +89,8 @@ void convert_vertical_slash_indexes_mergehead(
void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight,
double epsilon);

void fused_add_rms_norm(torch::Tensor& input, torch::Tensor& residual,
void fused_add_rms_norm(torch::Tensor& out, torch::Tensor& input,
torch::Tensor& residual_out, torch::Tensor& residual,
torch::Tensor& weight, double epsilon);

void apply_repetition_penalties_(torch::Tensor& logits,
@@ -103,6 +104,7 @@ void rms_norm_static_fp8_quant(torch::Tensor& out, torch::Tensor& input,

void fused_add_rms_norm_static_fp8_quant(torch::Tensor& out,
torch::Tensor& input,
torch::Tensor& residual_out,
torch::Tensor& residual,
torch::Tensor& weight,
torch::Tensor& scale, double epsilon);
@@ -113,6 +115,7 @@ void rms_norm_dynamic_per_token_quant(torch::Tensor& out,
torch::Tensor& scales,
double const epsilon,
std::optional<torch::Tensor> scale_ub,
std::optional<torch::Tensor> residual_out,
std::optional<torch::Tensor> residual);
Comment on lines +118 to 119
Contributor
medium

The residual tensor is read-only. It should be passed by const& to avoid an unnecessary copy of the std::optional and to signal its read-only nature.

Suggested change
std::optional<torch::Tensor> residual_out,
std::optional<torch::Tensor> residual);
std::optional<torch::Tensor> residual_out,
const std::optional<torch::Tensor>& residual);
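
A generic illustration (not vLLM code) of the copy this comment is pointing at: a by-value std::optional parameter copy-constructs its payload whenever it is passed an lvalue, while a const reference does not. For torch::Tensor the copy is only a reference-count bump rather than a data copy, but it is still avoidable.

#include <cstdio>
#include <optional>

struct Payload {
  Payload() = default;
  Payload(const Payload&) { std::puts("payload copied"); }
};

void by_value(std::optional<Payload>) {}             // copies lvalue arguments
void by_const_ref(const std::optional<Payload>&) {}  // never copies

int main() {
  std::optional<Payload> residual;
  residual.emplace();      // constructed in place, no copy yet
  by_value(residual);      // prints "payload copied"
  by_const_ref(residual);  // prints nothing
}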


void rotary_embedding(torch::Tensor& positions, torch::Tensor& query,
@@ -15,6 +15,7 @@ __device__ void rms_norm_dynamic_per_token_quant_vec(
scalar_t const* __restrict__ input, // [..., hidden_size]
scalar_t const* __restrict__ weight, // [hidden_size]
float const* scale_ub, float const var_epsilon, int32_t const hidden_size,
scalar_t* __restrict__ residual_out = nullptr,
scalar_t* __restrict__ residual = nullptr) {
Comment on lines +18 to 19
Contributor
medium

The residual pointer is only read from in this function. For const-correctness, it should be declared as const scalar_t* __restrict__.

    scalar_t* __restrict__ residual_out = nullptr,
    const scalar_t* __restrict__ residual = nullptr) {

float rms = 0.0f;
float token_scale = 0.0f;
@@ -33,12 +34,14 @@ if constexpr (std::is_same_v<scalar_out_t, int8_t>) {
if constexpr (std::is_same_v<scalar_out_t, int8_t>) {
vllm::vectorized::norm_and_quant<scalar_t, scalar_out_t, true,
has_residual>(
out, input, weight, rms, 1.0f / token_scale, hidden_size, residual);
out, input, weight, rms, 1.0f / token_scale, hidden_size, residual_out,
residual);
} else {
// FP8 - Do not invert token_scale for exact match with FBGemm
vllm::vectorized::norm_and_quant<scalar_t, scalar_out_t, false,
has_residual>(
out, input, weight, rms, token_scale, hidden_size, residual);
has_residual>(out, input, weight, rms,
token_scale, hidden_size,
residual_out, residual);
}
}

@@ -50,6 +53,7 @@ __global__ void rms_norm_dynamic_per_token_quant_kernel(
scalar_t const* __restrict__ input, // [..., hidden_size]
scalar_t const* __restrict__ weight, // [hidden_size]
float const* scale_ub, float const var_epsilon, int32_t const hidden_size,
scalar_t* __restrict__ residual_out = nullptr,
scalar_t* __restrict__ residual = nullptr) {
// For vectorization, token_input and token_output pointers need to be
// aligned at 8-byte and 4-byte addresses respectively.
@@ -76,11 +80,13 @@ __global__ void rms_norm_dynamic_per_token_quant_kernel(
// RMS Norm + Quant
if constexpr (std::is_same_v<scalar_out_t, int8_t>) {
vllm::norm_and_quant<scalar_t, scalar_out_t, true, has_residual>(
out, input, weight, rms, 1.0f / token_scale, hidden_size, residual);
out, input, weight, rms, 1.0f / token_scale, hidden_size, residual_out,
residual);
} else {
// FP8 - Do not invert s_token_scale for exact match with FBGemm
vllm::norm_and_quant<scalar_t, scalar_out_t, false, has_residual>(
out, input, weight, rms, token_scale, hidden_size, residual);
out, input, weight, rms, token_scale, hidden_size, residual_out,
residual);
}
}
} // namespace vllm
@@ -94,6 +100,7 @@ void rms_norm_dynamic_per_token_quant_dispatch(
torch::Tensor& scales, // [num_tokens]
double const var_epsilon, // Variance epsilon used in norm calculation
std::optional<at::Tensor> const& scale_ub,
std::optional<at::Tensor>& residual_out,
std::optional<at::Tensor>& residual) {
int32_t hidden_size = input.size(-1);
auto num_tokens = input.numel() / hidden_size;
@@ -112,7 +119,9 @@ void rms_norm_dynamic_per_token_quant_dispatch(
out.data_ptr<scalar_t>(), scales.data_ptr<float>(),
input.data_ptr<scalar_in_t>(), weight.data_ptr<scalar_in_t>(),
scale_ub.has_value() ? scale_ub->data_ptr<float>() : nullptr,
var_epsilon, hidden_size, residual->data_ptr<scalar_in_t>());
var_epsilon, hidden_size,
residual_out->data_ptr<scalar_in_t>(),
residual->data_ptr<scalar_in_t>());
});

} else {
@@ -124,7 +133,7 @@ void rms_norm_dynamic_per_token_quant_dispatch(
out.data_ptr<scalar_t>(), scales.data_ptr<float>(),
input.data_ptr<scalar_in_t>(), weight.data_ptr<scalar_in_t>(),
scale_ub.has_value() ? scale_ub->data_ptr<float>() : nullptr,
var_epsilon, hidden_size, nullptr);
var_epsilon, hidden_size, nullptr, nullptr);
});
}
}
@@ -135,13 +144,18 @@ void rms_norm_dynamic_per_token_quant(
torch::Tensor const& weight, // [hidden_size]
torch::Tensor& scales, // [num_tokens]
double const var_epsilon, // Variance epsilon used in norm calculation
std::optional<at::Tensor> scale_ub, std::optional<at::Tensor> residual) {
std::optional<at::Tensor> scale_ub, std::optional<at::Tensor> residual_out,
std::optional<at::Tensor> residual) {
static c10::ScalarType kFp8Type = is_fp8_ocp()
? c10::ScalarType::Float8_e4m3fn
: c10::ScalarType::Float8_e4m3fnuz;
TORCH_CHECK(out.dtype() == kFp8Type || out.dtype() == torch::kInt8);
TORCH_CHECK(out.is_contiguous() && input.is_contiguous());

if (residual.has_value()) {
TORCH_CHECK(residual_out.has_value());
}

if (scale_ub.has_value()) {
TORCH_CHECK(out.dtype() == kFp8Type);
}
@@ -150,6 +164,7 @@ VLLM_DISPATCH_FLOATING_TYPES(
VLLM_DISPATCH_FLOATING_TYPES(
input.scalar_type(), "rms_norm_dynamic_per_token_quant_dispatch", [&] {
rms_norm_dynamic_per_token_quant_dispatch<scalar_t>(
out, input, weight, scales, var_epsilon, scale_ub, residual);
out, input, weight, scales, var_epsilon, scale_ub, residual_out,
residual);
});
}