Skip to content

Commit

Permalink
Memory and Latency Regression Tests (argmaxinc#99)
Browse files Browse the repository at this point in the history
* Add initial code for regression tests

* Add processor info & generalize file write

* Add WER calculations

* Add unit tests for WER

* Add regression tests for each model

* * capture transcript in test report
* make `getMemoryUsed` static
* remove `jfk_long.mp4` as it's unused
* update dataset url to point to whisperkit
* dynamically test all models available on the hub

* Update Tests/WhisperKitTests/FunctionalTests.swift

Co-authored-by: Zach Nagengast <zacharynagengast@gmail.com>

* Remove WERUtils as it's not part of the current changes

---------

Co-authored-by: Zach Nagengast <zacharynagengast@gmail.com>
  • Loading branch information
Abhinay1997 and ZachNagengast authored Apr 21, 2024
1 parent c6782af commit d3a9a99
Show file tree
Hide file tree
Showing 2 changed files with 334 additions and 0 deletions.
177 changes: 177 additions & 0 deletions Tests/WhisperKitTests/MemoryTestUtils.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,177 @@
import Foundation
import WhisperKit

// MARK: RegressionStats
/// Bundles a `TestInfo`, a `MemoryStats` and a `LatencyStats` into a single
/// JSON-encodable value.
class RegressionStats: JSONCodable {
    let testInfo: TestInfo
    let memoryStats: MemoryStats
    let latencyStats: LatencyStats

    /// Stores the three components exactly as given.
    init(
        testInfo: TestInfo,
        memoryStats: MemoryStats,
        latencyStats: LatencyStats
    ) {
        self.testInfo = testInfo
        self.memoryStats = memoryStats
        self.latencyStats = latencyStats
    }

    /// Encodes this instance with a default-configured `JSONEncoder`.
    /// - Throws: Any error raised by `JSONEncoder.encode(_:)`.
    func jsonData() throws -> Data {
        let encoder = JSONEncoder()
        return try encoder.encode(self)
    }
}

// MARK: TestInfo
/// Metadata describing one test run. Every field is stored as passed to the
/// initializer and encoded through `JSONCodable`.
class TestInfo: JSONCodable {
    let device: String
    let audioFile: String
    let model: String
    let date: String
    let timeElapsedInSeconds: TimeInterval
    /// Optional timings object from WhisperKit; `nil` when not available.
    let timings: TranscriptionTimings?
    /// Optional transcript text; `nil` when not available.
    let transcript: String?

    init(
        device: String,
        audioFile: String,
        model: String,
        date: String,
        timeElapsedInSeconds: TimeInterval,
        timings: TranscriptionTimings?,
        transcript: String?
    ) {
        self.device = device
        self.audioFile = audioFile
        self.model = model
        self.date = date
        self.timeElapsedInSeconds = timeElapsedInSeconds
        self.timings = timings
        self.transcript = transcript
    }
}

// MARK: TestReport
/// Summary value holding a device name, the list of model names that were
/// tested, and a dictionary of failure descriptions keyed by string.
///
/// The compiler-synthesized memberwise initializer
/// `init(device:modelsTested:failureInfo:)` is used for construction.
struct TestReport: JSONCodable {
    let device: String
    let modelsTested: [String]
    let failureInfo: [String: String]
}

// MARK: Stats
/// Accumulates summarized batches of `Float` samples together with the unit
/// label they are expressed in.
class Stats: JSONCodable {
    /// One entry per non-empty batch passed to `measure(from:timeElapsed:)`.
    var measurements: [Measurement]
    let units: String
    /// Running sum of the sample counts of all recorded batches.
    var totalNumberOfMeasurements: Int

    init(measurements: [Measurement], units: String, totalNumberOfMeasurements: Int) {
        self.measurements = measurements
        self.units = units
        self.totalNumberOfMeasurements = totalNumberOfMeasurements
    }

    /// Summarizes `values` (min, max, arithmetic mean, count) into a new
    /// `Measurement` and appends it. An empty `values` array is ignored.
    /// - Parameters:
    ///   - values: The raw samples of the batch.
    ///   - timeElapsed: Stored verbatim on the resulting `Measurement`.
    func measure(from values: [Float], timeElapsed: TimeInterval) {
        guard let smallest = values.min(), let largest = values.max() else {
            return
        }
        let sampleCount = values.count
        let mean = values.reduce(0, +) / Float(sampleCount)
        let summary = Measurement(
            min: smallest,
            max: largest,
            average: mean,
            numberOfMeasurements: sampleCount,
            timeElapsed: timeElapsed
        )
        measurements.append(summary)
        totalNumberOfMeasurements += sampleCount
    }
}

// MARK: LatencyStats
/// `Stats` specialization used for latency samples. It adds no stored
/// properties of its own, only defaulted initializer arguments and an
/// averaging helper.
class LatencyStats: Stats {
    /// Same as the `Stats` initializer, but `measurements` defaults to empty
    /// and `totalNumberOfMeasurements` defaults to zero.
    override init(measurements: [Measurement] = [], units: String, totalNumberOfMeasurements: Int = 0) {
        super.init(measurements: measurements, units: units, totalNumberOfMeasurements: totalNumberOfMeasurements)
    }

    /// Decodes an instance by delegating to `Stats`.
    ///
    /// This class has no additional stored properties, so the superclass
    /// decoder is sufficient; previously this initializer trapped with
    /// `fatalError`, making any attempt to decode a report crash.
    required init(from decoder: any Decoder) throws {
        try super.init(from: decoder)
    }

    /// Returns `total / runs`, or `-1` as a sentinel when `runs` is not positive.
    func calculate(from total: Double, runs: Int) -> Double {
        guard runs > 0 else { return -1 }
        return total / Double(runs)
    }
}

/// `Stats` specialization that additionally stores two standalone memory
/// readings: one taken before and one taken after transcription.
class MemoryStats: Stats {
    var preTranscribeMemory: Float
    var postTranscribeMemory: Float

    /// Coding keys for the properties this subclass adds on top of `Stats`.
    enum CodingKeys: String, CodingKey {
        case preTranscribeMemory
        case postTranscribeMemory
    }

    init(measurements: [Measurement] = [], units: String, totalNumberOfMeasurements: Int = 0, preTranscribeMemory: Float, postTranscribeMemory: Float) {
        self.preTranscribeMemory = preTranscribeMemory
        self.postTranscribeMemory = postTranscribeMemory
        super.init(measurements: measurements, units: units, totalNumberOfMeasurements: totalNumberOfMeasurements)
    }

    /// Decodes the two subclass properties and then lets `Stats` decode the
    /// rest from the same decoder, mirroring `encode(to:)`.
    ///
    /// Previously this initializer trapped with `fatalError`, so encoded
    /// reports could not be read back without crashing.
    required init(from decoder: any Decoder) throws {
        let container = try decoder.container(keyedBy: CodingKeys.self)
        preTranscribeMemory = try container.decode(Float.self, forKey: .preTranscribeMemory)
        postTranscribeMemory = try container.decode(Float.self, forKey: .postTranscribeMemory)
        try super.init(from: decoder)
    }

    /// Encodes the inherited `Stats` properties plus the two memory readings
    /// into the encoder's keyed container.
    override func encode(to encoder: Encoder) throws {
        var container = encoder.container(keyedBy: CodingKeys.self)
        try super.encode(to: encoder)
        try container.encode(preTranscribeMemory, forKey: .preTranscribeMemory)
        try container.encode(postTranscribeMemory, forKey: .postTranscribeMemory)
    }
}

/// Summary of one batch of samples: extremes, mean, sample count and the
/// elapsed time supplied by the caller when the batch was recorded.
struct Measurement: JSONCodable {
    let min: Float
    let max: Float
    let average: Float
    let numberOfMeasurements: Int
    let timeElapsed: TimeInterval
}

/// `Codable` refinement whose conformers gain a default `jsonData()` helper.
protocol JSONCodable: Codable {}

extension JSONCodable {
    /// Encodes `self` with a default-configured `JSONEncoder`.
    /// - Returns: The encoded JSON bytes.
    /// - Throws: Any error raised by `JSONEncoder.encode(_:)`.
    func jsonData() throws -> Data {
        let encoder = JSONEncoder()
        let encoded = try encoder.encode(self)
        return encoded
    }
}

extension Data {
    /// Re-serializes these bytes as pretty-printed JSON with sorted keys.
    ///
    /// The result is an `NSString` (rather than `String`) because its
    /// `debugDescription` prints in a nicely sanitized form.
    /// Returns `nil` when the bytes are not parseable JSON, cannot be
    /// re-serialized, or the output cannot be read as UTF-8.
    var prettyPrintedJSONString: NSString? {
        guard let parsed = try? JSONSerialization.jsonObject(with: self, options: []) else {
            return nil
        }
        let writingOptions: JSONSerialization.WritingOptions = [.prettyPrinted, .sortedKeys]
        guard let formatted = try? JSONSerialization.data(withJSONObject: parsed, options: writingOptions) else {
            return nil
        }
        return NSString(data: formatted, encoding: String.Encoding.utf8.rawValue)
    }
}

// MARK: - SystemMemoryChecker
@available(macOS 13, iOS 16, watchOS 10, visionOS 1, *)
class SystemMemoryChecker: NSObject {

    /// Returns the current task's `phys_footprint`, converted from bytes to
    /// whole megabytes (1024 * 1024 bytes, rounded down).
    ///
    /// - Returns: The footprint in MB, or `0` when the `task_info` call fails
    ///   or returns fewer fields than the REV1 layout.
    static func getMemoryUsed() -> UInt64 {
        // The `TASK_VM_INFO_COUNT` and `TASK_VM_INFO_REV1_COUNT` macros are too
        // complex for the Swift C importer, so we have to define them ourselves.
        let integerSize = MemoryLayout<integer_t>.size
        let TASK_VM_INFO_COUNT = mach_msg_type_number_t(MemoryLayout<task_vm_info_data_t>.size / integerSize)
        guard let offset = MemoryLayout.offset(of: \task_vm_info_data_t.min_address) else { return 0 }
        let TASK_VM_INFO_REV1_COUNT = mach_msg_type_number_t(offset / integerSize)

        var info = task_vm_info_data_t()
        var count = TASK_VM_INFO_COUNT
        let kr = withUnsafeMutablePointer(to: &info) { infoPtr in
            infoPtr.withMemoryRebound(to: integer_t.self, capacity: Int(count)) { intPtr in
                task_info(mach_task_self_, task_flavor_t(TASK_VM_INFO), intPtr, &count)
            }
        }
        guard kr == KERN_SUCCESS, count >= TASK_VM_INFO_REV1_COUNT else { return 0 }

        // Stay in integer arithmetic: routing the byte count through `Float`
        // (24-bit mantissa) would round it before the division.
        let bytesPerMegabyte: UInt64 = 1024 * 1024
        return UInt64(info.phys_footprint) / bytesPerMegabyte
    }
}
157 changes: 157 additions & 0 deletions Tests/WhisperKitTests/RegressionTests.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
import CoreML
import Hub
@testable import WhisperKit
import XCTest

@available(macOS 13, iOS 16, watchOS 10, visionOS 1, *)
final class RegressionTests: XCTestCase {

var audioFileURL: URL?

/// Downloads the test audio file before each test when no URL is stored yet,
/// blocking until the download attempt has finished (or 30 seconds elapse).
override func setUp() {
    super.setUp()

    guard audioFileURL == nil else { return }

    let expectation = XCTestExpectation(description: "Download test audio")
    downloadTestAudio { success in
        if !success {
            XCTFail("Downloading audio file for testing failed")
        }
        // Fulfill on both outcomes: a failed download is already reported
        // above, so there is no reason to also block for the full timeout
        // and record a second, misleading "timed out" failure.
        expectation.fulfill()
    }
    // Wait for the expectation with a timeout
    wait(for: [expectation], timeout: 30)
}

/// Starts an unstructured task that fetches `4484146.mp3` from the
/// `argmaxinc/whisperkit-test-data` dataset repo into a `huggingface` folder
/// under the temporary directory and stores the resulting file URL in
/// `audioFileURL`.
/// - Parameter completion: Called with `true` after the URL has been stored,
///   or with `false` (after recording a test failure) when the download throws.
func downloadTestAudio(completion: @escaping (Bool) -> Void) {
    Task {
        let audioFileName = "4484146.mp3"
        let datasetRepo = Hub.Repo(id: "argmaxinc/whisperkit-test-data", type: .datasets)
        let downloadBase = FileManager.default.temporaryDirectory.appending(component: "huggingface")
        let hubApi = HubApi(downloadBase: downloadBase)
        do {
            let snapshotURL = try await hubApi.snapshot(from: datasetRepo, matching: [audioFileName])
            self.audioFileURL = snapshotURL.appending(component: audioFileName)
            completion(true)
        } catch {
            XCTFail("Async setup failed with error: \(error)")
            completion(false)
        }
    }
}

/// Transcribes the downloaded audio file with `model`, sampling memory use and
/// tokens-per-second from the progress callback, and attaches the collected
/// `RegressionStats` to the test as a JSON attachment.
/// - Parameters:
///   - model: Model name passed to `WhisperKit(model:)`.
///   - device: Device label written into the report and the attachment name.
/// - Throws: When the audio file URL is missing, the model fails to load, or
///   transcription yields no result.
func testAndMeasureModelPerformance(model: String, device: String) async throws {
    // Use the non-percent-encoded form: `path()` defaults to percent-encoding,
    // which is not a valid file-system path if it contains spaces or other
    // escaped characters.
    let audioFilePath = try XCTUnwrap(
        audioFileURL?.path(percentEncoded: false),
        "Audio file not found"
    )

    let startTime = Date()
    let iso8601DateTimeString = ISO8601DateFormatter().string(from: Date())

    var currentMemoryValues = [Float]()
    var currentTPSValues = [Float]()

    let memoryStats = MemoryStats(
        measurements: [], units: "MB",
        totalNumberOfMeasurements: 0,
        preTranscribeMemory: -1,
        postTranscribeMemory: -1
    )
    let latencyStats = LatencyStats(
        measurements: [], units: "Tokens/Sec",
        totalNumberOfMeasurements: 0
    )
    var count: Int = 0

    // Summarizes the samples buffered so far into one measurement per stat
    // and clears the buffers. Empty buffers are ignored by `measure`.
    let flushPendingSamples: () -> Void = {
        let timeElapsed = Date().timeIntervalSince(startTime)
        memoryStats.measure(from: currentMemoryValues, timeElapsed: timeElapsed)
        latencyStats.measure(from: currentTPSValues, timeElapsed: timeElapsed)
        currentMemoryValues = []
        currentTPSValues = []
    }

    let callback = {
        (result: TranscriptionProgress) -> Bool in
        count += 1
        let currentMemory = SystemMemoryChecker.getMemoryUsed()
        let currentTPS = result.timings.tokensPerSecond
        // A reading of 0 is the checker's failure sentinel; skip it.
        if currentMemory != 0 {
            currentMemoryValues.append(Float(currentMemory))
        }
        if !currentTPS.isNaN {
            currentTPSValues.append(Float(currentTPS))
        }
        // Flush on the 1st, 101st, 201st, ... callback.
        if count % 100 == 1 {
            flushPendingSamples()
        }
        return true
    }

    let whisperKit = try await WhisperKit(model: model)
    memoryStats.preTranscribeMemory = Float(SystemMemoryChecker.getMemoryUsed())

    let transcriptionResult = try await XCTUnwrapAsync(
        await whisperKit.transcribe(audioPath: audioFilePath, callback: callback),
        "Transcription failed"
    )
    XCTAssertFalse(transcriptionResult.text.isEmpty, "Transcription failed")

    memoryStats.postTranscribeMemory = Float(SystemMemoryChecker.getMemoryUsed())

    // Record whatever was sampled after the last periodic flush; without this
    // the tail of every run (up to 99 callbacks) is missing from the report.
    flushPendingSamples()

    let testInfo = TestInfo(
        device: device,
        audioFile: audioFilePath,
        model: model,
        date: startTime.formatted(Date.ISO8601FormatStyle().dateSeparator(.dash)),
        timeElapsedInSeconds: Date().timeIntervalSince(startTime),
        timings: transcriptionResult.timings,
        transcript: transcriptionResult.text
    )
    let json = RegressionStats(testInfo: testInfo, memoryStats: memoryStats, latencyStats: latencyStats)
    do {
        let attachment = try XCTAttachment(data: json.jsonData(), uniformTypeIdentifier: "json")
        attachment.lifetime = .keepAlways
        attachment.name = "\(device)_\(model)_\(iso8601DateTimeString).json"
        add(attachment)
    } catch {
        XCTFail("Failed with error: \(error)")
    }
}

/// Runs `testAndMeasureModelPerformance` once for every model name returned by
/// `WhisperKit.fetchAvailableModels()` and attaches a `TestReport` summary.
///
/// Errors thrown for an individual model are stored in the report's
/// `failureInfo` under that model's name instead of aborting the loop.
func testRegressionAndLatencyForAllModels() async throws {
    var modelNames: [String] = []
    var failuresByModel: [String: String] = [:]
    var currentDevice = WhisperKit.deviceName()
    let iso8601DateTimeString = ISO8601DateFormatter().string(from: Date())

    #if os(macOS) && arch(arm64)
    currentDevice = Process.processor
    #endif

    do {
        modelNames = try await WhisperKit.fetchAvailableModels()
    } catch {
        XCTFail("Failed to fetch available models: \(error.localizedDescription)")
    }

    for modelName in modelNames {
        do {
            try await testAndMeasureModelPerformance(model: modelName, device: currentDevice)
        } catch {
            failuresByModel[modelName] = error.localizedDescription
        }
    }

    let summary = TestReport(
        device: currentDevice,
        modelsTested: modelNames,
        failureInfo: failuresByModel
    )
    do {
        let summaryData = try summary.jsonData()
        let attachment = XCTAttachment(data: summaryData, uniformTypeIdentifier: "json")
        attachment.lifetime = .keepAlways
        attachment.name = "\(currentDevice)_summary_\(iso8601DateTimeString).json"
        add(attachment)
    } catch {
        XCTFail("Failed with error: \(error)")
    }
}

}

0 comments on commit d3a9a99

Please sign in to comment.