[Model] Add PLaMo2 #14323

Merged
merged 43 commits on Apr 16, 2025

Commits (43)
b7fb14d
Add PLaMo2 model at v0.6.3.post1
Alnusjaponica Mar 3, 2025
e58e384
Follow-up to the latest based on Jamba implementation
Alnusjaponica Mar 4, 2025
b0f101e
Modify interfaces
Alnusjaponica Mar 5, 2025
a783e31
Add workaround to use IsHybrid interface
Alnusjaponica Mar 5, 2025
5ffec2c
Update dependencies for test
Alnusjaponica Mar 5, 2025
49dd3b0
Add test for plamo2 model
Alnusjaponica Mar 5, 2025
68d3bed
Modify code comment
Alnusjaponica Mar 5, 2025
bee8035
Resolve mypy error
Alnusjaponica Mar 6, 2025
1058777
Add plamo to test_registry
Alnusjaponica Mar 6, 2025
80a9abb
Merge branch 'main' into add-plamo2
Alnusjaponica Mar 10, 2025
e86e46f
pip-compile
Alnusjaponica Mar 10, 2025
7659755
pip-compile
Alnusjaponica Mar 10, 2025
e394371
Add workarounds to handle the difference in config assumptions
Alnusjaponica Mar 11, 2025
9d7efcc
Make workaround simple
Alnusjaponica Mar 11, 2025
b0b222e
Merge branch 'main' into add-plamo2
Alnusjaponica Mar 17, 2025
121ab1d
Merge branch 'main' into add-plamo2
Alnusjaponica Mar 19, 2025
f4a6ac1
yapf
Alnusjaponica Mar 19, 2025
9e01348
Added PLaMo to docs
Alnusjaponica Mar 20, 2025
d051b1f
Set trust_remote_code=true for PLaMo in the test
Alnusjaponica Mar 20, 2025
1a7111b
Clean-up unused lines
Alnusjaponica Mar 20, 2025
b318d0f
Revert renaming final norm component on loading model
Alnusjaponica Mar 21, 2025
d8df40d
Clean-up PlamoConfig
Alnusjaponica Mar 21, 2025
a36caaf
Revert PlamoDecoder for class structure consistency with transformers
Alnusjaponica Mar 23, 2025
7cbdc8c
Rename PlamoDecoder to Plamo2Decoder
Alnusjaponica Mar 23, 2025
4368f63
Revert Plamo2DecoderLayer for consistency with transformers
Alnusjaponica Mar 23, 2025
a451011
Drop Plamo2MoE for consistency with transformers implementation
Alnusjaponica Mar 23, 2025
256957f
Minimize model's member renaming
Alnusjaponica Mar 23, 2025
0f9f140
Move causal-conv1d installation to buildkite config
Alnusjaponica Mar 23, 2025
dc50e0a
Simplify DenseMLP
Alnusjaponica Mar 23, 2025
c6adb46
Stop specifying use_mamba_kernels=False as a mamba kernel is installe…
Alnusjaponica Mar 24, 2025
83f6be5
Remove nn.Linear for quantization support
Alnusjaponica Mar 26, 2025
81a1954
Properly pass prefixes
Alnusjaponica Apr 1, 2025
0ed0042
Stop using float16 when dtype=auto is specified.
Alnusjaponica Apr 1, 2025
63283c1
Revert "Stop using float16 when dtype=auto is specified."
Alnusjaponica Apr 1, 2025
9ac51c5
Merge branch 'main' into add-plamo2
Alnusjaponica Apr 1, 2025
19fcd5f
Handle dtype for plamo2 in config
Alnusjaponica Apr 1, 2025
3f44675
Update object names to plamo2-prefixed
Alnusjaponica Apr 7, 2025
f43d02a
Update object names to plamo2-prefixed in the tests
Alnusjaponica Apr 7, 2025
2f3bed1
Merge branch 'main' into add-plamo2
Alnusjaponica Apr 15, 2025
7b41a18
Fix Plamo2ForCausalLM class name
Alnusjaponica Apr 15, 2025
f5bf80a
Merge branch 'main' into add-plamo2
Alnusjaponica Apr 15, 2025
313e050
Merge branch 'main' into add-plamo2
Alnusjaponica Apr 15, 2025
0c8fb36
Split plamo2 initialization test for debugging purpose
Alnusjaponica Apr 15, 2025
7 changes: 6 additions & 1 deletion .buildkite/test-pipeline.yaml
@@ -400,8 +400,9 @@ steps:
- pytest -v -s models/test_transformers.py
- pytest -v -s models/test_registry.py
# V1 Test: https://github.com/vllm-project/vllm/issues/14531
- VLLM_USE_V1=0 pytest -v -s models/test_initialization.py -k 'not llama4'
- VLLM_USE_V1=0 pytest -v -s models/test_initialization.py -k 'not llama4 and not plamo2'
- VLLM_USE_V1=0 pytest -v -s models/test_initialization.py -k 'llama4'
- VLLM_USE_V1=0 pytest -v -s models/test_initialization.py -k 'plamo2'

- label: Language Models Test (Standard) # 32min
#mirror_hardwares: [amd]
@@ -411,6 +412,8 @@ steps:
- tests/models/embedding/language
- tests/models/encoder_decoder/language
commands:
# Install causal-conv1d for plamo2 models here, as it is not compatible with pip-compile.
- pip install causal-conv1d
- pytest -v -s models/decoder_only/language -m 'core_model or quant_model'
- pytest -v -s models/embedding/language -m core_model

@@ -422,6 +425,8 @@ steps:
- tests/models/embedding/language
- tests/models/encoder_decoder/language
commands:
# Install causal-conv1d for plamo2 models here, as it is not compatible with pip-compile.
- pip install causal-conv1d
- pytest -v -s models/decoder_only/language -m 'not core_model and not quant_model'
- pytest -v -s models/embedding/language -m 'not core_model'

5 changes: 5 additions & 0 deletions docs/source/models/supported_models.md
@@ -497,6 +497,11 @@ See [this page](#generative-models) for more information on how to use generativ
* `adept/persimmon-8b-base`, `adept/persimmon-8b-chat`, etc.
*
* ✅︎
- * `Plamo2ForCausalLM`
* PLaMo2
* `pfnet/plamo-2-1b`, `pfnet/plamo-2-8b`, etc.
Comment on lines +500 to +502

Collaborator:

In https://huggingface.co/pfnet/plamo-2-1b/blob/main/config.json the architecture is PlamoForCausalLM instead of Plamo2ForCausalLM. Is this a mistake?

Contributor Author (@Alnusjaponica, Mar 25, 2025):

Actually, we have another PlamoForCausalLM with a different architecture (https://huggingface.co/pfnet/plamo-100b), and I used Plamo2ForCausalLM in vLLM to avoid misunderstandings. If it is necessary to use the same class name, I can ask our pre-training team if it's possible to rename the config. I apologize for any confusion.

Collaborator:

IMO this is a fairly minor thing but users will see a warning message like the following, which would be nice to fix:

You are using a model of type plamo2 to instantiate a model of type plamo. This is not supported for all configurations of models and can yield errors.

Contributor Author:

Thanks for your comment. We internally agreed to change the class name to Plamo2ForCausalLM, so I'll be updating it here after our public models are updated.

*
*
- * `QWenLMHeadModel`
* Qwen
* `Qwen/Qwen-7B`, `Qwen/Qwen-7B-Chat`, etc.
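As a quick illustration of the new entry (not part of this PR's diff; the prompt and sampling settings are made up, and `trust_remote_code=True` mirrors the registry entry added in `tests/models/registry.py`), a minimal offline-inference sketch with vLLM's public API might look like:

```python
from vllm import LLM, SamplingParams

# Load PLaMo2 through vLLM. trust_remote_code=True is needed because the
# model and tokenizer code live in the Hugging Face repo; with this PR,
# dtype="auto" (the default) is expected to resolve to bfloat16 for plamo2.
llm = LLM(model="pfnet/plamo-2-1b", trust_remote_code=True)

sampling_params = SamplingParams(temperature=0.0, max_tokens=32)
outputs = llm.generate(["The capital of Japan is"], sampling_params)
for output in outputs:
    print(output.outputs[0].text)
```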
1 change: 1 addition & 0 deletions requirements/test.in
@@ -27,6 +27,7 @@ torch==2.6.0
torchaudio==2.6.0
torchvision==0.21.0
transformers_stream_generator # required for qwen-vl test
mamba_ssm # required for plamo2 test
matplotlib # required for qwen-vl test
mistral_common[opencv] >= 1.5.4 # required for pixtral test
num2words # required for smolvlm test
9 changes: 9 additions & 0 deletions requirements/test.txt
@@ -111,6 +111,7 @@ einops==0.8.0
# via
# -r requirements/test.in
# encodec
# mamba-ssm
# vector-quantize-pytorch
# vocos
einx==0.3.0
@@ -233,6 +234,8 @@ lxml==5.3.0
# via
# blobfile
# sacrebleu
mamba-ssm==2.2.4
# via -r requirements/test.in
markdown-it-py==3.0.0
# via rich
markupsafe==3.0.2
@@ -268,6 +271,8 @@ mypy-extensions==1.0.0
# via black
networkx==3.2.1
# via torch
ninja==1.11.1.3
# via mamba-ssm
nltk==3.9.1
# via rouge-score
num2words==0.5.14
@@ -360,6 +365,7 @@ packaging==24.1
# fastparquet
# huggingface-hub
# lazy-loader
# mamba-ssm
# matplotlib
# peft
# plotly
@@ -571,6 +577,7 @@ sentencepiece==0.2.0
# via mistral-common
setuptools==75.8.0
# via
# mamba-ssm
# pytablewriter
# torch
shellingham==1.5.4
@@ -627,6 +634,7 @@ torch==2.6.0
# encodec
# fastsafetensors
# lm-eval
# mamba-ssm
# peft
# runai-model-streamer
# sentence-transformers
@@ -664,6 +672,7 @@ transformers==4.51.1
# -r requirements/test.in
# genai-perf
# lm-eval
# mamba-ssm
# peft
# sentence-transformers
# transformers-stream-generator
41 changes: 18 additions & 23 deletions tests/models/decoder_only/language/test_hybrid.py
@@ -9,9 +9,15 @@
from ...utils import check_outputs_equal

# This test is for the hybrid models
MODELS = ["ai21labs/Jamba-tiny-dev", "Zyphra/Zamba2-1.2B-instruct"]
MODELS = [
"ai21labs/Jamba-tiny-dev", "Zyphra/Zamba2-1.2B-instruct",
"pfnet/plamo-2-1b"
]
# Bamba at Fp32 is too big for the CI (L4 GPU).
# MODELS = ["ai21labs/Jamba-tiny-dev", "ibm-ai-platform/Bamba-9B"]
# Note: Running Plamo2 with the transformers implementation requires installing
# the causal-conv1d package, which is not listed as a test dependency as it's
# not compatible with pip-compile.


@pytest.mark.parametrize("model", MODELS)
@@ -25,21 +31,11 @@ def test_models(
dtype: str,
max_tokens: int,
) -> None:

# numeric error produces different generation
if "Bamba" in model:
example_prompts.pop(3)

model_kwargs = {
"use_mamba_kernels": False, # mamba kernels are not installed so HF
# don't use them
}
if "Zamba2" in model:
# Zamba2 HF implementation automatically checks if mamba kernels are
# installed
model_kwargs = {}

with hf_runner(model, dtype=dtype, model_kwargs=model_kwargs) as hf_model:
with hf_runner(model, dtype=dtype) as hf_model:
hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)

with vllm_runner(model, dtype=dtype) as vllm_model:
@@ -94,6 +90,10 @@ def test_mamba_prefill_chunking_with_parallel_sampling(
# correctly for n > 1 decoding steps inside a
# chunked prefill forward pass (where we have both prefills
# and decoding together )

if 'plamo-2' in model:
dtype = "float" # use a different dtype for plamo

sampling_params = SamplingParams(n=3,
temperature=1,
seed=0,
@@ -125,20 +125,14 @@ def test_mamba_prefill_chunking(hf_runner, vllm_runner, example_prompts,
example_prompts.pop(3)
example_prompts.pop(2)
dtype = "half" # use a different dtype for Bamba

elif "Zamba2" in model:
example_prompts.pop(7)
dtype = "half"
elif "plamo-2-1b" in model:
example_prompts.pop(7)

model_kwargs = {
"use_mamba_kernels": False, # mamba kernels are not installed so HF
# don't use them
}
if "Zamba2" in model:
# Zamba2 HF implementation automatically checks if mamba kernels are
# installed
model_kwargs = {}

with hf_runner(model, dtype=dtype, model_kwargs=model_kwargs) as hf_model:
with hf_runner(model, dtype=dtype) as hf_model:
non_chunked = hf_model.generate_greedy(example_prompts, max_tokens)

with vllm_runner(model,
@@ -208,7 +202,8 @@ def test_mamba_cache_cg_padding(
# This test is for verifying that mamba cache is padded to CG captured
# batch size. If it's not, a torch RuntimeError will be raised because
# tensor dimensions aren't compatible
vllm_config = EngineArgs(model=model).create_engine_config()
vllm_config = EngineArgs(model=model,
trust_remote_code=True).create_engine_config()
while len(example_prompts) == vllm_config.pad_for_cudagraph(
len(example_prompts)):
example_prompts.append(example_prompts[0])
2 changes: 2 additions & 0 deletions tests/models/registry.py
@@ -204,6 +204,8 @@ def check_available_online(
trust_remote_code=True),
"PhiMoEForCausalLM": _HfExamplesInfo("microsoft/Phi-3.5-MoE-instruct",
trust_remote_code=True),
"Plamo2ForCausalLM": _HfExamplesInfo("pfnet/plamo-2-1b",
trust_remote_code=True),
"QWenLMHeadModel": _HfExamplesInfo("Qwen/Qwen-7B-Chat",
trust_remote_code=True),
"Qwen2ForCausalLM": _HfExamplesInfo("Qwen/Qwen2-7B-Instruct",
12 changes: 12 additions & 0 deletions vllm/config.py
@@ -2838,6 +2838,13 @@ def _get_and_verify_dtype(
else:
torch_dtype = config_dtype

if config.model_type == "plamo2":
logger.info(
"For PLaMo2, we cast models to bfloat16 instead of using "
"float16 by default. This is because float16 does not work."
)
torch_dtype = torch.bfloat16

from vllm.platforms import current_platform
if (current_platform.is_cpu()
and current_platform.get_cpu_architecture()
@@ -2867,6 +2874,11 @@ def _get_and_verify_dtype(
"using float16 by default. Please specify `dtype` if you "
"want to use float16.")
torch_dtype = torch.bfloat16
elif dtype == "float16" and config.model_type == "plamo2":
logger.warning(
"For PLaMo2, using float16 is unstable and might cause "
"unexpected behavior. Please use bfloat16 or float32 instead.")
torch_dtype = torch.float16
else:
if dtype not in _STR_DTYPE_TO_TORCH_DTYPE:
raise ValueError(f"Unknown dtype: {dtype}")
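To summarize the combined effect of the two hunks above, here is a simplified, self-contained sketch of the resulting dtype resolution for plamo2 (an illustration only, not the actual `_get_and_verify_dtype` code; `config_dtype` stands for the dtype read from the HF config, and the helper name is made up):

```python
import torch

_STR_DTYPE_TO_TORCH_DTYPE = {
    "half": torch.float16, "float16": torch.float16,
    "bfloat16": torch.bfloat16, "float": torch.float32, "float32": torch.float32,
}

def resolve_dtype(dtype: str, model_type: str,
                  config_dtype: torch.dtype) -> torch.dtype:
    """Simplified sketch of the plamo2 branches added in this PR."""
    if dtype == "auto":
        if model_type == "plamo2":
            # PLaMo2 is forced to bfloat16: float16 does not work reliably.
            return torch.bfloat16
        if config_dtype == torch.float32:
            # Other fp32 configs fall back to float16 by default.
            return torch.float16
        return config_dtype
    if dtype == "float16" and model_type == "plamo2":
        # Honored, but a warning recommends bfloat16 or float32 instead.
        return torch.float16
    return _STR_DTYPE_TO_TORCH_DTYPE[dtype]
```

In practice this means the default `dtype="auto"` picks bfloat16 for `pfnet/plamo-2-*` checkpoints, while users who explicitly request float16 get a warning rather than an error.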