Skip to content

Commit 44e7a8c

Browse files
authored
Update template.py
1 parent b61787a commit 44e7a8c

File tree

1 file changed

+30
-1
lines changed

1 file changed

+30
-1
lines changed

api/adapter/template.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1216,6 +1216,34 @@ def template(self) -> str:
12161216
)
12171217

12181218

1219+
class MixtralTemplate(BaseTemplate):
1220+
""" https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.2/blob/main/tokenizer_config.json """
1221+
1222+
name = "mixtral"
1223+
allow_models = ["mixtral"]
1224+
stop = {
1225+
"strings": ["[INST]", "[/INST]"],
1226+
}
1227+
1228+
@property
1229+
def template(self) -> str:
1230+
return (
1231+
"{{ bos_token }}"
1232+
"{% for message in messages %}"
1233+
"{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}"
1234+
"{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}"
1235+
"{% endif %}"
1236+
"{% if message['role'] == 'user' %}"
1237+
"{{ '[INST] ' + message['content'] + ' [/INST]' }}"
1238+
"{% elif message['role'] == 'assistant' %}"
1239+
"{{ message['content'] + '</s>' }}"
1240+
"{% else %}"
1241+
"{{ raise_exception('Only user and assistant roles are supported!') }}"
1242+
"{% endif %}"
1243+
"{% endfor %}"
1244+
)
1245+
1246+
12191247
register_prompt_adapter(AlpacaTemplate)
12201248
register_prompt_adapter(AquilaChatTemplate)
12211249
register_prompt_adapter(BaiChuanTemplate)
@@ -1233,6 +1261,7 @@ def template(self) -> str:
12331261
register_prompt_adapter(HuatuoTemplate)
12341262
register_prompt_adapter(InternLMTemplate)
12351263
register_prompt_adapter(Llama2Template)
1264+
register_prompt_adapter(MixtralTemplate)
12361265
register_prompt_adapter(MossTemplate)
12371266
register_prompt_adapter(OctopackTemplate)
12381267
register_prompt_adapter(OpenBuddyTemplate)
@@ -1256,6 +1285,6 @@ def template(self) -> str:
12561285
{"role": "assistant", "content": "I'm doing great. How can I help you today?"},
12571286
{"role": "user", "content": "I'd like to show off how chat templating works!"},
12581287
]
1259-
template = get_prompt_adapter(prompt_name="sus-chat")
1288+
template = get_prompt_adapter(prompt_name="mixtral")
12601289
messages = template.postprocess_messages(chat)
12611290
print(template.apply_chat_template(messages))

0 commit comments

Comments
 (0)