Skip to content

Commit 94e5979

Browse files
committed
Integrate Jinja package
1 parent 4448e89 commit 94e5979

File tree

3 files changed

+24
-21
lines changed

3 files changed

+24
-21
lines changed

Sources/Hub/Config.swift

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
// Created by Piotr Kowalczuk on 06.03.25.
66

77
import Foundation
8+
import Jinja
89

910
// MARK: - Configuration files with dynamic lookup
1011

@@ -413,28 +414,28 @@ public struct Config: Hashable, Sendable,
413414
self.dictionary(or: or)
414415
}
415416

416-
public func toJinjaCompatible() -> Any? {
417+
public func jinjaValue() -> Jinja.Value {
417418
switch self.value {
418419
case let .array(val):
419-
return val.map { $0.toJinjaCompatible() }
420+
return .array(val.map { $0.jinjaValue() })
420421
case let .dictionary(val):
421-
var result: [String: Any?] = [:]
422+
var result: [String: Jinja.Value] = [:]
422423
for (key, config) in val {
423-
result[key.string] = config.toJinjaCompatible()
424+
result[key.string] = config.jinjaValue()
424425
}
425-
return result
426+
return .object(.init(uniqueKeysWithValues: result))
426427
case let .boolean(val):
427-
return val
428+
return .boolean(val)
428429
case let .floating(val):
429-
return val
430+
return .double(Double(String(val)) ?? Double(val))
430431
case let .integer(val):
431-
return val
432+
return .int(val)
432433
case let .string(val):
433-
return val.string
434+
return .string(val.string)
434435
case let .token(val):
435-
return [String(val.0): val.1.string] as [String: String]
436+
return [String(val.0): .string(val.1.string)]
436437
case .null:
437-
return nil
438+
return .null
438439
}
439440
}
440441

Sources/Tokenizers/Tokenizer.swift

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -549,32 +549,34 @@ public class PreTrainedTokenizer: Tokenizer {
549549
}
550550

551551
let template = try compiledTemplate(for: selectedChatTemplate)
552-
var context: [String: Any] = [
553-
"messages": messages,
554-
"add_generation_prompt": addGenerationPrompt,
552+
var context: [String: Jinja.Value] = try [
553+
"messages": .array(messages.map { try Value(any: $0) }),
554+
"add_generation_prompt": .boolean(addGenerationPrompt),
555555
]
556556
if let tools {
557-
context["tools"] = tools
557+
context["tools"] = try .array(tools.map { try Value(any: $0) })
558558
}
559559
if let additionalContext {
560560
// Additional keys and values to be added to the context provided to the prompt templating engine.
561561
// For example, the app could set "tools_in_user_message" to false for Llama 3.1 and 3.2 if a system message is provided.
562562
// The default value is true in the Llama 3.1 and 3.2 chat templates, but these models will perform better if the tools are included in a system message.
563563
for (key, value) in additionalContext {
564-
context[key] = value
564+
context[key] = try Value(any: value)
565565
}
566566
}
567567

568568
for (key, value) in tokenizerConfig.dictionary(or: [:]) {
569569
if specialTokenAttributes.contains(key.string), !value.isNull() {
570570
if let stringValue = value.string() {
571-
context[key.string] = stringValue
571+
context[key.string] = .string(stringValue)
572572
} else if let dictionary = value.dictionary() {
573-
context[key.string] = addedTokenAsString(Config(dictionary))
573+
if let addedTokenString = addedTokenAsString(Config(dictionary)) {
574+
context[key.string] = .string(addedTokenString)
575+
}
574576
} else if let array: [String] = value.get() {
575-
context[key.string] = array
577+
context[key.string] = .array(array.map { .string($0) })
576578
} else {
577-
context[key.string] = value
579+
context[key.string] = try Value(any: value)
578580
}
579581
}
580582
}

Tests/HubTests/ConfigTests.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -435,7 +435,7 @@ struct ConfigTests {
435435
"""
436436

437437
let got = try Template(template).render([
438-
"config": cfg.toJinjaCompatible()
438+
"config": cfg.jinjaValue()
439439
])
440440

441441
#expect(got == exp)

0 commit comments

Comments
 (0)