
Commit 4f6593b

[HybridKVCache][Platform] Add support_hybrid_kv_cache for platform (#24646)
Signed-off-by: MengqingCao <cmq0113@163.com>
1 parent: 94e6b2d

File tree

5 files changed: +20 -2 lines changed

  vllm/config/__init__.py
  vllm/platforms/cpu.py
  vllm/platforms/cuda.py
  vllm/platforms/interface.py
  vllm/platforms/rocm.py

vllm/config/__init__.py

Lines changed: 1 addition & 2 deletions
@@ -3529,8 +3529,7 @@ def __post_init__(self):
         # logger should only print warning message for hybrid models. As we
         # can't know whether the model is hybrid or not now, so we don't log
         # warning message here and will log it later.
-        if not (current_platform.is_cuda() or current_platform.is_rocm()
-                or current_platform.is_cpu()):
+        if not current_platform.support_hybrid_kv_cache():
             # Hybrid KV cache manager is not supported on non-GPU platforms.
             self.scheduler_config.disable_hybrid_kv_cache_manager = True
         if self.kv_transfer_config is not None:
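
For context, the hook is invoked through the platform object that vLLM resolves at import time. A minimal sketch of querying it directly, assuming a build that includes this commit:

from vllm.platforms import current_platform

# Expected to return True on CUDA, ROCm, and CPU after this change, and
# False on any platform that keeps the base-class default in interface.py.
print(current_platform.support_hybrid_kv_cache())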

vllm/platforms/cpu.py

Lines changed: 4 additions & 0 deletions
@@ -347,3 +347,7 @@ def default_v1(cls, model_config) -> bool:
     @classmethod
     def opaque_attention_op(cls) -> bool:
         return True
+
+    @classmethod
+    def support_hybrid_kv_cache(cls) -> bool:
+        return True

vllm/platforms/cuda.py

Lines changed: 4 additions & 0 deletions
@@ -571,6 +571,10 @@ def check_if_supports_dtype(cls, torch_dtype: torch.dtype):
                 "You can use float16 instead by explicitly setting the "
                 "`dtype` flag in CLI, for example: --dtype=half.")
 
+    @classmethod
+    def support_hybrid_kv_cache(cls) -> bool:
+        return True
+
 
 # NVML utils
 # Note that NVML is not affected by `CUDA_VISIBLE_DEVICES`,

vllm/platforms/interface.py

Lines changed: 7 additions & 0 deletions
@@ -586,6 +586,13 @@ def check_if_supports_dtype(cls, torch_dtype: torch.dtype):
         """
         raise NotImplementedError
 
+    @classmethod
+    def support_hybrid_kv_cache(cls) -> bool:
+        """
+        Returns if the hybrid kv cache is supported by the current platform.
+        """
+        return False
+
 
 class UnspecifiedPlatform(Platform):
     _enum = PlatformEnum.UNSPECIFIED
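
Since the base class defaults to False, an out-of-tree platform opts in by overriding the classmethod. A sketch under that assumption (MyPlatform is hypothetical, not part of this commit):

from vllm.platforms.interface import Platform, PlatformEnum

class MyPlatform(Platform):
    # PlatformEnum.OOT is the enum value vLLM uses for out-of-tree platforms.
    _enum = PlatformEnum.OOT

    @classmethod
    def support_hybrid_kv_cache(cls) -> bool:
        # Opt in once the hybrid KV cache manager has been validated on
        # this backend; the inherited default is False.
        return True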

vllm/platforms/rocm.py

Lines changed: 4 additions & 0 deletions
@@ -498,3 +498,7 @@ def check_if_supports_dtype(cls, torch_dtype: torch.dtype):
             f"Your {gpu_name} GPU {compute_str}. "
             "You can use float16 instead by explicitly setting the "
             "`dtype` flag in CLI, for example: --dtype=half.")
+
+    @classmethod
+    def support_hybrid_kv_cache(cls) -> bool:
+        return True
