Skip to content

Commit a45b075

Browse files
committed
Finalize + fix 3.6 tokenization and hf chat template, disable pydantic protected space warnings
1 parent 2c619f0 commit a45b075

File tree

3 files changed

+13
-8
lines changed

3 files changed

+13
-8
lines changed

ochat/config/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ def _v3_6_role_prefix(from_role, condition, role_start_token="", role_end_token=
5353
system_as_role=True,
5454
inference_condition="GPT4 Correct",
5555
message_prefix="\n\n"),
56-
hf_chat_template="{% set loop_messages = messages %}{% for message in loop_messages %}{% if message['role'] in ['user', 'assistant'] %}{% set content = '<|start_header_id|>GPT4 Correct ' + message['role'].title() + '<|end_header_id|>\n\n' + message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|start_header_id|>GPT4 Correct Assistant<|end_header_id|>\n\n' }}{% endif %}",
56+
hf_chat_template="{% set loop_messages = messages %}{% for message in loop_messages %}{% if message['role'] in ['user', 'assistant'] %}{% set content = '<|start_header_id|>GPT4 Correct ' + message['role'].title() + '<|end_header_id|>\n\n' + message['content'] | trim + '<|eot_id|>' %}{% elif message['role'] == 'system' %}{% set content = '<|start_header_id|>System<|end_header_id|>\n\n' + message['content'] | trim + '<|eot_id|>' %}{% endif %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}{% if add_generation_prompt %}{{ '<|start_header_id|>GPT4 Correct Assistant<|end_header_id|>\n\n' }}{% endif %}",
5757
),
5858

5959
# OpenChat V3.2

ochat/config/conversation_template.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ class ConversationTemplate(BaseModel):
3333
bos_tokens_: List[int]
3434
eot_tokens_: List[int]
3535
message_prefix_tokens_: List[int]
36+
system_role_tokens_: Optional[List[int]] = []
3637

3738
def __init__(self, **data):
3839
tokenizer = data["tokenizer"]
@@ -63,15 +64,14 @@ def tokenize_conversations(self, conversations: Iterable[Conversation], inferenc
6364
for msg in conv.items:
6465
role_mappings.add((msg.role, conv.condition or default_condition))
6566
all_text.append(msg.content)
66-
67+
68+
if self.system_as_role:
69+
self.system_role_tokens_ = self.tokenizer(self.role_prefix("system", ""), add_special_tokens=False).input_ids + self.message_prefix_tokens_
70+
6771
sys_mappings = list(sys_mappings)
6872
role_mappings = list(role_mappings)
6973

70-
# Tokenize
71-
if self.system_as_role:
72-
sys_mappings = dict(zip(sys_mappings, self._tokenize([self.role_prefix(sys) for sys in sys_mappings], ignore_special=False)))
73-
else:
74-
sys_mappings = dict(zip(sys_mappings, self._tokenize(sys_mappings)))
74+
sys_mappings = dict(zip(sys_mappings, self._tokenize(sys_mappings)))
7575
role_mappings = dict(zip(role_mappings, self._tokenize([self.role_prefix(*args) for args in role_mappings], ignore_special=False)))
7676
all_text = self._tokenize(all_text)
7777

@@ -89,6 +89,9 @@ def tokenize_conversations(self, conversations: Iterable[Conversation], inferenc
8989

9090
# System
9191
if conv.system:
92+
tokens.extend(self.system_role_tokens_)
93+
weights.extend([0.] * len(self.system_role_tokens_))
94+
9295
system = sys_mappings[conv.system]
9396
tokens.extend(system)
9497
weights.extend([0.] * len(system))

ochat/config/model_config.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from typing import Callable, Iterable
22

3-
from pydantic import BaseModel
3+
from pydantic import BaseModel, ConfigDict
44

55

66
class ModelConfig(BaseModel):
@@ -15,3 +15,5 @@ class ModelConfig(BaseModel):
1515
# conversation template
1616
conversation_template: Callable
1717
hf_chat_template: str = None
18+
19+
model_config = ConfigDict(protected_namespaces=()) # Disables warnings for the model_ namespace used abvoe

0 commit comments

Comments
 (0)