Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions Sources/Hub/Hub.swift
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ public extension Hub {

public final class LanguageModelConfigurationFromHub: Sendable {
struct Configurations {
var modelConfig: Config
var modelConfig: Config?
var tokenizerConfig: Config?
var tokenizerData: Config
}
Expand All @@ -96,7 +96,7 @@ public final class LanguageModelConfigurationFromHub: Sendable {
}
}

public var modelConfig: Config {
public var modelConfig: Config? {
get async throws {
try await configPromise.value.modelConfig
}
Expand Down Expand Up @@ -135,7 +135,7 @@ public final class LanguageModelConfigurationFromHub: Sendable {

public var modelType: String? {
get async throws {
try await modelConfig.modelType.string()
try await modelConfig?.modelType.string()
}
}

Expand Down Expand Up @@ -174,11 +174,11 @@ public final class LanguageModelConfigurationFromHub: Sendable {
do {
// Load required configurations
let modelConfigURL = modelFolder.appending(path: "config.json")
guard FileManager.default.fileExists(atPath: modelConfigURL.path) else {
throw Hub.HubClientError.configurationMissing("config.json")
}

let modelConfig = try hubApi.configuration(fileURL: modelConfigURL)
var modelConfig: Config? = nil
if FileManager.default.fileExists(atPath: modelConfigURL.path) {
modelConfig = try hubApi.configuration(fileURL: modelConfigURL)
}

let tokenizerDataURL = modelFolder.appending(path: "tokenizer.json")
guard FileManager.default.fileExists(atPath: tokenizerDataURL.path) else {
Expand Down
10 changes: 5 additions & 5 deletions Sources/Models/LanguageModel.swift
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ public extension LanguageModel {

/// async properties downloaded from the configuration
public extension LanguageModel {
var modelConfig: Config {
var modelConfig: Config? {
get async throws {
try await configuration!.modelConfig
}
Expand All @@ -161,13 +161,13 @@ public extension LanguageModel {

var modelType: String? {
get async throws {
try await modelConfig.modelType.string()
try await modelConfig?.modelType.string()
}
}

var textGenerationParameters: Config? {
get async throws {
try await modelConfig.taskSpecificParams.textGeneration
try await modelConfig?.taskSpecificParams.textGeneration
}
}

Expand All @@ -180,14 +180,14 @@ public extension LanguageModel {
var bosTokenId: Int? {
get async throws {
let modelConfig = try await modelConfig
return modelConfig.bosTokenId.integer()
return modelConfig?.bosTokenId.integer()
}
}

var eosTokenId: Int? {
get async throws {
let modelConfig = try await modelConfig
return modelConfig.eosTokenId.integer()
return modelConfig?.eosTokenId.integer()
}
}

Expand Down
10 changes: 8 additions & 2 deletions Tests/HubTests/HubTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,10 @@ class HubTests: XCTestCase {
func testConfigDownload() async {
do {
let configLoader = LanguageModelConfigurationFromHub(modelName: "t5-base", hubApi: hubApi)
let config = try await configLoader.modelConfig
guard let config = try await configLoader.modelConfig else {
XCTFail("Test repo is expected to have a config.json file")
return
}

// Test leaf value (Int)
guard let eos = config["eos_token_id"].integer() else {
Expand Down Expand Up @@ -71,7 +74,10 @@ class HubTests: XCTestCase {
func testConfigCamelCase() async {
do {
let configLoader = LanguageModelConfigurationFromHub(modelName: "t5-base", hubApi: hubApi)
let config = try await configLoader.modelConfig
guard let config = try await configLoader.modelConfig else {
XCTFail("Test repo is expected to have a config.json file")
return
}

// Test leaf value (Int)
guard let eos = config["eosTokenId"].integer() else {
Expand Down