
Commit cfea7d0

Alnusjaponica, nzw0301, HiroakiMikami, and nopperl authored and committed
[Model] Add PLaMo2 (vllm-project#14323)
Signed-off-by: Shinichi Hemmi <50256998+Alnusjaponica@users.noreply.github.com>
Signed-off-by: shemmi <shemmi@preferred.jp>
Co-authored-by: Kento Nozawa <nzw0301@preferred.jp>
Co-authored-by: Hiroaki Mikami <mhiroaki@preferred.jp>
Co-authored-by: Calvin Metzger <metzger@preferred.jp>
Signed-off-by: Yang Wang <elainewy@meta.com>
1 parent 80cf79d commit cfea7d0

File tree: 9 files changed, +800 −24 lines


.buildkite/test-pipeline.yaml

Lines changed: 6 additions & 1 deletion
@@ -400,8 +400,9 @@ steps:
   - pytest -v -s models/test_transformers.py
   - pytest -v -s models/test_registry.py
   # V1 Test: https://github.com/vllm-project/vllm/issues/14531
-  - VLLM_USE_V1=0 pytest -v -s models/test_initialization.py -k 'not llama4'
+  - VLLM_USE_V1=0 pytest -v -s models/test_initialization.py -k 'not llama4 and not plamo2'
   - VLLM_USE_V1=0 pytest -v -s models/test_initialization.py -k 'llama4'
+  - VLLM_USE_V1=0 pytest -v -s models/test_initialization.py -k 'plamo2'

 - label: Language Models Test (Standard) # 32min
   #mirror_hardwares: [amd]

@@ -411,6 +412,8 @@ steps:
   - tests/models/embedding/language
   - tests/models/encoder_decoder/language
   commands:
+  # Install causal-conv1d for plamo2 models here, as it is not compatible with pip-compile.
+  - pip install causal-conv1d
   - pytest -v -s models/decoder_only/language -m 'core_model or quant_model'
   - pytest -v -s models/embedding/language -m core_model

@@ -422,6 +425,8 @@ steps:
   - tests/models/embedding/language
   - tests/models/encoder_decoder/language
   commands:
+  # Install causal-conv1d for plamo2 models here, as it is not compatible with pip-compile.
+  - pip install causal-conv1d
   - pytest -v -s models/decoder_only/language -m 'not core_model and not quant_model'
   - pytest -v -s models/embedding/language -m 'not core_model'

docs/source/models/supported_models.md

Lines changed: 5 additions & 0 deletions
@@ -497,6 +497,11 @@ See [this page](#generative-models) for more information on how to use generativ
   * `adept/persimmon-8b-base`, `adept/persimmon-8b-chat`, etc.
   *
   * ✅︎
+- * `Plamo2ForCausalLM`
+  * PLaMo2
+  * `pfnet/plamo-2-1b`, `pfnet/plamo-2-8b`, etc.
+  *
+  *
 - * `QWenLMHeadModel`
   * Qwen
   * `Qwen/Qwen-7B`, `Qwen/Qwen-7B-Chat`, etc.

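With this documentation entry and the registry change later in this diff, PLaMo2 checkpoints load through vLLM's regular offline-inference path. A minimal sketch, assuming the `pfnet/plamo-2-1b` weights are accessible; `trust_remote_code=True` mirrors the registry entry and `dtype="bfloat16"` mirrors the default added in `vllm/config.py`:

```python
from vllm import LLM, SamplingParams

# Load the newly supported PLaMo2 model.
llm = LLM(
    model="pfnet/plamo-2-1b",
    trust_remote_code=True,  # PLaMo2 ships its modeling code on the Hub
    dtype="bfloat16",        # float16 is unstable for PLaMo2 (see vllm/config.py)
)

params = SamplingParams(temperature=0.0, max_tokens=32)
outputs = llm.generate(["The capital of Japan is"], params)
print(outputs[0].outputs[0].text)
```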
requirements/test.in

Lines changed: 1 addition & 0 deletions
@@ -27,6 +27,7 @@ torch==2.6.0
 torchaudio==2.6.0
 torchvision==0.21.0
 transformers_stream_generator # required for qwen-vl test
+mamba_ssm # required for plamo2 test
 matplotlib # required for qwen-vl test
 mistral_common[opencv] >= 1.5.4 # required for pixtral test
 num2words # required for smolvlm test

requirements/test.txt

Lines changed: 9 additions & 0 deletions
@@ -111,6 +111,7 @@ einops==0.8.0
     # via
     #   -r requirements/test.in
     #   encodec
+    #   mamba-ssm
     #   vector-quantize-pytorch
     #   vocos
 einx==0.3.0

@@ -233,6 +234,8 @@ lxml==5.3.0
     # via
     #   blobfile
     #   sacrebleu
+mamba-ssm==2.2.4
+    # via -r requirements/test.in
 markdown-it-py==3.0.0
     # via rich
 markupsafe==3.0.2

@@ -268,6 +271,8 @@ mypy-extensions==1.0.0
     # via black
 networkx==3.2.1
     # via torch
+ninja==1.11.1.3
+    # via mamba-ssm
 nltk==3.9.1
     # via rouge-score
 num2words==0.5.14

@@ -360,6 +365,7 @@ packaging==24.1
     #   fastparquet
     #   huggingface-hub
     #   lazy-loader
+    #   mamba-ssm
     #   matplotlib
     #   peft
     #   plotly

@@ -571,6 +577,7 @@ sentencepiece==0.2.0
     # via mistral-common
 setuptools==75.8.0
     # via
+    #   mamba-ssm
     #   pytablewriter
     #   torch
 shellingham==1.5.4

@@ -627,6 +634,7 @@ torch==2.6.0
     #   encodec
     #   fastsafetensors
     #   lm-eval
+    #   mamba-ssm
     #   peft
     #   runai-model-streamer
     #   sentence-transformers

@@ -664,6 +672,7 @@ transformers==4.51.1
     #   -r requirements/test.in
     #   genai-perf
     #   lm-eval
+    #   mamba-ssm
     #   peft
     #   sentence-transformers
     #   transformers-stream-generator

tests/models/decoder_only/language/test_hybrid.py

Lines changed: 18 additions & 23 deletions
@@ -9,9 +9,15 @@
 from ...utils import check_outputs_equal

 # This test is for the hybrid models
-MODELS = ["ai21labs/Jamba-tiny-dev", "Zyphra/Zamba2-1.2B-instruct"]
+MODELS = [
+    "ai21labs/Jamba-tiny-dev", "Zyphra/Zamba2-1.2B-instruct",
+    "pfnet/plamo-2-1b"
+]
 # Bamba at Fp32 is too big for the CI (L4 GPU).
 # MODELS = ["ai21labs/Jamba-tiny-dev", "ibm-ai-platform/Bamba-9B"]
+# Note: Running Plamo2 in transformers implementation requires to install
+# causal-conv1d package, which is not listed as a test dependency as it's
+# not compatible with pip-compile.


 @pytest.mark.parametrize("model", MODELS)

@@ -25,21 +31,11 @@ def test_models(
     dtype: str,
     max_tokens: int,
 ) -> None:
-
     # numeric error produces different generation
     if "Bamba" in model:
         example_prompts.pop(3)

-    model_kwargs = {
-        "use_mamba_kernels": False,  # mamba kernels are not installed so HF
-        # don't use them
-    }
-    if "Zamba2" in model:
-        # Zamba2 HF implementation automatically checks if mamba kernels are
-        # installed
-        model_kwargs = {}
-
-    with hf_runner(model, dtype=dtype, model_kwargs=model_kwargs) as hf_model:
+    with hf_runner(model, dtype=dtype) as hf_model:
         hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)

     with vllm_runner(model, dtype=dtype) as vllm_model:

@@ -94,6 +90,10 @@ def test_mamba_prefill_chunking_with_parallel_sampling(
     # correctly for n > 1 decoding steps inside a
     # chunked prefill forward pass (where we have both prefills
     # and decoding together )
+
+    if 'plamo-2' in model:
+        dtype = "float"  # use a different dtype for plamo
+
     sampling_params = SamplingParams(n=3,
                                      temperature=1,
                                      seed=0,

@@ -125,20 +125,14 @@ def test_mamba_prefill_chunking(hf_runner, vllm_runner, example_prompts,
         example_prompts.pop(3)
         example_prompts.pop(2)
         dtype = "half"  # use a different dtype for Bamba
+
     elif "Zamba2" in model:
         example_prompts.pop(7)
         dtype = "half"
+    elif "plamo-2-1b" in model:
+        example_prompts.pop(7)

-    model_kwargs = {
-        "use_mamba_kernels": False,  # mamba kernels are not installed so HF
-        # don't use them
-    }
-    if "Zamba2" in model:
-        # Zamba2 HF implementation automatically checks if mamba kernels are
-        # installed
-        model_kwargs = {}
-
-    with hf_runner(model, dtype=dtype, model_kwargs=model_kwargs) as hf_model:
+    with hf_runner(model, dtype=dtype) as hf_model:
         non_chunked = hf_model.generate_greedy(example_prompts, max_tokens)

     with vllm_runner(model,

@@ -208,7 +202,8 @@ def test_mamba_cache_cg_padding(
     # This test is for verifying that mamba cache is padded to CG captured
     # batch size. If it's not, a torch RuntimeError will be raised because
     # tensor dimensions aren't compatible
-    vllm_config = EngineArgs(model=model).create_engine_config()
+    vllm_config = EngineArgs(model=model,
+                             trust_remote_code=True).create_engine_config()
     while len(example_prompts) == vllm_config.pad_for_cudagraph(
             len(example_prompts)):
         example_prompts.append(example_prompts[0])

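The note about causal-conv1d concerns the Hugging Face reference side of these comparisons: loading `pfnet/plamo-2-1b` through `transformers` needs that package installed manually, since it cannot be listed in the pip-compiled requirements. A minimal sketch of that reference path, assuming `transformers` and `causal-conv1d` are installed and that the model's remote code follows the standard AutoModel path (this is roughly what `hf_runner` does with greedy decoding):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# HF reference side of the hybrid-model comparison; requires
# `pip install causal-conv1d` as noted in the test comment above.
tokenizer = AutoTokenizer.from_pretrained("pfnet/plamo-2-1b",
                                          trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained("pfnet/plamo-2-1b",
                                             torch_dtype=torch.bfloat16,
                                             trust_remote_code=True)

inputs = tokenizer("The capital of Japan is", return_tensors="pt")
output_ids = model.generate(**inputs, max_new_tokens=32, do_sample=False)  # greedy
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```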
tests/models/registry.py

Lines changed: 2 additions & 0 deletions
@@ -204,6 +204,8 @@ def check_available_online(
                                          trust_remote_code=True),
     "PhiMoEForCausalLM": _HfExamplesInfo("microsoft/Phi-3.5-MoE-instruct",
                                          trust_remote_code=True),
+    "Plamo2ForCausalLM": _HfExamplesInfo("pfnet/plamo-2-1b",
+                                         trust_remote_code=True),
     "QWenLMHeadModel": _HfExamplesInfo("Qwen/Qwen-7B-Chat",
                                        trust_remote_code=True),
     "Qwen2ForCausalLM": _HfExamplesInfo("Qwen/Qwen2-7B-Instruct",

vllm/config.py

Lines changed: 12 additions & 0 deletions
@@ -2838,6 +2838,13 @@ def _get_and_verify_dtype(
         else:
             torch_dtype = config_dtype

+        if config.model_type == "plamo2":
+            logger.info(
+                "For PLaMo2, we cast models to bfloat16 instead of using "
+                "float16 by default. This is because float16 does not work."
+            )
+            torch_dtype = torch.bfloat16
+
         from vllm.platforms import current_platform
         if (current_platform.is_cpu()
                 and current_platform.get_cpu_architecture()

@@ -2867,6 +2874,11 @@
                 "using float16 by default. Please specify `dtype` if you "
                 "want to use float16.")
             torch_dtype = torch.bfloat16
+        elif dtype == "float16" and config.model_type == "plamo2":
+            logger.warning(
+                "For PLaMo2, using float16 is unstable and might cause "
+                "unexpected behavior. Please use bfloat16 or float32 instead.")
+            torch_dtype = torch.float16
         else:
             if dtype not in _STR_DTYPE_TO_TORCH_DTYPE:
                 raise ValueError(f"Unknown dtype: {dtype}")

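Taken together, the two branches mean: with `dtype="auto"` a PLaMo2 config is forced to bfloat16, while an explicit `dtype="float16"` is honored but logged as unstable. A condensed, standalone sketch of that decision, simplified from `_get_and_verify_dtype` (the helper name below is hypothetical):

```python
import logging

import torch

logger = logging.getLogger(__name__)


# Hypothetical condensed view of the PLaMo2-specific branches added to
# _get_and_verify_dtype in vllm/config.py.
def resolve_plamo2_dtype(requested: str) -> torch.dtype:
    if requested == "auto":
        # Default: cast to bfloat16, because float16 does not work for PLaMo2.
        logger.info("For PLaMo2, casting to bfloat16 instead of float16.")
        return torch.bfloat16
    if requested == "float16":
        # Explicit float16 is allowed but flagged as unstable.
        logger.warning(
            "For PLaMo2, float16 is unstable; prefer bfloat16 or float32.")
        return torch.float16
    return {"bfloat16": torch.bfloat16, "float32": torch.float32}[requested]
```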