Merge branch 'main' into logprob_protocol
merrymercy authored Nov 1, 2023
2 parents c1ce8dd + d5e4b27 commit aa1d6b4
Showing 8 changed files with 113 additions and 4 deletions.
2 changes: 2 additions & 0 deletions docs/model_support.md
@@ -7,6 +7,8 @@
- Vicuna, Alpaca, LLaMA, Koala
- example: `python3 -m fastchat.serve.cli --model-path lmsys/vicuna-7b-v1.5`
- [BAAI/AquilaChat-7B](https://huggingface.co/BAAI/AquilaChat-7B)
- [BAAI/AquilaChat2-7B](https://huggingface.co/BAAI/AquilaChat2-7B)
- [BAAI/AquilaChat2-34B](https://huggingface.co/BAAI/AquilaChat2-34B)
- [BAAI/bge-large-en](https://huggingface.co/BAAI/bge-large-en#using-huggingface-transformers)
- [baichuan-inc/baichuan-7B](https://huggingface.co/baichuan-inc/baichuan-7B)
- [BlinkDL/RWKV-4-Raven](https://huggingface.co/BlinkDL/rwkv-4-raven)
46 changes: 45 additions & 1 deletion fastchat/conversation.py
@@ -969,13 +969,57 @@ def get_conv_template(name: str) -> Conversation:
name="aquila-chat",
system_message="A chat between a curious human and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the human's questions.",
roles=("Human", "Assistant", "System"),
roles=("Human", "Assistant"),
sep_style=SeparatorStyle.ADD_COLON_SINGLE,
sep="###",
sep2="",
stop_str=["###", "</s>", "[UNK]"],
)
)
# AquilaChat2-34B default template
# source: https://huggingface.co/BAAI/AquilaChat2-34B/blob/4608b75855334b93329a771aee03869dbf7d88cc/predict.py#L212
register_conv_template(
Conversation(
name="aquila-legacy",
system_message="A chat between a curious human and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n",
roles=("### Human: ", "### Assistant: "),
offset=0,
sep_style=SeparatorStyle.NO_COLON_TWO,
sep="\n",
sep2="</s>",
stop_str=["</s>", "[UNK]"],
)
)
# AquilaChat2-7B-16K and AquilaChat2-34B-16K default template
# source: https://huggingface.co/BAAI/AquilaChat2-34B/blob/4608b75855334b93329a771aee03869dbf7d88cc/predict.py#L227
register_conv_template(
Conversation(
name="aquila",
system_message="A chat between a curious human and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the human's questions.",
roles=("Human", "Assistant"),
offset=0,
sep_style=SeparatorStyle.ADD_COLON_TWO,
sep="###",
sep2="</s>",
stop_str=["</s>", "[UNK]"],
)
)

# AquilaChat2-7B default template
# source: https://huggingface.co/BAAI/AquilaChat2-34B/blob/4608b75855334b93329a771aee03869dbf7d88cc/predict.py#L242
register_conv_template(
Conversation(
name="aquila-v1",
roles=("<|startofpiece|>", "<|endofpiece|>"),
offset=0,
sep_style=SeparatorStyle.NO_COLON_TWO,
sep="",
sep2="</s>",
stop_str=["</s>", "<|endoftext|>"],
)
)

# Llama2-Chinese default template
# source: https://huggingface.co/FlagAlpha
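
As a quick check on the four Aquila templates registered above, the sketch below renders a one-turn prompt with the new `aquila-v1` template. It is a minimal example assuming a FastChat checkout that includes this change; the user message is made up.

```python
from fastchat.conversation import get_conv_template

# Build a one-turn prompt with the new AquilaChat2-7B template.
conv = get_conv_template("aquila-v1")
conv.append_message(conv.roles[0], "What is the capital of France?")
conv.append_message(conv.roles[1], None)  # None marks the slot the model will fill
print(conv.get_prompt())
```

Swapping in `"aquila"`, `"aquila-legacy"`, or `"aquila-chat"` shows how the separators and stop strings differ between the variants.
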
20 changes: 18 additions & 2 deletions fastchat/model/model_adapter.py
@@ -1532,7 +1532,13 @@ def get_default_conv_template(self, model_path: str) -> Conversation:


class AquilaChatAdapter(BaseModelAdapter):
"""The model adapter for BAAI/AquilaChat-7B"""
"""The model adapter for BAAI/Aquila
Now supports:
- BAAI/AquilaChat-7B
- BAAI/AquilaChat2-7B
- BAAI/AquilaChat2-34B
"""

def match(self, model_path: str):
return "aquila" in model_path.lower()
@@ -1552,7 +1558,17 @@ def load_model(self, model_path: str, from_pretrained_kwargs: dict):
return model, tokenizer

def get_default_conv_template(self, model_path: str) -> Conversation:
return get_conv_template("aquila-chat")
model_path = model_path.lower()
# See: https://huggingface.co/BAAI/AquilaChat2-34B/blob/4608b75855334b93329a771aee03869dbf7d88cc/predict.py#L347
if "aquilachat2" in model_path:
if "16k" in model_path:
return get_conv_template("aquila")
elif "34b" in model_path:
return get_conv_template("aquila-legacy")
else:
return get_conv_template("aquila-v1")
else:
return get_conv_template("aquila-chat")


class Lamma2ChineseAdapter(BaseModelAdapter):
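
The branching in `get_default_conv_template` above can be exercised directly. A minimal sketch, assuming `get_conversation_template` is exported from `fastchat.model` as in the released package; the model paths are only examples:

```python
from fastchat.model import get_conversation_template

# Expected mapping per the dispatch above:
#   any "16k" AquilaChat2 path -> "aquila"
#   other AquilaChat2-34B      -> "aquila-legacy"
#   other AquilaChat2          -> "aquila-v1"
#   original AquilaChat        -> "aquila-chat"
for path in [
    "BAAI/AquilaChat2-7B-16K",
    "BAAI/AquilaChat2-34B",
    "BAAI/AquilaChat2-7B",
    "BAAI/AquilaChat-7B",
]:
    print(path, "->", get_conversation_template(path).name)
```
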
11 changes: 11 additions & 0 deletions fastchat/model/model_registry.py
@@ -352,3 +352,14 @@ def get_model_info(name: str) -> ModelInfo:
"https://huggingface.co/Open-Orca/Mistral-7B-OpenOrca",
"A fine-tune of [Mistral 7B](https://huggingface.co/mistralai/Mistral-7B-v0.1) using [OpenOrca dataset](https://huggingface.co/datasets/Open-Orca/OpenOrca)",
)

register_model_info(
[
"AquilaChat-7B",
"AquilaChat2-7B",
"AquilaChat2-34B",
],
"Aquila-Chat",
"https://huggingface.co/BAAI/AquilaChat2-34B",
"Chat models developed by BAAI team",
)
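
A small usage sketch for the entry registered above, assuming the lookup key matches a registered name exactly (the returned record carries the display name, link, and description shown above):

```python
from fastchat.model.model_registry import get_model_info

# Returns the ModelInfo registered above for any of the three listed names.
print(get_model_info("AquilaChat2-7B"))
```
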
4 changes: 4 additions & 0 deletions fastchat/protocol/api_protocol.py
@@ -53,12 +53,15 @@ class APIChatCompletionRequest(BaseModel):
messages: Union[str, List[Dict[str, str]]]
temperature: Optional[float] = 0.7
top_p: Optional[float] = 1.0
top_k: Optional[int] = -1
n: Optional[int] = 1
max_tokens: Optional[int] = None
stop: Optional[Union[str, List[str]]] = None
stream: Optional[bool] = False
user: Optional[str] = None
repetition_penalty: Optional[float] = 1.0
frequency_penalty: Optional[float] = 0.0
presence_penalty: Optional[float] = 0.0


class ChatMessage(BaseModel):
@@ -130,6 +133,7 @@ class CompletionRequest(BaseModel):
stop: Optional[Union[str, List[str]]] = None
stream: Optional[bool] = False
top_p: Optional[float] = 1.0
top_k: Optional[int] = -1
logprobs: Optional[int] = None
echo: Optional[bool] = False
presence_penalty: Optional[float] = 0.0
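
The new request fields default to "off" (`top_k=-1`, zero penalties), so existing clients are unaffected. A minimal sketch of a request that opts in; the model name and message are illustrative:

```python
from fastchat.protocol.api_protocol import APIChatCompletionRequest

req = APIChatCompletionRequest(
    model="aquilachat2-7b",
    messages=[{"role": "user", "content": "Hello"}],
    top_k=40,               # new field; -1 (the default) disables top-k sampling
    frequency_penalty=0.5,  # new field; defaults to 0.0
    presence_penalty=0.2,   # new field; defaults to 0.0
)
print(req.top_k, req.frequency_penalty, req.presence_penalty)
```
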
2 changes: 2 additions & 0 deletions fastchat/protocol/openai_api_protocol.py
@@ -60,6 +60,7 @@ class ChatCompletionRequest(BaseModel):
messages: Union[str, List[Dict[str, str]]]
temperature: Optional[float] = 0.7
top_p: Optional[float] = 1.0
top_k: Optional[int] = -1
n: Optional[int] = 1
max_tokens: Optional[int] = None
stop: Optional[Union[str, List[str]]] = None
@@ -153,6 +154,7 @@ class CompletionRequest(BaseModel):
stop: Optional[Union[str, List[str]]] = None
stream: Optional[bool] = False
top_p: Optional[float] = 1.0
top_k: Optional[int] = -1
logprobs: Optional[int] = None
echo: Optional[bool] = False
presence_penalty: Optional[float] = 0.0
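
With `top_k` now part of the OpenAI-compatible schema, a client can pass it straight to the REST endpoint. A sketch using `requests`; the host, port, and model name are assumptions based on a default `openai_api_server` launch:

```python
import requests

resp = requests.post(
    "http://localhost:8000/v1/chat/completions",
    json={
        "model": "aquilachat2-7b",
        "messages": [{"role": "user", "content": "Say hi in one sentence."}],
        "temperature": 0.7,
        "top_k": 40,              # newly accepted alongside top_p
        "presence_penalty": 0.2,
        "frequency_penalty": 0.5,
    },
    timeout=60,
)
print(resp.json()["choices"][0]["message"]["content"])
```
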
25 changes: 24 additions & 1 deletion fastchat/serve/openai_api_server.py
@@ -200,6 +200,11 @@ def check_requests(request) -> Optional[JSONResponse]:
ErrorCode.PARAM_OUT_OF_RANGE,
f"{request.top_p} is greater than the maximum of 1 - 'temperature'",
)
if request.top_k is not None and (request.top_k > -1 and request.top_k < 1):
return create_error_response(
ErrorCode.PARAM_OUT_OF_RANGE,
f"{request.top_k} is out of Range. Either set top_k to -1 or >=1.",
)
if request.stop is not None and (
not isinstance(request.stop, str) and not isinstance(request.stop, list)
):
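
The new check rejects any `top_k` strictly between -1 and 1 (in practice, 0): -1 leaves top-k sampling disabled and any value >= 1 keeps that many candidate tokens. A standalone restatement of the condition:

```python
def top_k_rejected(top_k) -> bool:
    # Mirrors the check above: top_k is not None and (top_k > -1 and top_k < 1)
    return top_k is not None and -1 < top_k < 1

assert not top_k_rejected(None)  # field omitted
assert not top_k_rejected(-1)    # top-k sampling disabled
assert not top_k_rejected(40)    # keep only the 40 most likely tokens
assert top_k_rejected(0)         # rejected with PARAM_OUT_OF_RANGE
```
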
@@ -246,6 +251,9 @@ async def get_gen_params(
*,
temperature: float,
top_p: float,
top_k: Optional[int],
presence_penalty: Optional[float],
frequency_penalty: Optional[float],
max_tokens: Optional[int],
echo: Optional[bool],
logprobs: Optional[int] = None,
@@ -290,8 +298,11 @@
"model": model_name,
"prompt": prompt,
"temperature": temperature,
"top_p": top_p,
"logprobs": logprobs,
"top_p": top_p,
"top_k": top_k,
"presence_penalty": presence_penalty,
"frequency_penalty": frequency_penalty,
"max_new_tokens": max_tokens,
"echo": echo,
"stop_token_ids": conv.stop_token_ids,
@@ -374,6 +385,9 @@ async def create_chat_completion(request: ChatCompletionRequest):
request.messages,
temperature=request.temperature,
top_p=request.top_p,
top_k=request.top_k,
presence_penalty=request.presence_penalty,
frequency_penalty=request.frequency_penalty,
max_tokens=request.max_tokens,
echo=False,
stop=request.stop,
@@ -506,6 +520,9 @@ async def create_completion(request: CompletionRequest):
text,
temperature=request.temperature,
top_p=request.top_p,
top_k=request.top_k,
frequency_penalty=request.frequency_penalty,
presence_penalty=request.presence_penalty,
max_tokens=request.max_tokens,
logprobs=request.logprobs,
echo=request.echo,
@@ -561,6 +578,9 @@ async def generate_completion_stream_generator(
text,
temperature=request.temperature,
top_p=request.top_p,
top_k=request.top_k,
presence_penalty=request.presence_penalty,
frequency_penalty=request.frequency_penalty,
max_tokens=request.max_tokens,
logprobs=request.logprobs,
echo=request.echo,
@@ -741,6 +761,9 @@ async def create_chat_completion(request: APIChatCompletionRequest):
request.messages,
temperature=request.temperature,
top_p=request.top_p,
top_k=request.top_k,
presence_penalty=request.presence_penalty,
frequency_penalty=request.frequency_penalty,
max_tokens=request.max_tokens,
echo=False,
stop=request.stop,
7 changes: 7 additions & 0 deletions fastchat/serve/vllm_worker.py
@@ -68,6 +68,9 @@ async def generate_stream(self, params):
request_id = params.pop("request_id")
temperature = float(params.get("temperature", 1.0))
top_p = float(params.get("top_p", 1.0))
top_k = params.get("top_k", -1.0)
presence_penalty = float(params.get("presence_penalty", 0.0))
frequency_penalty = float(params.get("frequency_penalty", 0.0))
max_new_tokens = params.get("max_new_tokens", 256)
stop_str = params.get("stop", None)
stop_token_ids = params.get("stop_token_ids", None) or []
@@ -92,13 +95,17 @@ async def generate_stream(self, params):
top_p = max(top_p, 1e-5)
if temperature <= 1e-5:
top_p = 1.0

sampling_params = SamplingParams(
n=1,
temperature=temperature,
top_p=top_p,
use_beam_search=use_beam_search,
stop=list(stop),
max_tokens=max_new_tokens,
top_k=top_k,
presence_penalty=presence_penalty,
frequency_penalty=frequency_penalty,
best_of=best_of,
)
results_generator = engine.generate(context, sampling_params, request_id)
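
On the vLLM side, the three parameters map one-to-one onto `SamplingParams`. A minimal standalone sketch, assuming a vLLM version whose `SamplingParams` accepts these keyword arguments (as the worker code above does); the values are illustrative:

```python
from vllm import SamplingParams

sampling_params = SamplingParams(
    n=1,
    temperature=0.7,
    top_p=0.9,
    top_k=40,                 # -1 leaves top-k filtering disabled
    presence_penalty=0.2,
    frequency_penalty=0.5,
    max_tokens=256,
)
```
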
