Skip to content

Commit

Permalink
Add Yi support (#2723)
Browse files Browse the repository at this point in the history
  • Loading branch information
infwinston authored Nov 23, 2023
1 parent 6ac7d76 commit 1f21efb
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 0 deletions.
17 changes: 17 additions & 0 deletions fastchat/conversation.py
Original file line number Diff line number Diff line change
Expand Up @@ -1021,6 +1021,23 @@ def get_conv_template(name: str) -> Conversation:
)
)

# source: https://huggingface.co/01-ai/Yi-34B-Chat/blob/main/tokenizer_config.json#L60
register_conv_template(
Conversation(
name="Yi-34b-chat",
roles=("<|im_start|>user", "<|im_start|>assistant"),
sep_style=SeparatorStyle.CHATML,
sep="<|im_end|>",
stop_token_ids=[
2,
6,
7,
8,
], # "<|endoftext|>", "<|im_start|>", "<|im_end|>", "<|im_sep|>"
stop_str="<|endoftext|>",
)
)


# AquilaChat default template
# source: https://github.com/FlagAI-Open/FlagAI/blob/master/examples/Aquila/Aquila-chat/cyg_conversation.py
Expand Down
11 changes: 11 additions & 0 deletions fastchat/model/model_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -1885,6 +1885,16 @@ def get_default_conv_template(self, model_path: str) -> Conversation:
return get_conv_template("orca-2")


class YiAdapter(BaseModelAdapter):
"""The model adapter for Yi models"""

def match(self, model_path: str):
return "yi-34b-chat" in model_path.lower()

def get_default_conv_template(self, model_path: str) -> Conversation:
return get_conv_template("Yi-34b-chat")


# Note: the registration order matters.
# The one registered earlier has a higher matching priority.
register_model_adapter(PeftModelAdapter)
Expand Down Expand Up @@ -1954,6 +1964,7 @@ def get_default_conv_template(self, model_path: str) -> Conversation:
register_model_adapter(LemurAdapter)
register_model_adapter(PygmalionAdapter)
register_model_adapter(MicrosoftOrcaAdapter)
register_model_adapter(YiAdapter)

# After all adapters, try the default base adapter.
register_model_adapter(BaseModelAdapter)
7 changes: 7 additions & 0 deletions fastchat/model/model_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -397,3 +397,10 @@ def get_model_info(name: str) -> ModelInfo:
"https://huggingface.co/BAAI/AquilaChat2-34B",
"Chat models developed by BAAI team",
)

register_model_info(
["Yi-34B-Chat"],
"Yi-Chat",
"https://huggingface.co/01-ai",
"A large language model by 01.AI.",
)
1 change: 1 addition & 0 deletions fastchat/serve/vllm_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ async def generate_stream(self, params):
top_p=top_p,
use_beam_search=use_beam_search,
stop=list(stop),
stop_token_ids=stop_token_ids,
max_tokens=max_new_tokens,
top_k=top_k,
presence_penalty=presence_penalty,
Expand Down

0 comments on commit 1f21efb

Please sign in to comment.