
Commit 2c619f0

Treat system as a role

1 parent 198d10d

File tree

2 files changed: +9 −3 lines changed

ochat/config/__init__.py

Lines changed: 4 additions & 2 deletions

@@ -17,7 +17,8 @@
 
 _V3_6_PREFIXES = {
     "user": "User",
-    "assistant": "Assistant"
+    "assistant": "Assistant",
+    "system": "System"
 }
 
 
@@ -31,7 +32,7 @@ def _v3_2_role_prefix(from_role, condition):
     return f"{condition} {_V3_2_PREFIXES[from_role]}".strip()
 
 def _v3_6_role_prefix(from_role, condition, role_start_token="", role_end_token=""):
-    return f"{role_start_token}{condition} {_V3_6_PREFIXES[from_role]}{role_end_token}".strip()
+    return role_start_token + f"{condition} {_V3_6_PREFIXES[from_role]}".strip() + role_end_token
 
 MODEL_CONFIG_MAP = {
     # OpenChat V3.6 (llama 3)
@@ -49,6 +50,7 @@ def _v3_6_role_prefix(from_role, condition, role_start_token="", role_end_token=
                 role_end_token="<|end_header_id|>"),
             bos="<|begin_of_text|>",  # Llama 3 tokenizer needs manually specifing tokenizer
             eot="<|eot_id|>",
+            system_as_role=True,
             inference_condition="GPT4 Correct",
             message_prefix="\n\n"),
         hf_chat_template="{% set loop_messages = messages %}{% for message in loop_messages %}{% if message['role'] in ['user', 'assistant'] %}{% set content = '<|start_header_id|>GPT4 Correct ' + message['role'].title() + '<|end_header_id|>\n\n' + message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|start_header_id|>GPT4 Correct Assistant<|end_header_id|>\n\n' }}{% endif %}",

ochat/config/conversation_template.py

Lines changed: 5 additions & 1 deletion

@@ -24,6 +24,7 @@ class ConversationTemplate(BaseModel):
     bos: Optional[str] = None
     role_prefix: Callable
     message_prefix: str = ""
+    system_as_role: bool = False
     eot: str
 
     inference_condition: Optional[str] = None
@@ -67,7 +68,10 @@ def tokenize_conversations(self, conversations: Iterable[Conversation], inferenc
         role_mappings = list(role_mappings)
 
         # Tokenize
-        sys_mappings = dict(zip(sys_mappings, self._tokenize(sys_mappings)))
+        if self.system_as_role:
+            sys_mappings = dict(zip(sys_mappings, self._tokenize([self.role_prefix(sys) for sys in sys_mappings], ignore_special=False)))
+        else:
+            sys_mappings = dict(zip(sys_mappings, self._tokenize(sys_mappings)))
         role_mappings = dict(zip(role_mappings, self._tokenize([self.role_prefix(*args) for args in role_mappings], ignore_special=False)))
         all_text = self._tokenize(all_text)
 
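
As a rough illustration of the new branch in tokenize_conversations (the surrounding method and the _tokenize helper are not shown in this diff, so role_prefix and tokenize below are hypothetical stand-ins rather than the actual implementation):

# Hypothetical stand-ins for self.role_prefix and self._tokenize; only the
# branching mirrors the diff above.
def map_system_prompts(sys_mappings, system_as_role, role_prefix, tokenize):
    sys_mappings = list(sys_mappings)
    if system_as_role:
        # New behaviour: render each system entry through the role prefix
        # (special tokens kept), the same way role_mappings are handled.
        return dict(zip(sys_mappings,
                        tokenize([role_prefix(sys) for sys in sys_mappings],
                                 ignore_special=False)))
    # Previous behaviour: tokenize the raw system text as-is.
    return dict(zip(sys_mappings, tokenize(sys_mappings)))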
