Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Bugfix / Core] Prefix Caching Guards (merged with main) #4846

Merged
merged 32 commits into from
May 27, 2024

Conversation

zhuohan123
Copy link
Member

Updated version of #3903


PR Checklist (Click to Expand)

Thank you for your contribution to vLLM! Before submitting the pull request, please ensure the PR meets the following criteria. This helps vLLM maintain the code quality and improve the efficiency of the review process.

PR Title and Classification

Only specific types of PRs will be reviewed. The PR title is prefixed appropriately to indicate the type of change. Please use one of the following:

  • [Bugfix] for bug fixes.
  • [CI/Build] for build or continuous integration improvements.
  • [Doc] for documentation fixes and improvements.
  • [Model] for adding a new model or improving an existing model. Model name should appear in the title.
  • [Frontend] For changes on the vLLM frontend (e.g., OpenAI API server, LLM class, etc.)
  • [Kernel] for changes affecting CUDA kernels or other compute kernels.
  • [Core] for changes in the core vLLM logic (e.g., LLMEngine, AsyncLLMEngine, Scheduler, etc.)
  • [Hardware][Vendor] for hardware-specific changes. Vendor name should appear in the prefix (e.g., [Hardware][AMD]).
  • [Misc] for PRs that do not fit the above categories. Please use this sparingly.

Note: If the PR spans more than one category, please include all relevant prefixes.

Code Quality

The PR need to meet the following code quality standards:

  • We adhere to Google Python style guide and Google C++ style guide.
  • Pass all linter checks. Please use format.sh to format your code.
  • The code need to be well-documented to ensure future contributors can easily understand the code.
  • Include sufficient tests to ensure the project to stay correct and robust. This includes both unit tests and integration tests.
  • Please add documentation to docs/source/ if the PR modifies the user-facing behaviors of vLLM. It helps vLLM user understand and utilize the new features or changes.

Notes for Large Changes

Please keep the changes as concise as possible. For major architectural changes (>500 LOC excluding kernel/data/config/test), we would expect a GitHub issue (RFC) discussing the technical design and justification. Otherwise, we will tag it with rfc-required and might not go through the PR.

What to Expect for the Reviews

The goal of the vLLM team is to be a transparent reviewing machine. We would like to make the review process transparent and efficient and make sure no contributor feel confused or frustrated. However, the vLLM team is small, so we need to prioritize some PRs over others. Here is what you can expect from the review process:

  • After the PR is submitted, the PR will be assigned to a reviewer. Every reviewer will pick up the PRs based on their expertise and availability.
  • After the PR is assigned, the reviewer will provide status update every 2-3 days. If the PR is not reviewed within 7 days, please feel free to ping the reviewer or the vLLM team.
  • After the review, the reviewer will put an action-required label on the PR if there are changes required. The contributor should address the comments and ping the reviewer to re-review the PR.
  • Please respond to all comments within a reasonable time frame. If a comment isn't clear or you disagree with a suggestion, feel free to ask for clarification or discuss the suggestion.

Thank You

Finally, thank you for taking the time to read these guidelines and for your interest in contributing to vLLM. Your contributions make vLLM a great tool for everyone!

@zhuohan123 zhuohan123 changed the title Prefix caching guards new [Bugfix / Core] Prefix Caching Guards (merged with main) May 16, 2024
@@ -251,6 +263,18 @@ def get_sliding_window(self) -> Optional[int]:
return None
return getattr(self.hf_text_config, "sliding_window", None)

def get_sliding_window(self) -> Optional[int]:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How does it work for the model that already has sliding window like mistral?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Im not sure what you mean?

If the user does not specify --disable-sliding-window then we use sliding window if the model supports it

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oh maybe it is a dumb question, but my question is for models that has slinding window by default https://huggingface.co/mistralai/Mistral-7B-v0.1/blob/26bca36bde8333b5d7f72e9ed20ccda6a618af24/config.json#L18, if we use --disable-sliding-window, does it work properly?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, specifically what this does is handle a case like Mistral.

--disable-sliding-window means we turn off sliding window and set max_model_len=sliding_window

So in the case of Mistral, we then would treat the model as a 4096 ctx-len model with no sliding window.

The reason for this feature is that if we want to use features that are incompatible with sliding window (e.g. APC or chunked prefill), then there is a pathway to disable sliding window

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see. that makes sense! Thanks for the explanation

Copy link
Collaborator

@rkooo567 rkooo567 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM. Minor comments. Didn't review _get_and_verify_max_len and _get_and_verify_dtype assuming it is just code refactored (lmk if it is wrong)

vllm/config.py Outdated
if self.disable_sliding_window:
logger.info("Sliding window is disabled per configuration. "
"Model max length will be capped at sliding window "
"length.")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
"length.")
"length, %d tokens", self.get_hf_config_sliding_window())

parser.add_argument('--disable-sliding-window',
action='store_true',
help='Disables sliding window if the model '
'supports sliding window')
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you mention the model length is capped by the slinding window size?

@robertgshaw2-neuralmagic
Copy link
Collaborator

because sliding window is propogated to attention, this is going to require me to edit most model files.

Will get back to this tomorrow after I get mistral over the line

@@ -173,18 +168,14 @@ def __init__(
# Requires transformers > 4.32.0
rope_theta = getattr(config, "rope_theta", 1000000)
rope_scaling = getattr(config, "rope_scaling", None)
use_sliding_window = (config.use_sliding_window
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

{
  "architectures": [
    "Qwen2ForCausalLM"
  ],
  "attention_dropout": 0.0,
  "bos_token_id": 151643,
  "eos_token_id": 151645,
  "hidden_act": "silu",
  "hidden_size": 4096,
  "initializer_range": 0.02,
  "intermediate_size": 11008,
  "max_position_embeddings": 32768,
  "max_window_layers": 28,          << qwen2 uses sliding window for some layers
  "model_type": "qwen2",
  "num_attention_heads": 32,
  "num_hidden_layers": 32,
  "num_key_value_heads": 32,
  "rms_norm_eps": 1e-06,
  "rope_theta": 1000000.0,
  "sliding_window": 32768,
  "tie_word_embeddings": false,
  "torch_dtype": "bfloat16",
  "transformers_version": "4.37.0",
  "use_cache": true,
  "use_sliding_window": false,       << qwen2 does not use sliding window by default
  "vocab_size": 151936
}

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we had a bug in Qwen2 - this path will not be followed very often b/c qwen2 does not use sliding window by default

Currently, if use_sliding_window=True, only some layers will use sliding window. But we have global KV cache management that would treat KVs the same. So I do not see how it is possible that this could be working correctly.

This is not a very common user path because they would have to opt into sliding window on Qwen.

So I disabled this by default.

@robertgshaw2-neuralmagic
Copy link
Collaborator

Disabling sliding window ended up being more work than I expected because we broke some abstractions where the models are accessing the hf_config to determine whether sliding window is used when passing arguments to attention. As a result, the user's specification is ignored.

So, I updated Attention to use the cache_config's value, which is set properly rather than letting the model definition specify. This enabled me to remove the sliding_window argument from the various layers in the model.

Additionally, I noticed in this that Qwen2 attempts to support having only some layers with sliding window. We do not support this in our KV cache management, so I removed this bug by Failing if the system is configured this way. (note: this is not a popular codepath b/c Qwen2 does not use sliding window by default.

@zhuohan123 zhuohan123 merged commit 1102bef into main May 27, 2024
63 checks passed
dtrifiro pushed a commit to opendatahub-io/vllm that referenced this pull request May 31, 2024
…t#4846)

Co-authored-by: rsnm2 <rshaw@neuralmagic.com>
Co-authored-by: Robert Shaw <114415538+robertgshaw2-neuralmagic@users.noreply.github.com>
robertgshaw2-neuralmagic added a commit to neuralmagic/nm-vllm that referenced this pull request Jun 8, 2024
…t#4846)

Co-authored-by: rsnm2 <rshaw@neuralmagic.com>
Co-authored-by: Robert Shaw <114415538+robertgshaw2-neuralmagic@users.noreply.github.com>
robertgshaw2-neuralmagic added a commit to neuralmagic/nm-vllm that referenced this pull request Jun 8, 2024
…t#4846)

Co-authored-by: rsnm2 <rshaw@neuralmagic.com>
Co-authored-by: Robert Shaw <114415538+robertgshaw2-neuralmagic@users.noreply.github.com>
joerunde pushed a commit to joerunde/vllm that referenced this pull request Jun 17, 2024
…t#4846)

Co-authored-by: rsnm2 <rshaw@neuralmagic.com>
Co-authored-by: Robert Shaw <114415538+robertgshaw2-neuralmagic@users.noreply.github.com>
robertgshaw2-neuralmagic added a commit to neuralmagic/nm-vllm that referenced this pull request Jul 14, 2024
…t#4846)

Co-authored-by: rsnm2 <rshaw@neuralmagic.com>
Co-authored-by: Robert Shaw <114415538+robertgshaw2-neuralmagic@users.noreply.github.com>
Temirulan pushed a commit to Temirulan/vllm-whisper that referenced this pull request Sep 6, 2024
…t#4846)

Co-authored-by: rsnm2 <rshaw@neuralmagic.com>
Co-authored-by: Robert Shaw <114415538+robertgshaw2-neuralmagic@users.noreply.github.com>
@simon-mo simon-mo deleted the prefix-caching-guards-new branch October 28, 2024 16:50
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants