Skip to content

Commit 2347f20

Browse files
Use Thread.threadDictionary instead of TaskLocal for thread-local (#395)
## Motivation We were using a `@TaskLocal static var` to hold the `FiniteFieldArithmeticContext` (FFAC) on each of the curve types. The intention here was this would operate as a thread-local value, since task-locals behave like thread locals when the caller is not part of a task context. However, they still require binding, (i.e. the use of e.g. `P256.$__ffac.withValue { ... }`[^1], otherwise they just return the _default_ value, which is globally shared. We were not doing that, nor do we have a sensible place where we could. This causes crashes when using these values in a multithreaded environment since the same value will be used across threads. ## Modifications - Use an explicit thread-local API from Foundation to store and read the per-curve FFAC. - Add a test that shows each thread gets the same value on each read, but distinct from all other threads. ## Result Fixes crash when using ECToolbox-based APIs in multithreaded code. [^1]: https://developer.apple.com/documentation/swift/tasklocal#Using-task-local-values-outside-of-tasks
1 parent 334e682 commit 2347f20

File tree

2 files changed

+102
-9
lines changed

2 files changed

+102
-9
lines changed

Sources/_CryptoExtras/ECToolbox/BoringSSL/ECToolbox_boring.swift

Lines changed: 27 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -65,10 +65,16 @@ extension P256: OpenSSLSupportedNISTCurve {
6565
@inlinable
6666
static var hashToFieldByteCount: Int { 48 }
6767

68-
@TaskLocal
6968
@usableFromInline
70-
// NOTE: This could be a let when Swift 6.0 is the minimum supported version.
71-
static var __ffac = try! FiniteFieldArithmeticContext(fieldSize: P256.group.order)
69+
static var __ffac: FiniteFieldArithmeticContext {
70+
let key = "com.apple.swift-crypto.P256.__ffac"
71+
if let value = Thread.current.threadDictionary[key] as? FiniteFieldArithmeticContext {
72+
return value
73+
}
74+
let value = try! FiniteFieldArithmeticContext(fieldSize: P256.group.order)
75+
Thread.current.threadDictionary[key] = value
76+
return value
77+
}
7278
}
7379

7480
/// NOTE: This conformance applies to this type from the Crypto module even if it comes from the SDK.
@@ -92,10 +98,16 @@ extension P384: OpenSSLSupportedNISTCurve {
9298
@inlinable
9399
static var hashToFieldByteCount: Int { 72 }
94100

95-
@TaskLocal
96101
@usableFromInline
97-
// NOTE: This could be a let when Swift 6.0 is the minimum supported version.
98-
static var __ffac = try! FiniteFieldArithmeticContext(fieldSize: P384.group.order)
102+
static var __ffac: FiniteFieldArithmeticContext {
103+
let key = "com.apple.swift-crypto.P384.__ffac"
104+
if let value = Thread.current.threadDictionary[key] as? FiniteFieldArithmeticContext {
105+
return value
106+
}
107+
let value = try! FiniteFieldArithmeticContext(fieldSize: P384.group.order)
108+
Thread.current.threadDictionary[key] = value
109+
return value
110+
}
99111
}
100112

101113
/// NOTE: This conformance applies to this type from the Crypto module even if it comes from the SDK.
@@ -119,10 +131,16 @@ extension P521: OpenSSLSupportedNISTCurve {
119131
@inlinable
120132
static var hashToFieldByteCount: Int { 98 }
121133

122-
@TaskLocal
123134
@usableFromInline
124-
// NOTE: This could be a let when Swift 6.0 is the minimum supported version.
125-
static var __ffac = try! FiniteFieldArithmeticContext(fieldSize: P521.group.order)
135+
static var __ffac: FiniteFieldArithmeticContext {
136+
let key = "com.apple.swift-crypto.P521.__ffac"
137+
if let value = Thread.current.threadDictionary[key] as? FiniteFieldArithmeticContext {
138+
return value
139+
}
140+
let value = try! FiniteFieldArithmeticContext(fieldSize: P521.group.order)
141+
Thread.current.threadDictionary[key] = value
142+
return value
143+
}
126144
}
127145

128146
@available(macOS 10.15, iOS 13, watchOS 6, tvOS 13, macCatalyst 13, visionOS 1.0, *)
Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
//===----------------------------------------------------------------------===//
2+
//
3+
// This source file is part of the SwiftCrypto open source project
4+
//
5+
// Copyright (c) 2025 Apple Inc. and the SwiftCrypto project authors
6+
// Licensed under Apache License v2.0
7+
//
8+
// See LICENSE.txt for license information
9+
// See CONTRIBUTORS.txt for the list of SwiftCrypto project authors
10+
//
11+
// SPDX-License-Identifier: Apache-2.0
12+
//
13+
//===----------------------------------------------------------------------===//
14+
import Crypto
15+
import Foundation
16+
import XCTest
17+
18+
@testable import _CryptoExtras
19+
20+
final class ECToolboxBoringSSLTests: XCTestCase {
21+
func testThreadLocalFFAC() async {
22+
await testThreadLocalFFAC(P256.self)
23+
await testThreadLocalFFAC(P384.self)
24+
await testThreadLocalFFAC(P521.self)
25+
}
26+
27+
func testThreadLocalFFAC(_ Curve: (some OpenSSLSupportedNISTCurve & Sendable).Type) async {
28+
let numThreads = 3
29+
let numReadsPerThread = 2
30+
31+
var threads:
32+
[(
33+
thread: Thread,
34+
thisThreadDidReads: XCTestExpectation,
35+
allThreadsDidReads: XCTestExpectation,
36+
thisThreadFinished: XCTestExpectation
37+
)] = []
38+
39+
var objectIdentifiers: [(threadID: Int, ffacID: ObjectIdentifier)] = []
40+
let lock = NSLock()
41+
42+
for i in 1...numThreads {
43+
let thisThreadDidReads = expectation(description: "this thread did its reads")
44+
let allThreadsDidReads = expectation(description: "all threads did their reads")
45+
let thisThreadFinished = expectation(description: "this thread is finished")
46+
let thread = Thread {
47+
for _ in 1...numReadsPerThread {
48+
lock.lock()
49+
objectIdentifiers.append((i, ObjectIdentifier(Curve.__ffac)))
50+
lock.unlock()
51+
}
52+
thisThreadDidReads.fulfill()
53+
XCTWaiter().wait(for: [allThreadsDidReads], timeout: .greatestFiniteMagnitude)
54+
thisThreadFinished.fulfill()
55+
}
56+
thread.name = "thread-\(i)"
57+
threads.append((thread, thisThreadDidReads, allThreadsDidReads, thisThreadFinished))
58+
thread.start()
59+
}
60+
await fulfillment(of: threads.map(\.thisThreadDidReads), timeout: 0.5)
61+
for thread in threads { thread.allThreadsDidReads.fulfill() }
62+
await fulfillment(of: threads.map(\.thisThreadFinished), timeout: 0.5)
63+
64+
XCTAssertEqual(objectIdentifiers.count, numThreads * numReadsPerThread)
65+
for threadID in 1...numThreads {
66+
let partitionBoundary = objectIdentifiers.partition(by: { $0.threadID == threadID })
67+
let otherThreadsObjIDs = objectIdentifiers[..<partitionBoundary].map(\.ffacID)
68+
let thisThreadObjIDs = objectIdentifiers[partitionBoundary...].map(\.ffacID)
69+
let intersection = Set(thisThreadObjIDs).intersection(Set(otherThreadsObjIDs))
70+
XCTAssertEqual(thisThreadObjIDs.count, numReadsPerThread, "Thread should read \(numReadsPerThread) times.")
71+
XCTAssertEqual(Set(thisThreadObjIDs).count, 1, "Thread should see same object on every read.")
72+
XCTAssert(intersection.isEmpty, "Thread should see different objects from other threads.")
73+
}
74+
}
75+
}

0 commit comments

Comments
 (0)