Skip to content

Commit 198d10d

Browse files
committed
Fixed condition, added option to save hf chat template
1 parent 81b4fcc commit 198d10d

File tree

3 files changed

+14
-5
lines changed

3 files changed

+14
-5
lines changed

ochat/config/__init__.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -49,8 +49,9 @@ def _v3_6_role_prefix(from_role, condition, role_start_token="", role_end_token=
4949
role_end_token="<|end_header_id|>"),
5050
bos="<|begin_of_text|>", # Llama 3 tokenizer needs the BOS token manually specified
5151
eot="<|eot_id|>",
52-
inference_condition="GPT4",
53-
message_prefix="\n\n")
52+
inference_condition="GPT4 Correct",
53+
message_prefix="\n\n"),
54+
hf_chat_template="{% set loop_messages = messages %}{% for message in loop_messages %}{% if message['role'] in ['user', 'assistant'] %}{% set content = '<|start_header_id|>GPT4 Correct ' + message['role'].title() + '<|end_header_id|>\n\n' + message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|start_header_id|>GPT4 Correct Assistant<|end_header_id|>\n\n' }}{% endif %}",
5455
),
5556

5657
# OpenChat V3.2
@@ -83,7 +84,8 @@ def _v3_6_role_prefix(from_role, condition, role_start_token="", role_end_token=
8384
conversation_template=partial(ConversationTemplate,
8485
role_prefix=_v3_2_role_prefix,
8586
eot="<|end_of_turn|>",
86-
inference_condition="GPT4 Correct")
87+
inference_condition="GPT4 Correct"),
88+
hf_chat_template="{{ bos_token }}{% for message in messages %}{{ 'GPT4 Correct ' + message['role'].title() + ': ' + message['content'] + '<|end_of_turn|>'}}{% endfor %}{% if add_generation_prompt %}{{ 'GPT4 Correct Assistant:' }}{% endif %}"
8789
),
8890

8991
"openchat_v3.2_gemma_new": ModelConfig(
@@ -100,7 +102,8 @@ def _v3_6_role_prefix(from_role, condition, role_start_token="", role_end_token=
100102
conversation_template=partial(ConversationTemplate,
101103
role_prefix=_v3_2_role_prefix,
102104
eot="<end_of_turn>",
103-
inference_condition="GPT4 Correct")
105+
inference_condition="GPT4 Correct"),
106+
hf_chat_template="{{ bos_token }}{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '<start_of_turn>' + role + '\n' + message['content'] | trim + '<end_of_turn>\n' }}{% endfor %}{% if add_generation_prompt %}{{'<start_of_turn>model\n'}}{% endif %}"
104107
),
105108

106109
### Other models

ochat/config/model_config.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,3 +14,4 @@ class ModelConfig(BaseModel):
1414

1515
# conversation template
1616
conversation_template: Callable
17+
hf_chat_template: Optional[str] = None

ochat/training_deepspeed/train.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ def parse_args():
3131
parser.add_argument("--data_prefix", type=str, required=True)
3232
parser.add_argument("--save_path", type=str, required=True)
3333
parser.add_argument("--save_every", type=int, default=None)
34+
parser.add_argument("--save_hf_chat_template", action="store_true", default=False) # False until fully tested
3435

3536
# Hyperparameters
3637
parser.add_argument("--batch_max_len", type=int, default=81920)
@@ -136,7 +137,11 @@ def create_lr_scheduler(args, train_total_steps):
136137

137138

138139
def save_tokenizer(args, save_path):
139-
MODEL_CONFIG_MAP[args.model_type].model_tokenizer_create(args.model_path).save_pretrained(save_path)
140+
model_config = MODEL_CONFIG_MAP[args.model_type]
141+
tokenizer = model_config.model_tokenizer_create(args.model_path)
142+
if args.save_hf_chat_template and model_config.hf_chat_template:
143+
tokenizer.chat_template = model_config.hf_chat_template
144+
tokenizer.save_pretrained(save_path)
140145

141146

142147
def save_openchat_metadata(args, epoch, save_path):

0 commit comments

Comments
 (0)