Skip to content

[TPU] fix kv cache dtype in model runner #19244

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 9 additions & 3 deletions vllm/v1/worker/tpu_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@
PlaceholderRange)
from vllm.multimodal.utils import group_mm_inputs_by_modality
from vllm.sequence import IntermediateTensors
from vllm.utils import LayerBlockType, cdiv, is_pin_memory_available
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, LayerBlockType, cdiv,
is_pin_memory_available)
from vllm.v1.attention.backends.pallas import (PallasAttentionBackend,
PallasMetadata)
from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
Expand Down Expand Up @@ -138,6 +139,11 @@ def __init__(

self.pin_memory = is_pin_memory_available()
self.dtype = self.model_config.dtype
if cache_config.cache_dtype == "auto":
self.kv_cache_dtype = self.dtype
else:
self.kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[
cache_config.cache_dtype]
self._hidden_states_dtype = self.dtype

self.is_multimodal_model = model_config.is_multimodal_model
Expand Down Expand Up @@ -480,7 +486,7 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
block_size=block_size,
num_kv_heads=attn_module.num_kv_heads,
head_size=attn_module.head_size,
dtype=attn_module.dtype,
dtype=self.kv_cache_dtype,
sliding_window=attn_module.sliding_window,
use_mla=False,
)
Expand All @@ -489,7 +495,7 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
block_size=block_size,
num_kv_heads=attn_module.num_kv_heads,
head_size=attn_module.head_size,
dtype=attn_module.dtype,
dtype=self.kv_cache_dtype,
use_mla=False,
)
elif attn_module.attn_type in (AttentionType.ENCODER,
Expand Down