Skip to content

Commit

Permalink
Fix progress when using VAD chunking (#179)
Browse files Browse the repository at this point in the history
Co-authored-by: Zach Nagengast <zacharynagengast@gmail.com>
  • Loading branch information
finnvoor and ZachNagengast authored Jul 3, 2024
1 parent 9773663 commit b3f12fc
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 1 deletion.
5 changes: 4 additions & 1 deletion Sources/WhisperKit/Core/WhisperKit.swift
Original file line number Diff line number Diff line change
Expand Up @@ -769,9 +769,12 @@ open class WhisperKit {
}
try Task.checkCancellation()

let childProgress = Progress()
progress.totalUnitCount += 1
progress.addChild(childProgress, withPendingUnitCount: 1)
let transcribeTask = TranscribeTask(
currentTimings: currentTimings,
progress: progress,
progress: childProgress,
audioEncoder: audioEncoder,
featureExtractor: featureExtractor,
segmentSeeker: segmentSeeker,
Expand Down
9 changes: 9 additions & 0 deletions Tests/WhisperKitTests/TestUtils.swift
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import CoreML
import Combine
import Foundation
@testable import WhisperKit
import XCTest
Expand Down Expand Up @@ -274,3 +275,11 @@ extension Collection where Element == TranscriptionResult {
flatMap(\.segments)
}
}

extension Publisher {
public func withPrevious() -> AnyPublisher<(previous: Output?, current: Output), Failure> {
scan((Output?, Output)?.none) { ($0?.1, $1) }
.compactMap { $0 }
.eraseToAnyPublisher()
}
}
19 changes: 19 additions & 0 deletions Tests/WhisperKitTests/UnitTests.swift
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// For licensing see accompanying LICENSE.md file.
// Copyright © 2024 Argmax, Inc. All rights reserved.

import Combine
import AVFoundation
import CoreML
import Hub
Expand Down Expand Up @@ -1178,6 +1179,24 @@ final class UnitTests: XCTestCase {
XCTAssertTrue(chunkedResult.text.normalized.contains("But then came my 90 page senior".normalized), "Expected text not found in \(chunkedResult.text.normalized)")
}

func testVADProgress() async throws {
let pipe = try await WhisperKit(model: "tiny.en")

let cancellable: AnyCancellable? = pipe.progress.publisher(for: \.fractionCompleted)
.removeDuplicates()
.withPrevious()
.sink { previous, current in
if let previous {
XCTAssertLessThan(previous, current)
}
}
_ = try await pipe.transcribe(
audioPath: Bundle.module.path(forResource: "ted_60", ofType: "m4a")!,
decodeOptions: .init(chunkingStrategy: .vad)
)
cancellable?.cancel()
}

// MARK: - Word Timestamp Tests

func testDynamicTimeWarpingSimpleMatrix() {
Expand Down

0 comments on commit b3f12fc

Please sign in to comment.