[Model] Support Cohere2ForCausalLM (Cohere R7B) #11203

Merged · 5 commits · Dec 16, 2024
4 changes: 2 additions & 2 deletions docs/source/models/supported_models.rst
@@ -118,9 +118,9 @@ Text Generation (``--task generate``)
     - :code:`THUDM/chatglm2-6b`, :code:`THUDM/chatglm3-6b`, etc.
     - ✅︎
     - ✅︎
-  * - :code:`CohereForCausalLM`
+  * - :code:`CohereForCausalLM`, :code:`Cohere2ForCausalLM`
     - Command-R
-    - :code:`CohereForAI/c4ai-command-r-v01`, etc.
+    - :code:`CohereForAI/c4ai-command-r-v01`, :code:`CohereForAI/c4ai-command-r7b-12-2024`, etc.
     - ✅︎
     - ✅︎
   * - :code:`DbrxForCausalLM`
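For context, here is a minimal sketch of loading the newly listed checkpoint through vLLM's offline `LLM` API; the prompt and sampling parameters are illustrative, not from this PR:

```python
from vllm import LLM, SamplingParams

# Load the newly supported Command R7B checkpoint from the table above.
llm = LLM(model="CohereForAI/c4ai-command-r7b-12-2024")
params = SamplingParams(temperature=0.7, max_tokens=64)

outputs = llm.generate(["Explain sliding-window attention in one sentence."],
                       params)
print(outputs[0].outputs[0].text)
```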
2 changes: 2 additions & 0 deletions tests/models/registry.py
@@ -53,6 +53,8 @@ class _HfExamplesInfo:
     # ChatGLMModel supports multimodal
     "CohereForCausalLM": _HfExamplesInfo("CohereForAI/c4ai-command-r-v01",
                                          trust_remote_code=True),
+    "Cohere2ForCausalLM": _HfExamplesInfo("CohereForAI/c4ai-command-r7b-12-2024",  # noqa: E501
+                                          trust_remote_code=True),
     "DbrxForCausalLM": _HfExamplesInfo("databricks/dbrx-instruct"),
     "DeciLMForCausalLM": _HfExamplesInfo("Deci/DeciLM-7B-instruct",
                                          trust_remote_code=True),
4 changes: 4 additions & 0 deletions tests/models/test_initialization.py
@@ -1,6 +1,7 @@
 from unittest.mock import patch

 import pytest
+import transformers
 from transformers import PretrainedConfig

 from vllm import LLM
@@ -11,6 +12,9 @@
 @pytest.mark.parametrize("model_arch", HF_EXAMPLE_MODELS.get_supported_archs())
 def test_can_initialize(model_arch):
     model_info = HF_EXAMPLE_MODELS.get_hf_info(model_arch)
+    if (model_arch == "Cohere2ForCausalLM"
+            and transformers.__version__ < "4.48.0"):
+        pytest.skip(reason="Model introduced in HF >= 4.48.0")
     if not model_info.is_available_online:
         pytest.skip("Model is not available online")

Review comment (Member): The model is only available in transformers v4.48, which has not been released yet.

Reply: Thanks. That makes sense.
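As a side note (not part of this PR): comparing `transformers.__version__` as a raw string works here, but lexicographic comparison can misorder versions in general; `"4.9.0" < "4.48.0"` evaluates to `False` as strings even though 4.9.0 is the older release. A numerically robust gate, sketched under the assumption that the `packaging` library is available:

```python
import pytest
import transformers
from packaging import version  # assumption: packaging is installed

# packaging.version compares release segments numerically, so "4.9.0"
# correctly sorts before "4.48.0" (plain string comparison gets this wrong).
if version.parse(transformers.__version__) < version.parse("4.48.0"):
    pytest.skip(reason="Model introduced in HF >= 4.48.0")
```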
19 changes: 17 additions & 2 deletions vllm/model_executor/models/commandr.py
@@ -48,7 +48,7 @@
 from vllm.sequence import IntermediateTensors

 from .interfaces import SupportsLoRA, SupportsPP
-from .utils import (is_pp_missing_parameter,
+from .utils import (extract_layer_index, is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory, make_layers,
                     maybe_prefix)

@@ -171,12 +171,26 @@ def __init__(
             rope_scaling=self.rope_scaling,
             is_neox_style=False,
         )
+
+        sliding_window = getattr(config, "sliding_window", None)
+        # Model v2 has sliding windows, v1 does not
+        self.v1 = sliding_window is None
+
+        layer_idx = extract_layer_index(prefix)
+        layer_has_sliding_window = (
+            getattr(config, "sliding_window_pattern", False)
+            and (layer_idx + 1) % self.config.sliding_window_pattern != 0)
+
+        self.sliding_window = (sliding_window
+                               if layer_has_sliding_window else None)
+
         self.attn = Attention(self.num_heads,
                               self.head_dim,
                               self.scaling,
                               num_kv_heads=self.num_kv_heads,
                               cache_config=cache_config,
                               quant_config=quant_config,
+                              per_layer_sliding_window=self.sliding_window,
                               prefix=f"{prefix}.attn")
         if self.use_qk_norm:
             self.q_norm = LayerNorm(param_shape=(self.num_heads,

@@ -206,7 +220,8 @@ def forward(
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
         if self.use_qk_norm:
             q, k = self._apply_qk_norm(q, k)
-        q, k = self.rotary_emb(positions, q, k)
+        if self.v1 or self.sliding_window:
+            q, k = self.rotary_emb(positions, q, k)
         attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
         output, _ = self.o_proj(attn_output)
         return output
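Two things are worth noting about the logic above: RoPE is now applied only when the layer is a v1 layer or a v2 sliding-window layer, so full-attention layers in the v2 model use no positional embedding; and a layer gets a sliding window unless its 1-based index is a multiple of `sliding_window_pattern`. A standalone sketch of the resulting layer pattern, assuming `sliding_window_pattern = 4` (the value the R7B config reportedly uses):

```python
# Illustrative sketch of the per-layer attention pattern implemented above.
# Every sliding_window_pattern-th layer is a full-attention (no-RoPE) layer;
# all other layers use sliding-window attention with RoPE.
sliding_window_pattern = 4  # assumed value for Command R7B

for layer_idx in range(8):
    has_sliding_window = (layer_idx + 1) % sliding_window_pattern != 0
    kind = ("sliding window + RoPE" if has_sliding_window
            else "full attention, no RoPE")
    print(f"layer {layer_idx}: {kind}")
# layers 0-2: sliding window + RoPE; layer 3: full attention, no RoPE;
# the pattern then repeats every 4 layers.
```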
1 change: 1 addition & 0 deletions vllm/model_executor/models/registry.py
@@ -41,6 +41,7 @@
     "BloomForCausalLM": ("bloom", "BloomForCausalLM"),
     # ChatGLMModel supports multimodal
     "CohereForCausalLM": ("commandr", "CohereForCausalLM"),
+    "Cohere2ForCausalLM": ("commandr", "CohereForCausalLM"),
     "DbrxForCausalLM": ("dbrx", "DbrxForCausalLM"),
     "DeciLMForCausalLM": ("decilm", "DeciLMForCausalLM"),
     "DeepseekForCausalLM": ("deepseek", "DeepseekForCausalLM"),
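Note that `Cohere2ForCausalLM` is mapped to the existing `CohereForCausalLM` class in `commandr.py`; the sliding-window changes above let one implementation serve both architectures, branching on the config at runtime. For comparison, an out-of-tree plugin could wire up the same mapping; a sketch assuming vLLM's public `ModelRegistry.register_model` API:

```python
from vllm import ModelRegistry
from vllm.model_executor.models.commandr import CohereForCausalLM

# Out-of-tree equivalent of the registry entry added above: point the new
# architecture name at the shared Command-R implementation.
ModelRegistry.register_model("Cohere2ForCausalLM", CohereForCausalLM)
```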