Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ROCm][Bugfix] Fixed several bugs related to rccl path and attention selector logic #3699

Merged
merged 7 commits into from
Mar 29, 2024

Conversation

hongxiayang
Copy link
Collaborator

@hongxiayang hongxiayang commented Mar 28, 2024

FILL IN THE PR DESCRIPTION HERE

FIX #xxxx (link existing issues this PR will resolve)

This pull request fixes several bugs introduced in previous commits, for example: #3661, #3625 , and previous refactoring in attention backend.

(1) Fixed the librccl.so file name, it should be something like:
/opt/rocm/lib/librccl.so.1

(2) a bug related to check whether to use ref-attention resulted from previous refactoring:

Before: even flash-attn is available, it uses naive attention, which is quite slow for our users and is not intended.

WARNING 03-28 18:26:49 xformers.py:410] flash_attn is not installed. Using naive attention. This will take significantly more GPU memory.

Now:

INFO 03-28 18:30:12 selector.py:29] Cannot use FlashAttention backend for AMD GPUs.
Using XFormers backend.

BEFORE SUBMITTING, PLEASE READ THE CHECKLIST BELOW AND FILL IN THE DESCRIPTION ABOVE


PR Checklist (Click to Expand)

Thank you for your contribution to vLLM! Before submitting the pull request, please ensure the PR meets the following criteria. This helps vLLM maintain the code quality and improve the efficiency of the review process.

PR Title and Classification

Only specific types of PRs will be reviewed. The PR title is prefixed appropriately to indicate the type of change. Please use one of the following:

  • [Bugfix] for bug fixes.
  • [CI/Build] for build or continuous integration improvements.
  • [Doc] for documentation fixes and improvements.
  • [Model] for adding a new model or improving an existing model. Model name should appear in the title.
  • [Frontend] For changes on the vLLM frontend (e.g., OpenAI API server, LLM class, etc.)
  • [Kernel] for changes affecting CUDA kernels or other compute kernels.
  • [Core] for changes in the core vLLM logic (e.g., LLMEngine, AsyncLLMEngine, Scheduler, etc.)
  • [Hardware][Vendor] for hardware-specific changes. Vendor name should appear in the prefix (e.g., [Hardware][AMD]).
  • [Misc] for PRs that do not fit the above categories. Please use this sparingly.

Note: If the PR spans more than one category, please include all relevant prefixes.

Code Quality

The PR need to meet the following code quality standards:

  • We adhere to Google Python style guide and Google C++ style guide.
  • Pass all linter checks. Please use format.sh to format your code.
  • The code need to be well-documented to ensure future contributors can easily understand the code.
  • Include sufficient tests to ensure the project to stay correct and robust. This includes both unit tests and integration tests.
  • Please add documentation to docs/source/ if the PR modifies the user-facing behaviors of vLLM. It helps vLLM user understand and utilize the new features or changes.

Notes for Large Changes

Please keep the changes as concise as possible. For major architectural changes (>500 LOC excluding kernel/data/config/test), we would expect a GitHub issue (RFC) discussing the technical design and justification. Otherwise, we will tag it with rfc-required and might not go through the PR.

What to Expect for the Reviews

The goal of the vLLM team is to be a transparent reviewing machine. We would like to make the review process transparent and efficient and make sure no contributor feel confused or frustrated. However, the vLLM team is small, so we need to prioritize some PRs over others. Here is what you can expect from the review process:

  • After the PR is submitted, the PR will be assigned to a reviewer. Every reviewer will pick up the PRs based on their expertise and availability.
  • After the PR is assigned, the reviewer will provide status update every 2-3 days. If the PR is not reviewed within 7 days, please feel free to ping the reviewer or the vLLM team.
  • After the review, the reviewer will put an action-required label on the PR if there are changes required. The contributor should address the comments and ping the reviewer to re-review the PR.
  • Please respond to all comments within a reasonable time frame. If a comment isn't clear or you disagree with a suggestion, feel free to ask for clarification or discuss the suggestion.

Thank You

Finally, thank you for taking the time to read these guidelines and for your interest in contributing to vLLM. Your contributions make vLLM a great tool for everyone!

@youkaichao
Copy link
Member

Thanks for the contribution!

Some comments:

Because we import torch first, if torch already loads librccl.so , then it should just work . So we need to figure out how torch loads it. In NVIDIA case, torch always uses libnccl.so.2 to refer to the nccl library, that's why we use libnccl.so.2 .

For the rccl case, if the convention way of torch is to use librccl.so.1 , then we just need to append librccl.so.1 . It should work by default for pytorch users.

Furthermore, pytorch uses https://pypi.org/project/nvidia-nccl/ as a pip package to maintain nccl dependency. Does this apply for the rccl case? Or pytorch ships rccl with it? Or it just uses the rccl inside the OS?

@youkaichao
Copy link
Member

I agree your method of finding rccl is very robust, but we don't need to be so complicated. By default, we just need it to work with the default way users install torch. Otherwise, when users want to use their own rccl library, we cannot really have a robust way to "find" it, because it might not be in the search path of ldd. That's why I left an environment variable there for further use.

@youkaichao
Copy link
Member

In summary, the following information would be greatly helpful:

  • When people do pip install torch in rocm platform, how does torch use rccl? Does pytorch statically link librccl.a, or dynamically link to librccl.so? If the latter is true, does pytorch install its own version (and if yes, where?) or use the existing version in a typical search path (and if yes, where?)?
  • In the case of dynamic linking, what's the conventional name of librccl.so? For example, when I use rccl==2.18.3, do I get all of the librccl.so/librccl.so.2/librccl.so.2.18/librccl.so.2.18.3 ? Or just have one (if yes, what's the name)?

I can provide the above information for nvidia case, for your reference:

  • When people do pip install torch in cuda platform, pytorch dynamically links to libnccl.so. Pytorch install its own version in ${CONDA_PREFIX}/site-packages/nvidia/nccl/lib/ , and that path is embedded in libtorch_cuda.so's rpath.
  • The conventional name of libnccl.so is libnccl.so.2 .

@@ -41,7 +48,7 @@
if torch.version.cuda is not None:
so_file = "libnccl.so.2"
elif torch.version.hip is not None:
so_file = "librccl.so.2"
so_file = "librccl.so.1"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I looked at https://rocm.docs.amd.com/projects/rccl/en/latest/api.html , and it says the current version is 2.18.3 . Quite strange that the library name is librccl.so.1 .

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

that is why I am not assuming what the suffix is.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you talk to rccl team why this is the case? If they keep librccl.so.1 that would also be fine, but just please don't be too random.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My initial test with the current head is that, it does not work for ROCm. There are a bunch of other issues in addition to the ones described in this pull request.

We have tested using cupy and verified that it worked for the hipgraph path with our in-development newer ROCm.

However, this does not work for us.

Another thing, is that, will it be possible we can still opt in using cupy for all-reduce? Can it be abstracted so that people can choose use cupy, nccl, or, whatever?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

as how rccl so file name and its version definition: I found information ROCm/rccl repo. Links below:

https://github.com/ROCm/rccl/blob/2f6d59e2e651914d9d6e51b2b702b9a9ac0ea99d/makefiles/version.mk#L2
and
https://github.com/ROCm/rccl/blob/2f6d59e2e651914d9d6e51b2b702b9a9ac0ea99d/CMakeLists.txt#L669C1-L669C19

Hope this answers your question. Let's take a step back, we want to solve the problem of cudagraph mode.
My understanding is that below are possible ways :

  • cupy
  • user-defined nccl/rccl
  • custom all reduce
  • pytorch native all-reduce

How we can easily choose one over the other and what is our long-term plan?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

cupy is deprecated and removed now, because we got many bug report with regard to cupy .

pytorch native all-reduce is not available in cudagraph mode, because it usually contains some additional check that will fail graph capture.

Going forward, we will focus on the pynccl wrapper as the first choice, and custom all reduce as a backup plan (it is disabled by default because of instability).

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@youkaichao Our users need the fixes for the other part like the one related to naive attention, since now it becomes the default for those users and it was quite slow.
I need to simplify this PR so that it will be merged quickly

@hongxiayang
Copy link
Collaborator Author

Thanks for the contribution!

Some comments:

Because we import torch first, if torch already loads librccl.so , then it should just work . So we need to figure out how torch loads it. In NVIDIA case, torch always uses libnccl.so.2 to refer to the nccl library, that's why we use libnccl.so.2 .

For the rccl case, if the convention way of torch is to use librccl.so.1 , then we just need to append librccl.so.1 . It should work by default for pytorch users.

Furthermore, pytorch uses https://pypi.org/project/nvidia-nccl/ as a pip package to maintain nccl dependency. Does this apply for the rccl case? Or pytorch ships rccl with it? Or it just uses the rccl inside the OS?

The short answer for how pytorch finds rccl during its build, is in its cmake mechanism. By default, it finds rccl related version information in /opt/rocm/lib/cmake/rccl directory.

@hongxiayang hongxiayang marked this pull request as ready for review March 29, 2024 21:35
Copy link
Member

@youkaichao youkaichao left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm ok with the modification in pynccl.py . Please ping others for approval on the other parts.

@hongxiayang
Copy link
Collaborator Author

I'm ok with the modification in pynccl.py . Please ping others for approval on the other parts.

cc @simon-mo @WoosukKwon Please take a look at this one since right now users complained that naive attention is used which is 10x slower

Copy link
Collaborator

@WoosukKwon WoosukKwon left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM! Thanks for the fix and apologies for the late review.

@WoosukKwon WoosukKwon merged commit 9765b5c into vllm-project:main Mar 29, 2024
22 of 33 checks passed
xjpang pushed a commit to xjpang/vllm that referenced this pull request Mar 31, 2024
Temirulan pushed a commit to Temirulan/vllm-whisper that referenced this pull request Sep 6, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants