
Commit f0ef2e8

Update tokenizer apply_chat_template functionality (huggingface#647)
* Allow custom kwargs in `tokenizer.apply_chat_template`
* Update jinja dependency version
* Add `tokenizer_kwargs` options
* Add support for dictionaries of chat templates in the tokenizer config
* Add `CohereTokenizer`
* `apply_chat_template` is no longer async
* Add unit test for multiple chat templates
* Update tokenizers.js
* Also update when `chat_template` is undefined
* Support setting tokenizer and text from URL
* Update Claude tokenizer display name
* Add Cohere Command-R tokenizer to playground
* Add `Grok1Tokenizer`
* Throw error if chat template object is malformed
* Improved error checking
* Remove redundant error check
* `template_dict` can be a null-prototype object
1 parent 40cdd36 commit f0ef2e8
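Taken together, these changes let `apply_chat_template` resolve a named template from a dictionary of chat templates in the tokenizer config, and forward extra keyword arguments into the template's render context. A minimal usage sketch (the model id appears in the diffs below; the message content and empty `documents` value are illustrative):

    import { AutoTokenizer } from '@xenova/transformers';

    const tokenizer = await AutoTokenizer.from_pretrained('Xenova/c4ai-command-r-v01-tokenizer');

    // `apply_chat_template` is now synchronous (no await needed).
    const prompt = tokenizer.apply_chat_template(
        [{ role: 'user', content: 'Hello!' }],
        {
            chat_template: 'rag',       // a template name, looked up in the template dictionary
            tokenize: false,
            add_generation_prompt: true,
            documents: [],              // extra kwargs are forwarded to the template context
        },
    );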

6 files changed: 110 additions, 10 deletions


examples/tokenizer-playground/src/App.jsx

Lines changed: 15 additions & 2 deletions
@@ -4,12 +4,16 @@ import { Token } from './components/Token'
 
 
 function App() {
+  // Allow user to set tokenizer and text via URL query parameters
+  const urlParams = new URLSearchParams(window.location.search);
+  const tokenizerParam = urlParams.get('tokenizer');
+  const textParam = urlParams.get('text');
 
   const [tokenIds, setTokenIds] = useState([])
   const [decodedTokens, setDecodedTokens] = useState([])
   const [margins, setMargins] = useState([])
   const [outputOption, setOutputOption] = useState('text');
-  const [tokenizer, setTokenizer] = useState('Xenova/gpt-4');
+  const [tokenizer, setTokenizer] = useState(tokenizerParam ?? 'Xenova/gpt-4');
 
   const textareaRef = useRef(null);
   const outputRef = useRef(null);
@@ -51,6 +55,12 @@ function App() {
     worker.current.postMessage({ model_id, text });
   }, [tokenizer]);
 
+  useEffect(() => {
+    if (textParam) {
+      onInputChange({ target: { value: textParam } });
+    }
+  }, [onInputChange, textParam]);
+
   const onTokenizerChange = useCallback((e) => {
     const model_id = e.target.value;
     setTokenizer(model_id);
@@ -70,10 +80,12 @@ function App() {
           <option value="Xenova/gpt-4">gpt-4 / gpt-3.5-turbo / text-embedding-ada-002</option>
           <option value="Xenova/text-davinci-003">text-davinci-003 / text-davinci-002</option>
           <option value="Xenova/gpt-3">gpt-3</option>
-          <option value="Xenova/claude-tokenizer">Claude 3</option>
+          <option value="Xenova/grok-1-tokenizer">Grok-1</option>
+          <option value="Xenova/claude-tokenizer">Claude</option>
           <option value="Xenova/mistral-tokenizer">Mistral</option>
           <option value="Xenova/gemma-tokenizer">Gemma</option>
           <option value="Xenova/llama-tokenizer">LLaMA / Llama 2</option>
+          <option value="Xenova/c4ai-command-r-v01-tokenizer">Cohere Command-R</option>
           <option value="Xenova/t5-small">T5</option>
           <option value="Xenova/bert-base-cased">bert-base-cased</option>
         </select>
@@ -86,6 +98,7 @@ function App() {
           rows="8"
           className="font-mono text-lg block w-full p-2.5 text-gray-900 bg-gray-50 rounded-lg border border-gray-200"
           placeholder="Enter some text"
+          defaultValue={textParam ?? textareaRef.current?.value ?? ''}
         ></textarea>
 
         <div className='flex justify-center gap-5'>
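With the query-parameter support above, the playground can be pre-filled from the URL, for example (illustrative link; the exact host depends on where the playground is deployed):

    https://huggingface.co/spaces/Xenova/the-tokenizer-playground?tokenizer=Xenova/llama-tokenizer&text=Hello%20world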

examples/tokenizer-playground/src/worker.js

Lines changed: 1 addition & 0 deletions
@@ -22,6 +22,7 @@ self.addEventListener('message', async (event) => {
     // NOTE: We just remove the StripDecoder from the llama tokenizer
     switch (tokenizer.constructor.name) {
         case 'LlamaTokenizer':
+        case 'Grok1Tokenizer':
             // tokenizer.decoder.decoders.at(-1).constructor.name === 'StripDecoder'
             tokenizer.decoder.decoders.pop();
             break;

package-lock.json

Lines changed: 4 additions & 4 deletions
Some generated files are not rendered by default.

package.json

Lines changed: 1 addition & 1 deletion
@@ -40,7 +40,7 @@
   "dependencies": {
     "onnxruntime-web": "1.14.0",
     "sharp": "^0.32.0",
-    "@huggingface/jinja": "^0.2.1"
+    "@huggingface/jinja": "^0.2.2"
   },
   "optionalDependencies": {
     "onnxruntime-node": "1.14.0"

src/tokenizers.js

Lines changed: 51 additions & 1 deletion
@@ -2519,6 +2519,18 @@ export class PreTrainedTokenizer extends Callable {
         this.legacy = false;
 
         this.chat_template = tokenizerConfig.chat_template ?? null;
+        if (Array.isArray(this.chat_template)) {
+            // Chat templates are stored as lists of dicts with fixed key names,
+            // we reconstruct that into a single dict while loading them.
+            const chat_template = Object.create(null);
+            for (const { name, template } of this.chat_template) {
+                if (typeof name !== 'string' || typeof template !== 'string') {
+                    throw new Error('Chat template must be a list of objects with "name" and "template" properties');
+                }
+                chat_template[name] = template;
+            }
+            this.chat_template = chat_template;
+        }
         this._compiled_template_cache = new Map();
     }
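For reference, a tokenizer config with multiple chat templates stores them as a list of named entries, which the constructor code above folds into a single lookup object. A shortened, illustrative excerpt (template bodies elided):

    // tokenizer_config.json (excerpt, illustrative values):
    // "chat_template": [
    //     { "name": "default", "template": "{{ bos_token }}..." },
    //     { "name": "rag", "template": "..." }
    // ]
    //
    // After loading, this.chat_template is a null-prototype object:
    // { default: "{{ bos_token }}...", rag: "..." }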

@@ -2995,6 +3007,7 @@ export class PreTrainedTokenizer extends Callable {
      * @param {number} [options.max_length=null] Maximum length (in tokens) to use for padding or truncation. Has no effect if tokenize is false.
      * If not specified, the tokenizer's `max_length` attribute will be used as a default.
      * @param {boolean} [options.return_tensor=true] Whether to return the output as a Tensor or an Array. Has no effect if tokenize is false.
+     * @param {Object} [options.tokenizer_kwargs={}] Additional options to pass to the tokenizer.
      * @returns {string | Tensor | number[] | number[][]} The tokenized output.
      */
     apply_chat_template(conversation, {
@@ -3005,9 +3018,37 @@ export class PreTrainedTokenizer extends Callable {
         truncation = false,
         max_length = null,
         return_tensor = true,
+        tokenizer_kwargs = {},
+        ...kwargs
     } = {}) {
 
-        chat_template ??= this.chat_template ?? this.default_chat_template;
+        // First, handle the cases when the model has a dict of multiple templates
+        if (
+            (this.chat_template && typeof this.chat_template === 'object') ||
+            (this.chat_template === null && this.default_chat_template && typeof this.default_chat_template === 'object')
+        ) {
+            const template_dict = this.chat_template ?? this.default_chat_template; // Guaranteed to be a non-null object
+
+            if (chat_template !== null && Object.hasOwn(template_dict, chat_template)) {
+                // The user can pass the name of a template to the chat template argument instead of an entire template
+                chat_template = template_dict[chat_template];
+            } else if (chat_template === null && 'default' in template_dict) {
+                chat_template = template_dict['default'];
+            } else if (chat_template === null) {
+                throw Error(
+                    `This model has multiple chat templates with no default specified! Please either pass a chat ` +
+                    `template or the name of the template you wish to use to the 'chat_template' argument. Available ` +
+                    `template names are ${Object.keys(template_dict).sort()}.`
+                )
+            }
+        } else {
+            // These are the cases when the model has a single template
+            // priority: `chat_template` argument > `tokenizer.chat_template` > `tokenizer.default_chat_template`
+            chat_template ??= this.chat_template ?? this.default_chat_template;
+        }
+        if (typeof chat_template !== 'string') {
+            throw Error(`chat_template must be a string, but got ${typeof chat_template}`);
+        }
 
         // Compilation function uses a cache to avoid recompiling the same template
         let compiledTemplate = this._compiled_template_cache.get(chat_template);
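The resolution above lets callers pass either a complete template string or, when the tokenizer ships a template dictionary, just a name; with no name given, the 'default' entry is used if present, otherwise the error lists the available names. A minimal sketch, assuming a dictionary like { default: '...', rag: '...' }:

    tokenizer.apply_chat_template(messages, { chat_template: 'rag' }); // looked up by name
    tokenizer.apply_chat_template(messages, {});                       // falls back to the 'default' entry
    // With no 'default' entry and no name given, an Error lists the available template names.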
@@ -3029,6 +3070,7 @@ export class PreTrainedTokenizer extends Callable {
             add_generation_prompt: add_generation_prompt,
 
             ...special_tokens_map,
+            ...kwargs,
         });
 
         if (tokenize) {
@@ -3038,6 +3080,7 @@ export class PreTrainedTokenizer extends Callable {
                 truncation,
                 max_length,
                 return_tensor,
+                ...tokenizer_kwargs,
             }).input_ids;
         }
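`tokenizer_kwargs` is only consumed when `tokenize` is true, where it is spread into the underlying tokenizer call; any remaining `...kwargs` go into the template render context instead (see the previous hunk). An illustrative call; the option name shown is an assumption, not something this diff adds:

    const input_ids = tokenizer.apply_chat_template(messages, {
        add_generation_prompt: true,
        tokenize: true,
        tokenizer_kwargs: { add_special_tokens: false }, // forwarded to the tokenizer itself (assumed option)
    });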

@@ -3208,6 +3251,8 @@ export class GemmaTokenizer extends PreTrainedTokenizer {
     _default_chat_template = "{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '<start_of_turn>' + role + '\n' + message['content'] | trim + '<end_of_turn>\n' }}{% endfor %}{% if add_generation_prompt %}{{'<start_of_turn>model\n'}}{% endif %}"
 }
 
+export class Grok1Tokenizer extends PreTrainedTokenizer { }
+
 /**
  * Helper function to build translation inputs for an `NllbTokenizer` or `M2M100Tokenizer`.
  * @param {PreTrainedTokenizer} self The tokenizer instance.
@@ -4263,6 +4308,9 @@ export class VitsTokenizer extends PreTrainedTokenizer {
         this.decoder = new VitsDecoder({});
     }
 }
+
+export class CohereTokenizer extends PreTrainedTokenizer { }
+
 /**
  * Helper class which is used to instantiate pretrained tokenizers with the `from_pretrained` function.
  * The chosen tokenizer class is determined by the type specified in the tokenizer config.
@@ -4314,6 +4362,8 @@ export class AutoTokenizer {
         VitsTokenizer,
         Qwen2Tokenizer,
         GemmaTokenizer,
+        Grok1Tokenizer,
+        CohereTokenizer,
 
         // Base case:
         PreTrainedTokenizer,
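With `Grok1Tokenizer` and `CohereTokenizer` registered above, `AutoTokenizer.from_pretrained` picks the matching class from the type specified in the tokenizer config, e.g.:

    const grok = await AutoTokenizer.from_pretrained('Xenova/grok-1-tokenizer');
    // grok.constructor.name === 'Grok1Tokenizer'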

tests/tokenizers.test.js

Lines changed: 38 additions & 2 deletions
@@ -350,6 +350,42 @@ describe('Chat templates', () => {
         compare(input_ids, [1, 733, 16289, 28793, 22557, 28725, 910, 460, 368, 28804, 733, 28748, 16289, 28793, 28737, 28742, 28719, 2548, 1598, 28723, 1602, 541, 315, 1316, 368, 3154, 28804, 2, 28705, 733, 16289, 28793, 315, 28742, 28715, 737, 298, 1347, 805, 910, 10706, 5752, 1077, 3791, 28808, 733, 28748, 16289, 28793])
     });
 
+    it('should support multiple chat templates', async () => {
+
+        const tokenizer = await AutoTokenizer.from_pretrained("Xenova/c4ai-command-r-v01-tokenizer")
+
+        // define conversation input:
+        const conversation = [
+            { role: "user", content: "Whats the biggest penguin in the world?" }
+        ]
+        // define documents to ground on:
+        const documents = [
+            { title: "Tall penguins", text: "Emperor penguins are the tallest growing up to 122 cm in height." },
+            { title: "Penguin habitats", text: "Emperor penguins only live in Antarctica." }
+        ]
+
+        // render the RAG prompt as a string:
+        const grounded_generation_prompt = tokenizer.apply_chat_template(
+            conversation,
+            {
+                chat_template: "rag",
+                tokenize: false,
+                add_generation_prompt: true,
+
+                documents,
+                citation_mode: "accurate", // or "fast"
+            }
+        )
+        expect(grounded_generation_prompt).toEqual(
+            "<BOS_TOKEN><|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|># Safety Preamble\nThe instructions in this section override those in the task description and style guide sections. Don't answer questions that are harmful or immoral.\n\n" +
+            "# System Preamble\n## Basic Rules\nYou are a powerful conversational AI trained by Cohere to help people. You are augmented by a number of tools, and your job is to use and consume the output of these tools to best help the user. You will see a conversation history between yourself and a user, ending with an utterance from the user. You will then see a specific instruction instructing you what kind of response to generate. When you answer the user's requests, you cite your sources in your answers, according to those instructions.\n\n" +
+            "# User Preamble\n## Task and Context\nYou help people answer their questions and other requests interactively. You will be asked a very wide array of requests on all kinds of topics. You will be equipped with a wide range of search engines or similar tools to help you, which you use to research your answer. You should focus on serving the user's needs as best you can, which will be wide-ranging.\n\n## Style Guide\nUnless the user asks for a different style of answer, you should answer in full sentences, using proper grammar and spelling.<|END_OF_TURN_TOKEN|>" +
+            "<|START_OF_TURN_TOKEN|><|USER_TOKEN|>Whats the biggest penguin in the world?<|END_OF_TURN_TOKEN|>" +
+            "<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|><results>\nDocument: 0\ntitle: Tall penguins\ntext: Emperor penguins are the tallest growing up to 122 cm in height.\n\nDocument: 1\ntitle: Penguin habitats\ntext: Emperor penguins only live in Antarctica.\n</results><|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>Carefully perform the following instructions, in order, starting each with a new line.\nFirstly, Decide which of the retrieved documents are relevant to the user's last input by writing 'Relevant Documents:' followed by comma-separated list of document numbers. If none are relevant, you should instead write 'None'.\nSecondly, Decide which of the retrieved documents contain facts that should be cited in a good answer to the user's last input by writing 'Cited Documents:' followed a comma-separated list of document numbers. If you dont want to cite any of them, you should instead write 'None'.\nThirdly, Write 'Answer:' followed by a response to the user's last input in high quality natural english. Use the retrieved documents to help you. Do not insert any citations or grounding markup.\nFinally, Write 'Grounded answer:' followed by a response to the user's last input in high quality natural english. Use the symbols <co: doc> and </co: doc> to indicate when a fact comes from a document in the search result, e.g <co: 0>my fact</co: 0> for a fact from document 0.<|END_OF_TURN_TOKEN|>" +
+            "<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>"
+        );
+    });
+
     it('should support user-defined chat template', async () => {
         const tokenizer = await AutoTokenizer.from_pretrained("Xenova/llama-tokenizer");
 
@@ -395,7 +431,7 @@ describe('Chat templates', () => {
             .replaceAll('USE_DEFAULT_PROMPT', true)
             .replaceAll('DEFAULT_SYSTEM_MESSAGE', 'You are a helpful, respectful and honest assistant.');
 
-        const text = await tokenizer.apply_chat_template(chat, { tokenize: false, return_tensor: false, chat_template });
+        const text = tokenizer.apply_chat_template(chat, { tokenize: false, return_tensor: false, chat_template });
 
         expect(text).toEqual("<s>[INST] <<SYS>>\nYou are a helpful, respectful and honest assistant.\n<</SYS>>\n\nHello, how are you? [/INST] I'm doing great. How can I help you today? </s><s>[INST] I'd like to show off how chat templating works! [/INST]");
 
@@ -412,7 +448,7 @@ describe('Chat templates', () => {
 
         for (let { messages, add_generation_prompt, tokenize, target } of tests) {
 
-            const generated = await tokenizer.apply_chat_template(messages, {
+            const generated = tokenizer.apply_chat_template(messages, {
                 tokenize,
                 add_generation_prompt,
                 return_tensor: false,
