Skip to content

Refactor registry checksum TOFU logic #6190

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Feb 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 65 additions & 41 deletions Sources/PackageFingerprint/PackageFingerprintStorage.swift
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
//
// This source file is part of the Swift open source project
//
// Copyright (c) 2021-2022 Apple Inc. and the Swift project authors
// Copyright (c) 2021-2023 Apple Inc. and the Swift project authors
// Licensed under Apache License v2.0 with Runtime Library Exception
//
// See http://swift.org/LICENSE.txt for license information
Expand All @@ -17,59 +17,83 @@ import PackageModel
import struct TSCUtility.Version

public protocol PackageFingerprintStorage {
func get(package: PackageIdentity,
version: Version,
observabilityScope: ObservabilityScope,
callbackQueue: DispatchQueue,
callback: @escaping (Result<[Fingerprint.Kind: Fingerprint], Error>) -> Void)
func get(
package: PackageIdentity,
version: Version,
observabilityScope: ObservabilityScope,
callbackQueue: DispatchQueue,
callback: @escaping (Result<[Fingerprint.Kind: Fingerprint], Error>) -> Void
)

func put(package: PackageIdentity,
version: Version,
fingerprint: Fingerprint,
observabilityScope: ObservabilityScope,
callbackQueue: DispatchQueue,
callback: @escaping (Result<Void, Error>) -> Void)
func put(
package: PackageIdentity,
version: Version,
fingerprint: Fingerprint,
observabilityScope: ObservabilityScope,
callbackQueue: DispatchQueue,
callback: @escaping (Result<Void, Error>) -> Void
)

func get(package: PackageReference,
version: Version,
observabilityScope: ObservabilityScope,
callbackQueue: DispatchQueue,
callback: @escaping (Result<[Fingerprint.Kind: Fingerprint], Error>) -> Void)
func get(
package: PackageReference,
version: Version,
observabilityScope: ObservabilityScope,
callbackQueue: DispatchQueue,
callback: @escaping (Result<[Fingerprint.Kind: Fingerprint], Error>) -> Void
)

func put(package: PackageReference,
version: Version,
fingerprint: Fingerprint,
observabilityScope: ObservabilityScope,
callbackQueue: DispatchQueue,
callback: @escaping (Result<Void, Error>) -> Void)
func put(
package: PackageReference,
version: Version,
fingerprint: Fingerprint,
observabilityScope: ObservabilityScope,
callbackQueue: DispatchQueue,
callback: @escaping (Result<Void, Error>) -> Void
)
}

public extension PackageFingerprintStorage {
func get(package: PackageIdentity,
version: Version,
kind: Fingerprint.Kind,
observabilityScope: ObservabilityScope,
callbackQueue: DispatchQueue,
callback: @escaping (Result<Fingerprint, Error>) -> Void) {
self.get(package: package, version: version, observabilityScope: observabilityScope, callbackQueue: callbackQueue) { result in
extension PackageFingerprintStorage {
public func get(
package: PackageIdentity,
version: Version,
kind: Fingerprint.Kind,
observabilityScope: ObservabilityScope,
callbackQueue: DispatchQueue,
callback: @escaping (Result<Fingerprint, Error>) -> Void
) {
self.get(
package: package,
version: version,
observabilityScope: observabilityScope,
callbackQueue: callbackQueue
) { result in
self.get(kind: kind, result, callback: callback)
}
}

func get(package: PackageReference,
version: Version,
kind: Fingerprint.Kind,
observabilityScope: ObservabilityScope,
callbackQueue: DispatchQueue,
callback: @escaping (Result<Fingerprint, Error>) -> Void) {
self.get(package: package, version: version, observabilityScope: observabilityScope, callbackQueue: callbackQueue) { result in
public func get(
package: PackageReference,
version: Version,
kind: Fingerprint.Kind,
observabilityScope: ObservabilityScope,
callbackQueue: DispatchQueue,
callback: @escaping (Result<Fingerprint, Error>) -> Void
) {
self.get(
package: package,
version: version,
observabilityScope: observabilityScope,
callbackQueue: callbackQueue
) { result in
self.get(kind: kind, result, callback: callback)
}
}

private func get(kind: Fingerprint.Kind,
_ fingerprintsResult: Result<[Fingerprint.Kind: Fingerprint], Error>,
callback: @escaping (Result<Fingerprint, Error>) -> Void) {
private func get(
kind: Fingerprint.Kind,
_ fingerprintsResult: Result<[Fingerprint.Kind: Fingerprint], Error>,
callback: @escaping (Result<Fingerprint, Error>) -> Void
) {
callback(fingerprintsResult.tryMap { fingerprints in
guard let fingerprint = fingerprints[kind] else {
throw PackageFingerprintStorageError.notFound
Expand Down
5 changes: 3 additions & 2 deletions Sources/PackageRegistry/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# This source file is part of the Swift open source project
#
# Copyright (c) 2021 Apple Inc. and the Swift project authors
# Copyright (c) 2021-2023 Apple Inc. and the Swift project authors
# Licensed under Apache License v2.0 with Runtime Library Exception
#
# See http://swift.org/LICENSE.txt for license information
Expand All @@ -10,7 +10,8 @@ add_library(PackageRegistry STATIC
Registry.swift
RegistryConfiguration.swift
RegistryClient.swift
RegistryDownloadsManager.swift)
RegistryDownloadsManager.swift
ChecksumTOFU.swift)
target_link_libraries(PackageRegistry PUBLIC
Basics
PackageFingerprint
Expand Down
217 changes: 217 additions & 0 deletions Sources/PackageRegistry/ChecksumTOFU.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,217 @@
//===----------------------------------------------------------------------===//
//
// This source file is part of the Swift open source project
//
// Copyright (c) 2023 Apple Inc. and the Swift project authors
// Licensed under Apache License v2.0 with Runtime Library Exception
//
// See http://swift.org/LICENSE.txt for license information
// See http://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
//
//===----------------------------------------------------------------------===//

import Dispatch

import Basics
import PackageFingerprint
import PackageModel

import struct TSCUtility.Version

struct PackageVersionChecksumTOFU {
private let fingerprintStorage: PackageFingerprintStorage?
private let fingerprintCheckingMode: FingerprintCheckingMode

private let registryClient: RegistryClient

init(
fingerprintStorage: PackageFingerprintStorage?,
fingerprintCheckingMode: FingerprintCheckingMode,
registryClient: RegistryClient
) {
self.fingerprintStorage = fingerprintStorage
self.fingerprintCheckingMode = fingerprintCheckingMode
self.registryClient = registryClient
}

func check(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: validate probable a better fit?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I will rename it as part of #6184

registry: Registry,
package: PackageIdentity.RegistryIdentity,
version: Version,
checksum: String,
timeout: DispatchTimeInterval?,
observabilityScope: ObservabilityScope,
callbackQueue: DispatchQueue,
completion: @escaping (Result<Void, Error>) -> Void
) {
self.getExpectedChecksum(
registry: registry,
package: package,
version: version,
timeout: timeout,
observabilityScope: observabilityScope,
callbackQueue: callbackQueue
) { result in
completion(
result.tryMap { expectedChecksum in
if checksum != expectedChecksum {
switch self.fingerprintCheckingMode {
case .strict:
throw RegistryError.invalidChecksum(expected: expectedChecksum, actual: checksum)
case .warn:
observabilityScope
.emit(
warning: "The checksum \(checksum) does not match previously recorded value \(expectedChecksum)"
)
}
}
}
)
}
}

private func getExpectedChecksum(
registry: Registry,
package: PackageIdentity.RegistryIdentity,
version: Version,
timeout: DispatchTimeInterval?,
observabilityScope: ObservabilityScope,
callbackQueue: DispatchQueue,
completion: @escaping (Result<String, Error>) -> Void
) {
// We either use a previously recorded checksum, or fetch it from the registry.
self.readFromStorage(
package: package,
version: version,
observabilityScope: observabilityScope,
callbackQueue: callbackQueue
) { result in
switch result {
case .success(.some(let savedChecksum)):
completion(.success(savedChecksum))
default:
// Try fetching checksum from registry if:
// - No storage available
// - Checksum not found in storage
// - Reading from storage resulted in error
self.registryClient.getRawPackageVersionMetadata(
registry: registry,
package: package,
version: version,
timeout: timeout,
observabilityScope: observabilityScope,
callbackQueue: callbackQueue
) { result in
switch result {
case .success(let metadata):
guard let sourceArchive = metadata.resources
.first(where: { $0.name == "source-archive" })
else {
return completion(.failure(RegistryError.missingSourceArchive))
}

guard let checksum = sourceArchive.checksum else {
return completion(.failure(RegistryError.invalidSourceArchive))
}

self.writeToStorage(
registry: registry,
package: package,
version: version,
checksum: checksum,
observabilityScope: observabilityScope,
callbackQueue: callbackQueue
) { writeResult in
completion(writeResult.tryMap { _ in checksum })
}
case .failure(RegistryError.failedRetrievingReleaseInfo(_, _, _, let error)):
completion(.failure(RegistryError.failedRetrievingReleaseChecksum(
registry: registry,
package: package.underlying,
version: version,
error: error
)))
case .failure(let error):
completion(.failure(RegistryError.failedRetrievingReleaseChecksum(
registry: registry,
package: package.underlying,
version: version,
error: error
)))
}
}
}
}
}

private func readFromStorage(
package: PackageIdentity.RegistryIdentity,
version: Version,
observabilityScope: ObservabilityScope,
callbackQueue: DispatchQueue,
completion: @escaping (Result<String?, Error>) -> Void
) {
guard let fingerprintStorage = self.fingerprintStorage else {
return completion(.success(nil))
}

fingerprintStorage.get(
package: package.underlying,
version: version,
kind: .registry,
observabilityScope: observabilityScope,
callbackQueue: callbackQueue
) { result in
switch result {
case .success(let fingerprint):
completion(.success(fingerprint.value))
case .failure(PackageFingerprintStorageError.notFound):
completion(.success(nil))
case .failure(let error):
observabilityScope
.emit(error: "Failed to get registry fingerprint for \(package) \(version) from storage: \(error)")
completion(.failure(error))
}
}
}

private func writeToStorage(
registry: Registry,
package: PackageIdentity.RegistryIdentity,
version: Version,
checksum: String,
observabilityScope: ObservabilityScope,
callbackQueue: DispatchQueue,
completion: @escaping (Result<Void, Error>) -> Void
) {
guard let fingerprintStorage = self.fingerprintStorage else {
return completion(.success(()))
}

fingerprintStorage.put(
package: package.underlying,
version: version,
fingerprint: .init(origin: .registry(registry.url), value: checksum),
observabilityScope: observabilityScope,
callbackQueue: callbackQueue
) { result in
switch result {
case .success:
completion(.success(()))
case .failure(PackageFingerprintStorageError.conflict(_, let existing)):
switch self.fingerprintCheckingMode {
case .strict:
completion(.failure(RegistryError.checksumChanged(latest: checksum, previous: existing.value)))
case .warn:
observabilityScope
.emit(
warning: "The checksum \(checksum) from \(registry.url.absoluteString) does not match previously recorded value \(existing.value) from \(String(describing: existing.origin.url?.absoluteString))"
)
completion(.success(()))
}
case .failure(let error):
completion(.failure(error))
}
}
}
}
Loading