[Do not merge] Add out of place layernorm #20197
Conversation
Signed-off-by: charlifu <charlifu@amd.com>
👋 Hi! Thank you for contributing to the vLLM project. 💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels. Just a reminder: PRs do not trigger a full CI run by default. Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.
Summary of Changes
Hello @charlifu, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces a significant change to the fused_add_rms_norm and related layer normalization kernels by transitioning them from in-place to out-of-place operations. This modification enhances flexibility and potentially improves memory management by allowing the normalized output and updated residual to be written to distinct, pre-allocated tensors. The changes span the C++/CUDA kernels, Python bindings, and the compilation pipeline, ensuring the new API is consistently applied and tested. The [Do not merge] prefix suggests this is an exploratory or work-in-progress change.
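To make the new contract concrete, here is a plain-PyTorch reference of the out-of-place semantics described above. It is an illustrative sketch only (the function name, eps default, and dtype handling are assumptions), not the PR's CUDA kernel or wrapper:

import torch

def fused_add_rms_norm_ref(x: torch.Tensor,
                           residual: torch.Tensor,
                           weight: torch.Tensor,
                           eps: float = 1e-6):
    """Return (out, residual_out) without mutating x or residual."""
    residual_out = x + residual                            # updated residual
    variance = residual_out.float().pow(2).mean(dim=-1, keepdim=True)
    normed = residual_out.float() * torch.rsqrt(variance + eps)
    out = (normed * weight.float()).to(x.dtype)            # normalized output
    return out, residual_out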
Highlights
- API Change for Fused LayerNorm Kernels: Modified the fused_add_rms_norm and fused_add_rms_norm_static_fp8_quant operations to support out-of-place computation. This involves introducing new output tensors for the normalized result (out) and the updated residual (residual_out), rather than modifying input tensors in-place.
- Kernel Signature Updates: Updated the C++/CUDA kernel signatures (fused_add_rms_norm_kernel, fused_add_rms_norm_static_fp8_quant_kernel, rms_norm_dynamic_per_token_quant_kernel) to accept the new out and residual_out pointers and mark the original inputs as const.
- Python Binding and Wrapper Updates: The Python _custom_ops.py and model_executor/layers/layernorm.py wrappers, along with torch_bindings.cpp, have been updated to reflect the new out-of-place API, passing and returning the new output tensors (a caller-side sketch follows this list).
- Test Adjustments: Corresponding test cases in test_layernorm.py were adjusted to allocate and pass the new output tensors, ensuring correctness with the out-of-place behavior.
- Functionalization and Fusion Graph Updates: Changes were made in fix_functionalization.py and fusion.py to correctly handle the new output tensors in the graph compilation and optimization passes.
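As noted in the Python binding bullet above, a caller-side sketch of the out-of-place pattern could look like the following. The torch.ops path and the argument order are assumptions made for illustration; the PR's torch_bindings.cpp and _custom_ops.py definitions are authoritative:

import torch

def call_fused_add_rms_norm(x, residual, weight, eps=1e-6):
    # Pre-allocate both outputs instead of mutating x / residual in-place.
    out = torch.empty_like(x)                   # normalized result
    residual_out = torch.empty_like(residual)   # updated residual
    # Hypothetical op path and argument order, for illustration only.
    torch.ops._C.fused_add_rms_norm(out, x, residual_out, residual, weight, eps)
    return out, residual_out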
Code Review
This pull request refactors the fused RMSNorm kernels to be out-of-place. I've found a few critical issues that need to be addressed:
- There are bugs in the test suite where arguments to a custom op are swapped, which could lead to incorrect test results.
- A TypeError will occur in the layernorm layer due to incorrect usage of torch.empty().

Additionally, there are several places where const correctness could be improved in the C++ and CUDA code, and a potentially impactful line of code has been commented out in the torch.compile path. Addressing these points will greatly improve the quality and robustness of the PR.
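For context on the TypeError mentioned above: torch.empty() expects sizes, not a tensor, so allocating an output that matches an existing tensor is typically done with torch.empty_like(). A small illustration (not code from this PR):

import torch

x = torch.randn(4, 8)

out = torch.empty_like(x)                                        # same shape, dtype, device as x
out_alt = torch.empty(x.shape, dtype=x.dtype, device=x.device)   # equivalent spelling
# torch.empty(x)  # raises TypeError: empty() expects sizes, not a Tensor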
void fused_add_rms_norm(torch::Tensor& out,          // [..., hidden_size]
                        torch::Tensor& input,        // [..., hidden_size]
                        torch::Tensor& residual_out, // [..., hidden_size]
                        torch::Tensor& residual,     // [..., hidden_size]
                        torch::Tensor& weight,       // [hidden_size]
                        double epsilon) {
The input, residual, and weight tensors are not modified within this function; they are read-only. For const-correctness, they should be passed as const torch::Tensor&.

void fused_add_rms_norm(torch::Tensor& out,            // [..., hidden_size]
                        const torch::Tensor& input,    // [..., hidden_size]
                        torch::Tensor& residual_out,   // [..., hidden_size]
                        const torch::Tensor& residual, // [..., hidden_size]
                        const torch::Tensor& weight,   // [hidden_size]
                        double epsilon) {
scalar_t* __restrict__ input,        // [..., hidden_size]
scalar_t* __restrict__ residual_out, // [..., hidden_size]
scalar_t* __restrict__ residual,     // [..., hidden_size]
The input and residual pointers are only used for reading in this kernel. They should be declared as const scalar_t* __restrict__ to enforce const-correctness.
const scalar_t* __restrict__ input, // [..., hidden_size]
scalar_t* __restrict__ residual_out, // [..., hidden_size]
const scalar_t* __restrict__ residual, // [..., hidden_size]
void fused_add_rms_norm_static_fp8_quant(
-    torch::Tensor& out,          // [..., hidden_size],
-    torch::Tensor& input,        // [..., hidden_size]
-    torch::Tensor& residual,     // [..., hidden_size]
-    torch::Tensor& weight,       // [hidden_size]
-    torch::Tensor& scale,        // [1]
+    torch::Tensor& out,          // [..., hidden_size],
+    torch::Tensor& input,        // [..., hidden_size]
+    torch::Tensor& residual_out, // [..., hidden_size]
+    torch::Tensor& residual,     // [..., hidden_size]
+    torch::Tensor& weight,       // [hidden_size]
+    torch::Tensor& scale,        // [1]
     double epsilon) {
The input, residual, weight, and scale tensors are read-only in this function. They should be passed as const torch::Tensor& to ensure they are not modified and to improve code clarity.
void fused_add_rms_norm_static_fp8_quant(
torch::Tensor& out, // [..., hidden_size],
const torch::Tensor& input, // [..., hidden_size]
torch::Tensor& residual_out, // [..., hidden_size]
const torch::Tensor& residual, // [..., hidden_size]
const torch::Tensor& weight, // [hidden_size]
const torch::Tensor& scale, // [1]
double epsilon) {
std::optional<torch::Tensor> residual_out,
std::optional<torch::Tensor> residual);
The residual tensor is read-only. It should be passed by const& to avoid an unnecessary copy of the std::optional and to signal its read-only nature.

std::optional<torch::Tensor> residual_out,
const std::optional<torch::Tensor>& residual);
scalar_t* __restrict__ residual_out = nullptr,
scalar_t* __restrict__ residual = nullptr) {
Signed-off-by: charlifu <charlifu@amd.com>
No description provided.