Skip to content

Commit ca72653

Browse files
Add repo revision option (#202)
* Add repo revision option
* Add tokenizer revision option
* Add revision option to CLI
* Fix error type
* Add repo revision unit tests
* Update CLI option description
  Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
* Use correct revision when retrieving filenames from API
* Use revision in getFileMetadata method
* Apply formatting
* Change error type

---------

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
1 parent 85b5bac commit ca72653

File tree

4 files changed

+72
-23
lines changed

4 files changed

+72
-23
lines changed

Sources/Hub/Hub.swift

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -155,10 +155,11 @@ public class LanguageModelConfigurationFromHub {
155155

156156
public init(
157157
modelName: String,
158+
revision: String = "main",
158159
hubApi: HubApi = .shared
159160
) {
160161
configPromise = Task.init {
161-
try await self.loadConfig(modelName: modelName, hubApi: hubApi)
162+
try await self.loadConfig(modelName: modelName, revision: revision, hubApi: hubApi)
162163
}
163164
}
164165

@@ -216,13 +217,14 @@ public class LanguageModelConfigurationFromHub {
216217

217218
func loadConfig(
218219
modelName: String,
220+
revision: String,
219221
hubApi: HubApi = .shared
220222
) async throws -> Configurations {
221223
let filesToDownload = ["config.json", "tokenizer_config.json", "chat_template.json", "tokenizer.json"]
222224
let repo = Hub.Repo(id: modelName)
223225

224226
do {
225-
let downloadedModelFolder = try await hubApi.snapshot(from: repo, matching: filesToDownload)
227+
let downloadedModelFolder = try await hubApi.snapshot(from: repo, revision: revision, matching: filesToDownload)
226228
return try await loadConfig(modelFolder: downloadedModelFolder, hubApi: hubApi)
227229
} catch {
228230
// Convert generic errors to more specific ones

Sources/Hub/HubApi.swift

Lines changed: 23 additions & 20 deletions
Original file line number | Diff line number | Diff line change
@@ -134,16 +134,17 @@ public extension HubApi {
134134

135135
switch response.statusCode {
136136
case 200..<400: break // Allow redirects to pass through to the redirect delegate
137-
case 400..<500: throw Hub.HubClientError.authorizationRequired
137+
case 401, 403: throw Hub.HubClientError.authorizationRequired
138+
case 404: throw Hub.HubClientError.fileNotFound(url.lastPathComponent)
138139
default: throw Hub.HubClientError.httpStatusCode(response.statusCode)
139140
}
140141

141142
return (data, response)
142143
}
143144

144-
func getFilenames(from repo: Repo, matching globs: [String] = []) async throws -> [String] {
145+
func getFilenames(from repo: Repo, revision: String = "main", matching globs: [String] = []) async throws -> [String] {
145146
// Read repo info and only parse "siblings"
146-
let url = URL(string: "\(endpoint)/api/\(repo.type)/\(repo.id)")!
147+
let url = URL(string: "\(endpoint)/api/\(repo.type)/\(repo.id)/revision/\(revision)")!
147148
let (data, _) = try await httpGet(for: url)
148149
let response = try JSONDecoder().decode(SiblingsResponse.self, from: data)
149150
let filenames = response.siblings.map { $0.rfilename }
@@ -335,6 +336,7 @@ public extension HubApi {
335336
struct HubFileDownloader {
336337
let hub: HubApi
337338
let repo: Repo
339+
let revision: String
338340
let repoDestination: URL
339341
let repoMetadataDestination: URL
340342
let relativeFilename: String
@@ -349,7 +351,7 @@ public extension HubApi {
349351
url = url.appending(component: repo.type.rawValue)
350352
}
351353
url = url.appending(path: repo.id)
352-
url = url.appending(path: "resolve/main") // TODO: revisions
354+
url = url.appending(path: "resolve/\(revision)")
353355
url = url.appending(path: relativeFilename)
354356
return url
355357
}
@@ -462,7 +464,7 @@ public extension HubApi {
462464
}
463465

464466
@discardableResult
465-
func snapshot(from repo: Repo, matching globs: [String] = [], progressHandler: @escaping (Progress) -> Void = { _ in }) async throws -> URL {
467+
func snapshot(from repo: Repo, revision: String = "main", matching globs: [String] = [], progressHandler: @escaping (Progress) -> Void = { _ in }) async throws -> URL {
466468
let repoDestination = localRepoLocation(repo)
467469
let repoMetadataDestination = repoDestination
468470
.appendingPathComponent(".cache")
@@ -504,13 +506,14 @@ public extension HubApi {
504506
return repoDestination
505507
}
506508

507-
let filenames = try await getFilenames(from: repo, matching: globs)
509+
let filenames = try await getFilenames(from: repo, revision: revision, matching: globs)
508510
let progress = Progress(totalUnitCount: Int64(filenames.count))
509511
for filename in filenames {
510512
let fileProgress = Progress(totalUnitCount: 100, parent: progress, pendingUnitCount: 1)
511513
let downloader = HubFileDownloader(
512514
hub: self,
513515
repo: repo,
516+
revision: revision,
514517
repoDestination: repoDestination,
515518
repoMetadataDestination: repoMetadataDestination,
516519
relativeFilename: filename,
@@ -529,18 +532,18 @@ public extension HubApi {
529532
}
530533

531534
@discardableResult
532-
func snapshot(from repoId: String, matching globs: [String] = [], progressHandler: @escaping (Progress) -> Void = { _ in }) async throws -> URL {
533-
try await snapshot(from: Repo(id: repoId), matching: globs, progressHandler: progressHandler)
535+
func snapshot(from repoId: String, revision: String = "main", matching globs: [String] = [], progressHandler: @escaping (Progress) -> Void = { _ in }) async throws -> URL {
536+
try await snapshot(from: Repo(id: repoId), revision: revision, matching: globs, progressHandler: progressHandler)
534537
}
535538

536539
@discardableResult
537-
func snapshot(from repo: Repo, matching glob: String, progressHandler: @escaping (Progress) -> Void = { _ in }) async throws -> URL {
538-
try await snapshot(from: repo, matching: [glob], progressHandler: progressHandler)
540+
func snapshot(from repo: Repo, revision: String = "main", matching glob: String, progressHandler: @escaping (Progress) -> Void = { _ in }) async throws -> URL {
541+
try await snapshot(from: repo, revision: revision, matching: [glob], progressHandler: progressHandler)
539542
}
540543

541544
@discardableResult
542-
func snapshot(from repoId: String, matching glob: String, progressHandler: @escaping (Progress) -> Void = { _ in }) async throws -> URL {
543-
try await snapshot(from: Repo(id: repoId), matching: [glob], progressHandler: progressHandler)
545+
func snapshot(from repoId: String, revision: String = "main", matching glob: String, progressHandler: @escaping (Progress) -> Void = { _ in }) async throws -> URL {
546+
try await snapshot(from: Repo(id: repoId), revision: revision, matching: [glob], progressHandler: progressHandler)
544547
}
545548
}
546549

@@ -596,9 +599,9 @@ public extension HubApi {
596599
)
597600
}
598601

599-
func getFileMetadata(from repo: Repo, matching globs: [String] = []) async throws -> [FileMetadata] {
602+
func getFileMetadata(from repo: Repo, revision: String = "main", matching globs: [String] = []) async throws -> [FileMetadata] {
600603
let files = try await getFilenames(from: repo, matching: globs)
601-
let url = URL(string: "\(endpoint)/\(repo.id)/resolve/main")! // TODO: revisions
604+
let url = URL(string: "\(endpoint)/\(repo.id)/resolve/\(revision)")!
602605
var selectedMetadata: [FileMetadata] = []
603606
for file in files {
604607
let fileURL = url.appending(path: file)
@@ -607,16 +610,16 @@ public extension HubApi {
607610
return selectedMetadata
608611
}
609612

610-
func getFileMetadata(from repoId: String, matching globs: [String] = []) async throws -> [FileMetadata] {
611-
try await getFileMetadata(from: Repo(id: repoId), matching: globs)
613+
func getFileMetadata(from repoId: String, revision: String = "main", matching globs: [String] = []) async throws -> [FileMetadata] {
614+
try await getFileMetadata(from: Repo(id: repoId), revision: revision, matching: globs)
612615
}
613616

614-
func getFileMetadata(from repo: Repo, matching glob: String) async throws -> [FileMetadata] {
615-
try await getFileMetadata(from: repo, matching: [glob])
617+
func getFileMetadata(from repo: Repo, revision: String = "main", matching glob: String) async throws -> [FileMetadata] {
618+
try await getFileMetadata(from: repo, revision: revision, matching: [glob])
616619
}
617620

618-
func getFileMetadata(from repoId: String, matching glob: String) async throws -> [FileMetadata] {
619-
try await getFileMetadata(from: Repo(id: repoId), matching: [glob])
621+
func getFileMetadata(from repoId: String, revision: String = "main", matching glob: String) async throws -> [FileMetadata] {
622+
try await getFileMetadata(from: Repo(id: repoId), revision: revision, matching: [glob])
620623
}
621624
}
622625

Sources/HubCLI/HubCLI.swift

Lines changed: 4 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -48,6 +48,9 @@ struct Download: AsyncParsableCommand, SubcommandWithToken {
4848
@Option(help: "Repo type")
4949
var repoType: RepoType = .model
5050

51+
@Option(help: "Specific revision (e.g. branch, commit hash or tag)")
52+
var revision: String = "main"
53+
5154
@Option(help: "Glob patterns for files to include")
5255
var include: [String] = []
5356

@@ -57,7 +60,7 @@ struct Download: AsyncParsableCommand, SubcommandWithToken {
5760
func run() async throws {
5861
let hubApi = HubApi(hfToken: hfToken)
5962
let repo = Hub.Repo(id: repo, type: repoType.asHubApiRepoType)
60-
let downloadedTo = try await hubApi.snapshot(from: repo, matching: include) { progress in
63+
let downloadedTo = try await hubApi.snapshot(from: repo, revision: revision, matching: include) { progress in
6164
DispatchQueue.main.async {
6265
let totalPercent = 100 * progress.fractionCompleted
6366
print("\(progress.completedUnitCount)/\(progress.totalUnitCount) \(totalPercent.formatted("%.02f"))%", terminator: "\r")

Tests/HubTests/HubApiTests.swift

Lines changed: 41 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1081,4 +1081,45 @@ class SnapshotDownloadTests: XCTestCase {
10811081
XCTAssertTrue(FileManager.default.fileExists(atPath: filePath.path),
10821082
"Downloaded file should exist at \(filePath.path)")
10831083
}
1084+
1085+
func testDownloadWithRevision() async throws {
1086+
let hubApi = HubApi(downloadBase: downloadDestination)
1087+
var lastProgress: Progress? = nil
1088+
1089+
let commitHash = "eaf97358a37d03fd48e5a87d15aff2e8423c1afb"
1090+
let downloadedTo = try await hubApi.snapshot(from: repo, revision: commitHash, matching: "*.json") { progress in
1091+
print("Total Progress: \(progress.fractionCompleted)")
1092+
print("Files Completed: \(progress.completedUnitCount) of \(progress.totalUnitCount)")
1093+
lastProgress = progress
1094+
}
1095+
1096+
let downloadedFilenames = getRelativeFiles(url: downloadDestination, repo: repo)
1097+
XCTAssertEqual(lastProgress?.fractionCompleted, 1)
1098+
XCTAssertEqual(lastProgress?.completedUnitCount, 6)
1099+
XCTAssertEqual(downloadedTo, downloadDestination.appending(path: "models/\(repo)"))
1100+
XCTAssertEqual(
1101+
Set(downloadedFilenames),
1102+
Set([
1103+
"config.json", "tokenizer.json", "tokenizer_config.json",
1104+
"llama-2-7b-chat.mlpackage/Manifest.json",
1105+
"llama-2-7b-chat.mlpackage/Data/com.apple.CoreML/FeatureDescriptions.json",
1106+
"llama-2-7b-chat.mlpackage/Data/com.apple.CoreML/Metadata.json",
1107+
])
1108+
)
1109+
1110+
do {
1111+
let revision = "nonexistent-revision"
1112+
try await hubApi.snapshot(from: repo, revision: revision, matching: "*.json")
1113+
XCTFail("Expected an error to be thrown")
1114+
} catch let error as Hub.HubClientError {
1115+
switch error {
1116+
case .fileNotFound:
1117+
break // Error type is correct
1118+
default:
1119+
XCTFail("Wrong error type: \(error)")
1120+
}
1121+
} catch {
1122+
XCTFail("Unexpected error: \(error)")
1123+
}
1124+
}
10841125
}

0 commit comments

Comments
 (0)