
Commit 7342a7d

[Model] Support Mamba (#6484)
Parent commit: df3dcdf

29 files changed: +1603 / -343 lines. Four of the changed files are excerpted below.

.buildkite/run-cpu-test-ppc64le.sh (+7, -1)

@@ -18,7 +18,13 @@ docker run -itd --entrypoint /bin/bash -v ~/.cache/huggingface:/root/.cache/hugg
 # Run basic model test
 docker exec cpu-test bash -c "
   pip install pytest matplotlib einops transformers_stream_generator
-  pytest -v -s tests/models -m \"not vlm\" --ignore=tests/models/test_embedding.py --ignore=tests/models/test_oot_registration.py --ignore=tests/models/test_registry.py --ignore=tests/models/test_jamba.py --ignore=tests/models/test_danube3_4b.py" # Mamba and Danube3-4B on CPU is not supported
+  pytest -v -s tests/models -m \"not vlm\" \
+    --ignore=tests/models/test_embedding.py \
+    --ignore=tests/models/test_oot_registration.py \
+    --ignore=tests/models/test_registry.py \
+    --ignore=tests/models/test_jamba.py \
+    --ignore=tests/models/test_mamba.py \
+    --ignore=tests/models/test_danube3_4b.py" # Mamba kernels and Danube3-4B on CPU is not supported

 # online inference
 docker exec cpu-test bash -c "

.buildkite/run-cpu-test.sh (+1)

@@ -27,6 +27,7 @@ docker exec cpu-test bash -c "
   pytest -v -s tests/models/decoder_only/language \
     --ignore=tests/models/test_fp8.py \
     --ignore=tests/models/decoder_only/language/test_jamba.py \
+    --ignore=tests/models/decoder_only/language/test_mamba.py \
     --ignore=tests/models/decoder_only/language/test_granitemoe.py \
     --ignore=tests/models/decoder_only/language/test_danube3_4b.py" # Mamba and Danube3-4B on CPU is not supported
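
Both CI scripts now exclude the new Mamba test on CPU, since the Mamba SSM kernels are CUDA-only. On a GPU machine the skipped test can be run on its own; the sketch below simply drives pytest from Python and assumes a CUDA install of vLLM with the Mamba kernels available.

```python
# Hypothetical local run of the test that the CPU CI jobs skip above.
# Assumes a CUDA machine with vLLM (and its Mamba kernels) installed.
import sys

import pytest

if __name__ == "__main__":
    # Equivalent to: pytest -v -s tests/models/decoder_only/language/test_mamba.py
    sys.exit(pytest.main(
        ["-v", "-s", "tests/models/decoder_only/language/test_mamba.py"]))
```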

docs/source/models/supported_models.rst (+5)

@@ -152,6 +152,11 @@ Text Generation
     - :code:`meta-llama/Meta-Llama-3.1-405B-Instruct`, :code:`meta-llama/Meta-Llama-3.1-70B`, :code:`meta-llama/Meta-Llama-3-70B-Instruct`, :code:`meta-llama/Llama-2-70b-hf`, :code:`01-ai/Yi-34B`, etc.
     - ✅︎
     - ✅︎
+  * - :code:`MambaForCausalLM`
+    - Mamba
+    - :code:`state-spaces/mamba-130m-hf`, :code:`state-spaces/mamba-790m-hf`, :code:`state-spaces/mamba-2.8b-hf`, etc.
+    - ✅︎
+    -
   * - :code:`MiniCPMForCausalLM`
     - MiniCPM
     - :code:`openbmb/MiniCPM-2B-sft-bf16`, :code:`openbmb/MiniCPM-2B-dpo-bf16`, etc.
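
With :code:`MambaForCausalLM` listed as supported, a Mamba checkpoint can be loaded through vLLM's usual offline API. A minimal sketch, using one of the checkpoints named in the new table row; the prompt and sampling settings are illustrative only:

```python
# Minimal offline-inference sketch for one of the Mamba checkpoints listed in
# the table above; prompt and sampling settings are illustrative, not part of
# this commit.
from vllm import LLM, SamplingParams

llm = LLM(model="state-spaces/mamba-130m-hf")
sampling = SamplingParams(temperature=0.0, max_tokens=64)

outputs = llm.generate(["The Mamba architecture is"], sampling)
for output in outputs:
    print(output.outputs[0].text)
```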

tests/kernels/test_attention_selector.py (+21, -16)

@@ -20,22 +20,22 @@ def test_env(name: str, device: str, monkeypatch):

     if device == "cpu":
         with patch("vllm.attention.selector.is_cpu", return_value=True):
-            backend = which_attn_to_use(8, 16, 8, None, torch.float16,
-                                        torch.float16, 16)
+            backend = which_attn_to_use(16, None, torch.float16, torch.float16,
+                                        16, False)
         assert backend.name == "TORCH_SDPA"
     elif device == "hip":
         with patch("vllm.attention.selector.is_hip", return_value=True):
-            backend = which_attn_to_use(8, 16, 8, None, torch.float16,
-                                        torch.float16, 16)
+            backend = which_attn_to_use(16, None, torch.float16, torch.float16,
+                                        16, False)
         assert backend.name == "ROCM_FLASH"
     elif device == "openvino":
         with patch("vllm.attention.selector.is_openvino", return_value=True):
-            backend = which_attn_to_use(8, 16, 8, None, torch.float16,
-                                        torch.float16, 16)
+            backend = which_attn_to_use(16, None, torch.float16, torch.float16,
+                                        16, False)
         assert backend.name == "OPENVINO"
     else:
-        backend = which_attn_to_use(8, 16, 8, None, torch.float16,
-                                    torch.float16, 16)
+        backend = which_attn_to_use(16, None, torch.float16, torch.float16, 16,
+                                    False)
     assert backend.name == name


@@ -46,37 +46,42 @@ def test_flash_attn(monkeypatch):

     # Unsupported CUDA arch
     with patch("torch.cuda.get_device_capability", return_value=(7, 5)):
-        backend = which_attn_to_use(8, 16, 8, None, torch.float16, None, 16)
+        backend = which_attn_to_use(16, None, torch.float16, None, 16, False)
     assert backend.name != STR_FLASH_ATTN_VAL

     # Unsupported data type
-    backend = which_attn_to_use(8, 16, 8, None, torch.float8_e4m3fn, None, 16)
+    backend = which_attn_to_use(16, None, torch.float8_e4m3fn, None, 16, False)
     assert backend.name != STR_FLASH_ATTN_VAL

     # Unsupported kv cache data type
-    backend = which_attn_to_use(8, 16, 8, None, torch.float16, "fp8", 16)
+    backend = which_attn_to_use(16, None, torch.float16, "fp8", 16, False)
     assert backend.name != STR_FLASH_ATTN_VAL

     # Unsupported block size
-    backend = which_attn_to_use(8, 16, 8, None, torch.float16, None, 8)
+    backend = which_attn_to_use(16, None, torch.float16, None, 8, False)
     assert backend.name != STR_FLASH_ATTN_VAL

     # Unsupported sliding window
-    backend = which_attn_to_use(8, 16, 8, 1, torch.float16, None, 16)
+    backend = which_attn_to_use(16, 1, torch.float16, None, 16, False)
     assert backend.name != STR_FLASH_ATTN_VAL

     # flash-attn is not installed
     with patch.dict('sys.modules', {'vllm_flash_attn': None}):
-        backend = which_attn_to_use(8, 16, 8, None, torch.float16, None, 16)
+        backend = which_attn_to_use(16, None, torch.float16, None, 16, False)
     assert backend.name != STR_FLASH_ATTN_VAL

     # Unsupported head size
-    backend = which_attn_to_use(8, 17, 8, None, torch.float16, None, 16)
+    backend = which_attn_to_use(17, None, torch.float16, None, 16, False)
+    assert backend.name != STR_FLASH_ATTN_VAL
+
+    # Attention-free models should bypass env and use PlaceholderAttention
+    backend = which_attn_to_use(16, None, torch.float16, torch.float16, 16,
+                                True)
     assert backend.name != STR_FLASH_ATTN_VAL


 def test_invalid_env(monkeypatch):
     """Throw an exception if the backend name is invalid."""
     override_backend_env_variable(monkeypatch, STR_INVALID_VAL)
     with pytest.raises(ValueError):
-        which_attn_to_use(8, 16, 8, None, torch.float16, None, 16)
+        which_attn_to_use(16, None, torch.float16, None, 16, False)
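
Judging from the updated calls, `which_attn_to_use` no longer takes the head-count arguments and instead ends with a boolean that marks attention-free models (the new test comment says such models should fall back to a placeholder backend). A rough sketch of the apparent call shape; the parameter names and the import path are inferred from this test file, not stated in the diff:

```python
# Call shape inferred from the updated tests; argument names are assumptions:
#   which_attn_to_use(head_size, sliding_window, dtype, kv_cache_dtype,
#                     block_size, is_attention_free)
import torch

from vllm.attention.selector import which_attn_to_use  # import path assumed

# Ordinary attention model: head size 16, no sliding window, fp16 weights and
# fp16 KV cache, block size 16, not attention-free.
backend = which_attn_to_use(16, None, torch.float16, torch.float16, 16, False)
print(backend.name)

# Attention-free model (e.g. pure Mamba): the final flag is True, so the
# selector should not return a flash-attention backend.
backend = which_attn_to_use(16, None, torch.float16, torch.float16, 16, True)
print(backend.name)
```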
