Fix resampling large files (#183)
* Update resampling logic to handle chunking properly

* Cleanup logging

* Optimize memory usage when resampling

* Add filter to input prompt text

* Correct timestamp filter logic for #170

* Filter out zero-length segments

- when calculating word timestamps
- resolves #170

* Add method for async audio loading (see the sketch below)

* Fix async load audio function

* Fix tests

* Fix tests

* Fix tests

* Revert timestamp filter changes

* Temporarily remove xcpretty for tests

* Check suspected test crash

* Remove errant test case for Japanese options

* Add bigger range for early stopping test

* Reset progress between runs

* Fix progress resetting and improve example app transcription handling

* Update tests

* Minimize crash risk for early stop checks

* Fix finalize text

* Add source text to language label
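The memory-usage and async-loading items above come down to loading and converting the audio inside a short-lived task with its own autorelease pool, so the raw PCM buffer can be released before transcription of a large file begins (see the transcribeCurrentFile hunk in ContentView.swift below). A minimal sketch of that pattern, assuming a hypothetical wrapper named loadAudioSamples around the same AudioProcessor calls that appear in the diff:

import Foundation
import WhisperKit

// Load and convert the file in a limited scope so the AVAudioPCMBuffer
// is released as soon as the Float samples have been extracted.
// Hypothetical helper name; the example app inlines this in transcribeCurrentFile.
func loadAudioSamples(fromPath path: String) async throws -> [Float] {
    try await Task {
        try autoreleasepool {
            let buffer = try AudioProcessor.loadAudio(fromPath: path)
            return AudioProcessor.convertBufferToArray(buffer: buffer)
        }
    }.value
}

After this, the caller only holds the [Float] samples rather than the full audio buffer, which is what keeps peak memory down when resampling large files.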
ZachNagengast authored Jul 11, 2024
1 parent 3186ca6 commit 02763ca
Showing 13 changed files with 536 additions and 169 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/unit-tests.yml
@@ -77,4 +77,4 @@ jobs:
run: |
set -o pipefail
xcodebuild clean build-for-testing -scheme whisperkit-Package -destination '${{ matrix.run-config['clean-destination'] }}' | xcpretty
xcodebuild test -only-testing WhisperKitTests/UnitTests -scheme whisperkit-Package -destination '${{ matrix.run-config['test-destination'] }}' | xcpretty
xcodebuild test -only-testing WhisperKitTests/UnitTests -scheme whisperkit-Package -destination '${{ matrix.run-config['test-destination'] }}'
4 changes: 2 additions & 2 deletions Examples/WhisperAX/WhisperAX.xcodeproj/project.pbxproj
@@ -890,7 +890,7 @@
LD_RUNPATH_SEARCH_PATHS = "@executable_path/Frameworks";
"LD_RUNPATH_SEARCH_PATHS[sdk=macosx*]" = "@executable_path/../Frameworks";
MACOSX_DEPLOYMENT_TARGET = 14.0;
MARKETING_VERSION = 0.3.1;
MARKETING_VERSION = 0.3.2;
PRODUCT_BUNDLE_IDENTIFIER = "com.argmax.whisperkit.WhisperAX${DEVELOPMENT_TEAM}";
PRODUCT_NAME = "$(TARGET_NAME)";
SDKROOT = auto;
@@ -936,7 +936,7 @@
LD_RUNPATH_SEARCH_PATHS = "@executable_path/Frameworks";
"LD_RUNPATH_SEARCH_PATHS[sdk=macosx*]" = "@executable_path/../Frameworks";
MACOSX_DEPLOYMENT_TARGET = 14.0;
MARKETING_VERSION = 0.3.1;
MARKETING_VERSION = 0.3.2;
PRODUCT_BUNDLE_IDENTIFIER = com.argmax.whisperkit.WhisperAX;
PRODUCT_NAME = "$(TARGET_NAME)";
SDKROOT = auto;
103 changes: 76 additions & 27 deletions Examples/WhisperAX/WhisperAX/Views/ContentView.swift
@@ -97,7 +97,7 @@ struct ContentView: View {
@State private var showAdvancedOptions: Bool = false
@State private var transcriptionTask: Task<Void, Never>? = nil
@State private var selectedCategoryId: MenuItem.ID?
@State private var transcribeFileTask: Task<Void, Never>? = nil
@State private var transcribeTask: Task<Void, Never>? = nil

struct MenuItem: Identifiable, Hashable {
var id = UUID()
@@ -122,7 +122,7 @@
// MARK: Views

func resetState() {
transcribeFileTask?.cancel()
transcribeTask?.cancel()
isRecording = false
isTranscribing = false
whisperKit?.audioProcessor.stopRecording()
@@ -311,15 +311,27 @@ struct ContentView: View {
.textSelection(.enabled)
.padding()
if let whisperKit,
!isRecording,
!isTranscribing,
whisperKit.progress.fractionCompleted > 0,
!isStreamMode,
isTranscribing,
let task = transcribeTask,
!task.isCancelled,
whisperKit.progress.fractionCompleted < 1
{
ProgressView(whisperKit.progress)
.progressViewStyle(.linear)
.labelsHidden()
.padding(.horizontal)
HStack {
ProgressView(whisperKit.progress)
.progressViewStyle(.linear)
.labelsHidden()
.padding(.horizontal)

Button {
transcribeTask?.cancel()
transcribeTask = nil
} label: {
Image(systemName: "xmark.circle.fill")
.foregroundColor(.secondary)
}
.buttonStyle(BorderlessButtonStyle())
}
}
}
}
@@ -706,7 +718,7 @@ struct ContentView: View {
}
.disabled(!(whisperKit?.modelVariant.isMultilingual ?? false))
} label: {
Label("Language", systemImage: "globe")
Label("Source Language", systemImage: "globe")
}
.padding(.horizontal)
.padding(.top)
@@ -1149,12 +1161,14 @@ struct ContentView: View {
func transcribeFile(path: String) {
resetState()
whisperKit?.audioProcessor = AudioProcessor()
self.transcribeFileTask = Task {
self.transcribeTask = Task {
isTranscribing = true
do {
try await transcribeCurrentFile(path: path)
} catch {
print("File selection error: \(error.localizedDescription)")
}
isTranscribing = false
}
}

@@ -1218,21 +1232,49 @@

// If not looping, transcribe the full buffer
if !loop {
Task {
self.transcribeTask = Task {
isTranscribing = true
do {
try await transcribeCurrentBuffer()
} catch {
print("Error: \(error.localizedDescription)")
}
finalizeText()
isTranscribing = false
}
}

finalizeText()
}

func finalizeText() {
// Finalize unconfirmed text
Task {
await MainActor.run {
if hypothesisText != "" {
confirmedText += hypothesisText
hypothesisText = ""
}

if unconfirmedSegments.count > 0 {
confirmedSegments.append(contentsOf: unconfirmedSegments)
unconfirmedSegments = []
}
}
}
}

// MARK: - Transcribe Logic

func transcribeCurrentFile(path: String) async throws {
let audioFileBuffer = try AudioProcessor.loadAudio(fromPath: path)
let audioFileSamples = AudioProcessor.convertBufferToArray(buffer: audioFileBuffer)
// Load and convert buffer in a limited scope
let audioFileSamples = try await Task {
try autoreleasepool {
let audioFileBuffer = try AudioProcessor.loadAudio(fromPath: path)
return AudioProcessor.convertBufferToArray(buffer: audioFileBuffer)
}
}.value

let transcription = try await transcribeAudioSamples(audioFileSamples)

await MainActor.run {
@@ -1258,7 +1300,7 @@

let languageCode = Constants.languages[selectedLanguage, default: Constants.defaultLanguageCode]
let task: DecodingTask = selectedTask == "transcribe" ? .transcribe : .translate
let seekClip: [Float] = []
let seekClip: [Float] = [lastConfirmedSegmentEndSeconds]

let options = DecodingOptions(
verbose: true,
@@ -1271,6 +1313,7 @@
usePrefillCache: enableCachePrefill,
skipSpecialTokens: !enableSpecialCharacters,
withoutTimestamps: !enableTimestamps,
wordTimestamps: true,
clipTimestamps: seekClip,
chunkingStrategy: chunkingStrategy
)
@@ -1279,7 +1322,7 @@
let decodingCallback: ((TranscriptionProgress) -> Bool?) = { (progress: TranscriptionProgress) in
DispatchQueue.main.async {
let fallbacks = Int(progress.timings.totalDecodingFallbacks)
let chunkId = progress.windowId
let chunkId = isStreamMode ? 0 : progress.windowId

// First check if this is a new window for the same chunk, append if so
var updatedChunk = (chunkText: [progress.text], fallbacks: fallbacks)
@@ -1292,7 +1335,7 @@
// This is either a new window or a fallback (only in streaming mode)
if fallbacks == currentChunk.fallbacks && isStreamMode {
// New window (since fallbacks havent changed)
updatedChunk.chunkText = currentChunk.chunkText + [progress.text]
updatedChunk.chunkText = [updatedChunk.chunkText.first ?? "" + progress.text]
} else {
// Fallback, overwrite the previous bad text
updatedChunk.chunkText[currentChunk.chunkText.endIndex - 1] = progress.text
@@ -1419,6 +1462,7 @@ struct ContentView: View {
// Run realtime transcribe using word timestamps for segmentation
let transcription = try await transcribeEagerMode(Array(currentBuffer))
await MainActor.run {
currentText = ""
self.tokensPerSecond = transcription?.timings.tokensPerSecond ?? 0
self.firstTokenTime = transcription?.timings.firstTokenTime ?? 0
self.pipelineStart = transcription?.timings.pipelineStart ?? 0
@@ -1464,10 +1508,13 @@ struct ContentView: View {
// Update lastConfirmedSegmentEnd based on the last confirmed segment
if let lastConfirmedSegment = confirmedSegmentsArray.last, lastConfirmedSegment.end > lastConfirmedSegmentEndSeconds {
lastConfirmedSegmentEndSeconds = lastConfirmedSegment.end
print("Last confirmed segment end: \(lastConfirmedSegmentEndSeconds)")

// Add confirmed segments to the confirmedSegments array
if !self.confirmedSegments.contains(confirmedSegmentsArray) {
self.confirmedSegments.append(contentsOf: confirmedSegmentsArray)
for segment in confirmedSegmentsArray {
if !self.confirmedSegments.contains(segment: segment) {
self.confirmedSegments.append(segment)
}
}
}

@@ -1584,18 +1631,20 @@ struct ContentView: View {
eagerResults.append(transcription)
}
}

await MainActor.run {
let finalWords = confirmedWords.map { $0.word }.joined()
confirmedText = finalWords

// Accept the final hypothesis because it is the last of the available audio
let lastHypothesis = lastAgreedWords + findLongestDifferentSuffix(prevWords, hypothesisWords)
hypothesisText = lastHypothesis.map { $0.word }.joined()
}
} catch {
Logging.error("[EagerMode] Error: \(error)")
finalizeText()
}

await MainActor.run {
let finalWords = confirmedWords.map { $0.word }.joined()
confirmedText = finalWords

// Accept the final hypothesis because it is the last of the available audio
let lastHypothesis = lastAgreedWords + findLongestDifferentSuffix(prevWords, hypothesisWords)
hypothesisText = lastHypothesis.map { $0.word }.joined()
}

let mergedResult = mergeTranscriptionResults(eagerResults, confirmedWords: confirmedWords)

