Skip to content

Commit 2b104b4

Browse files
committed
Refactor Logging utility for thread safety
1 parent 9c673e3 commit 2b104b4

File tree

5 files changed

+90
-49
lines changed

5 files changed

+90
-49
lines changed

Sources/WhisperKit/Core/TranscribeTask.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ open class TranscribeTask {
6666
Logging.debug("Starting pipeline at: \(Date())")
6767

6868
var options = decodeOptions ?? DecodingOptions()
69-
options.verbose = Logging.shared.logLevel != .none
69+
options.verbose = await Logging.isLoggingEnabled()
7070

7171
var detectedLanguage: String?
7272

Sources/WhisperKit/Core/WhisperKit.swift

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ open class WhisperKit {
6969
tokenizerFolder = config.tokenizerFolder ?? config.downloadBase
7070
useBackgroundDownloadSession = config.useBackgroundDownloadSession
7171
currentTimings = TranscriptionTimings()
72-
Logging.shared.logLevel = config.verbose ? config.logLevel : .none
72+
await Logging.updateLogLevel(config.verbose ? config.logLevel : .none)
7373

7474
try await setupModels(
7575
model: config.model,
@@ -282,7 +282,7 @@ open class WhisperKit {
282282

283283
Logging.debug("Downloading model \(variantPath)...")
284284
let modelFolder = try await hubApi.snapshot(from: repo, matching: [modelSearchPath]) { progress in
285-
Logging.debug(progress)
285+
Logging.debug(progress.debugDescription)
286286
if let callback = progressCallback {
287287
callback(progress)
288288
}
@@ -291,7 +291,7 @@ open class WhisperKit {
291291
let modelFolderName = modelFolder.appending(path: variantPath)
292292
return modelFolderName
293293
} catch {
294-
Logging.debug(error)
294+
Logging.debug(error.localizedDescription)
295295
throw error
296296
}
297297
}
@@ -518,7 +518,9 @@ open class WhisperKit {
518518

519519
/// Pass in your own logging callback here
520520
open func loggingCallback(_ callback: Logging.LoggingCallback?) {
521-
Logging.shared.loggingCallback = callback
521+
Task(priority: .utility) {
522+
await Logging.updateCallback(callback)
523+
}
522524
}
523525

524526
// MARK: - Detect language
@@ -557,7 +559,7 @@ open class WhisperKit {
557559
throw WhisperError.tokenizerUnavailable()
558560
}
559561

560-
let options = DecodingOptions(verbose: Logging.shared.logLevel != .none)
562+
let options = DecodingOptions(verbose: await Logging.isLoggingEnabled())
561563
let decoderInputs = try textDecoder.prepareDecoderInputs(withPrompt: [tokenizer.specialTokens.startOfTranscriptToken])
562564

563565
// Detect language using up to the first 30 seconds

Sources/WhisperKit/Utilities/Logging.swift

Lines changed: 79 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -3,57 +3,105 @@
33

44
import OSLog
55

6-
open class Logging {
7-
public static let shared = Logging()
8-
public var logLevel: LogLevel = .none
6+
public enum Logging: Sendable {
97

10-
public typealias LoggingCallback = (_ message: String) -> Void
11-
public var loggingCallback: LoggingCallback?
8+
// MARK: - Helper Types
129

13-
private let logger = OSLog(subsystem: Bundle.main.bundleIdentifier ?? "com.argmax.whisperkit", category: "WhisperKit")
10+
public typealias LoggingCallback = @Sendable (_ message: String) -> Void
1411

15-
@frozen
16-
public enum LogLevel: Int {
12+
public enum LogLevel: Int, Sendable, Comparable {
1713
case debug = 1
1814
case info = 2
1915
case error = 3
2016
case none = 4
2117

22-
func shouldLog(level: LogLevel) -> Bool {
23-
return self.rawValue <= level.rawValue
18+
var osLogType: OSLogType {
19+
switch self {
20+
case .debug: return .debug
21+
case .info: return .info
22+
case .error: return .error
23+
case .none: return .default
24+
}
25+
}
26+
27+
public static func < (lhs: LogLevel, rhs: LogLevel) -> Bool {
28+
lhs.rawValue < rhs.rawValue
2429
}
2530
}
2631

27-
private init() {}
32+
private actor Engine {
33+
var level: LogLevel
34+
var callback: LoggingCallback?
35+
private let logger: Logger
36+
37+
init(level: LogLevel = .none, callback: LoggingCallback? = nil) {
38+
self.level = level
39+
self.callback = callback
40+
self.logger = Logger(
41+
subsystem: Constants.Logging.subsystem,
42+
category: "WhisperKit"
43+
)
44+
}
2845

29-
public func log(_ items: Any..., separator: String = " ", terminator: String = "\n", type: OSLogType) {
30-
let message = items.map { "\($0)" }.joined(separator: separator)
31-
if let logger = loggingCallback {
32-
logger(message)
33-
} else {
34-
os_log("%{public}@", log: logger, type: type, message)
46+
func updateLogLevel(_ level: LogLevel) {
47+
self.level = level
3548
}
36-
}
3749

38-
public static func debug(_ items: Any..., separator: String = " ", terminator: String = "\n") {
39-
if shared.logLevel.shouldLog(level: .debug) {
40-
shared.log(items, separator: separator, terminator: terminator, type: .debug)
50+
func updateCallback(_ callback: LoggingCallback?) {
51+
self.callback = callback
4152
}
42-
}
4353

44-
public static func info(_ items: Any..., separator: String = " ", terminator: String = "\n") {
45-
if shared.logLevel.shouldLog(level: .info) {
46-
shared.log(items, separator: separator, terminator: terminator, type: .info)
54+
func log(level: LogLevel, message: String) {
55+
guard self.level != .none, self.level <= level else { return }
56+
57+
if let callback {
58+
callback(message)
59+
} else {
60+
logger.log(level: level.osLogType, "\(message, privacy: .public)")
61+
}
4762
}
4863
}
4964

50-
public static func error(_ items: Any..., separator: String = " ", terminator: String = "\n") {
51-
if shared.logLevel.shouldLog(level: .error) {
52-
shared.log(items, separator: separator, terminator: terminator, type: .error)
65+
// MARK: - Properties
66+
67+
private static let engine = Engine()
68+
69+
public static func isLoggingEnabled() async -> Bool {
70+
let level = await engine.level
71+
return level != .none
72+
}
73+
74+
public static func updateLogLevel(_ level: LogLevel) async {
75+
await engine.updateLogLevel(level)
76+
}
77+
78+
public static func updateCallback(_ callback: LoggingCallback?) async {
79+
await engine.updateCallback(callback)
80+
}
81+
82+
// MARK: - Convenience
83+
84+
public static func debug(_ message: String) {
85+
dispatch(level: .debug, message)
86+
}
87+
88+
public static func info(_ message: String) {
89+
dispatch(level: .info, message)
90+
}
91+
92+
public static func error(_ message: String) {
93+
dispatch(level: .error, message)
94+
}
95+
96+
private static func dispatch(level: LogLevel, _ message: String) {
97+
Task(priority: .utility) {
98+
await engine.log(level: level, message: message)
5399
}
54100
}
55101
}
56102

103+
// MARK: - Memory Usage
104+
57105
public extension Logging {
58106
static func logCurrentMemoryUsage(_ message: String) {
59107
let memoryUsage = getMemoryUsage()
@@ -78,15 +126,7 @@ public extension Logging {
78126
}
79127
}
80128

81-
@available(*, deprecated, message: "Subject to removal in a future version. Use `Logging.logCurrentMemoryUsage(_:)` instead.")
82-
public func logCurrentMemoryUsage(_ message: String) {
83-
Logging.logCurrentMemoryUsage(message)
84-
}
85-
86-
@available(*, deprecated, message: "Subject to removal in a future version. Use `Logging.getMemoryUsage()` instead.")
87-
public func getMemoryUsage() -> UInt64 {
88-
return Logging.getMemoryUsage()
89-
}
129+
// MARK: - Feature Specific Logging
90130

91131
extension Logging {
92132
enum AudioEncoding {
@@ -130,14 +170,13 @@ extension Logging {
130170
}
131171

132172
static func formatTimestamp(_ timestamp: Float) -> String {
133-
return String(format: "%.2f", timestamp)
173+
String(format: "%.2f", timestamp)
134174
}
135175

136176
static func formatTimeWithPercentage(_ time: Double, _ runs: Double, _ fullPipelineDuration: Double) -> String {
137-
let percentage = (time * 1000 / fullPipelineDuration) * 100 // Convert to percentage
177+
let percentage = (time * 1000 / fullPipelineDuration) * 100
138178
let runTime = runs > 0 ? time * 1000 / Double(runs) : 0
139179
let formattedString = String(format: "%8.2f ms / %6.0f runs (%8.2f ms/run) %5.2f%%", time * 1000, runs, runTime, percentage)
140180
return formattedString
141181
}
142182
}
143-

Tests/WhisperKitTests/RegressionTests.swift

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -666,7 +666,7 @@ class RegressionTests: XCTestCase {
666666
let monitorTask = Task {
667667
while true {
668668
let remainingMemory = SystemMemoryCheckerAdvanced.getMemoryUsage().totalAvailableGB
669-
Logging.debug(remainingMemory, "GB of memory left")
669+
Logging.debug("\(remainingMemory) GB of memory left")
670670

671671
if remainingMemory <= 0.1 { // Cancel with 100MB remaining
672672
Logging.debug("Cancelling due to oom")
@@ -712,7 +712,7 @@ class RegressionTests: XCTestCase {
712712
initializationTask.cancel()
713713
monitorTask.cancel()
714714
timeoutTask.cancel()
715-
Logging.debug(error)
715+
Logging.debug(error.localizedDescription)
716716
throw error
717717
}
718718
}

Tests/WhisperKitTests/UnitTests.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ import XCTest
1212

1313
final class UnitTests: XCTestCase {
1414
override func setUp() async throws {
15-
Logging.shared.logLevel = .debug
15+
await Logging.updateLogLevel(.debug)
1616
}
1717

1818
// MARK: - Model Loading Test

0 commit comments

Comments
 (0)