[Do not merge] Add out of place layernorm #20197
@@ -65,6 +65,7 @@ __global__ std::enable_if_t<(width > 0) && _typeConvert<scalar_t>::exists>
 fused_add_rms_norm_static_fp8_quant_kernel(
     fp8_type* __restrict__ out,           // [..., hidden_size]
     scalar_t* __restrict__ input,         // [..., hidden_size]
+    scalar_t* __restrict__ residual_out,  // [..., hidden_size]
     scalar_t* __restrict__ residual,      // [..., hidden_size]
     const scalar_t* __restrict__ weight,  // [hidden_size]
     const float* __restrict__ scale,      // [1]
@@ -81,6 +82,8 @@ fused_add_rms_norm_static_fp8_quant_kernel(
     in this kernel as that would be undefined behavior */
   auto* __restrict__ input_v =
       reinterpret_cast<_f16Vec<scalar_t, width>*>(input);
+  auto* __restrict__ residual_out_v =
+      reinterpret_cast<_f16Vec<scalar_t, width>*>(residual_out);
   auto* __restrict__ residual_v =
       reinterpret_cast<_f16Vec<scalar_t, width>*>(residual);
   auto* __restrict__ weight_v =
@@ -91,7 +94,7 @@ fused_add_rms_norm_static_fp8_quant_kernel(
     _f16Vec<scalar_t, width> temp = input_v[id];
     temp += residual_v[id];
     variance += temp.sum_squares();
-    residual_v[id] = temp;
+    residual_out_v[id] = temp;
   }

   using BlockReduce = cub::BlockReduce<float, 1024>;
@@ -108,7 +111,7 @@ fused_add_rms_norm_static_fp8_quant_kernel(

   for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) {
     int id = blockIdx.x * vec_hidden_size + idx;
-    _f16Vec<scalar_t, width> temp = residual_v[id];
+    _f16Vec<scalar_t, width> temp = residual_out_v[id];
     temp *= s_variance;
     temp *= weight_v[idx];
 #pragma unroll
@@ -127,6 +130,7 @@ __global__ std::enable_if_t<(width == 0) || !_typeConvert<scalar_t>::exists>
 fused_add_rms_norm_static_fp8_quant_kernel(
     fp8_type* __restrict__ out,           // [..., hidden_size]
     scalar_t* __restrict__ input,         // [..., hidden_size]
+    scalar_t* __restrict__ residual_out,  // [..., hidden_size]
     scalar_t* __restrict__ residual,      // [..., hidden_size]
     const scalar_t* __restrict__ weight,  // [hidden_size]
     const float* __restrict__ scale,      // [1]
@@ -139,7 +143,7 @@ fused_add_rms_norm_static_fp8_quant_kernel(
     z += residual[blockIdx.x * hidden_size + idx];
     float x = (float)z;
     variance += x * x;
-    residual[blockIdx.x * hidden_size + idx] = z;
+    residual_out[blockIdx.x * hidden_size + idx] = z;
   }

   using BlockReduce = cub::BlockReduce<float, 1024>;
@@ -155,7 +159,7 @@ fused_add_rms_norm_static_fp8_quant_kernel(
   float const scale_inv = 1.0f / *scale;

   for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
-    float x = (float)residual[blockIdx.x * hidden_size + idx];
+    float x = (float)residual_out[blockIdx.x * hidden_size + idx];
     float const out_norm = ((scalar_t)(x * s_variance)) * weight[idx];
     out[blockIdx.x * hidden_size + idx] =
         scaled_fp8_conversion<true, fp8_type>(out_norm, scale_inv);
@@ -198,17 +202,19 @@ void rms_norm_static_fp8_quant(torch::Tensor& out, // [..., hidden_size]
               width, fp8_t>                                          \
           <<<grid, block, 0, stream>>>(                              \
               out.data_ptr<fp8_t>(), input.data_ptr<scalar_t>(),     \
+              residual_out.data_ptr<scalar_t>(),                     \
               residual.data_ptr<scalar_t>(),                         \
               weight.data_ptr<scalar_t>(), scale.data_ptr<float>(),  \
               epsilon, num_tokens, hidden_size);                     \
     });                                                              \
   });
 void fused_add_rms_norm_static_fp8_quant(
-    torch::Tensor& out,       // [..., hidden_size],
-    torch::Tensor& input,     // [..., hidden_size]
-    torch::Tensor& residual,  // [..., hidden_size]
-    torch::Tensor& weight,    // [hidden_size]
-    torch::Tensor& scale,     // [1]
+    torch::Tensor& out,           // [..., hidden_size],
+    torch::Tensor& input,         // [..., hidden_size]
+    torch::Tensor& residual_out,  // [..., hidden_size]
+    torch::Tensor& residual,      // [..., hidden_size]
+    torch::Tensor& weight,        // [hidden_size]
+    torch::Tensor& scale,         // [1]
     double epsilon) {
   int hidden_size = input.size(-1);
   int num_tokens = input.numel() / hidden_size;
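For readers skimming the kernel diff above, here is a minimal scalar sketch of what the out-of-place variant computes for a single token, under the assumptions visible in the hunks: residual_out receives input + residual (instead of writing the sum back into residual), and the RMS-normalized, weighted result is scaled for static quantization. The function name and the plain float stand-in for scaled_fp8_conversion are illustrative, not part of the PR.

```cpp
#include <cmath>
#include <cstddef>
#include <vector>

// Out-of-place fused add + RMSNorm + static-scale quant for one token.
// residual_out receives input + residual; residual itself is left untouched.
// The fp8 conversion is approximated by a scaled float for illustration.
std::vector<float> fused_add_rms_norm_reference(
    const std::vector<float>& input,     // [hidden_size]
    const std::vector<float>& residual,  // [hidden_size]
    std::vector<float>& residual_out,    // [hidden_size], written
    const std::vector<float>& weight,    // [hidden_size]
    float scale, float epsilon) {
  const size_t hidden_size = input.size();
  residual_out.resize(hidden_size);

  // z = input + residual; accumulate the sum of squares of z.
  float variance = 0.0f;
  for (size_t i = 0; i < hidden_size; ++i) {
    float z = input[i] + residual[i];
    variance += z * z;
    residual_out[i] = z;  // out of place: the old kernel wrote back to residual
  }
  const float s_variance = 1.0f / std::sqrt(variance / hidden_size + epsilon);

  // out = (z * rsqrt(mean(z^2) + eps)) * weight, scaled by 1/scale for quant.
  std::vector<float> out(hidden_size);
  const float scale_inv = 1.0f / scale;
  for (size_t i = 0; i < hidden_size; ++i) {
    float norm = residual_out[i] * s_variance * weight[i];
    out[i] = norm * scale_inv;  // real kernel: scaled_fp8_conversion to fp8
  }
  return out;
}
```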
@@ -89,7 +89,8 @@ void convert_vertical_slash_indexes_mergehead(
 void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight,
               double epsilon);

-void fused_add_rms_norm(torch::Tensor& input, torch::Tensor& residual,
+void fused_add_rms_norm(torch::Tensor& out, torch::Tensor& input,
+                        torch::Tensor& residual_out, torch::Tensor& residual,
                         torch::Tensor& weight, double epsilon);

 void apply_repetition_penalties_(torch::Tensor& logits,
@@ -103,6 +104,7 @@ void rms_norm_static_fp8_quant(torch::Tensor& out, torch::Tensor& input,

 void fused_add_rms_norm_static_fp8_quant(torch::Tensor& out,
                                          torch::Tensor& input,
+                                         torch::Tensor& residual_out,
                                          torch::Tensor& residual,
                                          torch::Tensor& weight,
                                          torch::Tensor& scale, double epsilon);
@@ -113,6 +115,7 @@ void rms_norm_dynamic_per_token_quant(torch::Tensor& out,
                                       torch::Tensor& scales,
                                       double const epsilon,
                                       std::optional<torch::Tensor> scale_ub,
+                                      std::optional<torch::Tensor> residual_out,
                                       std::optional<torch::Tensor> residual);

 void rotary_embedding(torch::Tensor& positions, torch::Tensor& query,
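To make the declaration change concrete, a hypothetical caller under the new out-of-place signature might look like the sketch below; with the old signature the results were written back into input and residual, whereas here they land in separately allocated out and residual_out tensors. The helper name and the empty_like allocations are illustrative, not taken from the PR.

```cpp
#include <torch/torch.h>
#include <utility>

// New out-of-place declaration, as in the diff above.
void fused_add_rms_norm(torch::Tensor& out, torch::Tensor& input,
                        torch::Tensor& residual_out, torch::Tensor& residual,
                        torch::Tensor& weight, double epsilon);

// Hypothetical caller: results go into freshly allocated out / residual_out
// tensors and the inputs are left intact (allocation strategy is illustrative).
std::pair<torch::Tensor, torch::Tensor> call_out_of_place(
    torch::Tensor& input, torch::Tensor& residual, torch::Tensor& weight,
    double epsilon) {
  torch::Tensor out = torch::empty_like(input);
  torch::Tensor residual_out = torch::empty_like(residual);
  fused_add_rms_norm(out, input, residual_out, residual, weight, epsilon);
  return {out, residual_out};
}
```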
@@ -15,6 +15,7 @@ __device__ void rms_norm_dynamic_per_token_quant_vec(
     scalar_t const* __restrict__ input,   // [..., hidden_size]
     scalar_t const* __restrict__ weight,  // [hidden_size]
     float const* scale_ub, float const var_epsilon, int32_t const hidden_size,
+    scalar_t* __restrict__ residual_out = nullptr,
     scalar_t* __restrict__ residual = nullptr) {
   float rms = 0.0f;
   float token_scale = 0.0f;
@@ -33,12 +34,14 @@ __device__ void rms_norm_dynamic_per_token_quant_vec(
   if constexpr (std::is_same_v<scalar_out_t, int8_t>) {
     vllm::vectorized::norm_and_quant<scalar_t, scalar_out_t, true,
                                      has_residual>(
-        out, input, weight, rms, 1.0f / token_scale, hidden_size, residual);
+        out, input, weight, rms, 1.0f / token_scale, hidden_size, residual_out,
+        residual);
   } else {
     // FP8 - Do not invert token_scale for exact match with FBGemm
     vllm::vectorized::norm_and_quant<scalar_t, scalar_out_t, false,
-                                     has_residual>(
-        out, input, weight, rms, token_scale, hidden_size, residual);
+                                     has_residual>(out, input, weight, rms,
+                                                   token_scale, hidden_size,
+                                                   residual_out, residual);
   }
 }

@@ -50,6 +53,7 @@ __global__ void rms_norm_dynamic_per_token_quant_kernel(
     scalar_t const* __restrict__ input,   // [..., hidden_size]
     scalar_t const* __restrict__ weight,  // [hidden_size]
     float const* scale_ub, float const var_epsilon, int32_t const hidden_size,
+    scalar_t* __restrict__ residual_out = nullptr,
     scalar_t* __restrict__ residual = nullptr) {
   // For vectorization, token_input and token_output pointers need to be
   // aligned at 8-byte and 4-byte addresses respectively.
@@ -76,11 +80,13 @@ __global__ void rms_norm_dynamic_per_token_quant_kernel(
   // RMS Norm + Quant
   if constexpr (std::is_same_v<scalar_out_t, int8_t>) {
     vllm::norm_and_quant<scalar_t, scalar_out_t, true, has_residual>(
-        out, input, weight, rms, 1.0f / token_scale, hidden_size, residual);
+        out, input, weight, rms, 1.0f / token_scale, hidden_size, residual_out,
+        residual);
   } else {
     // FP8 - Do not invert s_token_scale for exact match with FBGemm
     vllm::norm_and_quant<scalar_t, scalar_out_t, false, has_residual>(
-        out, input, weight, rms, token_scale, hidden_size, residual);
+        out, input, weight, rms, token_scale, hidden_size, residual_out,
+        residual);
   }
 }
 }  // namespace vllm
@@ -94,6 +100,7 @@ void rms_norm_dynamic_per_token_quant_dispatch(
     torch::Tensor& scales,     // [num_tokens]
     double const var_epsilon,  // Variance epsilon used in norm calculation
     std::optional<at::Tensor> const& scale_ub,
+    std::optional<at::Tensor>& residual_out,
     std::optional<at::Tensor>& residual) {
   int32_t hidden_size = input.size(-1);
   auto num_tokens = input.numel() / hidden_size;
@@ -112,7 +119,9 @@ void rms_norm_dynamic_per_token_quant_dispatch(
           out.data_ptr<scalar_t>(), scales.data_ptr<float>(),
           input.data_ptr<scalar_in_t>(), weight.data_ptr<scalar_in_t>(),
           scale_ub.has_value() ? scale_ub->data_ptr<float>() : nullptr,
-          var_epsilon, hidden_size, residual->data_ptr<scalar_in_t>());
+          var_epsilon, hidden_size,
+          residual_out->data_ptr<scalar_in_t>(),
+          residual->data_ptr<scalar_in_t>());
     });

   } else {
@@ -124,7 +133,7 @@ void rms_norm_dynamic_per_token_quant_dispatch(
          out.data_ptr<scalar_t>(), scales.data_ptr<float>(),
          input.data_ptr<scalar_in_t>(), weight.data_ptr<scalar_in_t>(),
          scale_ub.has_value() ? scale_ub->data_ptr<float>() : nullptr,
-          var_epsilon, hidden_size, nullptr);
+          var_epsilon, hidden_size, nullptr, nullptr);
        });
  }
}
@@ -135,13 +144,18 @@ void rms_norm_dynamic_per_token_quant(
     torch::Tensor const& weight,  // [hidden_size]
     torch::Tensor& scales,        // [num_tokens]
     double const var_epsilon,     // Variance epsilon used in norm calculation
-    std::optional<at::Tensor> scale_ub, std::optional<at::Tensor> residual) {
+    std::optional<at::Tensor> scale_ub, std::optional<at::Tensor> residual_out,
+    std::optional<at::Tensor> residual) {
   static c10::ScalarType kFp8Type = is_fp8_ocp()
                                         ? c10::ScalarType::Float8_e4m3fn
                                         : c10::ScalarType::Float8_e4m3fnuz;
   TORCH_CHECK(out.dtype() == kFp8Type || out.dtype() == torch::kInt8);
   TORCH_CHECK(out.is_contiguous() && input.is_contiguous());

+  if (residual.has_value()) {
+    TORCH_CHECK(residual_out.has_value());
+  }
+
   if (scale_ub.has_value()) {
     TORCH_CHECK(out.dtype() == kFp8Type);
   }
@@ -150,6 +164,7 @@ void rms_norm_dynamic_per_token_quant(
   VLLM_DISPATCH_FLOATING_TYPES(
       input.scalar_type(), "rms_norm_dynamic_per_token_quant_dispatch", [&] {
         rms_norm_dynamic_per_token_quant_dispatch<scalar_t>(
-            out, input, weight, scales, var_epsilon, scale_ub, residual);
+            out, input, weight, scales, var_epsilon, scale_ub, residual_out,
+            residual);
       });
 }
The input, residual, and weight tensors are not modified within this function; they are read-only. For const-correctness, they should be passed as const torch::Tensor&.
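Assuming the comment refers to a host-side signature like fused_add_rms_norm_static_fp8_quant above, the const-qualified version it suggests might read as in the sketch below; torch::Tensor is taken to be in scope via the header's existing includes, and only the tensors the reviewer names are changed.

```cpp
// Sketch of the const-qualified declaration implied by the review comment:
// input, residual, and weight are only read, so they become const references,
// while out and residual_out stay mutable because the kernel writes into them.
void fused_add_rms_norm_static_fp8_quant(
    torch::Tensor& out,             // [..., hidden_size], written
    const torch::Tensor& input,     // [..., hidden_size], read-only
    torch::Tensor& residual_out,    // [..., hidden_size], written
    const torch::Tensor& residual,  // [..., hidden_size], read-only
    const torch::Tensor& weight,    // [hidden_size], read-only
    torch::Tensor& scale,           // [1]
    double epsilon);
```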