Skip to content

[Misc] Fix a config typo in disable_hybrid_kv_cache_manager configuration #19383

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 9, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions vllm/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -4497,13 +4497,13 @@ def __post_init__(self):
# warning message here and will log it later.
if not (current_platform.is_cuda() or current_platform.is_rocm()):
# Hybrid KV cache manager is not supported on non-GPU platforms.
self.disable_hybrid_kv_cache_manager = True
self.scheduler_config.disable_hybrid_kv_cache_manager = True
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This change, along with the similar ones on lines 4503 and 4506, correctly targets self.scheduler_config.disable_hybrid_kv_cache_manager.

The disable_hybrid_kv_cache_manager attribute is defined within the SchedulerConfig class (see line 2120). By modifying it on self.scheduler_config, this PR ensures that the flag is set in the intended configuration object.

The previous assignments to self.disable_hybrid_kv_cache_manager would have dynamically created a new attribute on the VllmConfig instance. This new attribute was likely not being read or utilized by the relevant logic, meaning the hybrid KV cache manager would not have been disabled under these critical conditions (non-GPU platform, KV transfer enabled, or KV events enabled).

This is a good catch and an important fix for ensuring the correctness and robustness of the hybrid KV cache manager's behavior.

if self.kv_transfer_config is not None:
# Hybrid KV cache manager is not compatible with KV transfer.
self.disable_hybrid_kv_cache_manager = True
self.scheduler_config.disable_hybrid_kv_cache_manager = True
if self.kv_events_config is not None:
# Hybrid KV cache manager is not compatible with KV events.
self.disable_hybrid_kv_cache_manager = True
self.scheduler_config.disable_hybrid_kv_cache_manager = True

def update_sizes_for_sequence_parallelism(self,
possible_sizes: list) -> list:
Expand Down