
Remove Yi model definition, please use LlamaForCausalLM instead #2854


Merged: 2 commits merged into vllm-project:main on Feb 13, 2024

Conversation

pcmoritz (Collaborator)

This is ported over from #2637 and removes the Yi model definition. The Yi architecture is the same as Llama's, so using LlamaForCausalLM instead avoids code duplication and ensures that Yi models inherit every fix we make to Llama, such as LoRA support.
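For context, vLLM resolves the architecture string from the Hugging Face config through a model registry, and with this change the Yi entry simply points at the Llama implementation. A minimal sketch of what that mapping looks like (the registry name _MODELS and its (module, class) value format are assumptions about vLLM's internals, not part of this diff):

# Sketch only: redirect the "YiForCausalLM" architecture string to the
# Llama implementation instead of a dedicated yi module.
_MODELS = {
    "LlamaForCausalLM": ("llama", "LlamaForCausalLM"),
    # before this PR: "YiForCausalLM": ("yi", "YiForCausalLM"),
    "YiForCausalLM": ("llama", "LlamaForCausalLM"),
}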

This was the diff between the two model definitions:

23c23
< """Inference-only Yi model (https://01.ai) compatible with HuggingFace weights."""
---
> """Inference-only LLaMA model compatible with HuggingFace weights."""
28c28
< from vllm.transformers_utils.configs.yi import YiConfig
---
> from transformers import LlamaConfig
41c41
<     VocabParallelEmbedding, ParallelLMHead)
---
>     VocabParallelEmbedding, ParallelLMHead, DEFAULT_VOCAB_PADDING_SIZE)
47a48
> from vllm.config import LoRAConfig
52c53
< class YiMLP(nn.Module):
---
> class LlamaMLP(nn.Module):
82c83
< class YiAttention(nn.Module):
---
> class LlamaAttention(nn.Module):
130a132
> 
135c137
<             base=self.rope_theta,
---
>             base=rope_theta,
159c161
< class YiDecoderLayer(nn.Module):
---
> class LlamaDecoderLayer(nn.Module):
163c165
<         config: YiConfig,
---
>         config: LlamaConfig,
172c174
<         self.self_attn = YiAttention(
---
>         self.self_attn = LlamaAttention(
181c183
<         self.mlp = YiMLP(
---
>         self.mlp = LlamaMLP(
187,188c189,192
<         self.ln1 = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
<         self.ln2 = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
---
>         self.input_layernorm = RMSNorm(config.hidden_size,
>                                        eps=config.rms_norm_eps)
>         self.post_attention_layernorm = RMSNorm(config.hidden_size,
>                                                 eps=config.rms_norm_eps)
201c205
<             hidden_states = self.ln1(hidden_states)
---
>             hidden_states = self.input_layernorm(hidden_states)
203c207,208
<             hidden_states, residual = self.ln1(hidden_states, residual)
---
>             hidden_states, residual = self.input_layernorm(
>                 hidden_states, residual)
212c217,218
<         hidden_states, residual = self.ln2(hidden_states, residual)
---
>         hidden_states, residual = self.post_attention_layernorm(
>             hidden_states, residual)
217c223
< class YiModel(nn.Module):
---
> class LlamaModel(nn.Module):
221c227
<         config: YiConfig,
---
>         config: LlamaConfig,
222a229
>         lora_config: Optional[LoRAConfig] = None,
227c234,237
<         self.vocab_size = config.vocab_size
---
>         lora_vocab = (lora_config.lora_extra_vocab_size *
>                       (lora_config.max_loras or 1)) if lora_config else 0
>         self.vocab_size = config.vocab_size + lora_vocab
>         self.org_vocab_size = config.vocab_size
229c239
<             config.vocab_size,
---
>             self.vocab_size,
230a241
>             org_num_embeddings=config.vocab_size,
233c244
<             YiDecoderLayer(config, linear_method)
---
>             LlamaDecoderLayer(config, linear_method)
260c271,272
< class YiForCausalLM(nn.Module):
---
> class LlamaForCausalLM(nn.Module):
>     supports_lora = True
264c276
<         config: YiConfig,
---
>         config: LlamaConfig,
265a278
>         lora_config: Optional[LoRAConfig] = None,
270,272c283,296
<         self.model = YiModel(config, linear_method)
<         self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
<         self.sampler = Sampler(config.vocab_size)
---
>         self.model = LlamaModel(config, linear_method, lora_config=lora_config)
>         unpadded_vocab_size = config.vocab_size
>         if lora_config:
>             unpadded_vocab_size += lora_config.lora_extra_vocab_size
>         self.lm_head = ParallelLMHead(
>             unpadded_vocab_size,
>             config.hidden_size,
>             org_num_embeddings=config.vocab_size,
>             padding_size=DEFAULT_VOCAB_PADDING_SIZE
>             # We need bigger padding if using lora for kernel
>             # compatibility
>             if not lora_config else lora_config.lora_vocab_padding_size,
>         )
>         self.sampler = Sampler(unpadded_vocab_size, config.vocab_size)
311a336,340
>             if ("rotary_emb.cos_cached" in name
>                     or "rotary_emb.sin_cached" in name):
>                 # Models trained using ColossalAI may include these tensors in
>                 # the checkpoint. Skip them.
>                 continue
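One thing the diff highlights is that the Llama code path sizes the embedding and lm_head differently when LoRA is enabled. A rough sketch of that arithmetic is below; the function name, the padding constants, and the vocab size used in the example are illustrative assumptions, not values taken from vLLM:

# Illustration of the vocab sizing in the Llama code path shown above.
# All concrete numbers below are illustrative placeholders.
def vocab_sizes(vocab_size, lora_extra_vocab_size=0, max_loras=1,
                default_padding=64, lora_padding=256):
    """Mirror the embedding / lm_head sizing logic from the diff."""
    if lora_extra_vocab_size == 0:
        # Without LoRA, embeddings and lm_head use the base vocab size.
        return vocab_size, vocab_size, default_padding
    # With LoRA, the embedding reserves extra rows per adapter slot and the
    # lm_head is padded to the size required by the LoRA kernels.
    embed_vocab = vocab_size + lora_extra_vocab_size * (max_loras or 1)
    unpadded_lm_head = vocab_size + lora_extra_vocab_size
    return embed_vocab, unpadded_lm_head, lora_padding

print(vocab_sizes(64000))                                          # (64000, 64000, 64)
print(vocab_sizes(64000, lora_extra_vocab_size=256, max_loras=4))  # (65024, 64256, 256)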

Co-authored-by: Roy <jasonailu87@gmail.com>
pcmoritz (Collaborator, Author) commented on Feb 13, 2024:

I tested the PR with

In [2]: from vllm import LLM, SamplingParams
   ...: 

In [3]: prompts = [
   ...:     "Hello, my name is",
   ...:     "The president of the United States is",
   ...:     "The capital of France is",
   ...:     "The future of AI is",
   ...: ]
   ...: sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

In [4]: llm = LLM(model="01-ai/Yi-6B")

In [5]: outputs = llm.generate(prompts, sampling_params)

In [6]: outputs
Out[6]: 
[RequestOutput(request_id=0, prompt='Hello, my name is', prompt_token_ids=[25102, 97, 826, 1815, 620], prompt_logprobs=None, outputs=[CompletionOutput(index=0, text=" Jim and I'm from Germany. In the next 30 days I", token_ids=[8465, 597, 616, 59610, 59583, 742, 8598, 98, 967, 567, 1724, 59568, 80, 77, 2043, 616], cumulative_logprob=-32.77659580856562, logprobs=None, finish_reason=length)], finished=True, lora_request=None),
 RequestOutput(request_id=1, prompt='The president of the United States is', prompt_token_ids=[1263, 4313, 593, 567, 3094, 3557, 620], prompt_logprobs=None, outputs=[CompletionOutput(index=0, text=' the head of state of the United States of America and the commander-in-', token_ids=[567, 1806, 593, 1622, 593, 567, 3094, 3557, 593, 3868, 597, 567, 28588, 59594, 563, 59594], cumulative_logprob=-6.3254562839865685, logprobs=None, finish_reason=length)], finished=True, lora_request=None),
 RequestOutput(request_id=2, prompt='The capital of France is', prompt_token_ids=[1263, 5771, 593, 8718, 620], prompt_logprobs=None, outputs=[CompletionOutput(index=0, text=' Paris and the capital of Japan is Tokyo.\nWhich city is the capital of', token_ids=[11080, 597, 567, 5771, 593, 4968, 620, 22723, 98, 144, 35916, 2726, 620, 567, 5771, 593], cumulative_logprob=-16.77308939769864, logprobs=None, finish_reason=length)], finished=True, lora_request=None),
 RequestOutput(request_id=3, prompt='The future of AI is', prompt_token_ids=[1263, 2653, 593, 13821, 620], prompt_logprobs=None, outputs=[CompletionOutput(index=0, text=' promising and it is also a huge concern for our society. It is very easy', token_ids=[15519, 597, 648, 620, 962, 562, 4284, 4921, 631, 915, 5740, 98, 983, 620, 1196, 2616], cumulative_logprob=-37.312677688896656, logprobs=None, finish_reason=length)], finished=True, lora_request=None)]

I don't even think the old code path was used in most cases, because Yi on Hugging Face now uses the Llama architecture: https://huggingface.co/01-ai/Yi-6B/blob/main/config.json, https://huggingface.co/01-ai/Yi-34B/blob/main/config.json
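A quick way to verify this locally is to load the config and inspect the declared architecture; the expected output below reflects what the linked config.json files indicate:

from transformers import AutoConfig

# The config.json linked above already declares the Llama architecture,
# so these checkpoints load through LlamaForCausalLM anyway.
config = AutoConfig.from_pretrained("01-ai/Yi-6B")
print(config.architectures)  # expected: ['LlamaForCausalLM']
print(config.model_type)     # expected: 'llama'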

pcmoritz changed the title from "Remove Yi model definition, use LlamaForCausalLM instead" to "Remove Yi model definition, please use LlamaForCausalLM instead" on Feb 13, 2024
WoosukKwon (Collaborator) left a comment:

LGTM! Left a minor comment.

WoosukKwon (Collaborator)

This should close #1899

WoosukKwon merged commit 317b29d into vllm-project:main on Feb 13, 2024