
Commit 3a2cda0

pcmoritz authored and jimpang committed
Revert "Refactor llama family models (vllm-project#2637)" (vllm-project#2851)
This reverts commit 5c976a7.
1 parent: dd90bd1 · commit: 3a2cda0

File tree

17 files changed: +2720 -236 lines changed


vllm/model_executor/layers/layernorm.py

Lines changed: 0 additions & 25 deletions
@@ -7,31 +7,6 @@
 from vllm._C import ops


-class LayerNorm(nn.LayerNorm):
-
-    def __init__(
-        self,
-        hidden_size: int,
-        eps: float = 1e-6,
-    ) -> None:
-        super().__init__(hidden_size, eps=eps)
-
-    def forward(
-        self,
-        x: torch.Tensor,
-        residual: Optional[torch.Tensor] = None,
-    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
-        """normalization."""
-        if residual is not None:
-            x = x + residual
-            residual = x
-        x = super().forward(x)
-        if residual is None:
-            return x
-        else:
-            return x, residual
-
-
 class RMSNorm(nn.Module):
     """Root mean square normalization.

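For reference, the class deleted above folds a residual addition into the normalization call: when a residual tensor is supplied, it is added to the input first, the pre-norm sum is kept, and both the normalized output and that sum are returned so the next layer can reuse it. A minimal, self-contained sketch of that behavior follows; the hidden size and tensor shapes are illustrative assumptions, not vLLM code.

from typing import Optional, Tuple, Union

import torch
import torch.nn as nn


class LayerNorm(nn.LayerNorm):
    """Fused residual-add LayerNorm, as in the hunk removed above."""

    def __init__(self, hidden_size: int, eps: float = 1e-6) -> None:
        super().__init__(hidden_size, eps=eps)

    def forward(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        # Add the residual before normalizing and return the pre-norm sum
        # alongside the normalized output.
        if residual is not None:
            x = x + residual
            residual = x
        x = super().forward(x)
        return x if residual is None else (x, residual)


# Illustrative usage (shapes are assumptions):
norm = LayerNorm(hidden_size=16)
x = torch.randn(2, 8, 16)       # (batch, seq_len, hidden_size)
res = torch.randn(2, 8, 16)
y = norm(x)                     # no residual -> normalized tensor only
y, new_res = norm(x, res)       # residual -> (normalized(x + res), x + res)
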
vllm/model_executor/models/__init__.py

Lines changed: 5 additions & 4 deletions
@@ -10,8 +10,8 @@
 
 # Architecture -> (module, class).
 _MODELS = {
-    "AquilaModel": ("llama", "LlamaForCausalLM"),
-    "AquilaForCausalLM": ("llama", "LlamaForCausalLM"),  # AquilaChat2
+    "AquilaModel": ("aquila", "AquilaForCausalLM"),
+    "AquilaForCausalLM": ("aquila", "AquilaForCausalLM"),  # AquilaChat2
     "BaiChuanForCausalLM": ("baichuan", "BaiChuanForCausalLM"),  # baichuan-7b
     "BaichuanForCausalLM": ("baichuan", "BaichuanForCausalLM"),  # baichuan-13b
     "BloomForCausalLM": ("bloom", "BloomForCausalLM"),
@@ -24,12 +24,12 @@
     "GPTBigCodeForCausalLM": ("gpt_bigcode", "GPTBigCodeForCausalLM"),
     "GPTJForCausalLM": ("gpt_j", "GPTJForCausalLM"),
     "GPTNeoXForCausalLM": ("gpt_neox", "GPTNeoXForCausalLM"),
-    "InternLMForCausalLM": ("llama", "LlamaForCausalLM"),
+    "InternLMForCausalLM": ("internlm", "InternLMForCausalLM"),
     "InternLM2ForCausalLM": ("internlm2", "InternLM2ForCausalLM"),
     "LlamaForCausalLM": ("llama", "LlamaForCausalLM"),
     # For decapoda-research/llama-*
     "LLaMAForCausalLM": ("llama", "LlamaForCausalLM"),
-    "MistralForCausalLM": ("llama", "LlamaForCausalLM"),
+    "MistralForCausalLM": ("mistral", "MistralForCausalLM"),
     "MixtralForCausalLM": ("mixtral", "MixtralForCausalLM"),
     "QuantMixtralForCausalLM": ("mixtral_quant", "MixtralForCausalLM"),
     # transformers's mpt class has lower case
@@ -41,6 +41,7 @@
     "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
     "RWForCausalLM": ("falcon", "FalconForCausalLM"),
     "StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"),
+    "YiForCausalLM": ("yi", "YiForCausalLM")
 }

 # Models not supported by ROCm.

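The registry restored above maps each Hugging Face architecture string to a (module, class) pair under vllm/model_executor/models, so the model loader can import the matching module and pick the class by name. A rough sketch of that kind of lookup follows; the helper name and error handling are assumptions for illustration, not vLLM's actual loader code.

import importlib

# Small excerpt of an architecture registry in the same shape as _MODELS.
_MODELS = {
    "LlamaForCausalLM": ("llama", "LlamaForCausalLM"),
    "MistralForCausalLM": ("mistral", "MistralForCausalLM"),
    "YiForCausalLM": ("yi", "YiForCausalLM"),
}


def resolve_model_cls(architecture: str):
    # Map the architecture reported by the HF config to a concrete class.
    if architecture not in _MODELS:
        raise ValueError(f"Model architecture {architecture!r} is not supported.")
    module_name, cls_name = _MODELS[architecture]
    module = importlib.import_module(f"vllm.model_executor.models.{module_name}")
    return getattr(module, cls_name)


# e.g. for a config with architectures == ["MistralForCausalLM"]:
# model_cls = resolve_model_cls("MistralForCausalLM")  # -> MistralForCausalLM
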