Templates: Switch to Jinja2
Jinja2 is a lightweight templating engine that Transformers uses for
rendering chat completion prompts. It's much more efficient than Fastchat
and can be installed as part of the requirements.

This also unblocks Pydantic's version requirement.

Users now have to provide their own template if one is needed. A separate
repo may be used later for storing common prompt templates.

Signed-off-by: kingbri <bdashore3@proton.me>
bdashore3 committed Dec 19, 2023
1 parent 95fd0f0 commit f631dd6
Showing 14 changed files with 115 additions and 74 deletions.
4 changes: 2 additions & 2 deletions OAI/types/chat_completion.py
@@ -1,7 +1,7 @@
from uuid import uuid4
from time import time
from pydantic import BaseModel, Field
from typing import Union, List, Optional
from typing import Union, List, Optional, Dict
from OAI.types.common import UsageStats, CommonCompletionRequest

class ChatCompletionMessage(BaseModel):
@@ -24,7 +24,7 @@ class ChatCompletionStreamChoice(BaseModel):
class ChatCompletionRequest(CommonCompletionRequest):
# Messages
# Take in a string as well even though it's not part of the OAI spec
messages: Union[str, List[ChatCompletionMessage]]
messages: Union[str, List[Dict[str, str]]]
prompt_template: Optional[str] = None

class ChatCompletionResponse(BaseModel):
55 changes: 2 additions & 53 deletions OAI/utils.py
@@ -10,18 +10,9 @@
from OAI.types.common import UsageStats
from OAI.types.lora import LoraList, LoraCard
from OAI.types.model import ModelList, ModelCard
from packaging import version
from typing import Optional, List
from utils import unwrap
from typing import Optional

# Check fastchat
try:
import fastchat
from fastchat.model.model_adapter import get_conversation_template, get_conv_template
from fastchat.conversation import SeparatorStyle
_fastchat_available = True
except ImportError:
_fastchat_available = False
from utils import unwrap

def create_completion_response(text: str, prompt_tokens: int, completion_tokens: int, model_name: Optional[str]):
choice = CompletionRespChoice(
@@ -110,45 +101,3 @@ def get_lora_list(lora_path: pathlib.Path):
lora_list.data.append(lora_card)

return lora_list

def get_chat_completion_prompt(model_path: str, messages: List[ChatCompletionMessage], prompt_template: Optional[str] = None):

# TODO: Replace fastchat with in-house jinja templates
# Check if fastchat is available
if not _fastchat_available:
raise ModuleNotFoundError(
"Fastchat must be installed to parse these chat completion messages.\n"
"Please run the following command: pip install fschat[model_worker]"
)
if version.parse(fastchat.__version__) < version.parse("0.2.23"):
raise ImportError(
"Parsing these chat completion messages requires fastchat 0.2.23 or greater. "
f"Current version: {fastchat.__version__}\n"
"Please upgrade fastchat by running the following command: "
"pip install -U fschat[model_worker]"
)

if prompt_template:
conv = get_conv_template(prompt_template)
else:
conv = get_conversation_template(model_path)

if conv.sep_style is None:
conv.sep_style = SeparatorStyle.LLAMA2

for message in messages:
msg_role = message.role
if msg_role == "system":
conv.set_system_message(message.content)
elif msg_role == "user":
conv.append_message(conv.roles[0], message.content)
elif msg_role == "assistant":
conv.append_message(conv.roles[1], message.content)
else:
raise ValueError(f"Unknown role: {msg_role}")

conv.append_message(conv.roles[1], None)
prompt = conv.get_prompt()

print(prompt)
return prompt
8 changes: 6 additions & 2 deletions README.md
@@ -54,8 +54,6 @@ NOTE: For Flash Attention 2 to work on Windows, CUDA 12.x **must** be installed!

3. ROCm 5.6: `pip install -r requirements-amd.txt`

5. If you want the `/v1/chat/completions` endpoint to work with a list of messages, install fastchat by running `pip install fschat[model_worker]`

## Configuration

A config.yml file is required for overriding project defaults. If you are okay with the defaults, you don't need a config file!
@@ -126,6 +124,12 @@ All routes require an API key except for the following which require an **admin*

- `/v1/model/unload`

## Chat Completions

`/v1/chat/completions` now uses Jinja2 for templating. Please read [Huggingface's documentation](https://huggingface.co/docs/transformers/main/chat_templating) for more information on how chat templates work.

Also, make sure to set `prompt_template` in `config.yml` to the template's filename (without the `.jinja` extension).
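
For example, to use the bundled Alpaca template, a minimal `config.yml` sketch looks like this (swap `alpaca` for whichever file you've placed in `templates/`):

```yaml
model:
  # Name of a .jinja file inside the templates/ folder, specified without the extension
  prompt_template: alpaca
```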

## Common Issues

- AMD cards will error out with flash attention installed, even if the config option is set to False. Run `pip uninstall flash_attn` to remove the wheel from your system.
4 changes: 2 additions & 2 deletions config_sample.yml
@@ -56,9 +56,9 @@ model:
# Enable 8 bit cache mode for VRAM savings (slight performance hit). Possible values FP16, FP8. (default: FP16)
cache_mode: FP16

# Set the prompt template for this model. If empty, fastchat will automatically find the best template to use (default: None)
# Set the prompt template for this model. If empty, chat completions will be disabled. (default: alpaca)
# NOTE: Only works with chat completion message lists!
prompt_template:
prompt_template: alpaca

# Number of experts to use per token. Loads from the model's config.json if not specified (default: None)
# WARNING: Don't set this unless you know what you're doing!
19 changes: 11 additions & 8 deletions main.py
@@ -27,10 +27,10 @@
create_completion_response,
get_model_list,
get_lora_list,
get_chat_completion_prompt,
create_chat_completion_response,
create_chat_completion_stream_chunk
)
from templating import get_prompt_from_template
from utils import get_generator_error, get_sse_packet, load_progress, unwrap

app = FastAPI()
@@ -76,14 +76,15 @@ async def list_models():
@app.get("/v1/internal/model/info", dependencies=[Depends(check_api_key), Depends(_check_model_container)])
async def get_current_model():
model_name = model_container.get_model_path().name
prompt_template = model_container.prompt_template
model_card = ModelCard(
id = model_name,
parameters = ModelCardParameters(
rope_scale = model_container.config.scale_pos_emb,
rope_alpha = model_container.config.scale_alpha_value,
max_seq_len = model_container.config.max_seq_len,
cache_mode = "FP8" if model_container.cache_fp8 else "FP16",
prompt_template = unwrap(model_container.prompt_template, "auto")
prompt_template = prompt_template.name if prompt_template else None
),
logging = gen_logging.config
)
@@ -302,19 +303,21 @@ async def generator():
# Chat completions endpoint
@app.post("/v1/chat/completions", dependencies=[Depends(check_api_key), Depends(_check_model_container)])
async def generate_chat_completion(request: Request, data: ChatCompletionRequest):
if model_container.prompt_template is None:
raise HTTPException(422, "This endpoint is disabled because a prompt template is not set.")

model_path = model_container.get_model_path()

if isinstance(data.messages, str):
prompt = data.messages
else:
# If the request specified prompt template isn't found, use the one from model container
# Otherwise, let fastchat figure it out
prompt_template = unwrap(data.prompt_template, model_container.prompt_template)

try:
prompt = get_chat_completion_prompt(model_path.name, data.messages, prompt_template)
prompt = get_prompt_from_template(data.messages, model_container.prompt_template)
except KeyError:
return HTTPException(400, f"Could not find a Conversation from prompt template '{prompt_template}'. Check your spelling?")
raise HTTPException(
400,
f"Could not find a Conversation from prompt template '{model_container.prompt_template.name}'. Check your spelling?"
)

if data.stream:
const_id = f"chatcmpl-{uuid4().hex}"
18 changes: 16 additions & 2 deletions model.py
@@ -17,6 +17,7 @@

from gen_logging import log_generation_params, log_prompt, log_response
from typing import List, Optional, Union
from templating import PromptTemplate
from utils import coalesce, unwrap

# Bytes to reserve on first device when loading with auto split
@@ -31,7 +32,7 @@ class ModelContainer:
draft_cache: Optional[ExLlamaV2Cache] = None
tokenizer: Optional[ExLlamaV2Tokenizer] = None
generator: Optional[ExLlamaV2StreamingGenerator] = None
prompt_template: Optional[str] = None
prompt_template: Optional[PromptTemplate] = None

cache_fp8: bool = False
gpu_split_auto: bool = True
@@ -103,7 +104,20 @@ def progress(loaded_modules: int, total_modules: int, loading_draft: bool)
"""

# Set prompt template override if provided
self.prompt_template = kwargs.get("prompt_template")
prompt_template_name = kwargs.get("prompt_template")
if prompt_template_name:
try:
with open(pathlib.Path(f"templates/{prompt_template_name}.jinja"), "r") as raw_template:
self.prompt_template = PromptTemplate(
name = prompt_template_name,
template = raw_template.read()
)
except OSError:
print("Chat completions are disabled because the provided prompt template couldn't be found.")
self.prompt_template = None
else:
print("Chat completions are disabled because a provided prompt template couldn't be found.")
self.prompt_template = None

# Set num of experts per token if provided
num_experts_override = kwargs.get("num_experts_per_token")
1 change: 1 addition & 0 deletions requirements-amd.txt
@@ -12,3 +12,4 @@ pydantic < 2,>= 1
PyYAML
progress
uvicorn
jinja2
1 change: 1 addition & 0 deletions requirements-cu118.txt
@@ -18,6 +18,7 @@ pydantic < 2,>= 1
PyYAML
progress
uvicorn
jinja2

# Linux FA2 from https://github.com/Dao-AILab/flash-attention/releases
https://github.com/Dao-AILab/flash-attention/releases/download/v2.3.6/flash_attn-2.3.6+cu118torch2.1cxx11abiFALSE-cp310-cp310-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.10"
1 change: 1 addition & 0 deletions requirements.txt
@@ -18,6 +18,7 @@ pydantic < 2,>= 1
PyYAML
progress
uvicorn
jinja2

# Flash attention v2

7 changes: 7 additions & 0 deletions templates/README.md
@@ -0,0 +1,7 @@
# Templates

NOTE: This folder will be replaced by a submodule or something similar in the future

These templates are examples from [Aphrodite Engine](https://github.com/PygmalionAI/aphrodite-engine/tree/main/examples)

Please look at [Huggingface's documentation](https://huggingface.co/docs/transformers/main/chat_templating) for making Jinja2 templates.
29 changes: 29 additions & 0 deletions templates/alpaca.jinja
@@ -0,0 +1,29 @@
{{ (messages|selectattr('role', 'equalto', 'system')|list|last).content|trim if (messages|selectattr('role', 'equalto', 'system')|list) else '' }}

{% for message in messages %}
{% if message['role'] == 'user' %}
### Instruction:
{{ message['content']|trim -}}
{% if not loop.last %}


{% endif %}
{% elif message['role'] == 'assistant' %}
### Response:
{{ message['content']|trim -}}
{% if not loop.last %}


{% endif %}
{% elif message['role'] == 'user_context' %}
### Input:
{{ message['content']|trim -}}
{% if not loop.last %}


{% endif %}
{% endif %}
{% endfor %}
{% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}
### Response:
{% endif %}
2 changes: 2 additions & 0 deletions templates/chatml.jinja
@@ -0,0 +1,2 @@
{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content']}}{% if (loop.last and add_generation_prompt) or not loop.last %}{{ '<|im_end|>' + '\n'}}{% endif %}{% endfor %}
{% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}{{ '<|im_start|>assistant\n' }}{% endif %}
30 changes: 30 additions & 0 deletions templating.py
@@ -0,0 +1,30 @@
from functools import lru_cache
from importlib.metadata import version as package_version
from packaging import version
from jinja2.sandbox import ImmutableSandboxedEnvironment
from pydantic import BaseModel

# Small replication of AutoTokenizer's chat template system for efficiency

class PromptTemplate(BaseModel):
    name: str
    template: str

def get_prompt_from_template(messages, prompt_template: PromptTemplate):
    if version.parse(package_version("jinja2")) < version.parse("3.0.0"):
        raise ImportError(
            "Parsing these chat completion messages requires jinja2 3.0.0 or greater. "
            f"Current version: {package_version('jinja2')}\n"
            "Please upgrade jinja2 by running the following command: "
            "pip install -U jinja2"
        )

    compiled_template = _compile_template(prompt_template.template)

    # Pass add_generation_prompt so templates can append the trailing assistant header
    return compiled_template.render(messages = messages, add_generation_prompt = True)

# Inspired from https://github.com/huggingface/transformers/blob/main/src/transformers/tokenization_utils_base.py#L1761
@lru_cache
def _compile_template(template: str):
    jinja_env = ImmutableSandboxedEnvironment(trim_blocks = True, lstrip_blocks = True)
    jinja_template = jinja_env.from_string(template)
    return jinja_template
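
As a quick reference, here's a minimal usage sketch of the new helper (assuming the `templates/chatml.jinja` file added above is on disk; the message list mirrors what `ChatCompletionRequest.messages` accepts):

```python
from templating import PromptTemplate, get_prompt_from_template

# Load a template the same way model.py does: templates/<name>.jinja
with open("templates/chatml.jinja", "r") as raw_template:
    template = PromptTemplate(name = "chatml", template = raw_template.read())

messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Hello!"},
]

# Renders the ChatML-formatted prompt string that main.py passes to the generator
print(get_prompt_from_template(messages, template))
```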
10 changes: 5 additions & 5 deletions tests/wheel_test.py
@@ -25,12 +25,12 @@
print("Torch is not found in your environment.")
errored_packages.append("torch")

if find_spec("fastchat") is not None:
print(f"Fastchat on version {version('fschat')} successfully imported")
successful_packages.append("fastchat")
if find_spec("jinja2") is not None:
print(f"Jinja2 on version {version('jinja2')} successfully imported")
successful_packages.append("jinja2")
else:
print("Fastchat is not found in your environment. It isn't needed unless you're using chat completions with message arrays.")
errored_packages.append("fastchat")
print("Jinja2 is not found in your environment.")
errored_packages.append("jinja2")

print(
f"\nSuccessful imports: {', '.join(successful_packages)}",
