
Commit 101f278

tdoublep, danielafrimi, bringlein, and heheda12345 authored and committed
[V1] [Hybrid] Support using float32 for state in Hybrid Models (Mamba2, Mamba1, Minimax) (vllm-project#22928)
Signed-off-by: Daniel Afrimi <danielafrimi8@gmail.com>
Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
Signed-off-by: Chen Zhang <zhangch99@outlook.com>
Co-authored-by: Daniel Afrimi <danielafrimi8@gmail.com>
Co-authored-by: Burkhard Ringlein <ngl@zurich.ibm.com>
Co-authored-by: Chen Zhang <zhangch99@outlook.com>
Signed-off-by: Xiao Yu <xiao.yu@amd.com>
1 parent c30d42e commit 101f278

File tree: 23 files changed, +467 additions, -87 deletions
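For quick orientation, a minimal offline-inference sketch (not part of this commit) of how the new knob is exercised end to end. It assumes the LLM entrypoint forwards the new engine arg as a keyword argument, the same way vllm_runner does in the test added below; the checkpoint is simply the hybrid model that test uses.

# Sketch only: keep the SSM state in float32 while the conv state keeps
# following the model dtype (mamba_cache_dtype is left at "auto").
from vllm import LLM, SamplingParams

llm = LLM(
    model="Zyphra/Zamba2-1.2B-instruct",  # hybrid model used in the new test
    mamba_ssm_cache_dtype="float32",      # fp32 ssm state only
)
outputs = llm.generate(["The capital of France is"],
                       SamplingParams(temperature=0.0, max_tokens=16))
print(outputs[0].outputs[0].text)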

tests/models/language/generation/test_hybrid.py

Lines changed: 62 additions & 0 deletions
@@ -431,3 +431,65 @@ def test_full_cuda_graph(
         name_0="hf" if hf_outputs is not None else "vllm-v0",
         name_1="vllm-v1",
     )
+
+
+@pytest.mark.parametrize("model", ["Zyphra/Zamba2-1.2B-instruct"])
+@pytest.mark.parametrize("max_tokens", [64])
+@pytest.mark.parametrize("num_logprobs", [5])
+def test_fp32_state(
+    hf_runner,
+    vllm_runner,
+    example_prompts,
+    monkeypatch,
+    model: str,
+    max_tokens: int,
+    num_logprobs: int,
+) -> None:
+
+    try:
+        model_info = HF_EXAMPLE_MODELS.find_hf_info(model)
+        model_info.check_available_online(on_fail="skip")
+        model_info.check_transformers_version(on_fail="skip")
+    except ValueError:
+        pass
+
+    with hf_runner(model) as hf_model:
+        if model not in HF_UNSUPPORTED_MODELS:
+            hf_outputs = hf_model.generate_greedy_logprobs_limit(
+                example_prompts, max_tokens, num_logprobs)
+        else:
+            hf_outputs = None
+
+    with vllm_runner(model,
+                     max_num_seqs=MAX_NUM_SEQS,
+                     mamba_ssm_cache_dtype="float32") as vllm_model:
+        vllm_v0_outputs = vllm_model.generate_greedy_logprobs(
+            example_prompts, max_tokens, num_logprobs)
+
+    with monkeypatch.context() as m:
+        m.setenv("VLLM_USE_V1", "1")
+        if model in HYBRID_MODELS:
+            # required due to reorder_batch behaviour
+            m.setenv("VLLM_ATTENTION_BACKEND", "FLASHINFER")
+        with vllm_runner(model,
+                         max_num_seqs=MAX_NUM_SEQS,
+                         mamba_ssm_cache_dtype="float32",
+                         enable_prefix_caching=False) as vllm_model:
+            vllm_v1_outputs = vllm_model.generate_greedy_logprobs(
+                example_prompts, max_tokens, num_logprobs)
+
+    if hf_outputs is not None:
+        check_logprobs_close(
+            outputs_0_lst=hf_outputs,
+            outputs_1_lst=vllm_v0_outputs,
+            name_0="hf",
+            name_1="vllm-v0",
+        )
+
+    ref_outputs = hf_outputs if hf_outputs is not None else vllm_v0_outputs
+    check_logprobs_close(
+        outputs_0_lst=ref_outputs,
+        outputs_1_lst=vllm_v1_outputs,
+        name_0="hf" if hf_outputs is not None else "vllm-v0",
+        name_1="vllm-v1",
+    )

tests/v1/worker/test_gpu_model_runner.py

Lines changed: 2 additions & 0 deletions
@@ -772,6 +772,8 @@ def test_hybrid_attention_mamba_tensor_shapes(monkeypatch):
             head_dim=hf_config.mamba_d_head,
             rms_norm_eps=hf_config.rms_norm_eps,
             activation=hf_config.hidden_act,
+            cache_config=cache_config,
+            model_config=model_config,
             prefix=key,
         )
     # suppress var not used error

vllm/config/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -29,7 +29,7 @@

 import vllm.envs as envs
 from vllm import version
-from vllm.config.cache import (BlockSize, CacheConfig, CacheDType,
+from vllm.config.cache import (BlockSize, CacheConfig, CacheDType, MambaDType,
                                PrefixCachingHashAlgo)
 from vllm.config.compilation import (CompilationConfig, CompilationLevel,
                                      PassConfig)

vllm/config/cache.py

Lines changed: 12 additions & 0 deletions
@@ -23,6 +23,7 @@

 BlockSize = Literal[1, 8, 16, 32, 64, 128]
 CacheDType = Literal["auto", "fp8", "fp8_e4m3", "fp8_e5m2", "fp8_inc"]
+MambaDType = Literal["auto", "float32"]
 PrefixCachingHashAlgo = Literal["builtin", "sha256", "sha256_cbor_64bit"]


@@ -93,6 +94,15 @@ class CacheConfig:
     """ Optional override for mamba page size; used by hybrid mamba/attention
     models to ensure exact alignment with attention page size."""

+    mamba_cache_dtype: MambaDType = "auto"
+    """The data type to use for the Mamba cache (both the conv as well as the
+    ssm state). If set to 'auto', the data type will be inferred from the model
+    config."""
+    mamba_ssm_cache_dtype: MambaDType = "auto"
+    """The data type to use for the Mamba cache (ssm state only, conv state will
+    still be controlled by mamba_cache_dtype). If set to 'auto', the data type
+    for the ssm state will be determined by mamba_cache_dtype."""
+
     # Will be set after profiling.
     num_gpu_blocks: Optional[int] = field(default=None, init=False)
     """The number of blocks to allocate for GPU memory."""
@@ -123,6 +133,8 @@ def compute_hash(self) -> str:
         """
         factors: list[Any] = []
         factors.append(self.cache_dtype)
+        factors.append(self.mamba_cache_dtype)
+        factors.append(self.mamba_ssm_cache_dtype)
         # `cpu_offload_gb` does not use `torch.compile` yet.
         hash_str = hashlib.md5(str(factors).encode(),
                                usedforsecurity=False).hexdigest()

vllm/engine/arg_utils.py

Lines changed: 14 additions & 6 deletions
@@ -27,12 +27,12 @@
                         DeviceConfig, DistributedExecutorBackend,
                         GuidedDecodingBackend, HfOverrides, KVEventsConfig,
                         KVTransferConfig, LoadConfig, LogprobsMode,
-                        LoRAConfig, ModelConfig, ModelDType, ModelImpl,
-                        MultiModalConfig, ObservabilityConfig, ParallelConfig,
-                        PoolerConfig, PrefixCachingHashAlgo, RunnerOption,
-                        SchedulerConfig, SchedulerPolicy, SpeculativeConfig,
-                        TaskOption, TokenizerMode, VllmConfig, get_attr_docs,
-                        get_field)
+                        LoRAConfig, MambaDType, ModelConfig, ModelDType,
+                        ModelImpl, MultiModalConfig, ObservabilityConfig,
+                        ParallelConfig, PoolerConfig, PrefixCachingHashAlgo,
+                        RunnerOption, SchedulerConfig, SchedulerPolicy,
+                        SpeculativeConfig, TaskOption, TokenizerMode,
+                        VllmConfig, get_attr_docs, get_field)
 from vllm.logger import init_logger
 from vllm.platforms import CpuArchEnum, current_platform
 from vllm.plugins import load_general_plugins
@@ -422,6 +422,8 @@ class EngineArgs:
     override_attention_dtype: str = ModelConfig.override_attention_dtype

     calculate_kv_scales: bool = CacheConfig.calculate_kv_scales
+    mamba_cache_dtype: MambaDType = CacheConfig.mamba_cache_dtype
+    mamba_ssm_cache_dtype: MambaDType = CacheConfig.mamba_ssm_cache_dtype

     additional_config: dict[str, Any] = \
         get_field(VllmConfig, "additional_config")
@@ -694,6 +696,10 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
                                  **cache_kwargs["calculate_kv_scales"])
         cache_group.add_argument("--kv-sharing-fast-prefill",
                                  **cache_kwargs["kv_sharing_fast_prefill"])
+        cache_group.add_argument("--mamba-cache-dtype",
+                                 **cache_kwargs["mamba_cache_dtype"])
+        cache_group.add_argument("--mamba-ssm-cache-dtype",
+                                 **cache_kwargs["mamba_ssm_cache_dtype"])

         # Multimodal related configs
         multimodal_kwargs = get_kwargs(MultiModalConfig)
@@ -1105,6 +1111,8 @@ def create_engine_config(
             cpu_offload_gb=self.cpu_offload_gb,
             calculate_kv_scales=self.calculate_kv_scales,
             kv_sharing_fast_prefill=self.kv_sharing_fast_prefill,
+            mamba_cache_dtype=self.mamba_cache_dtype,
+            mamba_ssm_cache_dtype=self.mamba_ssm_cache_dtype,
         )

         ray_runtime_env = None
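The same wiring can be seen from Python. The following is an illustrative sketch (not part of the diff) that assumes create_engine_config() can be called with its default usage context: the two new EngineArgs fields flow into CacheConfig, mirroring the --mamba-cache-dtype and --mamba-ssm-cache-dtype flags registered above.

from vllm.engine.arg_utils import EngineArgs

engine_args = EngineArgs(
    model="Zyphra/Zamba2-1.2B-instruct",
    mamba_cache_dtype="auto",         # conv and ssm state default to the model dtype
    mamba_ssm_cache_dtype="float32",  # override only the ssm state
)
vllm_config = engine_args.create_engine_config()
# The values land on the cache config added in vllm/config/cache.py.
assert vllm_config.cache_config.mamba_ssm_cache_dtype == "float32"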

vllm/model_executor/layers/mamba/mamba_mixer.py

Lines changed: 15 additions & 2 deletions
@@ -9,7 +9,7 @@

 from vllm import envs
 from vllm.attention.backends.abstract import AttentionMetadata
-from vllm.config import get_current_vllm_config
+from vllm.config import CacheConfig, ModelConfig, get_current_vllm_config
 from vllm.distributed.parallel_state import (
     get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
 from vllm.forward_context import ForwardContext, get_forward_context
@@ -20,7 +20,7 @@
                                                RowParallelLinear)
 from vllm.model_executor.layers.mamba.abstract import MambaBase
 from vllm.model_executor.layers.mamba.mamba_utils import (
-    MambaStateShapeCalculator)
+    MambaStateDtypeCalculator, MambaStateShapeCalculator)
 from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
     causal_conv1d_fn, causal_conv1d_update)
 from vllm.model_executor.layers.mamba.ops.mamba_ssm import (
@@ -56,6 +56,8 @@ def __init__(self,
                  rms_norm_eps: float = 1e-5,
                  activation="silu",
                  is_lora_enabled: bool = False,
+                 model_config: Optional[ModelConfig] = None,
+                 cache_config: Optional[CacheConfig] = None,
                  prefix: str = ""):
         super().__init__()
         self.time_step_rank = time_step_rank
@@ -153,6 +155,8 @@ def A_weight_loader(param: Parameter, loaded_weight: torch.Tensor):
         # The inner tuple is (conv_state, ssm_state)
         self.kv_cache = [(torch.tensor([]), torch.tensor([]))]

+        self.model_config = model_config
+        self.cache_config = cache_config
         self.prefix = prefix

     def _ssm_transform(
@@ -369,6 +373,15 @@ def forward_cuda(self,

         return out

+    def get_state_dtype(self) -> tuple[torch.dtype]:
+        assert self.model_config is not None
+        assert self.cache_config is not None
+        return MambaStateDtypeCalculator.mamba1_state_dtype(
+            self.model_config.dtype,
+            self.cache_config.mamba_cache_dtype,
+            self.cache_config.mamba_ssm_cache_dtype,
+        )
+
     def get_state_shape(self) -> tuple[tuple[int, ...], tuple[int, ...]]:
         return MambaStateShapeCalculator.mamba1_state_shape(
             tp_world_size=get_tensor_model_parallel_world_size(),

vllm/model_executor/layers/mamba/mamba_mixer2.py

Lines changed: 31 additions & 20 deletions
@@ -8,7 +8,7 @@

 from vllm import envs
 from vllm.attention.backends.abstract import AttentionMetadata
-from vllm.config import get_current_vllm_config
+from vllm.config import CacheConfig, ModelConfig, get_current_vllm_config
 from vllm.distributed import (divide, get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size,
                               tensor_model_parallel_all_gather,
@@ -21,7 +21,7 @@
 from vllm.model_executor.layers.mamba.mamba2_metadata import (Mamba2Metadata,
                                                               update_metadata)
 from vllm.model_executor.layers.mamba.mamba_utils import (
-    MambaStateShapeCalculator)
+    MambaStateDtypeCalculator, MambaStateShapeCalculator)
 from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
     causal_conv1d_fn, causal_conv1d_update)
 from vllm.model_executor.layers.mamba.ops.layernorm_gated import rms_norm_gated
@@ -218,23 +218,23 @@ class MambaMixer2(MambaBase, CustomOp):
     **selective** state spaces)
     """

-    def __init__(
-        self,
-        hidden_size: int,
-        ssm_state_size: int,
-        conv_kernel_size: int,
-        intermediate_size: int,
-        use_conv_bias: bool,
-        use_bias: bool,
-        n_groups: int = 1,
-        num_heads: int = 128,
-        head_dim: int = 64,
-        rms_norm_eps: float = 1e-5,
-        activation: str = "silu",
-        use_rms_norm: bool = True,
-        quant_config: Optional[QuantizationConfig] = None,
-        prefix: str = "",
-    ):
+    def __init__(self,
+                 hidden_size: int,
+                 ssm_state_size: int,
+                 conv_kernel_size: int,
+                 intermediate_size: int,
+                 use_conv_bias: bool,
+                 use_bias: bool,
+                 n_groups: int = 1,
+                 num_heads: int = 128,
+                 head_dim: int = 64,
+                 rms_norm_eps: float = 1e-5,
+                 activation: str = "silu",
+                 use_rms_norm: bool = True,
+                 model_config: Optional[ModelConfig] = None,
+                 cache_config: Optional[CacheConfig] = None,
+                 quant_config: Optional[QuantizationConfig] = None,
+                 prefix: str = ""):
         super().__init__()

         # For TP, the sharding plan is as follows:
@@ -417,6 +417,8 @@ def __init__(
         # The inner tuple is (conv_state, ssm_state)
         self.kv_cache = [(torch.tensor([]), torch.tensor([]))]

+        self.model_config = model_config
+        self.cache_config = cache_config
         self.prefix = prefix

     def forward_native(
@@ -670,7 +672,7 @@ def forward_cuda(
                 dt_limit=(0.0, float("inf")),
                 out=preallocated_ssm_out_p.view(1, num_prefill_tokens, -1,
                                                 self.head_dim),
-            )
+                state_dtype=ssm_state.dtype)

             # update ssm states
             # - varlen state is a (num_prefills, nheads, headdim, dstate) tensor
@@ -732,6 +734,15 @@ def forward_cuda(
         # 5. Final linear projection
         output[:num_actual_tokens], _ = self.out_proj(hidden_states)

+    def get_state_dtype(self) -> tuple[torch.dtype, torch.dtype]:
+        assert self.model_config is not None
+        assert self.cache_config is not None
+        return MambaStateDtypeCalculator.mamba2_state_dtype(
+            self.model_config.dtype,
+            self.cache_config.mamba_cache_dtype,
+            self.cache_config.mamba_ssm_cache_dtype,
+        )
+
     def get_state_shape(self) -> tuple[tuple[int, ...], tuple[int, ...]]:
         return MambaStateShapeCalculator.mamba2_state_shape(
             intermediate_size=self.intermediate_size,

vllm/model_executor/layers/mamba/mamba_utils.py

Lines changed: 52 additions & 0 deletions
@@ -1,6 +1,58 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from typing import Union
+
+import torch
+
+from vllm.config import MambaDType, ModelDType
 from vllm.distributed import divide
+from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, get_kv_cache_torch_dtype
+
+
+class MambaStateDtypeCalculator:
+
+    @classmethod
+    def linear_attention_state_dtype(
+        cls,
+        model_dtype: Union[ModelDType, torch.dtype],
+        mamba_cache_dtype: MambaDType,
+    ) -> tuple[torch.dtype, ...]:
+        # TODO (tdoublep) requires testing
+        if mamba_cache_dtype == "float32":
+            raise ValueError("fp32 state for minimax is not yet supported")
+        state_dtype = get_kv_cache_torch_dtype(mamba_cache_dtype, model_dtype)
+        return (state_dtype, )
+
+    @classmethod
+    def mamba1_state_dtype(
+        cls,
+        model_dtype: Union[ModelDType, torch.dtype],
+        mamba_cache_dtype: MambaDType,
+        mamba_ssm_cache_dtype: MambaDType,
+    ) -> tuple[torch.dtype, ...]:
+        # TODO (tdoublep) requires kernel changes
+        if mamba_cache_dtype == "float32" or mamba_ssm_cache_dtype == "float32":
+            raise ValueError("fp32 state for mamba1 is not yet supported")
+        else:
+            return MambaStateDtypeCalculator.mamba2_state_dtype(
+                model_dtype, mamba_cache_dtype, mamba_ssm_cache_dtype)
+
+    @classmethod
+    def mamba2_state_dtype(
+        cls,
+        model_dtype: Union[ModelDType, torch.dtype],
+        mamba_cache_dtype: MambaDType,
+        mamba_ssm_cache_dtype: MambaDType,
+    ) -> tuple[torch.dtype, ...]:
+        conv_state_dtype = get_kv_cache_torch_dtype(mamba_cache_dtype,
+                                                    model_dtype)
+        if mamba_ssm_cache_dtype == "auto":
+            temporal_state_dtype = conv_state_dtype
+        else:
+            temporal_state_dtype = (
+                STR_DTYPE_TO_TORCH_DTYPE[mamba_ssm_cache_dtype])
+
+        return (conv_state_dtype, temporal_state_dtype)


 class MambaStateShapeCalculator:
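To make the precedence concrete, an illustrative call to the new helper (assuming a bfloat16 model dtype; get_kv_cache_torch_dtype resolves "auto" to the model dtype, so the conv state follows the model while the ssm state is promoted):

import torch

from vllm.model_executor.layers.mamba.mamba_utils import (
    MambaStateDtypeCalculator)

conv_dtype, ssm_dtype = MambaStateDtypeCalculator.mamba2_state_dtype(
    model_dtype=torch.bfloat16,
    mamba_cache_dtype="auto",         # conv state: follows the model dtype
    mamba_ssm_cache_dtype="float32",  # ssm state: forced to fp32
)
# conv_dtype == torch.bfloat16, ssm_dtype == torch.float32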
