Commit db87efd
OAI: Add ability to specify fastchat prompt template
Sometimes fastchat may not be able to detect the prompt template from the model path. Therefore, add the ability to set it in config.yml or via the request object itself. Also send the provided prompt template on model info request.

Signed-off-by: kingbri <bdashore3@proton.me>
1 parent 9f195af commit db87efd
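As a request-side illustration of the commit message above, here is a minimal sketch of passing a prompt template with a chat completion request. The host, port, endpoint path, message shape, and the "llama-2" template name are illustrative assumptions; only the prompt_template field itself comes from this commit.

import requests

# Hypothetical client call against an OAI-compatible chat completions endpoint.
# The URL and the template name are placeholders, not taken from this diff.
payload = {
    "messages": [
        {"role": "user", "content": "Hello!"},
    ],
    # New optional field from this commit; when omitted, the server falls back
    # to the config.yml value and then to fastchat auto-detection.
    "prompt_template": "llama-2",
}

response = requests.post("http://localhost:5000/v1/chat/completions", json=payload)
print(response.json())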

File tree

7 files changed (+34 lines, -8 lines)

OAI/types/chat_completion.py

Lines changed: 1 addition & 0 deletions
@@ -25,6 +25,7 @@ class ChatCompletionRequest(CommonCompletionRequest):
     # Messages
     # Take in a string as well even though it's not part of the OAI spec
     messages: Union[str, List[ChatCompletionMessage]]
+    prompt_template: Optional[str] = None
 
 class ChatCompletionResponse(BaseModel):
     id: str = Field(default_factory=lambda: f"chatcmpl-{uuid4().hex}")

OAI/types/lora.py

Lines changed: 1 addition & 1 deletion
@@ -1,4 +1,4 @@
-from pydantic import BaseModel, Field;
+from pydantic import BaseModel, Field
 from time import time
 from typing import Optional, List
 

OAI/types/model.py

Lines changed: 2 additions & 0 deletions
@@ -6,6 +6,7 @@ class ModelCardParameters(BaseModel):
     max_seq_len: Optional[int] = 4096
     rope_scale: Optional[float] = 1.0
     rope_alpha: Optional[float] = 1.0
+    prompt_template: Optional[str] = None
     draft: Optional['ModelCard'] = None
 
 class ModelCard(BaseModel):
@@ -34,6 +35,7 @@ class ModelLoadRequest(BaseModel):
     rope_alpha: Optional[float] = 1.0
     no_flash_attention: Optional[bool] = False
     low_mem: Optional[bool] = False
+    prompt_template: Optional[str] = None
     draft: Optional[DraftModelLoadRequest] = None
 
 class ModelLoadResponse(BaseModel):

OAI/utils.py

Lines changed: 12 additions & 6 deletions
@@ -1,5 +1,5 @@
-import os, pathlib
-from OAI.types.completion import CompletionResponse, CompletionRespChoice, UsageStats
+import pathlib
+from OAI.types.completion import CompletionResponse, CompletionRespChoice
 from OAI.types.chat_completion import (
     ChatCompletionMessage,
     ChatCompletionRespChoice,
@@ -11,13 +11,13 @@
 from OAI.types.lora import LoraList, LoraCard
 from OAI.types.model import ModelList, ModelCard
 from packaging import version
-from typing import Optional, List, Dict
+from typing import Optional, List
 from utils import unwrap
 
 # Check fastchat
 try:
     import fastchat
-    from fastchat.model.model_adapter import get_conversation_template
+    from fastchat.model.model_adapter import get_conversation_template, get_conv_template
     from fastchat.conversation import SeparatorStyle
     _fastchat_available = True
 except ImportError:
@@ -111,8 +111,9 @@ def get_lora_list(lora_path: pathlib.Path):
 
     return lora_list
 
-def get_chat_completion_prompt(model_path: str, messages: List[ChatCompletionMessage]):
+def get_chat_completion_prompt(model_path: str, messages: List[ChatCompletionMessage], prompt_template: Optional[str] = None):
 
+    # TODO: Replace fastchat with in-house jinja templates
     # Check if fastchat is available
     if not _fastchat_available:
         raise ModuleNotFoundError(
@@ -127,7 +128,11 @@ def get_chat_completion_prompt(model_path: str, messages: List[ChatCompletionMes
             "pip install -U fschat[model_worker]"
         )
 
-    conv = get_conversation_template(model_path)
+    if prompt_template:
+        conv = get_conv_template(prompt_template)
+    else:
+        conv = get_conversation_template(model_path)
+
     if conv.sep_style is None:
         conv.sep_style = SeparatorStyle.LLAMA2
 
@@ -145,4 +150,5 @@ def get_chat_completion_prompt(model_path: str, messages: List[ChatCompletionMes
     conv.append_message(conv.roles[1], None)
     prompt = conv.get_prompt()
 
+    print(prompt)
     return prompt
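For context on the two fastchat calls used above, a small hedged sketch of their behavior (assuming fschat is installed; "llama-2" is used only as an example of a registered template name):

from fastchat.model.model_adapter import get_conversation_template, get_conv_template

# Explicit template: looked up by name and raises KeyError when the name is
# unknown, which is the error main.py turns into a 400 response below.
conv = get_conv_template("llama-2")

# Auto-detection: fastchat guesses a template from the model path/name and
# falls back to a generic template when nothing matches.
auto_conv = get_conversation_template("my-model-7b")

# Same prompt-building steps as in get_chat_completion_prompt()
conv.append_message(conv.roles[0], "Hello!")
conv.append_message(conv.roles[1], None)
print(conv.get_prompt())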

config_sample.yml

Lines changed: 4 additions & 0 deletions
@@ -48,6 +48,10 @@ model:
   # Enable 8 bit cache mode for VRAM savings (slight performance hit). Possible values FP16, FP8. (default: FP16)
   cache_mode: FP16
 
+  # Set the prompt template for this model. If empty, fastchat will automatically find the best template to use (default: None)
+  # NOTE: Only works with chat completion message lists!
+  prompt_template:
+
   # Options for draft models (speculative decoding). This will use more VRAM!
   draft:
     # Overrides the directory to look for draft (default: models)
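A quick note on the empty prompt_template key added above: under standard YAML parsing, a key with no value loads as None, so fastchat auto-detection stays in effect until a template name is filled in. A hedged sketch follows (PyYAML is assumed purely for illustration; the project's own config loader may differ):

import yaml

# Mirrors the new config_sample.yml entry with the value left empty.
config_text = """
model:
  prompt_template:
"""

config = yaml.safe_load(config_text)
template = config["model"]["prompt_template"]

# An empty value parses as None, so auto-detection applies; setting e.g.
# `prompt_template: llama-2` would pin a specific fastchat template instead.
print(template)  # -> None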

main.py

Lines changed: 9 additions & 1 deletion
@@ -80,6 +80,7 @@ async def get_current_model():
             rope_scale = model_container.config.scale_pos_emb,
             rope_alpha = model_container.config.scale_alpha_value,
             max_seq_len = model_container.config.max_seq_len,
+            prompt_template = unwrap(model_container.prompt_template, "auto")
         )
     )
 
@@ -302,7 +303,14 @@ async def generate_chat_completion(request: Request, data: ChatCompletionRequest
     if isinstance(data.messages, str):
         prompt = data.messages
     else:
-        prompt = get_chat_completion_prompt(model_path.name, data.messages)
+        # If the request specified prompt template isn't found, use the one from model container
+        # Otherwise, let fastchat figure it out
+        prompt_template = unwrap(data.prompt_template, model_container.prompt_template)
+
+        try:
+            prompt = get_chat_completion_prompt(model_path.name, data.messages, prompt_template)
+        except KeyError:
+            return HTTPException(400, f"Could not find a Conversation from prompt template '{prompt_template}'. Check your spelling?")
 
     if data.stream:
         const_id = f"chatcmpl-{uuid4().hex}"
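The fallback chain in the hunk above resolves the template in this order: request field, then the model container value loaded from config.yml, then fastchat auto-detection. A minimal sketch of that precedence, assuming unwrap(value, default) simply returns the value unless it is None (a stand-in for the project's helper, not its actual implementation):

def unwrap(value, default=None):
    # Assumed behavior of the project's unwrap() helper, for illustration only
    return value if value is not None else default

request_template = None           # data.prompt_template (client did not send one)
container_template = "llama-2"    # model_container.prompt_template (from config.yml)

prompt_template = unwrap(request_template, container_template)
print(prompt_template)  # -> "llama-2"; None here would mean fastchat auto-detects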

model.py

Lines changed: 5 additions & 0 deletions
@@ -27,6 +27,7 @@ class ModelContainer:
     draft_cache: Optional[ExLlamaV2Cache] = None
     tokenizer: Optional[ExLlamaV2Tokenizer] = None
     generator: Optional[ExLlamaV2StreamingGenerator] = None
+    prompt_template: Optional[str] = None
 
     cache_fp8: bool = False
     gpu_split_auto: bool = True
@@ -48,6 +49,7 @@ def progress(loaded_modules: int, total_modules: int, loading_draft: bool)
             'max_seq_len' (int): Override model's default max sequence length (default: 4096)
             'rope_scale' (float): Set RoPE scaling factor for model (default: 1.0)
             'rope_alpha' (float): Set RoPE alpha (NTK) factor for model (default: 1.0)
+            'prompt_template' (str): Manually sets the prompt template for this model (default: None)
             'chunk_size' (int): Sets the maximum chunk size for the model (default: 2048)
                 Inferencing in chunks reduces overall VRAM overhead by processing very long sequences in smaller
                 batches. This limits the size of temporary buffers needed for the hidden state and attention
@@ -93,6 +95,9 @@ def progress(loaded_modules: int, total_modules: int, loading_draft: bool)
             self.config.set_low_mem()
         """
 
+        # Set prompt template override if provided
+        self.prompt_template = kwargs.get("prompt_template")
+
         chunk_size = min(unwrap(kwargs.get("chunk_size"), 2048), self.config.max_seq_len)
         self.config.max_input_len = chunk_size
         self.config.max_attn_size = chunk_size ** 2
