Skip to content

Commit 77406c4

Browse files
committed
Refactored tokenizer config, added more supported models
1 parent b9a8e6b commit 77406c4

File tree

3 files changed

+47
-20
lines changed

3 files changed

+47
-20
lines changed

README.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,8 @@ Some of the supported models on `Hugging Face`:
6262
- [minishlab/potion-base-2M](https://huggingface.co/minishlab/potion-base-2M)
6363
- [minishlab/potion-base-4M](https://huggingface.co/minishlab/potion-base-4M)
6464
- [minishlab/potion-base-8M](https://huggingface.co/minishlab/potion-base-8M)
65+
- [minishlab/potion-retrieval-32M](https://huggingface.co/minishlab/potion-retrieval-32M)
66+
- [minishlab/potion-base-32M](https://huggingface.co/minishlab/potion-base-32M)
6567
- [minishlab/M2V_base_output](https://huggingface.co/minishlab/M2V_base_output)
6668

6769
### Static Embeddings

Sources/Embeddings/EmbeddingsUtils.swift

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -47,16 +47,21 @@ extension String {
4747
}
4848
}
4949

50+
public enum TokenizerConfigType {
51+
case filePath(String)
52+
case data([String: Any])
53+
}
54+
5055
public struct TokenizerConfig {
51-
public let dataFileName: String
52-
public let tokenizerClass: String
56+
public let data: TokenizerConfigType
57+
public let config: TokenizerConfigType
5358

5459
public init(
55-
dataFileName: String = "tokenizer.json",
56-
tokenizerClass: String = "BertTokenizer"
60+
data: TokenizerConfigType = .filePath("tokenizer.json"),
61+
config: TokenizerConfigType = .filePath("tokenizer_config.json")
5762
) {
58-
self.dataFileName = dataFileName
59-
self.tokenizerClass = tokenizerClass
63+
self.data = data
64+
self.config = config
6065
}
6166
}
6267

@@ -104,9 +109,11 @@ extension LoadConfig {
104109
modelConfig: ModelConfig(
105110
weightsFileName: "0_StaticEmbedding/model.safetensors"
106111
),
112+
// In case of `StaticEmbeddings` tokenizer `data` is loaded from `0_StaticEmbedding/tokenizer.json` file
113+
// and tokenizer `config` is a dictionary with a single key `tokenizerClass` and value `BertTokenizer`.
107114
tokenizerConfig: TokenizerConfig(
108-
dataFileName: "0_StaticEmbedding/tokenizer.json",
109-
tokenizerClass: "BertTokenizer"
115+
data: .filePath("0_StaticEmbedding/tokenizer.json"),
116+
config: .data(["tokenizerClass": "BertTokenizer"])
110117
)
111118
)
112119
}

Sources/Embeddings/StaticEmbeddings/StaticEmbeddingsUtils.swift

Lines changed: 30 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -31,9 +31,9 @@ extension StaticEmbeddings {
3131
let tokenizer =
3232
if let tokenizerConfig = loadConfig.tokenizerConfig {
3333
try AutoTokenizer.from(
34-
tokenizerDataFile: modelFolder.appendingPathComponent(
35-
tokenizerConfig.dataFileName),
36-
tokenizerClass: tokenizerConfig.tokenizerClass
34+
modelFolder: modelFolder,
35+
tokenizerData: tokenizerConfig.data,
36+
tokenizerConfig: tokenizerConfig.config
3737
)
3838
} else {
3939
try await AutoTokenizer.from(modelFolder: modelFolder)
@@ -65,17 +65,35 @@ extension StaticEmbeddings {
6565

6666
extension AutoTokenizer {
6767
static func from(
68-
tokenizerDataFile: URL,
69-
tokenizerClass: String
68+
modelFolder: URL,
69+
tokenizerData: TokenizerConfigType,
70+
tokenizerConfig: TokenizerConfigType
7071
) throws -> any Tokenizer {
71-
let data = try Data(contentsOf: tokenizerDataFile)
72-
let parsedData = try JSONSerialization.jsonObject(with: data, options: [])
73-
guard let tokenizerData = parsedData as? [NSString: Any] else {
74-
throw EmbeddingsError.invalidFile
75-
}
72+
let tokenizerConfig = try resolveConfig(tokenizerConfig, in: modelFolder)
73+
let tokenizerData = try resolveConfig(tokenizerData, in: modelFolder)
7674
return try AutoTokenizer.from(
77-
tokenizerConfig: Config(["tokenizerClass": tokenizerClass]),
78-
tokenizerData: Config(tokenizerData)
75+
tokenizerConfig: tokenizerConfig,
76+
tokenizerData: tokenizerData
7977
)
8078
}
8179
}
80+
81+
func resolveConfig(_ tokenizerConfig: TokenizerConfigType, in modelFolder: URL) throws -> Config {
82+
switch tokenizerConfig {
83+
case .filePath(let filePath):
84+
let fileURL = modelFolder.appendingPathComponent(filePath)
85+
let data = try loadJSONConfig(at: fileURL)
86+
return Config(data as [NSString: Any])
87+
case .data(let data):
88+
return Config(data as [NSString: Any])
89+
}
90+
}
91+
92+
func loadJSONConfig(at filePath: URL) throws -> [String: Any] {
93+
let data = try Data(contentsOf: filePath)
94+
let parsedData = try JSONSerialization.jsonObject(with: data, options: [])
95+
guard let config = parsedData as? [String: Any] else {
96+
throw EmbeddingsError.invalidFile
97+
}
98+
return config
99+
}

0 commit comments

Comments
 (0)