@@ -9,6 +9,9 @@ import Hub
99import Foundation
1010import Jinja
1111
12+ public typealias Message = [ String : Any ]
13+ public typealias ToolSpec = [ String : Any ]
14+
1215enum TokenizerError : Error {
1316 case missingConfig
1417 case missingTokenizerClassInConfig
@@ -142,23 +145,57 @@ public protocol Tokenizer {
142145 var unknownTokenId : Int ? { get }
143146
144147 /// The appropriate chat template is selected from the tokenizer config
145- func applyChatTemplate( messages: [ [ String : String ] ] ) throws -> [ Int ]
148+ func applyChatTemplate( messages: [ Message ] ) throws -> [ Int ]
149+
150+ /// The appropriate chat template is selected from the tokenizer config
151+ func applyChatTemplate( messages: [ Message ] , tools: [ ToolSpec ] ) throws -> [ Int ]
146152
147153 /// The chat template is provided as a string literal or specified by name
148- func applyChatTemplate( messages: [ [ String : String ] ] , chatTemplate: ChatTemplateArgument ) throws -> [ Int ]
154+ func applyChatTemplate( messages: [ Message ] , chatTemplate: ChatTemplateArgument ) throws -> [ Int ]
149155
150156 /// The chat template is provided as a string literal
151- func applyChatTemplate( messages: [ [ String : String ] ] , chatTemplate: String ) throws -> [ Int ]
157+ func applyChatTemplate( messages: [ Message ] , chatTemplate: String ) throws -> [ Int ]
152158
153159 func applyChatTemplate(
154- messages: [ [ String : String ] ] ,
160+ messages: [ Message ] ,
155161 /// A chat template can optionally be provided or specified by name when several templates are included in the tokenizer config. Normally this is not necessary.
156162 chatTemplate: ChatTemplateArgument ? ,
157163 addGenerationPrompt: Bool ,
158164 truncation: Bool ,
159165 maxLength: Int ? ,
160- tools: [ [ String : Any ] ] ?
166+ tools: [ ToolSpec ] ?
161167 ) throws -> [ Int ]
168+
169+ func applyChatTemplate(
170+ messages: [ Message ] ,
171+ /// A chat template can optionally be provided or specified by name when several templates are included in the tokenizer config. Normally this is not necessary.
172+ chatTemplate: ChatTemplateArgument ? ,
173+ addGenerationPrompt: Bool ,
174+ truncation: Bool ,
175+ maxLength: Int ? ,
176+ tools: [ ToolSpec ] ? ,
177+ additionalContext: [ String : Any ] ?
178+ ) throws -> [ Int ]
179+ }
180+
181+ extension Tokenizer {
182+ /// Call previous signature for backwards compatibility
183+ func applyChatTemplate(
184+ messages: [ Message ] ,
185+ /// A chat template can optionally be provided or specified by name when several templates are included in the tokenizer config. Normally this is not necessary.
186+ chatTemplate: ChatTemplateArgument ? ,
187+ addGenerationPrompt: Bool ,
188+ truncation: Bool ,
189+ maxLength: Int ? ,
190+ tools: [ ToolSpec ] ? ,
191+ additionalContext: [ String : Any ] ?
192+ ) throws -> [ Int ] {
193+ if additionalContext == nil {
194+ try applyChatTemplate ( messages: messages, chatTemplate: chatTemplate, addGenerationPrompt: addGenerationPrompt, truncation: truncation, maxLength: maxLength, tools: tools)
195+ } else {
196+ throw TokenizerError . chatTemplate ( " Not implemented " )
197+ }
198+ }
162199}
163200
164201public extension Tokenizer {
@@ -359,20 +396,46 @@ public class PreTrainedTokenizer: Tokenizer {
359396 model. convertIdToToken ( id)
360397 }
361398
362- public func applyChatTemplate( messages: [ [ String : String ] ] ) throws -> [ Int ] {
399+ public func applyChatTemplate( messages: [ Message ] ) throws -> [ Int ] {
363400 try applyChatTemplate ( messages: messages, addGenerationPrompt: true )
364401 }
365402
366- public func applyChatTemplate( messages: [ [ String : String ] ] , chatTemplate: ChatTemplateArgument ) throws -> [ Int ] {
403+ public func applyChatTemplate( messages: [ Message ] , tools: [ ToolSpec ] ) throws -> [ Int ] {
404+ try applyChatTemplate ( messages: messages, addGenerationPrompt: true , tools: tools)
405+ }
406+
407+ public func applyChatTemplate( messages: [ Message ] , tools: [ ToolSpec ] , additionalContext: [ String : Any ] ) throws
408+ -> [ Int ]
409+ {
410+ try applyChatTemplate (
411+ messages: messages,
412+ addGenerationPrompt: true ,
413+ tools: tools,
414+ additionalContext: additionalContext
415+ )
416+ }
417+
418+ public func applyChatTemplate( messages: [ Message ] , chatTemplate: ChatTemplateArgument ) throws -> [ Int ] {
367419 try applyChatTemplate ( messages: messages, chatTemplate: chatTemplate, addGenerationPrompt: true )
368420 }
369421
370- public func applyChatTemplate( messages: [ [ String : String ] ] , chatTemplate: String ) throws -> [ Int ] {
422+ public func applyChatTemplate( messages: [ Message ] , chatTemplate: String ) throws -> [ Int ] {
371423 try applyChatTemplate ( messages: messages, chatTemplate: . literal( chatTemplate) , addGenerationPrompt: true )
372424 }
373425
374426 public func applyChatTemplate(
375- messages: [ [ String : String ] ] ,
427+ messages: [ Message ] ,
428+ chatTemplate: ChatTemplateArgument ? = nil ,
429+ addGenerationPrompt: Bool = false ,
430+ truncation: Bool = false ,
431+ maxLength: Int ? = nil ,
432+ tools: [ ToolSpec ] ? = nil
433+ ) throws -> [ Int ] {
434+ try applyChatTemplate ( messages: messages, chatTemplate: chatTemplate, addGenerationPrompt: addGenerationPrompt, truncation: truncation, maxLength: maxLength, tools: tools, additionalContext: nil )
435+ }
436+
437+ public func applyChatTemplate(
438+ messages: [ Message ] ,
376439 chatTemplate: ChatTemplateArgument ? = nil ,
377440 addGenerationPrompt: Bool = false ,
378441 truncation: Bool = false ,
@@ -382,8 +445,8 @@ public class PreTrainedTokenizer: Tokenizer {
382445 /// giving the name, description and argument types for the tool. See the
383446 /// [chat templating guide](https://huggingface.co/docs/transformers/main/en/chat_templating#automated-function-conversion-for-tool-use)
384447 /// for more information.
385- /// Note: tool calling is not supported yet, it will be available in a future update.
386- tools : [ [ String : Any ] ] ? = nil
448+ tools : [ ToolSpec ] ? = nil ,
449+ additionalContext : [ String : Any ] ? = nil
387450 ) throws -> [ Int ] {
388451 var selectedChatTemplate : String ?
389452 if let chatTemplate, case . literal( let template) = chatTemplate {
@@ -425,10 +488,21 @@ public class PreTrainedTokenizer: Tokenizer {
425488 let template = try Template ( selectedChatTemplate)
426489 var context : [ String : Any ] = [
427490 " messages " : messages,
428- " add_generation_prompt " : addGenerationPrompt
429- // TODO: Add `tools` entry when support is added in Jinja
430- // "tools": tools
491+ " add_generation_prompt " : addGenerationPrompt,
431492 ]
493+ if let tools {
494+ context [ " tools " ] = tools
495+ }
496+ if let additionalContext {
497+ /*
498+ Additional keys and values to be added to the context provided to the prompt templating engine.
499+ For example, the app could set "tools_in_user_message" to false for Llama 3.1 and 3.2 if a system message is provided.
500+ The default value is true in the Llama 3.1 and 3.2 chat templates, but these models will perform better if the tools are included in a system message.
501+ */
502+ for (key, value) in additionalContext {
503+ context [ key] = value
504+ }
505+ }
432506
433507 // TODO: maybe keep NSString here
434508 for (key, value) in tokenizerConfig. dictionary as [ String : Any ] {
0 commit comments