
[MODEL] Add support for Zamba2 models #13185


Merged · 10 commits · Mar 18, 2025
5 changes: 5 additions & 0 deletions docs/source/models/supported_models.md
@@ -477,6 +477,11 @@ See [this page](#generative-models) for more information on how to use generative models.
  * `xverse/XVERSE-7B-Chat`, `xverse/XVERSE-13B-Chat`, `xverse/XVERSE-65B-Chat`, etc.
  * ✅︎
  * ✅︎
- * `Zamba2ForCausalLM`
  * Zamba2
  * `Zyphra/Zamba2-7B-instruct`, `Zyphra/Zamba2-2.7B-instruct`, `Zyphra/Zamba2-1.2B-instruct`, etc.
  *
  *
:::

:::{note}
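For context on what this documentation row enables, here is a minimal sketch of running one of the listed checkpoints through vLLM's offline `LLM` API. The checkpoint name comes from the row above; the prompt and sampling settings are purely illustrative, and a vLLM build containing this PR plus transformers >= 4.49 is assumed.

```python
# Minimal sketch: greedy generation with a Zamba2 checkpoint via the offline API.
from vllm import LLM, SamplingParams

llm = LLM(model="Zyphra/Zamba2-1.2B-instruct")
params = SamplingParams(temperature=0.0, max_tokens=64)
outputs = llm.generate(["Briefly explain hybrid Mamba/attention models."], params)
print(outputs[0].outputs[0].text)
```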
51 changes: 29 additions & 22 deletions tests/models/decoder_only/language/test_hybrid.py
@@ -9,7 +9,7 @@
from ...utils import check_outputs_equal

# This test is for the hybrid models
MODELS = ["ai21labs/Jamba-tiny-dev"]
MODELS = ["ai21labs/Jamba-tiny-dev", "Zyphra/Zamba2-1.2B-instruct"]
# Bamba at Fp32 is too big for the CI (L4 GPU).
# MODELS = ["ai21labs/Jamba-tiny-dev", "ibm-ai-platform/Bamba-9B"]

@@ -27,17 +27,19 @@ def test_models(
) -> None:

    # numeric error produces different generation
    if 'Bamba' in model:
    if "Bamba" in model:
        example_prompts.pop(3)

    with hf_runner(
            model,
            dtype=dtype,
            model_kwargs={
                "use_mamba_kernels":
                False,  # mamba kernels are not installed so HF
                # don't use them
            }) as hf_model:
    model_kwargs = {
        "use_mamba_kernels": False,  # mamba kernels are not installed so HF
        # don't use them
    }
    if "Zamba2" in model:
        # Zamba2 HF implementation automatically checks if mamba kernels are
        # installed
        model_kwargs = {}

    with hf_runner(model, dtype=dtype, model_kwargs=model_kwargs) as hf_model:
        hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)

    with vllm_runner(model, dtype=dtype) as vllm_model:
@@ -112,26 +114,31 @@ def test_mamba_prefill_chunking_with_parallel_sampling(
def test_mamba_prefill_chunking(hf_runner, vllm_runner, example_prompts,
                                model: str, dtype: str,
                                max_tokens: int) -> None:
    # numeric error during prefill chucking produces different generation
    # numeric error during prefill chunking produces different generation
    # compared to w/o prefill chunking for those examples, removed them for now
    if 'Jamba' in model:
    if "Jamba" in model:
        example_prompts.pop(7)
        example_prompts.pop(2)
        example_prompts.pop(1)
    elif 'Bamba' in model:
    elif "Bamba" in model:
        example_prompts.pop(6)
        example_prompts.pop(3)
        example_prompts.pop(2)
        dtype = "half"  # use a different dtype for Bamba

    with hf_runner(
            model,
            dtype=dtype,
            model_kwargs={
                "use_mamba_kernels":
                False,  # mamba kernels are not installed so HF
                # don't use them
            }) as hf_model:
    elif "Zamba2" in model:
        example_prompts.pop(7)
        dtype = "half"

    model_kwargs = {
        "use_mamba_kernels": False,  # mamba kernels are not installed so HF
        # don't use them
    }
    if "Zamba2" in model:
        # Zamba2 HF implementation automatically checks if mamba kernels are
        # installed
        model_kwargs = {}

    with hf_runner(model, dtype=dtype, model_kwargs=model_kwargs) as hf_model:
        non_chunked = hf_model.generate_greedy(example_prompts, max_tokens)

    with vllm_runner(model,
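To exercise only the new Zamba2 cases of this test file locally, something along these lines should work (a CUDA-capable machine with the test dependencies installed is assumed; the `-k` filter is illustrative):

```python
# Sketch: run the Zamba2 hybrid-model tests from Python; equivalent to
# `pytest tests/models/decoder_only/language/test_hybrid.py -k Zamba2`.
import sys

import pytest

sys.exit(
    pytest.main([
        "tests/models/decoder_only/language/test_hybrid.py",
        "-k", "Zamba2",
    ]))
```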
2 changes: 2 additions & 0 deletions tests/models/registry.py
@@ -195,6 +195,8 @@ def check_available_online(
"XverseForCausalLM": _HfExamplesInfo("xverse/XVERSE-7B-Chat",
is_available_online=False,
trust_remote_code=True),
"Zamba2ForCausalLM": _HfExamplesInfo("Zyphra/Zamba2-7B-instruct",
min_transformers_version="4.49"),
# [Encoder-decoder]
"BartModel": _HfExamplesInfo("facebook/bart-base"),
"BartForConditionalGeneration": _HfExamplesInfo("facebook/bart-large-cnn"),
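The `min_transformers_version="4.49"` argument records that the upstream `transformers` Zamba2 implementation only exists from 4.49 onward. A hedged sketch of the kind of check this implies; the helper below is hypothetical, and the real gating lives in the registry/test utilities:

```python
# Hypothetical helper mirroring the intent of min_transformers_version="4.49":
# skip the Zamba2 example when the installed transformers is too old.
from packaging.version import Version

import transformers


def zamba2_example_available() -> bool:
    return Version(transformers.__version__) >= Version("4.49")
```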
14 changes: 14 additions & 0 deletions vllm/config.py
@@ -821,6 +821,11 @@ def get_head_size(self) -> int:
        if qk_rope_head_dim and qk_nope_head_dim:
            return qk_rope_head_dim + qk_nope_head_dim

        if hasattr(self.hf_text_config,
                   "model_type") and (self.hf_text_config.model_type
                                      == "zamba2"):
            return self.hf_text_config.attention_head_dim

        if self.is_attention_free:
            return 0

@@ -942,6 +947,15 @@ def get_num_layers_by_block_type(
"cannot determine the num of "
f"{block_type.value} layers")

if hasattr(self.hf_text_config,
"model_type") and (self.hf_text_config.model_type
== "zamba2"):
if attn_block_type:
return sum(t == "hybrid"
for t in layers_block_type_value[start:end])
else:
return self.get_num_layers(parallel_config)

return sum(t == block_type.value
for t in layers_block_type_value[start:end])

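A small illustration of what the new `zamba2` branches compute: `get_head_size()` reads `attention_head_dim` straight from the HF text config, and `get_num_layers_by_block_type()` counts attention layers as the `"hybrid"` entries of `layers_block_type`. The list below is a toy layout, not the real Zamba2 configuration:

```python
# Illustrative only: counting attention-bearing layers the way the zamba2
# branch above does, on a made-up layers_block_type list.
layers_block_type = ["mamba", "mamba", "hybrid", "mamba", "mamba", "hybrid"]

num_attention_layers = sum(t == "hybrid" for t in layers_block_type)
num_total_layers = len(layers_block_type)  # what get_num_layers() would report
print(num_attention_layers, num_total_layers)  # 2 6
```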
1 change: 0 additions & 1 deletion vllm/model_executor/layers/mamba/mamba_mixer2.py
@@ -245,7 +245,6 @@ def __init__(self,
        assert num_heads % self.tp_size == 0, \
            "Tensor parallel world size must divide num heads."


        assert (n_groups % self.tp_size) == 0 or n_groups == 1, \
            (
                "If tensor parallel world size does not divide num_heads, "
2 changes: 0 additions & 2 deletions vllm/model_executor/models/bamba.py
@@ -38,8 +38,6 @@
                    make_empty_intermediate_tensors_factory, make_layers,
                    maybe_prefix)

KVCache = Tuple[torch.Tensor, torch.Tensor]


class BambaMLP(nn.Module):

2 changes: 0 additions & 2 deletions vllm/model_executor/models/jamba.py
@@ -36,8 +36,6 @@
                    make_empty_intermediate_tensors_factory, make_layers,
                    maybe_prefix)

KVCache = Tuple[torch.Tensor, torch.Tensor]


class JambaMoE(nn.Module):

1 change: 1 addition & 0 deletions vllm/model_executor/models/registry.py
@@ -105,6 +105,7 @@
"SolarForCausalLM": ("solar", "SolarForCausalLM"),
"TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
"XverseForCausalLM": ("llama", "LlamaForCausalLM"),
"Zamba2ForCausalLM": ("zamba2", "Zamba2ForCausalLM"),
# [Encoder-decoder]
"BartModel": ("bart", "BartForConditionalGeneration"),
"BartForConditionalGeneration": ("bart", "BartForConditionalGeneration"),
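As a quick way to see how this mapping is exercised: vLLM resolves the `architectures` field of a checkpoint's Hugging Face config against this table, so Zamba2 checkpoints get routed to the new `zamba2` module. A minimal check (network access assumed; the printed value is whatever the checkpoint's `config.json` declares, expected to contain `Zamba2ForCausalLM`):

```python
# Sketch: inspect the architecture string a Zamba2 checkpoint advertises,
# which is the key vLLM looks up in the registry mapping above.
from transformers import AutoConfig

cfg = AutoConfig.from_pretrained("Zyphra/Zamba2-1.2B-instruct")
print(cfg.architectures)  # expected: ["Zamba2ForCausalLM"]
```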