From 1f21efb5883b4cefbd7baf49721f1b5302b6c52c Mon Sep 17 00:00:00 2001
From: Wei-Lin Chiang
Date: Wed, 22 Nov 2023 23:45:31 -0800
Subject: [PATCH] Add Yi support (#2723)

---
 fastchat/conversation.py         | 17 +++++++++++++++++
 fastchat/model/model_adapter.py  | 11 +++++++++++
 fastchat/model/model_registry.py |  7 +++++++
 fastchat/serve/vllm_worker.py    |  1 +
 4 files changed, 36 insertions(+)

diff --git a/fastchat/conversation.py b/fastchat/conversation.py
index f00efbd6e..20426d080 100644
--- a/fastchat/conversation.py
+++ b/fastchat/conversation.py
@@ -1021,6 +1021,23 @@ def get_conv_template(name: str) -> Conversation:
     )
 )

+# source: https://huggingface.co/01-ai/Yi-34B-Chat/blob/main/tokenizer_config.json#L60
+register_conv_template(
+    Conversation(
+        name="Yi-34b-chat",
+        roles=("<|im_start|>user", "<|im_start|>assistant"),
+        sep_style=SeparatorStyle.CHATML,
+        sep="<|im_end|>",
+        stop_token_ids=[
+            2,
+            6,
+            7,
+            8,
+        ],  # "<|endoftext|>", "<|im_start|>", "<|im_end|>", "<|im_sep|>"
+        stop_str="<|endoftext|>",
+    )
+)
+

 # AquilaChat default template
 # source: https://github.com/FlagAI-Open/FlagAI/blob/master/examples/Aquila/Aquila-chat/cyg_conversation.py
diff --git a/fastchat/model/model_adapter.py b/fastchat/model/model_adapter.py
index 194fec6e8..91fe223fb 100644
--- a/fastchat/model/model_adapter.py
+++ b/fastchat/model/model_adapter.py
@@ -1885,6 +1885,16 @@ def get_default_conv_template(self, model_path: str) -> Conversation:
         return get_conv_template("orca-2")


+class YiAdapter(BaseModelAdapter):
+    """The model adapter for Yi models"""
+
+    def match(self, model_path: str):
+        return "yi-34b-chat" in model_path.lower()
+
+    def get_default_conv_template(self, model_path: str) -> Conversation:
+        return get_conv_template("Yi-34b-chat")
+
+
 # Note: the registration order matters.
 # The one registered earlier has a higher matching priority.
 register_model_adapter(PeftModelAdapter)
@@ -1954,6 +1964,7 @@ def get_default_conv_template(self, model_path: str) -> Conversation:
 register_model_adapter(LemurAdapter)
 register_model_adapter(PygmalionAdapter)
 register_model_adapter(MicrosoftOrcaAdapter)
+register_model_adapter(YiAdapter)

 # After all adapters, try the default base adapter.
 register_model_adapter(BaseModelAdapter)
diff --git a/fastchat/model/model_registry.py b/fastchat/model/model_registry.py
index 5bef2996f..a8e603c72 100644
--- a/fastchat/model/model_registry.py
+++ b/fastchat/model/model_registry.py
@@ -397,3 +397,10 @@ def get_model_info(name: str) -> ModelInfo:
     "https://huggingface.co/BAAI/AquilaChat2-34B",
     "Chat models developed by BAAI team",
 )
+
+register_model_info(
+    ["Yi-34B-Chat"],
+    "Yi-Chat",
+    "https://huggingface.co/01-ai",
+    "A large language model by 01.AI.",
+)
diff --git a/fastchat/serve/vllm_worker.py b/fastchat/serve/vllm_worker.py
index 6428d8b44..59ee172a1 100644
--- a/fastchat/serve/vllm_worker.py
+++ b/fastchat/serve/vllm_worker.py
@@ -101,6 +101,7 @@ async def generate_stream(self, params):
             top_p=top_p,
             use_beam_search=use_beam_search,
             stop=list(stop),
+            stop_token_ids=stop_token_ids,
             max_tokens=max_new_tokens,
             top_k=top_k,
             presence_penalty=presence_penalty,
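
For reference, a minimal usage sketch (not part of the patch) of how the newly registered `Yi-34b-chat` template builds a ChatML-style prompt, assuming a FastChat checkout with this patch applied:

```python
from fastchat.conversation import get_conv_template

# Fetch a copy of the template registered in this patch.
conv = get_conv_template("Yi-34b-chat")

# roles are ("<|im_start|>user", "<|im_start|>assistant")
conv.append_message(conv.roles[0], "What is the capital of France?")
conv.append_message(conv.roles[1], None)  # None marks the slot the model fills in

# Produces a ChatML prompt ending with "<|im_start|>assistant\n",
# with "<|im_end|>" separating completed turns (sep_style=CHATML).
print(conv.get_prompt())
```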