forked from argmaxinc/WhisperKit
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Memory and Latency Regression Tests (argmaxinc#99)
* Add initial code for regression tests * Add processor info & generalize file write * Add WER calculations * Add unit tests for WER * Add regression tests for each model * * capture transcript in test report * make `getMemoryUsed` static * remove `jfk_long.mp4` as it's unused * update dataset url to point to whisperkit * dynamically test all models available on the hub * Update Tests/WhisperKitTests/FunctionalTests.swift Co-authored-by: Zach Nagengast <zacharynagengast@gmail.com> * Remove WERUtils as it's not part of the current changes --------- Co-authored-by: Zach Nagengast <zacharynagengast@gmail.com>
- Loading branch information
1 parent
c6782af
commit d3a9a99
Showing
2 changed files
with
334 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,177 @@ | ||
import Foundation | ||
import WhisperKit | ||
|
||
// MARK: RegressionStats | ||
/// Complete report for one regression run: the run's metadata together with
/// the memory and latency statistics collected for it.
class RegressionStats: JSONCodable {
    /// Metadata describing the run (device, model, audio file, timings, ...).
    let testInfo: TestInfo
    /// Memory statistics gathered for the run.
    let memoryStats: MemoryStats
    /// Latency statistics gathered for the run.
    let latencyStats: LatencyStats

    /// Creates a report from its three sections, storing each as given.
    init(testInfo: TestInfo, memoryStats: MemoryStats, latencyStats: LatencyStats) {
        self.latencyStats = latencyStats
        self.memoryStats = memoryStats
        self.testInfo = testInfo
    }

    /// Encodes this report as JSON using a default-configured `JSONEncoder`.
    /// - Throws: Any error raised by the encoder.
    func jsonData() throws -> Data {
        let encoder = JSONEncoder()
        let encoded = try encoder.encode(self)
        return encoded
    }
}
|
||
// MARK: TestInfo | ||
/// Metadata recorded for a single model run.
class TestInfo: JSONCodable {
    /// Name of the device the run was executed on.
    let device: String
    /// Audio file that was transcribed.
    let audioFile: String
    /// Name of the model under test.
    let model: String
    /// Date of the run, already formatted as a string by the caller.
    let date: String
    /// Duration of the run in seconds.
    let timeElapsedInSeconds: TimeInterval
    /// Timings reported for the transcription, if any.
    let timings: TranscriptionTimings?
    /// Transcribed text, if any.
    let transcript: String?

    /// Creates the metadata record, storing every argument unchanged.
    init(
        device: String,
        audioFile: String,
        model: String,
        date: String,
        timeElapsedInSeconds: TimeInterval,
        timings: TranscriptionTimings?,
        transcript: String?
    ) {
        self.device = device
        self.audioFile = audioFile
        self.model = model
        self.date = date
        self.timeElapsedInSeconds = timeElapsedInSeconds
        self.timings = timings
        self.transcript = transcript
    }
}
|
||
// MARK: TestReport | ||
/// Summary of a multi-model run: which models were tested on which device and
/// what went wrong, if anything.
///
/// The compiler-synthesized memberwise initializer
/// `init(device:modelsTested:failureInfo:)` is used to create instances.
struct TestReport: JSONCodable {
    /// Name of the device the models were run on.
    let device: String
    /// Names of every model that was attempted.
    let modelsTested: [String]
    /// Failure descriptions keyed by model name.
    let failureInfo: [String: String]
}
|
||
// MARK: Stats | ||
/// Accumulates summarized batches of samples that share one unit.
class Stats: JSONCodable {
    /// One summary entry per recorded batch, in recording order.
    var measurements: [Measurement]
    /// Unit label that applies to every measurement.
    let units: String
    /// Running total of raw samples across all recorded batches.
    var totalNumberOfMeasurements: Int

    /// Creates a statistics container with the given initial state.
    init(measurements: [Measurement], units: String, totalNumberOfMeasurements: Int) {
        self.measurements = measurements
        self.units = units
        self.totalNumberOfMeasurements = totalNumberOfMeasurements
    }

    /// Summarizes `values` (min, max, mean, count) into a single `Measurement`
    /// and appends it; an empty batch is ignored.
    /// - Parameters:
    ///   - values: Raw samples collected since the previous batch.
    ///   - timeElapsed: Time stamp stored alongside the summary.
    func measure(from values: [Float], timeElapsed: TimeInterval) {
        // An empty batch has neither a minimum nor a maximum, so nothing is recorded.
        guard let lowest = values.min(), let highest = values.max() else { return }

        let sampleCount = values.count
        let mean = values.reduce(0, +) / Float(sampleCount)
        let summary = Measurement(
            min: lowest,
            max: highest,
            average: mean,
            numberOfMeasurements: sampleCount,
            timeElapsed: timeElapsed
        )
        measurements.append(summary)
        totalNumberOfMeasurements += sampleCount
    }
}
|
||
// MARK: LatencyStats | ||
/// Latency statistics; adds no stored state on top of `Stats`.
class LatencyStats: Stats {
    /// Creates latency statistics, defaulting to no measurements and a zero sample count.
    override init(measurements: [Measurement] = [], units: String, totalNumberOfMeasurements: Int = 0) {
        super.init(measurements: measurements, units: units, totalNumberOfMeasurements: totalNumberOfMeasurements)
    }

    /// Decodes latency statistics.
    ///
    /// This type declares no stored properties of its own, so decoding is
    /// delegated entirely to `Stats`. Previously this initializer trapped with
    /// `fatalError`, which made every type containing a `LatencyStats`
    /// impossible to decode.
    required init(from decoder: any Decoder) throws {
        try super.init(from: decoder)
    }

    /// Returns the per-run average `total / runs`, or `-1` when `runs` is not positive.
    func calculate(from total: Double, runs: Int) -> Double {
        guard runs > 0 else { return -1 }
        return total / Double(runs)
    }
}
|
||
// MARK: MemoryStats
/// Memory statistics: the shared `Stats` fields plus the memory readings taken
/// immediately before and after transcription.
class MemoryStats: Stats {
    /// Memory reading taken before transcription starts.
    var preTranscribeMemory: Float
    /// Memory reading taken after transcription finishes.
    var postTranscribeMemory: Float

    /// Creates memory statistics, defaulting to no measurements and a zero sample count.
    init(measurements: [Measurement] = [], units: String, totalNumberOfMeasurements: Int = 0, preTranscribeMemory: Float, postTranscribeMemory: Float) {
        self.preTranscribeMemory = preTranscribeMemory
        self.postTranscribeMemory = postTranscribeMemory
        super.init(measurements: measurements, units: units, totalNumberOfMeasurements: totalNumberOfMeasurements)
    }

    /// Decodes memory statistics, mirroring `encode(to:)`: the two memory
    /// readings are read from this type's keys and the remaining fields are
    /// decoded by `Stats` from the same decoder.
    ///
    /// Previously this initializer trapped with `fatalError`, so any encoded
    /// report containing a `MemoryStats` could not be read back.
    required init(from decoder: any Decoder) throws {
        let container = try decoder.container(keyedBy: CodingKeys.self)
        // Own stored properties must be initialized before calling into the superclass.
        preTranscribeMemory = try container.decode(Float.self, forKey: .preTranscribeMemory)
        postTranscribeMemory = try container.decode(Float.self, forKey: .postTranscribeMemory)
        try super.init(from: decoder)
    }

    /// Encodes the inherited `Stats` fields and then the two memory readings
    /// through the same encoder.
    override func encode(to encoder: Encoder) throws {
        var container = encoder.container(keyedBy: CodingKeys.self)
        try super.encode(to: encoder)
        try container.encode(preTranscribeMemory, forKey: .preTranscribeMemory)
        try container.encode(postTranscribeMemory, forKey: .postTranscribeMemory)
    }

    /// Keys for the properties declared by `MemoryStats` itself.
    enum CodingKeys: String, CodingKey {
        case preTranscribeMemory
        case postTranscribeMemory
    }
}
|
||
/// Summary of one batch of samples.
struct Measurement: JSONCodable {
    /// Smallest sample in the batch.
    let min: Float
    /// Largest sample in the batch.
    let max: Float
    /// Arithmetic mean of the batch.
    let average: Float
    /// Number of samples the batch contained.
    let numberOfMeasurements: Int
    /// Time stamp supplied by the caller when the batch was recorded.
    let timeElapsed: TimeInterval
}
|
||
/// Marker protocol for `Codable` types that can render themselves as JSON.
protocol JSONCodable: Codable {}

extension JSONCodable {
    /// Encodes the receiver as JSON using a default-configured `JSONEncoder`.
    /// - Throws: Any error raised by the encoder.
    func jsonData() throws -> Data {
        let encoder = JSONEncoder()
        return try encoder.encode(self)
    }
}
|
||
extension Data {
    /// The receiver re-serialized as pretty-printed JSON with sorted keys, or
    /// `nil` when the bytes cannot be parsed as JSON or turned back into UTF-8 text.
    ///
    /// The value is an `NSString` because its `debugDescription` gives a nice
    /// sanitized rendering.
    var prettyPrintedJSONString: NSString? {
        do {
            let parsed = try JSONSerialization.jsonObject(with: self, options: [])
            let formatted = try JSONSerialization.data(
                withJSONObject: parsed,
                options: [.prettyPrinted, .sortedKeys]
            )
            return NSString(data: formatted, encoding: String.Encoding.utf8.rawValue)
        } catch {
            // Any parsing or serialization failure yields no string.
            return nil
        }
    }
}
|
||
// MARK: - SystemMemoryChecker | ||
@available(macOS 13, iOS 16, watchOS 10, visionOS 1, *)
class SystemMemoryChecker: NSObject {
    /// Returns the `phys_footprint` reported by `task_info(TASK_VM_INFO)` for
    /// the current task, converted to whole mebibytes (bytes / 1024 / 1024).
    ///
    /// - Returns: The footprint in MB, or `0` when the query fails or the
    ///   kernel returns fewer fields than the REV1 layout.
    static func getMemoryUsed() -> UInt64 {
        // The `TASK_VM_INFO_COUNT` and `TASK_VM_INFO_REV1_COUNT` macros are too
        // complex for the Swift C importer, so we have to define them ourselves.
        let TASK_VM_INFO_COUNT = mach_msg_type_number_t(MemoryLayout<task_vm_info_data_t>.size / MemoryLayout<integer_t>.size)
        guard let offset = MemoryLayout.offset(of: \task_vm_info_data_t.min_address) else { return 0 }
        let TASK_VM_INFO_REV1_COUNT = mach_msg_type_number_t(offset / MemoryLayout<integer_t>.size)

        var info = task_vm_info_data_t()
        var count = TASK_VM_INFO_COUNT
        // `task_info` expects a buffer of `integer_t`, so the struct is rebound for the call.
        let kr = withUnsafeMutablePointer(to: &info) { infoPtr in
            infoPtr.withMemoryRebound(to: integer_t.self, capacity: Int(count)) { intPtr in
                task_info(mach_task_self_, task_flavor_t(TASK_VM_INFO), intPtr, &count)
            }
        }
        guard
            kr == KERN_SUCCESS,
            count >= TASK_VM_INFO_REV1_COUNT
        else { return 0 }

        // Stay in integer arithmetic: a round trip through `Float` (24-bit
        // significand) rounds the byte count and can shift the result by one MB
        // near a boundary. `clamping` keeps the conversion trap-free.
        let usedBytes = UInt64(clamping: info.phys_footprint)
        return usedBytes / 1024 / 1024
    }
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,157 @@ | ||
import CoreML | ||
import Hub | ||
@testable import WhisperKit | ||
import XCTest | ||
|
||
/// Runs every model returned by `WhisperKit.fetchAvailableModels()` against a
/// downloaded audio file and attaches JSON reports with memory and latency
/// statistics to the test results.
@available(macOS 13, iOS 16, watchOS 10, visionOS 1, *)
final class RegressionTests: XCTestCase {

    /// Local URL of the test audio; set by `downloadTestAudio(completion:)`.
    var audioFileURL: URL?

    /// Downloads the test audio before each test, waiting up to 30 seconds.
    override func setUp() {
        super.setUp()

        guard audioFileURL == nil else { return }

        let expectation = XCTestExpectation(description: "Download test audio")
        downloadTestAudio { success in
            if !success {
                XCTFail("Downloading audio file for testing failed")
            }
            // Fulfill on both outcomes: a failed download has already been
            // reported, so there is no reason to block until the timeout and
            // record a second, misleading timeout failure.
            expectation.fulfill()
        }
        // Wait for the expectation with a timeout
        wait(for: [expectation], timeout: 30)
    }

    /// Downloads `4484146.mp3` from the `argmaxinc/whisperkit-test-data`
    /// dataset into a `huggingface` folder inside the temporary directory and
    /// stores its location in `audioFileURL`.
    /// - Parameter completion: Called with `true` on success and `false` after
    ///   a failure has been recorded.
    func downloadTestAudio(completion: @escaping (Bool) -> Void) {
        Task {
            do {
                let earnings22CompressedDataset = Hub.Repo(id: "argmaxinc/whisperkit-test-data", type: .datasets)
                let tempPath = FileManager.default.temporaryDirectory
                let downloadBase = tempPath.appending(component: "huggingface")
                let hubApi = HubApi(downloadBase: downloadBase)
                let fileURL = try await hubApi.snapshot(from: earnings22CompressedDataset, matching: ["4484146.mp3"])
                self.audioFileURL = fileURL.appending(component: "4484146.mp3")
                completion(true)
            } catch {
                XCTFail("Async setup failed with error: \(error)")
                completion(false)
            }
        }
    }

    /// Transcribes the downloaded audio with `model`, sampling memory use and
    /// tokens-per-second from the progress callback, and attaches the
    /// resulting `RegressionStats` as a JSON file.
    ///
    /// The parameters keep XCTest from treating this as a test method; it is
    /// driven by `testRegressionAndLatencyForAllModels()`.
    /// - Parameters:
    ///   - model: Model name passed to `WhisperKit(model:)`.
    ///   - device: Device label written into the report and the attachment name.
    /// - Throws: When the audio file is missing, the model cannot be loaded, or
    ///   transcription fails.
    func testAndMeasureModelPerformance(model: String, device: String) async throws {
        let audioFilePath = try XCTUnwrap(
            self.audioFileURL?.path(),
            "Audio file not found"
        )

        // A single time stamp feeds the report date, the elapsed times, and the attachment name.
        let startTime = Date()
        let iso8601DateTimeString = ISO8601DateFormatter().string(from: startTime)

        var currentMemoryValues = [Float]()
        var currentTPSValues = [Float]()

        let memoryStats = MemoryStats(
            measurements: [], units: "MB",
            totalNumberOfMeasurements: 0,
            preTranscribeMemory: -1,
            postTranscribeMemory: -1
        )
        let latencyStats = LatencyStats(
            measurements: [], units: "Tokens/Sec",
            totalNumberOfMeasurements: 0
        )
        var callbackCount = 0

        // Summarizes the samples gathered since the previous batch and resets
        // the buffers. `Stats.measure` ignores empty input, so calling this
        // with nothing pending is harmless.
        let recordPendingSamples = {
            let timeElapsed = Date().timeIntervalSince(startTime)
            memoryStats.measure(from: currentMemoryValues, timeElapsed: timeElapsed)
            latencyStats.measure(from: currentTPSValues, timeElapsed: timeElapsed)
            currentMemoryValues = []
            currentTPSValues = []
        }

        let callback = {
            (result: TranscriptionProgress) -> Bool in
            callbackCount += 1
            let currentMemory = SystemMemoryChecker.getMemoryUsed()
            let currentTPS = result.timings.tokensPerSecond
            // `getMemoryUsed()` returns 0 when the reading failed; skip those samples.
            if currentMemory != 0 {
                currentMemoryValues.append(Float(currentMemory))
            }
            if !currentTPS.isNaN {
                currentTPSValues.append(Float(currentTPS))
            }
            // Record a batch on the first callback and then every 100 callbacks.
            if callbackCount % 100 == 1 {
                recordPendingSamples()
            }
            return true
        }

        let whisperKit = try await WhisperKit(model: model)
        memoryStats.preTranscribeMemory = Float(SystemMemoryChecker.getMemoryUsed())

        let transcriptionResult = try await XCTUnwrapAsync(
            await whisperKit.transcribe(audioPath: audioFilePath, callback: callback),
            "Transcription failed"
        )
        XCTAssert(transcriptionResult.text.isEmpty == false, "Transcription failed")

        // Samples collected after the last 100-callback boundary would
        // otherwise be dropped from the report.
        recordPendingSamples()

        memoryStats.postTranscribeMemory = Float(SystemMemoryChecker.getMemoryUsed())
        let testInfo = TestInfo(
            device: device,
            audioFile: audioFilePath,
            model: model,
            date: startTime.formatted(Date.ISO8601FormatStyle().dateSeparator(.dash)),
            timeElapsedInSeconds: Date().timeIntervalSince(startTime),
            timings: transcriptionResult.timings,
            transcript: transcriptionResult.text
        )
        let json = RegressionStats(testInfo: testInfo, memoryStats: memoryStats, latencyStats: latencyStats)
        attachJSON(json, named: "\(device)_\(model)_\(iso8601DateTimeString).json")
    }

    /// Measures every available model in turn and attaches a summary report
    /// listing the models tested and any per-model failure descriptions.
    ///
    /// Errors thrown for a single model are collected into the summary instead
    /// of aborting the remaining models.
    func testRegressionAndLatencyForAllModels() async throws {
        var allModels: [String] = []
        var failureInfo: [String: String] = [:]
        var currentDevice = WhisperKit.deviceName()
        let iso8601DateTimeString = ISO8601DateFormatter().string(from: Date())

        #if os(macOS) && arch(arm64)
        currentDevice = Process.processor
        #endif

        do {
            allModels = try await WhisperKit.fetchAvailableModels()
        } catch {
            XCTFail("Failed to fetch available models: \(error.localizedDescription)")
        }

        for model in allModels {
            do {
                try await testAndMeasureModelPerformance(model: model, device: currentDevice)
            } catch {
                failureInfo[model] = error.localizedDescription
            }
        }

        let testReport = TestReport(device: currentDevice, modelsTested: allModels, failureInfo: failureInfo)
        attachJSON(testReport, named: "\(currentDevice)_summary_\(iso8601DateTimeString).json")
    }

    /// Encodes `report` as JSON and adds it to the test as an always-kept
    /// attachment; an encoding error is recorded as a test failure.
    private func attachJSON<Report: JSONCodable>(_ report: Report, named name: String) {
        do {
            let attachment = try XCTAttachment(data: report.jsonData(), uniformTypeIdentifier: "json")
            attachment.lifetime = .keepAlways
            attachment.name = name
            add(attachment)
        } catch {
            XCTFail("Failed with error: \(error)")
        }
    }
}