OAI: Add ability to specify fastchat prompt template
Sometimes fastchat may not be able to detect the prompt template from
the model path. Therefore, add the ability to set it in config.yml or
via the request object itself.

Also return the provided prompt template in the model info response.
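
For illustration, a per-request override might look like the sketch below. This is not part of the commit: the host and port, the /v1/chat/completions route, any auth headers, and the "llama-2" fastchat template name are assumptions that depend on the deployment and fastchat version. Setting prompt_template once in config.yml applies the same template to every chat completion request instead.

    # Hypothetical client sketch: override the prompt template for one request.
    # Host/port, route, and the template name are assumptions; auth headers omitted.
    import requests

    payload = {
        "messages": [{"role": "user", "content": "Hello!"}],
        # Takes precedence over any prompt_template set in config.yml
        "prompt_template": "llama-2",
    }

    response = requests.post("http://127.0.0.1:5000/v1/chat/completions", json=payload)
    print(response.json())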

Signed-off-by: kingbri <bdashore3@proton.me>
bdashore3 committed Dec 10, 2023
1 parent 9f195af commit db87efd
Showing 7 changed files with 34 additions and 8 deletions.
1 change: 1 addition & 0 deletions OAI/types/chat_completion.py
@@ -25,6 +25,7 @@ class ChatCompletionRequest(CommonCompletionRequest):
# Messages
# Take in a string as well even though it's not part of the OAI spec
messages: Union[str, List[ChatCompletionMessage]]
prompt_template: Optional[str] = None

class ChatCompletionResponse(BaseModel):
id: str = Field(default_factory=lambda: f"chatcmpl-{uuid4().hex}")
2 changes: 1 addition & 1 deletion OAI/types/lora.py
@@ -1,4 +1,4 @@
from pydantic import BaseModel, Field;
from pydantic import BaseModel, Field
from time import time
from typing import Optional, List

2 changes: 2 additions & 0 deletions OAI/types/model.py
@@ -6,6 +6,7 @@ class ModelCardParameters(BaseModel):
max_seq_len: Optional[int] = 4096
rope_scale: Optional[float] = 1.0
rope_alpha: Optional[float] = 1.0
prompt_template: Optional[str] = None
draft: Optional['ModelCard'] = None

class ModelCard(BaseModel):
@@ -34,6 +35,7 @@ class ModelLoadRequest(BaseModel):
rope_alpha: Optional[float] = 1.0
no_flash_attention: Optional[bool] = False
low_mem: Optional[bool] = False
prompt_template: Optional[str] = None
draft: Optional[DraftModelLoadRequest] = None

class ModelLoadResponse(BaseModel):
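
With prompt_template now on ModelLoadRequest, the template can also be pinned when a model is loaded. A rough sketch, not part of this commit: the /v1/model/load route, host/port, and the model name are assumptions, and any admin/auth headers the deployment requires are omitted.

    # Hypothetical sketch: load a model and pin its prompt template at load time.
    # Route, host/port, and all names here are placeholders for illustration.
    import requests

    load_request = {
        "name": "MyModel-GPTQ",        # placeholder model folder name
        "prompt_template": "llama-2",  # stored on the container and reported by model info
    }

    response = requests.post("http://127.0.0.1:5000/v1/model/load", json=load_request)
    print(response.status_code)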
18 changes: 12 additions & 6 deletions OAI/utils.py
@@ -1,5 +1,5 @@
import os, pathlib
from OAI.types.completion import CompletionResponse, CompletionRespChoice, UsageStats
import pathlib
from OAI.types.completion import CompletionResponse, CompletionRespChoice
from OAI.types.chat_completion import (
ChatCompletionMessage,
ChatCompletionRespChoice,
@@ -11,13 +11,13 @@
from OAI.types.lora import LoraList, LoraCard
from OAI.types.model import ModelList, ModelCard
from packaging import version
from typing import Optional, List, Dict
from typing import Optional, List
from utils import unwrap

# Check fastchat
try:
import fastchat
from fastchat.model.model_adapter import get_conversation_template
from fastchat.model.model_adapter import get_conversation_template, get_conv_template
from fastchat.conversation import SeparatorStyle
_fastchat_available = True
except ImportError:
@@ -111,8 +111,9 @@ def get_lora_list(lora_path: pathlib.Path):

return lora_list

def get_chat_completion_prompt(model_path: str, messages: List[ChatCompletionMessage]):
def get_chat_completion_prompt(model_path: str, messages: List[ChatCompletionMessage], prompt_template: Optional[str] = None):

# TODO: Replace fastchat with in-house jinja templates
# Check if fastchat is available
if not _fastchat_available:
raise ModuleNotFoundError(
@@ -127,7 +128,11 @@ def get_chat_completion_prompt(model_path: str, messages: List[ChatCompletionMes
"pip install -U fschat[model_worker]"
)

conv = get_conversation_template(model_path)
if prompt_template:
conv = get_conv_template(prompt_template)
else:
conv = get_conversation_template(model_path)

if conv.sep_style is None:
conv.sep_style = SeparatorStyle.LLAMA2

@@ -145,4 +150,5 @@ def get_chat_completion_prompt(model_path: str, messages: List[ChatCompletionMes
conv.append_message(conv.roles[1], None)
prompt = conv.get_prompt()

print(prompt)
return prompt
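
The new prompt_template argument switches between the two fastchat lookups: a named template goes through get_conv_template, while the old behavior of guessing from the model path via get_conversation_template remains the fallback. A minimal usage sketch, assuming fastchat is installed, that ChatCompletionMessage takes role/content fields, and that "vicuna_v1.1" is a valid template name in the installed fastchat version:

    # Minimal sketch of the two code paths in get_chat_completion_prompt.
    # The template name and message fields are assumptions for illustration.
    from OAI.types.chat_completion import ChatCompletionMessage
    from OAI.utils import get_chat_completion_prompt

    messages = [ChatCompletionMessage(role="user", content="Summarize this repo.")]

    # Named template: resolved with get_conv_template("vicuna_v1.1")
    prompt = get_chat_completion_prompt("MyModel-GPTQ", messages, "vicuna_v1.1")

    # No template given: fastchat guesses via get_conversation_template(model_path)
    prompt = get_chat_completion_prompt("MyModel-GPTQ", messages)

If the named template does not exist, fastchat's get_conv_template raises a KeyError, which the main.py change below turns into a 400 response.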
4 changes: 4 additions & 0 deletions config_sample.yml
@@ -48,6 +48,10 @@ model:
# Enable 8 bit cache mode for VRAM savings (slight performance hit). Possible values FP16, FP8. (default: FP16)
cache_mode: FP16

# Set the prompt template for this model. If empty, fastchat will automatically find the best template to use (default: None)
# NOTE: Only works with chat completion message lists!
prompt_template:

# Options for draft models (speculative decoding). This will use more VRAM!
draft:
# Overrides the directory to look for draft (default: models)
10 changes: 9 additions & 1 deletion main.py
@@ -80,6 +80,7 @@ async def get_current_model():
rope_scale = model_container.config.scale_pos_emb,
rope_alpha = model_container.config.scale_alpha_value,
max_seq_len = model_container.config.max_seq_len,
prompt_template = unwrap(model_container.prompt_template, "auto")
)
)

@@ -302,7 +303,14 @@ async def generate_chat_completion(request: Request, data: ChatCompletionRequest
if isinstance(data.messages, str):
prompt = data.messages
else:
prompt = get_chat_completion_prompt(model_path.name, data.messages)
# If the request doesn't specify a prompt template, fall back to the one from the model container
# If neither is set, let fastchat figure it out
prompt_template = unwrap(data.prompt_template, model_container.prompt_template)

try:
prompt = get_chat_completion_prompt(model_path.name, data.messages, prompt_template)
except KeyError:
raise HTTPException(400, f"Could not find a conversation template named '{prompt_template}'. Check your spelling?")

if data.stream:
const_id = f"chatcmpl-{uuid4().hex}"
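
With the get_current_model change above, the model card now reports the active prompt template, or "auto" when nothing was set and fastchat is left to detect it. A quick check might look like the sketch below; the /v1/model route, host/port, and the exact response layout (a "parameters" object on the model card) are assumptions, and auth headers are omitted.

    # Hypothetical sketch: read back the prompt template from the model info endpoint.
    import requests

    info = requests.get("http://127.0.0.1:5000/v1/model").json()
    # Expected to be the configured template name, or "auto" when unset
    print(info.get("parameters", {}).get("prompt_template"))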
5 changes: 5 additions & 0 deletions model.py
@@ -27,6 +27,7 @@ class ModelContainer:
draft_cache: Optional[ExLlamaV2Cache] = None
tokenizer: Optional[ExLlamaV2Tokenizer] = None
generator: Optional[ExLlamaV2StreamingGenerator] = None
prompt_template: Optional[str] = None

cache_fp8: bool = False
gpu_split_auto: bool = True
@@ -48,6 +49,7 @@ def progress(loaded_modules: int, total_modules: int, loading_draft: bool)
'max_seq_len' (int): Override model's default max sequence length (default: 4096)
'rope_scale' (float): Set RoPE scaling factor for model (default: 1.0)
'rope_alpha' (float): Set RoPE alpha (NTK) factor for model (default: 1.0)
'prompt_template' (str): Manually sets the prompt template for this model (default: None)
'chunk_size' (int): Sets the maximum chunk size for the model (default: 2048)
Inferencing in chunks reduces overall VRAM overhead by processing very long sequences in smaller
batches. This limits the size of temporary buffers needed for the hidden state and attention
@@ -93,6 +95,9 @@ def progress(loaded_modules: int, total_modules: int, loading_draft: bool)
self.config.set_low_mem()
"""

# Set prompt template override if provided
self.prompt_template = kwargs.get("prompt_template")

chunk_size = min(unwrap(kwargs.get("chunk_size"), 2048), self.config.max_seq_len)
self.config.max_input_len = chunk_size
self.config.max_attn_size = chunk_size ** 2
