65 changes: 36 additions & 29 deletions Sources/Tokenizers/Tokenizer.swift
@@ -787,41 +787,48 @@ public class PreTrainedTokenizer: @unchecked Sendable, Tokenizer {
             throw TokenizerError.missingChatTemplate
         }
 
-        let template = try compiledTemplate(for: selectedChatTemplate)
-        var context: [String: Jinja.Value] = try [
-            "messages": .array(messages.map { try Value(any: $0) }),
-            "add_generation_prompt": .boolean(addGenerationPrompt),
-        ]
-        if let tools {
-            context["tools"] = try .array(tools.map { try Value(any: $0) })
-        }
-        if let additionalContext {
-            // Additional keys and values to be added to the context provided to the prompt templating engine.
-            // For example, the app could set "tools_in_user_message" to false for Llama 3.1 and 3.2 if a system message is provided.
-            // The default value is true in the Llama 3.1 and 3.2 chat templates, but these models will perform better if the tools are included in a system message.
-            for (key, value) in additionalContext {
-                context[key] = try Value(any: value)
+        let renderedTemplate: String
+        do {
+            let template = try compiledTemplate(for: selectedChatTemplate)
+
+            var context: [String: Jinja.Value] = try [
+                "messages": .array(messages.map { try Value(any: $0) }),
+                "add_generation_prompt": .boolean(addGenerationPrompt),
+            ]
+            if let tools {
+                context["tools"] = try .array(tools.map { try Value(any: $0) })
+            }
+            if let additionalContext {
+                // Additional keys and values to be added to the context provided to the prompt templating engine.
+                // For example, the app could set "tools_in_user_message" to false for Llama 3.1 and 3.2 if a system message is provided.
+                // The default value is true in the Llama 3.1 and 3.2 chat templates, but these models will perform better if the tools are included in a system message.
+                for (key, value) in additionalContext {
+                    context[key] = try Value(any: value)
+                }
             }
-        }
 
-        for (key, value) in tokenizerConfig.dictionary(or: [:]) {
-            if specialTokenAttributes.contains(key.string), !value.isNull() {
-                if let stringValue = value.string() {
-                    context[key.string] = .string(stringValue)
-                } else if let dictionary = value.dictionary() {
-                    if let addedTokenString = addedTokenAsString(Config(dictionary)) {
-                        context[key.string] = .string(addedTokenString)
+            for (key, value) in tokenizerConfig.dictionary(or: [:]) {
+                if specialTokenAttributes.contains(key.string), !value.isNull() {
+                    if let stringValue = value.string() {
+                        context[key.string] = .string(stringValue)
+                    } else if let dictionary = value.dictionary() {
+                        if let addedTokenString = addedTokenAsString(Config(dictionary)) {
+                            context[key.string] = .string(addedTokenString)
+                        }
+                    } else if let array: [String] = value.get() {
+                        context[key.string] = .array(array.map { .string($0) })
+                    } else {
+                        context[key.string] = try Value(any: value)
                     }
-                } else if let array: [String] = value.get() {
-                    context[key.string] = .array(array.map { .string($0) })
-                } else {
-                    context[key.string] = try Value(any: value)
                 }
             }
-        }
 
-        let rendered = try template.render(context)
-        var encodedTokens = encode(text: rendered, addSpecialTokens: false)
+            renderedTemplate = try template.render(context)
+        } catch let error as JinjaError {
+            let description = (error as? LocalizedError)?.errorDescription ?? "\(error)"
+            throw TokenizerError.chatTemplate(description)
+        }
+        var encodedTokens = encode(text: renderedTemplate, addSpecialTokens: false)
         var maxLength = maxLength ?? encodedTokens.count
         maxLength = min(maxLength, tokenizerConfig.modelMaxLength.integer() ?? maxLength)
        if encodedTokens.count > maxLength {
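For context on how the new error reporting surfaces to callers, here is a minimal caller-side sketch. It assumes the usual swift-transformers entry points (AutoTokenizer.from(pretrained:) and an applyChatTemplate(messages:) overload taking [[String: Any]] messages); those signatures are assumptions for illustration, not something this diff establishes. Only TokenizerError.chatTemplate and TokenizerError.missingChatTemplate come from the change above.

import Tokenizers

// Hypothetical usage sketch. The loading call and message shape are assumptions;
// the point is that Jinja rendering failures now arrive as
// TokenizerError.chatTemplate(description) instead of an opaque JinjaError.
func encodePrompt() async throws -> [Int] {
    let tokenizer = try await AutoTokenizer.from(pretrained: "meta-llama/Llama-3.2-1B-Instruct")
    let messages: [[String: Any]] = [
        ["role": "user", "content": "Hello!"]
    ]
    do {
        return try tokenizer.applyChatTemplate(messages: messages)
    } catch TokenizerError.chatTemplate(let description) {
        // The description carries the Jinja error text gathered in the do/catch above.
        print("Chat template failed to render: \(description)")
        throw TokenizerError.chatTemplate(description)
    }
}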