
Commit 7eacffd

pcmoritz and esmeetu authored
Migrate InternLMForCausalLM to LlamaForCausalLM (#2860)
Co-authored-by: Roy <jasonailu87@gmail.com>
1 parent: 2a543d6 · commit: 7eacffd

File tree

3 files changed · +5 −302 lines changed


vllm/model_executor/models/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -24,7 +24,7 @@
     "GPTBigCodeForCausalLM": ("gpt_bigcode", "GPTBigCodeForCausalLM"),
     "GPTJForCausalLM": ("gpt_j", "GPTJForCausalLM"),
     "GPTNeoXForCausalLM": ("gpt_neox", "GPTNeoXForCausalLM"),
-    "InternLMForCausalLM": ("internlm", "InternLMForCausalLM"),
+    "InternLMForCausalLM": ("llama", "LlamaForCausalLM"),
     "InternLM2ForCausalLM": ("internlm2", "InternLM2ForCausalLM"),
     "LlamaForCausalLM": ("llama", "LlamaForCausalLM"),
     # For decapoda-research/llama-*
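
With this one-line registry change, checkpoints whose config.json declares the InternLMForCausalLM architecture are served by the existing Llama implementation, which is what lets the dedicated internlm module below be deleted. A minimal sketch of how a lazy (module, class) mapping of this shape can be resolved; the _MODELS subset and load_model_cls helper here are simplified stand-ins for illustration, not vLLM's actual registry code, and the import assumes vLLM is installed:

import importlib
from typing import Optional, Type

# Simplified mirror of the mapping edited above: architecture name as read
# from the checkpoint's config.json -> (module name, class name).
_MODELS = {
    "InternLMForCausalLM": ("llama", "LlamaForCausalLM"),  # after this commit
    "LlamaForCausalLM": ("llama", "LlamaForCausalLM"),
}

def load_model_cls(arch: str) -> Optional[Type]:
    """Resolve an architecture string to its model class, importing lazily."""
    if arch not in _MODELS:
        return None
    module_name, cls_name = _MODELS[arch]
    module = importlib.import_module(
        f"vllm.model_executor.models.{module_name}")
    return getattr(module, cls_name)

Lazy import keeps startup cheap: a model file is only imported once a checkpoint actually requests that architecture.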

vllm/model_executor/models/internlm.py

Lines changed: 0 additions & 299 deletions
This file was deleted.

vllm/model_executor/models/llama.py

Lines changed: 4 additions & 2 deletions
@@ -91,6 +91,7 @@ def __init__(
         rope_scaling: Optional[Dict[str, Any]] = None,
         max_position_embeddings: int = 8192,
         linear_method: Optional[LinearMethodBase] = None,
+        bias: bool = False,
     ) -> None:
         super().__init__()
         self.hidden_size = hidden_size
@@ -120,13 +121,13 @@ def __init__(
             self.head_dim,
             self.total_num_heads,
             self.total_num_kv_heads,
-            bias=False,
+            bias=bias,
             linear_method=linear_method,
         )
         self.o_proj = RowParallelLinear(
             self.total_num_heads * self.head_dim,
             hidden_size,
-            bias=False,
+            bias=bias,
             linear_method=linear_method,
         )

@@ -179,6 +180,7 @@ def __init__(
             rope_scaling=rope_scaling,
             max_position_embeddings=max_position_embeddings,
             linear_method=linear_method,
+            bias=getattr(config, "bias", False),
         )
         self.mlp = LlamaMLP(
             hidden_size=self.hidden_size,
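
These three hunks carry the entire delta InternLM needed from the Llama code path: a bias flag threaded into the attention qkv_proj and o_proj projections, defaulting to False so existing Llama models are unaffected. The getattr(config, "bias", False) call reads the flag from the Hugging Face config when present; InternLM-style configs carry a bias field while Llama configs have no such attribute and fall back to False. A small runnable illustration of that fallback, using SimpleNamespace stand-ins rather than real transformers config objects:

from types import SimpleNamespace

# Hypothetical stand-ins for Hugging Face configs; the real objects come
# from `transformers`. Only the presence/absence of `bias` matters here.
llama_cfg = SimpleNamespace(hidden_size=4096)                # no `bias` field
internlm_cfg = SimpleNamespace(hidden_size=4096, bias=True)  # InternLM-style

# Mirrors the new call site in LlamaDecoderLayer.__init__ above:
assert getattr(llama_cfg, "bias", False) is False     # projections stay bias-free
assert getattr(internlm_cfg, "bias", False) is True   # qkv_proj/o_proj gain bias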

0 commit comments