Fixes and cleanup from early feedback (argmaxinc#15)
* Fixes and cleanup from early feedback

* Formatting

* Update tests
ZachNagengast authored Feb 5, 2024
1 parent e5ed2e9 commit 82c1fa1
Showing 11 changed files with 140 additions and 43 deletions.
3 changes: 2 additions & 1 deletion .github/workflows/unit-tests.yml
@@ -36,8 +36,9 @@ jobs:
path: Models
key: ${{ runner.os }}-models
- name: Build
run: xcodebuild build-for-testing -scheme whisperkit-Package -destination 'platform=macOS'
run: xcodebuild clean build-for-testing -scheme whisperkit-Package -destination 'platform=macOS'
- name: Run tests
run: |
set -o pipefail
xcodebuild test-without-building -only-testing WhisperKitTests/UnitTests -scheme whisperkit-Package -showdestinations
xcodebuild test-without-building -only-testing WhisperKitTests/UnitTests -scheme whisperkit-Package -destination "platform=macOS,arch=arm64" | xcpretty
11 changes: 8 additions & 3 deletions Examples/WhisperAX/WhisperAX/Views/ContentView.swift
@@ -157,19 +157,19 @@ struct ContentView: View {
#if os(macOS)
selectedCategoryId = menu.first(where: { $0.name == selectedTab })?.id
#endif

fetchModels()
}
}


// MARK: - Transcription

var transcriptionView: some View {
VStack {
ScrollView(.horizontal) {
HStack(spacing: 1) {
let startIndex = max(bufferEnergy.count - 300, 0)
ForEach(Array(bufferEnergy.enumerated())[startIndex...], id: \.offset) { index, energy in
ForEach(Array(bufferEnergy.enumerated())[startIndex...], id: \.element) { index, energy in
ZStack {
RoundedRectangle(cornerRadius: 2)
.frame(width: 2, height: CGFloat(energy) * 24)
@@ -660,7 +660,12 @@ struct ContentView: View {
}

localModels = WhisperKit.formatModelFiles(localModels)
availableModels = localModels
for model in localModels {
if !availableModels.contains(model),
!disabledModels.contains(model){
availableModels.append(model)
}
}

print("Found locally: \(localModels)")
print("Previously selected model: \(selectedModel)")
2 changes: 1 addition & 1 deletion Makefile
@@ -49,7 +49,7 @@ setup-model-repo:
cd $(MODEL_REPO_DIR) && git fetch --all && git reset --hard origin/main && git clean -fdx; \
else \
echo "Repository not found, initializing..."; \
GIT_LFS_SKIP_SMUDGE=1 git clone https://hf.co/$(MODEL_REPO) $(MODEL_REPO_DIR); \
GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/$(MODEL_REPO) $(MODEL_REPO_DIR); \
fi

# Download all models
45 changes: 37 additions & 8 deletions README.md
@@ -1,6 +1,19 @@

<div align="center">

<a href="https://github.com/argmaxinc/WhisperKit#gh-light-mode-only">
<img src="https://github.com/argmaxinc/WhisperKit/assets/1981179/6ac3360b-2f5c-4392-a71a-05c5dda71093" alt="WhisperKit" width="20%" />
</a>

# WhisperKit

WhisperKit is a Swift package that integrates OpenAI's popular [Whisper](https://github.com/openai/whisper) speech recognition model with Apple's CoreML framework for efficient, local inference on Apple devices.
[![Unit Tests](https://github.com/argmaxinc/whisperkit/actions/workflows/unit-tests.yml/badge.svg)](https://github.com/argmaxinc/whisperkit/actions/workflows/unit-tests.yml)
[![Supported Swift Version](https://img.shields.io/endpoint?url=https%3A%2F%2Fswiftpackageindex.com%2Fapi%2Fpackages%2Fargmaxinc%2FWhisperKit%2Fbadge%3Ftype%3Dswift-versions)](https://swiftpackageindex.com/argmaxinc/WhisperKit) [![Supported Platforms](https://img.shields.io/endpoint?url=https%3A%2F%2Fswiftpackageindex.com%2Fapi%2Fpackages%2Fargmaxinc%2FWhisperKit%2Fbadge%3Ftype%3Dplatforms)](https://swiftpackageindex.com/argmaxinc/WhisperKit)
[![License](https://img.shields.io/github/license/argmaxinc/whisperkit?color=green)](LICENSE.md)

</div>

WhisperKit is a Swift package that integrates OpenAI's popular [Whisper](https://github.com/openai/whisper) speech recognition model with Apple's CoreML framework for efficient, local inference on Apple devices.

Check out the demo app on [TestFlight](https://testflight.apple.com/join/LPVOyJZW).

@@ -21,23 +34,28 @@ Check out the demo app on [TestFlight](https://testflight.apple.com/join/LPVOyJZ
- [Citation](#citation)

## Installation

WhisperKit can be integrated into your Swift project using the Swift Package Manager.

### Prerequisites

- macOS 14.0 or later.
- Xcode 15.0 or later.

### Steps

1. Open your Swift project in Xcode.
2. Navigate to `File` > `Add Package Dependencies...`.
3. Enter the package repository URL: `https://github.com/argmaxinc/whisperkit`.
4. Choose the version range or specific version.
5. Click `Finish` to add WhisperKit to your project.

## Getting Started

To get started with WhisperKit, you need to initialize it in your project.

### Quick Example

This example demonstrates how to transcribe a local audio file:

```swift
@@ -52,7 +70,9 @@ Task {
```
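
The quick-example code block above is collapsed in this diff; the following is a minimal sketch of such a call, based on the `transcribe(audioPath:decodeOptions:)` API used in `Sources/WhisperKitCLI/transcribe.swift` below (the audio path is a placeholder, and `decodeOptions` is assumed to default to `nil`):

```swift
import WhisperKit

Task {
    // Assumes the default initializer downloads/loads the recommended model for the device.
    let pipe = try? await WhisperKit()
    // Transcribe a local audio file; "path/to/your/audio.wav" is a placeholder.
    let text = try? await pipe?.transcribe(audioPath: "path/to/your/audio.wav")?.text
    print(text ?? "Transcription failed")
}
```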

### Model Selection

WhisperKit automatically downloads the recommended model for the device if not specified. You can also select a specific model by passing in the model name:

```swift
let pipe = try? await WhisperKit(model: "large-v3")
```
@@ -76,38 +96,47 @@ git clone https://github.com/argmaxinc/whisperkit.git
cd whisperkit
```

Then, setup the environment and download the models.
Then, set up the environment and download your desired model.

```bash
make setup
make download-model MODEL=large-v3
```

**Note**:
1. this will download all available models to your local folder, if you only want to download a specific model, see our [HuggingFace repo](https://huggingface.co/argmaxinc/whisperkit-coreml))
2. before running `download-models`, make sure [git-lfs](https://git-lfs.com) is installed

1. This will download only the model specified by `MODEL` (see what's available in our [HuggingFace repo](https://huggingface.co/argmaxinc/whisperkit-coreml), where we use the prefix `openai_whisper-{MODEL}`)
2. Before running `download-model`, make sure [git-lfs](https://git-lfs.com) is installed

If you would like to download all available models to your local folder, use this command instead:

```bash
make setup
make download-models
```

You can then run the CLI with:
You can then run them via the CLI with:

```bash
swift run transcribe --model-path "Models/whisperkit-coreml/openai_whisper-large-v3" --audio-path "path/to/your/audio.{wav,mp3,m4a,flac}"
```

This should print a transcription of the audio file.


## Contributing & Roadmap

Our goal is to make WhisperKit better and better over time and we'd love your help! Just search the code for "TODO" for a variety of features that are yet to be built. Please refer to our [contribution guidelines](CONTRIBUTING.md) for submitting issues, pull requests, and coding standards, where we also have a public roadmap of features we are looking forward to building in the future.

## License

WhisperKit is released under the MIT License. See [LICENSE.md](LICENSE.md) for more details.

## Citation

If you use WhisperKit for something cool or just find it useful, please drop us a note at [info@takeargmax.com](mailto:info@takeargmax.com)!

If you use WhisperKit for academic work, here is the BibTeX:

```
```bibtex
@misc{whisperkit-argmax,
title = {WhisperKit},
author = {Argmax, Inc.},
11 changes: 4 additions & 7 deletions Sources/WhisperKit/Core/AudioProcessor.swift
@@ -40,7 +40,7 @@ public protocol AudioProcessing {
var relativeEnergyWindow: Int { get set }

/// Starts recording audio from the specified input device, resetting the previous state
func startRecordingLive(from inputDevice: AVCaptureDevice?, callback: (([Float]) -> Void)?) throws
func startRecordingLive(callback: (([Float]) -> Void)?) throws

/// Pause recording
func pauseRecording()
@@ -53,7 +53,7 @@ public extension AudioProcessing {
public extension AudioProcessing {
// Use default recording device
func startRecordingLive(callback: (([Float]) -> Void)?) throws {
try startRecordingLive(from: nil, callback: callback)
try startRecordingLive(callback: callback)
}

static func padOrTrimAudio(fromArray audioArray: [Float], startAt startIndex: Int = 0, toLength frameLength: Int = 480_000, saveSegment: Bool = false) -> MLMultiArray? {
@@ -382,14 +382,11 @@ public extension AudioProcessor {
}
}

func startRecordingLive(from inputDevice: AVCaptureDevice? = nil, callback: (([Float]) -> Void)? = nil) throws {
func startRecordingLive(callback: (([Float]) -> Void)? = nil) throws {
audioSamples = []
audioEnergy = []

if inputDevice != nil {
// TODO: implement selecting input device
Logging.debug("Input device selection not yet supported")
}
// TODO: implement selecting input device

audioEngine = try setupEngine()

8 changes: 5 additions & 3 deletions Sources/WhisperKit/Core/TextDecoder.swift
@@ -45,7 +45,7 @@ public protocol TextDecoding {

@available(macOS 14, iOS 17, tvOS 14, watchOS 10, *)
public extension TextDecoding {
func prepareDecoderInputs(withPrompt initialPrompt: [Int]) -> DecodingInputs {
func prepareDecoderInputs(withPrompt initialPrompt: [Int]) -> DecodingInputs? {
let tokenShape = [NSNumber(value: 1), NSNumber(value: initialPrompt.count)]

// Initialize MLMultiArray for tokens
@@ -59,11 +59,13 @@ public extension TextDecoding {
}

guard let kvCacheEmbedDim = self.kvCacheEmbedDim else {
fatalError("Unable to determine kvCacheEmbedDim")
Logging.error("Unable to determine kvCacheEmbedDim")
return nil
}

guard let kvCacheMaxSequenceLength = self.kvCacheMaxSequenceLength else {
fatalError("Unable to determine kvCacheMaxSequenceLength")
Logging.error("Unable to determine kvCacheMaxSequenceLength")
return nil
}

// Initialize each MLMultiArray
63 changes: 57 additions & 6 deletions Sources/WhisperKit/Core/Utils.swift
@@ -153,14 +153,65 @@ public func modelSupport(for deviceName: String) -> (default: String, disabled:
let model where model.hasPrefix("iPhone16"): // A17
return ("base", ["large-v3_turbo", "large-v3", "large-v2_turbo", "large-v2"])

// TODO: Disable turbo variants for M1
case let model where model.hasPrefix("arm64"): // Mac
return ("base", [""])

// Catch-all for unhandled models or macs
// Fall through to macOS checks
default:
return ("base", [""])
break
}

#if os(macOS)
if deviceName.hasPrefix("arm64") {
if Process.processor.contains("Apple M1") {
// Disable turbo variants for M1
return ("base", ["large-v3_turbo", "large-v3_turbo_1049MB", "large-v3_turbo_1307MB", "large-v2_turbo", "large-v2_turbo_1116MB", "large-v2_turbo_1430MB"])
} else {
// Enable all variants for M2 or M3, none disabled
return ("base", [])
}
}
#endif

// Unhandled device to base variant
return ("base", [""])
}

#if os(macOS)
// From: https://stackoverflow.com/a/71726663
extension Process {
static func stringFromTerminal(command: String) -> String {
let task = Process()
let pipe = Pipe()
task.standardOutput = pipe
task.launchPath = "/bin/bash"
task.arguments = ["-c", "sysctl -n " + command]
task.launch()
return String(bytes: pipe.fileHandleForReading.availableData, encoding: .utf8) ?? ""
}
static let processor = stringFromTerminal(command: "machdep.cpu.brand_string")
static let cores = stringFromTerminal(command: "machdep.cpu.core_count")
static let threads = stringFromTerminal(command: "machdep.cpu.thread_count")
static let vendor = stringFromTerminal(command: "machdep.cpu.vendor")
static let family = stringFromTerminal(command: "machdep.cpu.family")
}
#endif

public func resolveAbsolutePath(_ inputPath: String) -> String {
let fileManager = FileManager.default

// Expanding tilde if present
let pathWithTildeExpanded = NSString(string: inputPath).expandingTildeInPath

// If the path is already absolute, return it
if pathWithTildeExpanded.hasPrefix("/") {
return pathWithTildeExpanded
}

// Resolving relative path based on the current working directory
if let cwd = fileManager.currentDirectoryPath as String? {
let resolvedPath = URL(fileURLWithPath: cwd).appendingPathComponent(pathWithTildeExpanded).path
return resolvedPath
}

return inputPath
}

func loadTokenizer(for pretrained: ModelVariant) async throws -> Tokenizer {
2 changes: 1 addition & 1 deletion Sources/WhisperKit/Core/WhisperKit.swift
@@ -141,7 +141,7 @@ public class WhisperKit {
return (modelInfo + additionalInfo).trimmingFromEnd(character: "/", upto: 1)
}

// Custom sorting order
// Sorting order based on enum
let sizeOrder = ModelVariant.allCases.map { $0.description }

let sortedModels = availableModels.sorted { firstModel, secondModel in
17 changes: 12 additions & 5 deletions Sources/WhisperKitCLI/transcribe.swift
@@ -11,10 +11,10 @@ import WhisperKit
@main
struct WhisperKitCLI: AsyncParsableCommand {
@Option(help: "Path to audio file")
var audioPath: String = "./Tests/WhisperKitTests/Resources/jfk.wav"
var audioPath: String = "Tests/WhisperKitTests/Resources/jfk.wav"

@Option(help: "Path of model files")
var modelPath: String = "./Models/whisperkit-coreml/openai_whisper-tiny"
var modelPath: String = "Models/whisperkit-coreml/openai_whisper-tiny"

@Option(help: "Compute units for audio encoder model with {all,cpuOnly,cpuAndGPU,cpuAndNeuralEngine,random}")
var audioEncoderComputeUnits: ComputeUnits = .cpuAndNeuralEngine
@@ -71,10 +71,17 @@ struct WhisperKitCLI: AsyncParsableCommand {
var reportPath: String = "."

func transcribe(audioPath: String, modelPath: String) async throws {
guard FileManager.default.fileExists(atPath: modelPath) else {
fatalError("Resource path does not exist \(modelPath)")
let resolvedModelPath = resolveAbsolutePath(modelPath)
guard FileManager.default.fileExists(atPath: resolvedModelPath) else {
fatalError("Model path does not exist \(resolvedModelPath)")
}

let resolvedAudioPath = resolveAbsolutePath(audioPath)
guard FileManager.default.fileExists(atPath: resolvedAudioPath) else {
fatalError("Resource path does not exist \(resolvedAudioPath)")
}


let computeOptions = ModelComputeOptions(
audioEncoderCompute: audioEncoderComputeUnits.asMLComputeUnits,
textDecoderCompute: textDecoderComputeUnits.asMLComputeUnits
@@ -104,7 +111,7 @@ struct WhisperKitCLI: AsyncParsableCommand {
noSpeechThreshold: noSpeechThreshold
)

let transcribeResult = try await whisperKit.transcribe(audioPath: audioPath, decodeOptions: options)
let transcribeResult = try await whisperKit.transcribe(audioPath: resolvedAudioPath, decodeOptions: options)

let transcription = transcribeResult?.text ?? "Transcription failed"

6 changes: 6 additions & 0 deletions Tests/WhisperKitTests/FunctionalTests.swift
@@ -7,6 +7,12 @@ import XCTest

@available(macOS 14, iOS 17, *)
final class FunctionalTests: XCTestCase {
func testInitLarge() async {
let modelPath = largev3ModelPath()
let whisperKit = try? await WhisperKit(modelFolder: modelPath, logLevel: .error)
XCTAssertNotNil(whisperKit)
}

func testOutputAll() async throws {
let modelPaths = allModelPaths()
