Example app VAD default + memory reduction (#217)
* Release memory when transcribing single files

Co-authored-by: keleftheriou <keleftheriou@users.noreply.github.com>

* Add method to load from file into float array iteratively

- Reduces peak memory by converting to a float array in chunks while loading, so each array copy is smaller
- Previously the entire buffer was copied at once, which spiked memory to 2x

* Fix leak

* Use VAD by default in examples

* Fix VAD thread issue

* Fix unused warning

* Revert change to early stop callback

* Fix warnings

- Optional CLI options are deprecated
- @_disfavoredOverload required @available to prevent an infinite loop

* PR review - simplify early stop test logic

Co-authored-by: Andrey Leonov <aleonov@gmail.com>

* Cleanup from review

---------

Co-authored-by: keleftheriou <keleftheriou@users.noreply.github.com>
Co-authored-by: Andrey Leonov <aleonov@gmail.com>
3 people authored Oct 8, 2024
1 parent bfb1316 commit e3e21d4
Showing 9 changed files with 117 additions and 55 deletions.
11 changes: 8 additions & 3 deletions Examples/WhisperAX/WhisperAX/Views/ContentView.swift
@@ -51,7 +51,8 @@ struct ContentView: View {
@AppStorage("silenceThreshold") private var silenceThreshold: Double = 0.3
@AppStorage("useVAD") private var useVAD: Bool = true
@AppStorage("tokenConfirmationsNeeded") private var tokenConfirmationsNeeded: Double = 2
@AppStorage("chunkingStrategy") private var chunkingStrategy: ChunkingStrategy = .none
@AppStorage("concurrentWorkerCount") private var concurrentWorkerCount: Int = 4
@AppStorage("chunkingStrategy") private var chunkingStrategy: ChunkingStrategy = .vad
@AppStorage("encoderComputeUnits") private var encoderComputeUnits: MLComputeUnits = .cpuAndNeuralEngine
@AppStorage("decoderComputeUnits") private var decoderComputeUnits: MLComputeUnits = .cpuAndNeuralEngine

@@ -1269,12 +1270,15 @@ struct ContentView: View {

func transcribeCurrentFile(path: String) async throws {
// Load and convert buffer in a limited scope
Logging.debug("Loading audio file: \(path)")
let loadingStart = Date()
let audioFileSamples = try await Task {
try autoreleasepool {
let audioFileBuffer = try AudioProcessor.loadAudio(fromPath: path)
return AudioProcessor.convertBufferToArray(buffer: audioFileBuffer)
return try AudioProcessor.loadAudioAsFloatArray(fromPath: path)
}
}.value
Logging.debug("Loaded audio file in \(Date().timeIntervalSince(loadingStart)) seconds")


let transcription = try await transcribeAudioSamples(audioFileSamples)

@@ -1316,6 +1320,7 @@ struct ContentView: View {
withoutTimestamps: !enableTimestamps,
wordTimestamps: true,
clipTimestamps: seekClip,
concurrentWorkerCount: concurrentWorkerCount,
chunkingStrategy: chunkingStrategy
)

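For context, a short usage sketch (not code from this diff) of what the example app's new persisted defaults amount to at the API level, assuming every other DecodingOptions parameter keeps its library default:

import WhisperKit

// Illustrative sketch: the example app's new defaults made explicit.
// concurrentWorkerCount: 4 and chunkingStrategy: .vad are the two values
// this commit changes; everything else stays at its default.
let options = DecodingOptions(
    concurrentWorkerCount: 4,
    chunkingStrategy: .vad
)
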
1 change: 0 additions & 1 deletion Sources/WhisperKit/Core/Audio/AudioChunker.swift
@@ -81,7 +81,6 @@ open class VADAudioChunker: AudioChunking {
// Typically this will be the full audio file, unless seek points are explicitly provided
var startIndex = seekClipStart
while startIndex < seekClipEnd - windowPadding {
let currentFrameLength = audioArray.count
guard startIndex >= 0 && startIndex < audioArray.count else {
throw WhisperError.audioProcessingFailed("startIndex is outside the buffer size")
}
67 changes: 58 additions & 9 deletions Sources/WhisperKit/Core/Audio/AudioProcessor.swift
@@ -93,8 +93,6 @@ public extension AudioProcessing {
}

static func padOrTrimAudio(fromArray audioArray: [Float], startAt startIndex: Int = 0, toLength frameLength: Int = 480_000, saveSegment: Bool = false) -> MLMultiArray? {
let currentFrameLength = audioArray.count

guard startIndex >= 0 && startIndex < audioArray.count else {
Logging.error("startIndex is outside the buffer size")
return nil
@@ -197,7 +195,15 @@ public class AudioProcessor: NSObject, AudioProcessing {

let audioFileURL = URL(fileURLWithPath: audioFilePath)
let audioFile = try AVAudioFile(forReading: audioFileURL, commonFormat: .pcmFormatFloat32, interleaved: false)
return try loadAudio(fromFile: audioFile, startTime: startTime, endTime: endTime, maxReadFrameSize: maxReadFrameSize)
}

public static func loadAudio(
fromFile audioFile: AVAudioFile,
startTime: Double? = 0,
endTime: Double? = nil,
maxReadFrameSize: AVAudioFrameCount? = nil
) throws -> AVAudioPCMBuffer {
let sampleRate = audioFile.fileFormat.sampleRate
let channelCount = audioFile.fileFormat.channelCount
let frameLength = AVAudioFrameCount(audioFile.length)
@@ -243,13 +249,56 @@ public class AudioProcessor: NSObject, AudioProcessing {
}
}

public static func loadAudioAsFloatArray(
fromPath audioFilePath: String,
startTime: Double? = 0,
endTime: Double? = nil
) throws -> [Float] {
guard FileManager.default.fileExists(atPath: audioFilePath) else {
throw WhisperError.loadAudioFailed("Resource path does not exist \(audioFilePath)")
}

let audioFileURL = URL(fileURLWithPath: audioFilePath)
let audioFile = try AVAudioFile(forReading: audioFileURL, commonFormat: .pcmFormatFloat32, interleaved: false)
let inputSampleRate = audioFile.fileFormat.sampleRate
let inputFrameCount = AVAudioFrameCount(audioFile.length)
let inputDuration = Double(inputFrameCount) / inputSampleRate

let start = startTime ?? 0
let end = min(endTime ?? inputDuration, inputDuration)

// Load 10m of audio at a time to reduce peak memory while converting
// Particularly impactful for large audio files
let chunkDuration: Double = 60 * 10
var currentTime = start
var result: [Float] = []

while currentTime < end {
let chunkEnd = min(currentTime + chunkDuration, end)

try autoreleasepool {
let buffer = try loadAudio(
fromFile: audioFile,
startTime: currentTime,
endTime: chunkEnd
)

let floatArray = Self.convertBufferToArray(buffer: buffer)
result.append(contentsOf: floatArray)
}

currentTime = chunkEnd
}

return result
}

public static func loadAudio(at audioPaths: [String]) async -> [Result<[Float], Swift.Error>] {
await withTaskGroup(of: [(index: Int, result: Result<[Float], Swift.Error>)].self) { taskGroup -> [Result<[Float], Swift.Error>] in
for (index, audioPath) in audioPaths.enumerated() {
taskGroup.addTask {
do {
let audioBuffer = try AudioProcessor.loadAudio(fromPath: audioPath)
let audio = AudioProcessor.convertBufferToArray(buffer: audioBuffer)
let audio = try AudioProcessor.loadAudioAsFloatArray(fromPath: audioPath)
return [(index: index, result: .success(audio))]
} catch {
return [(index: index, result: .failure(error))]
@@ -280,10 +329,10 @@ public class AudioProcessor: NSObject, AudioProcessing {
frameCount: AVAudioFrameCount? = nil,
maxReadFrameSize: AVAudioFrameCount = Constants.defaultAudioReadFrameSize
) -> AVAudioPCMBuffer? {
let inputFormat = audioFile.fileFormat
let inputSampleRate = audioFile.fileFormat.sampleRate
let inputStartFrame = audioFile.framePosition
let inputFrameCount = frameCount ?? AVAudioFrameCount(audioFile.length)
let inputDuration = Double(inputFrameCount) / inputFormat.sampleRate
let inputDuration = Double(inputFrameCount) / inputSampleRate
let endFramePosition = min(inputStartFrame + AVAudioFramePosition(inputFrameCount), audioFile.length + 1)

guard let outputFormat = AVAudioFormat(standardFormatWithSampleRate: sampleRate, channels: channelCount) else {
@@ -305,8 +354,8 @@ public class AudioProcessor: NSObject, AudioProcessing {
let remainingFrames = AVAudioFrameCount(endFramePosition - audioFile.framePosition)
let framesToRead = min(remainingFrames, maxReadFrameSize)

let currentPositionInSeconds = Double(audioFile.framePosition) / inputFormat.sampleRate
let nextPositionInSeconds = (Double(audioFile.framePosition) + Double(framesToRead)) / inputFormat.sampleRate
let currentPositionInSeconds = Double(audioFile.framePosition) / inputSampleRate
let nextPositionInSeconds = (Double(audioFile.framePosition) + Double(framesToRead)) / inputSampleRate
Logging.debug("Resampling \(String(format: "%.2f", currentPositionInSeconds))s - \(String(format: "%.2f", nextPositionInSeconds))s")

do {
@@ -644,7 +693,7 @@ public class AudioProcessor: NSObject, AudioProcessing {
&propertySize,
&name
)
if status == noErr, let deviceNameCF = name?.takeUnretainedValue() as String? {
if status == noErr, let deviceNameCF = name?.takeRetainedValue() as String? {
deviceName = deviceNameCF
}

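A usage sketch for the chunked loader added above (the path is illustrative): each 10-minute chunk is loaded and converted inside an autoreleasepool, so peak memory tracks one chunk rather than roughly 2x the full file.

import WhisperKit

// Illustrative call site for loadAudioAsFloatArray as introduced above.
// The file is read and converted to [Float] chunk by chunk internally.
let samples: [Float] = try AudioProcessor.loadAudioAsFloatArray(
    fromPath: "/tmp/long-recording.wav",  // illustrative path
    startTime: 0,
    endTime: nil  // nil reads to the end of the file
)

The takeRetainedValue() change in the same file is the leak fix: under Core Foundation ownership rules, the CFString that AudioObjectGetPropertyData writes out is owned by the caller, so taking it unretained leaked one reference per device-name query.
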
6 changes: 4 additions & 2 deletions Sources/WhisperKit/Core/TextDecoder.swift
@@ -591,9 +591,11 @@ open class TextDecoder: TextDecoding, WhisperMLModel {
var hasAlignment = false
var isFirstTokenLogProbTooLow = false
let windowUUID = UUID()
DispatchQueue.global().async { [weak self] in
Task { [weak self] in
guard let self = self else { return }
self.shouldEarlyStop[windowUUID] = false
await MainActor.run {
self.shouldEarlyStop[windowUUID] = false
}
}
for tokenIndex in prefilledIndex..<loopCount {
let loopStart = Date()
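
The VAD thread fix replaces an unsynchronized dictionary write from DispatchQueue.global() with a hop to the main actor. As a hypothetical alternative illustrating the same principle (not what this diff does), shared mutable state like this can also be funneled through a dedicated actor:

import Foundation

// Hypothetical sketch: concurrent writes to a Swift Dictionary are a data
// race, so all access is serialized on a single executor, here an actor.
actor EarlyStopFlags {
    private var flags: [UUID: Bool] = [:]
    func set(_ value: Bool, for id: UUID) { flags[id] = value }
    func value(for id: UUID) -> Bool { flags[id] ?? false }
}
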
11 changes: 10 additions & 1 deletion Sources/WhisperKit/Core/Utils.swift
@@ -205,6 +205,14 @@ public extension String {
}

extension AVAudioPCMBuffer {
/// Converts the buffer to a float array
func asFloatArray() throws -> [Float] {
guard let data = floatChannelData?.pointee else {
throw WhisperError.audioProcessingFailed("Error converting audio, missing floatChannelData")
}
return Array(UnsafeBufferPointer(start: data, count: Int(frameLength)))
}

/// Appends the contents of another buffer to the current buffer
func appendContents(of buffer: AVAudioPCMBuffer) -> Bool {
return appendContents(of: buffer, startingFrame: 0, frameCount: buffer.frameLength)
@@ -446,8 +454,9 @@ public func modelSupport(for deviceName: String, from config: ModelSupportConfig
/// Deprecated
@available(*, deprecated, message: "Subject to removal in a future version. Use modelSupport(for:from:) -> ModelSupport instead.")
@_disfavoredOverload
@available(macOS 13, iOS 16, watchOS 10, visionOS 1, *)
public func modelSupport(for deviceName: String, from config: ModelSupportConfig? = nil) -> (default: String, disabled: [String]) {
let modelSupport = modelSupport(for: deviceName, from: config)
let modelSupport: ModelSupport = modelSupport(for: deviceName, from: config)
return (modelSupport.default, modelSupport.disabled)
}

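Per the commit notes, the @available gate works together with @_disfavoredOverload: the deprecated tuple-returning modelSupport(for:from:) forwards to the ModelSupport-returning overload, and the explicit ModelSupport type annotation plus the availability attribute steer overload resolution away from the deprecated function calling itself in an infinite loop. A usage sketch for the new AVAudioPCMBuffer extension (the file path is illustrative):

import AVFoundation
import WhisperKit

// Illustrative: decode a file into a PCM buffer, then view it as [Float].
let file = try AVAudioFile(
    forReading: URL(fileURLWithPath: "/tmp/clip.wav"),  // illustrative path
    commonFormat: .pcmFormatFloat32,
    interleaved: false
)
let buffer = try AudioProcessor.loadAudio(fromFile: file)
let samples = try buffer.asFloatArray()  // throws if floatChannelData is nil
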
49 changes: 26 additions & 23 deletions Sources/WhisperKit/Core/WhisperKit.swift
@@ -446,7 +446,8 @@ open class WhisperKit {
open func detectLanguage(
audioPath: String
) async throws -> (language: String, langProbs: [String: Float]) {
let audioBuffer = try AudioProcessor.loadAudio(fromPath: audioPath)
// Only need the first 30s for language detection
let audioBuffer = try AudioProcessor.loadAudio(fromPath: audioPath, endTime: 30.0)
let audioArray = AudioProcessor.convertBufferToArray(buffer: audioBuffer)
return try await detectLangauge(audioArray: audioArray)
}
@@ -721,15 +722,17 @@ open class WhisperKit {
callback: TranscriptionCallback = nil
) async throws -> [TranscriptionResult] {
// Process input audio file into audio samples
let loadAudioStart = Date()
let audioBuffer = try AudioProcessor.loadAudio(fromPath: audioPath)
let loadTime = Date().timeIntervalSince(loadAudioStart)
let audioArray = try await withThrowingTaskGroup(of: [Float].self) { group -> [Float] in
let convertAudioStart = Date()
defer {
let convertTime = Date().timeIntervalSince(convertAudioStart)
currentTimings.audioLoading = convertTime
Logging.debug("Audio loading and convert time: \(convertTime)")
logCurrentMemoryUsage("Audio Loading and Convert")
}

let convertAudioStart = Date()
let audioArray = AudioProcessor.convertBufferToArray(buffer: audioBuffer)
let convertTime = Date().timeIntervalSince(convertAudioStart)
currentTimings.audioLoading = loadTime + convertTime
Logging.debug("Audio loading time: \(loadTime), Audio convert time: \(convertTime)")
return try AudioProcessor.loadAudioAsFloatArray(fromPath: audioPath)
}

let transcribeResults: [TranscriptionResult] = try await transcribe(
audioArray: audioArray,
@@ -837,23 +840,23 @@ open class WhisperKit {
throw WhisperError.tokenizerUnavailable()
}

let childProgress = Progress()
progress.totalUnitCount += 1
progress.addChild(childProgress, withPendingUnitCount: 1)

let transcribeTask = TranscribeTask(
currentTimings: currentTimings,
progress: childProgress,
audioEncoder: audioEncoder,
featureExtractor: featureExtractor,
segmentSeeker: segmentSeeker,
textDecoder: textDecoder,
tokenizer: tokenizer
)

do {
try Task.checkCancellation()

let childProgress = Progress()
progress.totalUnitCount += 1
progress.addChild(childProgress, withPendingUnitCount: 1)

let transcribeTask = TranscribeTask(
currentTimings: currentTimings,
progress: childProgress,
audioEncoder: audioEncoder,
featureExtractor: featureExtractor,
segmentSeeker: segmentSeeker,
textDecoder: textDecoder,
tokenizer: tokenizer
)

let transcribeTaskResult = try await transcribeTask.run(
audioArray: audioArray,
decodeOptions: decodeOptions,
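
A sketch of the two call paths this file touches (model name and paths are illustrative, and the WhisperKit(model:) convenience initializer is an assumption):

import WhisperKit

// Illustrative pipeline usage. Language detection now reads only the
// first 30 seconds; file transcription now loads samples chunk by chunk.
let pipe = try await WhisperKit(model: "tiny")  // assumed convenience init
let (language, _) = try await pipe.detectLanguage(audioPath: "/tmp/clip.wav")
let results = try await pipe.transcribe(audioPath: "/tmp/clip.wav")
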
8 changes: 4 additions & 4 deletions Sources/WhisperKitCLI/CLIArguments.swift
@@ -103,9 +103,9 @@ struct CLIArguments: ParsableArguments {
@Flag(help: "Simulate streaming transcription using the input audio file")
var streamSimulated: Bool = false

@Option(help: "Maximum concurrent inference, might be helpful when processing more than 1 audio file at the same time. 0 means unlimited")
var concurrentWorkerCount: Int = 0
@Option(help: "Maximum concurrent inference, might be helpful when processing more than 1 audio file at the same time. 0 means unlimited. Default: 4")
var concurrentWorkerCount: Int = 4

@Option(help: "Chunking strategy for audio processing, `nil` means no chunking, `vad` means using voice activity detection")
var chunkingStrategy: String? = nil
@Option(help: "Chunking strategy for audio processing, `none` means no chunking, `vad` means using voice activity detection. Default: `vad`")
var chunkingStrategy: String = "vad"
}
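
With swift-argument-parser's default kebab-case naming, these properties surface as --concurrent-worker-count and --chunking-strategy on the command line. A hypothetical helper mirroring the validation TranscribeCLI performs below, now that the strategy flag is non-optional:

import ArgumentParser
import WhisperKit

// Hypothetical sketch of the raw-value-to-enum round trip; the error
// message mirrors TranscribeCLI.validate().
func parseStrategy(_ raw: String) throws -> ChunkingStrategy {
    guard let strategy = ChunkingStrategy(rawValue: raw) else {
        throw ValidationError("Wrong chunking strategy \"\(raw)\", valid strategies: \(ChunkingStrategy.allCases.map { $0.rawValue })")
    }
    return strategy
}
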
14 changes: 3 additions & 11 deletions Sources/WhisperKitCLI/TranscribeCLI.swift
@@ -38,10 +38,8 @@ struct TranscribeCLI: AsyncParsableCommand {
cliArguments.audioPath = audioFiles.map { audioFolder + "/" + $0 }
}

if let chunkingStrategyRaw = cliArguments.chunkingStrategy {
if ChunkingStrategy(rawValue: chunkingStrategyRaw) == nil {
throw ValidationError("Wrong chunking strategy \"\(chunkingStrategyRaw)\", valid strategies: \(ChunkingStrategy.allCases.map { $0.rawValue })")
}
if ChunkingStrategy(rawValue: cliArguments.chunkingStrategy) == nil {
throw ValidationError("Wrong chunking strategy \"\(cliArguments.chunkingStrategy)\", valid strategies: \(ChunkingStrategy.allCases.map { $0.rawValue })")
}
}

@@ -318,12 +316,6 @@ struct TranscribeCLI: AsyncParsableCommand {
}

private func decodingOptions(task: DecodingTask) -> DecodingOptions {
let chunkingStrategy: ChunkingStrategy? =
if let chunkingStrategyRaw = cliArguments.chunkingStrategy {
ChunkingStrategy(rawValue: chunkingStrategyRaw)
} else {
nil
}
return DecodingOptions(
verbose: cliArguments.verbose,
task: task,
@@ -344,7 +336,7 @@ struct TranscribeCLI: AsyncParsableCommand {
firstTokenLogProbThreshold: cliArguments.firstTokenLogProbThreshold,
noSpeechThreshold: cliArguments.noSpeechThreshold ?? 0.6,
concurrentWorkerCount: cliArguments.concurrentWorkerCount,
chunkingStrategy: chunkingStrategy
chunkingStrategy: ChunkingStrategy(rawValue: cliArguments.chunkingStrategy)
)
}

5 changes: 4 additions & 1 deletion Tests/WhisperKitTests/UnitTests.swift
@@ -548,9 +548,11 @@ final class UnitTests: XCTestCase {
}

func testDecodingEarlyStopping() async throws {
let earlyStopTokenCount = 10
let options = DecodingOptions()
let continuationCallback: TranscriptionCallback = { (progress: TranscriptionProgress) -> Bool? in
false
// Stop after only 10 tokens (full test audio contains 16)
return progress.tokens.count <= earlyStopTokenCount
}

let result = try await XCTUnwrapAsync(
@@ -576,6 +578,7 @@ final class UnitTests: XCTestCase {
XCTAssertNotNil(resultWithWait)
let tokenCountWithWait = resultWithWait.segments.flatMap { $0.tokens }.count
let decodingTimePerTokenWithWait = resultWithWait.timings.decodingLoop / Double(tokenCountWithWait)
Logging.debug("Decoding loop without wait: \(result.timings.decodingLoop), with wait: \(resultWithWait.timings.decodingLoop)")

// Assert that the decoding predictions per token are not slower with the waiting
XCTAssertEqual(decodingTimePerTokenWithWait, decodingTimePerToken, accuracy: decodingTimePerToken, "Decoding predictions per token should not be significantly slower with waiting")
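
For reference, the callback contract the updated test exercises: returning false stops decoding of the current window early, while true (or nil) lets it continue. A minimal sketch with an illustrative threshold:

import WhisperKit

// Stop early once more than 10 tokens have been produced for the window.
let earlyStop: TranscriptionCallback = { progress in
    progress.tokens.count <= 10
}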
