@@ -8,7 +8,7 @@
 
 from vllm import envs
 from vllm.attention.backends.abstract import AttentionMetadata
-from vllm.config import get_current_vllm_config
+from vllm.config import CacheConfig, ModelConfig, get_current_vllm_config
 from vllm.distributed import (divide, get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size,
                               tensor_model_parallel_all_gather,
@@ -21,7 +21,7 @@
 from vllm.model_executor.layers.mamba.mamba2_metadata import (Mamba2Metadata,
                                                               update_metadata)
 from vllm.model_executor.layers.mamba.mamba_utils import (
-    MambaStateShapeCalculator)
+    MambaStateDtypeCalculator, MambaStateShapeCalculator)
 from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
     causal_conv1d_fn, causal_conv1d_update)
 from vllm.model_executor.layers.mamba.ops.layernorm_gated import rms_norm_gated
@@ -218,23 +218,23 @@ class MambaMixer2(MambaBase, CustomOp):
     **selective** state spaces)
     """
 
-    def __init__(
-        self,
-        hidden_size: int,
-        ssm_state_size: int,
-        conv_kernel_size: int,
-        intermediate_size: int,
-        use_conv_bias: bool,
-        use_bias: bool,
-        n_groups: int = 1,
-        num_heads: int = 128,
-        head_dim: int = 64,
-        rms_norm_eps: float = 1e-5,
-        activation: str = "silu",
-        use_rms_norm: bool = True,
-        quant_config: Optional[QuantizationConfig] = None,
-        prefix: str = "",
-    ):
+    def __init__(self,
+                 hidden_size: int,
+                 ssm_state_size: int,
+                 conv_kernel_size: int,
+                 intermediate_size: int,
+                 use_conv_bias: bool,
+                 use_bias: bool,
+                 n_groups: int = 1,
+                 num_heads: int = 128,
+                 head_dim: int = 64,
+                 rms_norm_eps: float = 1e-5,
+                 activation: str = "silu",
+                 use_rms_norm: bool = True,
+                 model_config: Optional[ModelConfig] = None,
+                 cache_config: Optional[CacheConfig] = None,
+                 quant_config: Optional[QuantizationConfig] = None,
+                 prefix: str = ""):
         super().__init__()
 
         # For TP, the sharding plan is as follows:
@@ -417,6 +417,8 @@ def __init__(
         # The inner tuple is (conv_state, ssm_state)
         self.kv_cache = [(torch.tensor([]), torch.tensor([]))]
 
+        self.model_config = model_config
+        self.cache_config = cache_config
         self.prefix = prefix
 
     def forward_native(
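For context, the two new constructor arguments are expected to be threaded in from the current vLLM configuration rather than built by hand. The snippet below is a hedged sketch of such a call site; the numeric parameter values, the use of `get_current_vllm_config()` here, and the `prefix`/`quant_config` values are illustrative assumptions, not code from this change.

```python
# Sketch of a call site passing the new configs (illustrative only).
from vllm.config import get_current_vllm_config

vllm_config = get_current_vllm_config()
mixer = MambaMixer2(
    hidden_size=4096,            # placeholder model dimensions
    ssm_state_size=128,
    conv_kernel_size=4,
    intermediate_size=8192,
    use_conv_bias=True,
    use_bias=False,
    n_groups=8,
    num_heads=128,
    head_dim=64,
    model_config=vllm_config.model_config,   # new in this change
    cache_config=vllm_config.cache_config,   # new in this change
    quant_config=None,
    prefix="model.layers.0.mixer",           # hypothetical layer prefix
)
```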
@@ -670,7 +672,7 @@ def forward_cuda(
                 dt_limit=(0.0, float("inf")),
                 out=preallocated_ssm_out_p.view(1, num_prefill_tokens, -1,
                                                 self.head_dim),
-            )
+                state_dtype=ssm_state.dtype)
 
             # update ssm states
             # - varlen state is a (num_prefills, nheads, headdim, dstate) tensor
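The new `state_dtype=ssm_state.dtype` argument lets the chunked scan emit the per-sequence final states directly in the dtype of the SSM state cache, which after this change may differ from the activation dtype (for example a float32 cache with bfloat16 activations). A minimal torch-only sketch of the equivalent out-of-kernel cast, with made-up shapes:

```python
import torch

# Illustrative shapes only: (num_prefills, nheads, headdim, dstate) states
# computed in the activation dtype, written into a cache of a wider dtype.
varlen_state = torch.randn(2, 8, 64, 128, dtype=torch.bfloat16)
ssm_state = torch.zeros(4, 8, 64, 128, dtype=torch.float32)  # preallocated cache slots
state_indices = torch.tensor([1, 3])  # cache slots owned by the prefills

# Passing state_dtype to the kernel produces states already in the cache
# dtype; this is the equivalent cast performed outside the kernel.
ssm_state[state_indices] = varlen_state.to(ssm_state.dtype)
```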
@@ -732,6 +734,15 @@ def forward_cuda(
         # 5. Final linear projection
         output[:num_actual_tokens], _ = self.out_proj(hidden_states)
 
+    def get_state_dtype(self) -> tuple[torch.dtype, torch.dtype]:
+        assert self.model_config is not None
+        assert self.cache_config is not None
+        return MambaStateDtypeCalculator.mamba2_state_dtype(
+            self.model_config.dtype,
+            self.cache_config.mamba_cache_dtype,
+            self.cache_config.mamba_ssm_cache_dtype,
+        )
+
     def get_state_shape(self) -> tuple[tuple[int, ...], tuple[int, ...]]:
         return MambaStateShapeCalculator.mamba2_state_shape(
             intermediate_size=self.intermediate_size,
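Together, `get_state_dtype()` and `get_state_shape()` give a cache allocator everything it needs to size the per-layer `(conv_state, ssm_state)` buffers. The helper below is a hedged sketch of that contract, not vLLM's real allocator; `allocate_mamba_state`, `layer`, and `num_blocks` are assumed names for illustration.

```python
import torch

def allocate_mamba_state(layer, num_blocks: int, device: str = "cuda"):
    """Hypothetical helper: allocate (conv_state, ssm_state) for one layer
    using the dtype and shape hooks added in this change."""
    conv_dtype, ssm_dtype = layer.get_state_dtype()
    conv_shape, ssm_shape = layer.get_state_shape()
    conv_state = torch.empty(num_blocks, *conv_shape,
                             dtype=conv_dtype, device=device)
    ssm_state = torch.empty(num_blocks, *ssm_shape,
                            dtype=ssm_dtype, device=device)
    return conv_state, ssm_state
```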