Support distil whisper models (#88)
* Support distil models

* Lint

* Update remaining open classes

* Update readme

* Increase tolerance for accuracy test

* Normalize text in accuracy test

* Check for model uniqueness before downloading

- Allows previous model download flows to continue working

* Documentation and cleanup

* Add functional test for non-blob models
ZachNagengast authored Mar 25, 2024
1 parent 508240f · commit 0f19f7e
Showing 20 changed files with 258 additions and 100 deletions.
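
Taken together, the changes remove the hardcoded `openai` model-name assumptions so that distil variants resolve, download, and load like any other model. A rough sketch of the usage this unlocks, mirroring the quick-start style in the README hunk further down (the audio path is a placeholder and the `transcribe(audioPath:)?.text` shape follows the project's existing example, so treat it as illustrative):

```swift
import WhisperKit

// Illustrative only: select a distil variant with the new glob-style
// model search (documented in the README change below) and transcribe.
func transcribeWithDistil() async {
    let pipe = try? await WhisperKit(model: "distil*large-v3")
    let text = try? await pipe?.transcribe(audioPath: "path/to/audio.wav")?.text
    print(text ?? "transcription failed")
}
```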
@@ -33,8 +33,8 @@
       "kind" : "remoteSourceControl",
       "location" : "https://github.com/huggingface/swift-transformers.git",
       "state" : {
-        "revision" : "564442fba36b0b694d730a62d0593e5f54043b55",
-        "version" : "0.1.2"
+        "revision" : "4f915610451d29a05948802a140880ff37494dad",
+        "version" : "0.1.6"
       }
     }
   ],
11 changes: 6 additions & 5 deletions Examples/WhisperAX/WhisperAX/Views/ContentView.swift
@@ -57,7 +57,7 @@ struct ContentView: View {
     @State private var currentFallbacks: Int = 0
     @State private var lastBufferSize: Int = 0
     @State private var lastConfirmedSegmentEndSeconds: Float = 0
-    @State private var requiredSegmentsForConfirmation: Int = 2
+    @State private var requiredSegmentsForConfirmation: Int = 4
     @State private var bufferEnergy: [Float] = []
     @State private var confirmedSegments: [TranscriptionSegment] = []
     @State private var unconfirmedSegments: [TranscriptionSegment] = []
@@ -269,7 +269,8 @@ struct ContentView: View {
 
                 #if os(macOS)
                 Button(action: {
-                    if let folder = whisperKit?.modelFolder {
+                    let folderURL = whisperKit?.modelFolder ?? (localModels.contains(selectedModel) ? URL(fileURLWithPath: localModelPath) : nil)
+                    if let folder = folderURL {
                         NSWorkspace.shared.open(folder)
                     }
                 }, label: {
@@ -708,7 +709,7 @@ struct ContentView: View {
         localModelPath = modelPath
         do {
             let downloadedModels = try FileManager.default.contentsOfDirectory(atPath: modelPath)
-            for model in downloadedModels where !localModels.contains(model) && model.starts(with: "openai") {
+            for model in downloadedModels where !localModels.contains(model) {
                 localModels.append(model)
             }
         } catch {
@@ -763,7 +764,7 @@ struct ContentView: View {
             if localModels.contains(model) && !redownload {
                 // Get local model folder URL from localModels
                 // TODO: Make this configurable in the UI
-                folder = URL(fileURLWithPath: localModelPath).appendingPathComponent("openai_whisper-\(model)")
+                folder = URL(fileURLWithPath: localModelPath).appendingPathComponent(model)
             } else {
                 // Download the model
                 folder = try await WhisperKit.download(variant: model, from: repoName, progressCallback: { progress in
@@ -833,7 +834,7 @@ struct ContentView: View {
 
     func deleteModel() {
         if localModels.contains(selectedModel) {
-            let modelFolder = URL(fileURLWithPath: localModelPath).appendingPathComponent("openai_whisper-\(selectedModel)")
+            let modelFolder = URL(fileURLWithPath: localModelPath).appendingPathComponent(selectedModel)
 
             do {
                 try FileManager.default.removeItem(at: modelFolder)
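
The net effect of the `ContentView` hunks above is that local model discovery, folder resolution, and deletion are all keyed by each folder's actual name rather than an `openai_whisper-` prefix, so distil folders are picked up too. A self-contained sketch of the resulting discovery logic (the function name and paths are illustrative):

```swift
import Foundation

// Mirrors the updated loop above: every directory under the model path
// now counts as a local model, not just those starting with "openai".
func discoverLocalModels(at modelPath: String) -> [String] {
    var localModels: [String] = []
    do {
        let downloadedModels = try FileManager.default.contentsOfDirectory(atPath: modelPath)
        for model in downloadedModels where !localModels.contains(model) {
            localModels.append(model)
        }
    } catch {
        print("Error checking local models: \(error.localizedDescription)")
    }
    return localModels
}

// Folder resolution is likewise keyed by the full variant name:
// URL(fileURLWithPath: modelPath).appendingPathComponent(model)
```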
@@ -357,7 +357,7 @@ struct WhisperAXWatchView: View {
             // Get local model folder URL from localModels
             // TODO: Make this configurable in the UI
             // TODO: Handle incomplete downloads
-            folder = URL(fileURLWithPath: localModelPath).appendingPathComponent("openai_whisper-\(model)")
+            folder = URL(fileURLWithPath: localModelPath).appendingPathComponent(model)
         } else {
             // Download the model
             folder = try await WhisperKit.download(variant: model, from: repoName, progressCallback: { progress in
4 changes: 2 additions & 2 deletions Package.resolved
@@ -14,8 +14,8 @@
       "kind" : "remoteSourceControl",
       "location" : "https://github.com/huggingface/swift-transformers.git",
       "state" : {
-        "revision" : "3bd02269b7797ade67c15679a575cd5c6f203ce6",
-        "version" : "0.1.5"
+        "revision" : "4f915610451d29a05948802a140880ff37494dad",
+        "version" : "0.1.6"
       }
     }
   ],
2 changes: 1 addition & 1 deletion Package.swift
@@ -20,7 +20,7 @@ let package = Package(
         ),
     ],
     dependencies: [
-        .package(url: "https://github.com/huggingface/swift-transformers.git", exact: "0.1.5"),
+        .package(url: "https://github.com/huggingface/swift-transformers.git", exact: "0.1.6"),
         .package(url: "https://github.com/apple/swift-argument-parser.git", exact: "1.3.0"),
     ],
     targets: [
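
Because WhisperKit pins swift-transformers with `exact:`, downstream packages only declare WhisperKit itself and inherit the pinned transitive dependency. A hypothetical consumer manifest (the WhisperKit URL is the project's public repo; the version requirement and target names are illustrative):

```swift
// swift-tools-version:5.9
import PackageDescription

let package = Package(
    name: "MyTranscriberApp",
    platforms: [.macOS(.v13), .iOS(.v16)],
    dependencies: [
        // Version requirement is illustrative, not prescriptive.
        .package(url: "https://github.com/argmaxinc/WhisperKit.git", from: "0.3.0"),
    ],
    targets: [
        .executableTarget(
            name: "MyTranscriberApp",
            dependencies: [.product(name: "WhisperKit", package: "WhisperKit")]
        ),
    ]
)
```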
10 changes: 10 additions & 0 deletions README.md
@@ -26,8 +26,10 @@ Check out the demo app on [TestFlight](https://testflight.apple.com/join/LPVOyJZ
 ## Table of Contents
 
 - [Installation](#installation)
+  - [Swift Package Manager](#swift-package-manager)
     - [Prerequisites](#prerequisites)
     - [Steps](#steps)
+  - [Homebrew](#homebrew)
 - [Getting Started](#getting-started)
 - [Quick Example](#quick-example)
 - [Model Selection](#model-selection)
@@ -91,6 +93,14 @@ WhisperKit automatically downloads the recommended model for the device if not s
 let pipe = try? await WhisperKit(model: "large-v3")
 ```
+
+This method also supports glob search, so you can use wildcards to select a model:
+
+```swift
+let pipe = try? await WhisperKit(model: "distil*large-v3")
+```
+
+Note that the model search must return a single model from the source repo, otherwise an error will be thrown.
+
 For a list of available models, see our [HuggingFace repo](https://huggingface.co/argmaxinc/whisperkit-coreml).
 
 ### Generating Models
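
Since the README addition notes that a wildcard matching zero or several models throws, a caller that wants to surface that failure can trade the `try?` from the examples for `do`/`catch`; a minimal sketch (the logging is illustrative):

```swift
import WhisperKit

func loadDistilModel() async {
    do {
        // Throws if "distil*large-v3" does not resolve to exactly one
        // model in the source repo (among other failure modes).
        let pipe = try await WhisperKit(model: "distil*large-v3")
        print("Loaded model folder: \(pipe.modelFolder?.lastPathComponent ?? "unknown")")
    } catch {
        print("Model resolution failed: \(error)")
    }
}
```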
46 changes: 23 additions & 23 deletions Sources/WhisperKit/Core/AudioProcessor.swift
@@ -6,18 +6,18 @@ import AVFoundation
 import CoreAudio
 import CoreML
 
-/// Core Audio Device
+// Core Audio Device
 #if os(macOS)
 public typealias DeviceID = AudioDeviceID
 #else
-    public typealias DeviceID = String
+public typealias DeviceID = String
 #endif
 
 public struct AudioDevice: Identifiable, Hashable {
     public let id: DeviceID
     public let name: String
 }
 
 public protocol AudioProcessing {
     /// Loads audio data from a specified file path.
     /// - Parameter audioFilePath: The file path of the audio file.
@@ -53,7 +53,7 @@ public protocol AudioProcessing {
 
     /// Starts recording audio from the specified input device, resetting the previous state
     func startRecordingLive(inputDeviceID: DeviceID?, callback: (([Float]) -> Void)?) throws
 
     /// Pause recording
     func pauseRecording()
 
@@ -341,14 +341,14 @@ public class AudioProcessor: NSObject, AudioProcessing {
             #endif
         }
     }
 
     #if os(macOS)
     public static func getAudioDevices() -> [AudioDevice] {
         var devices = [AudioDevice]()
 
         var propertySize: UInt32 = 0
         var status: OSStatus = noErr
 
         // Get the number of devices
         var propertyAddress = AudioObjectPropertyAddress(
             mSelector: kAudioHardwarePropertyDevices,
@@ -366,7 +366,7 @@ public class AudioProcessor: NSObject, AudioProcessing {
             Logging.error("Error: Unable to get the number of audio devices.")
             return devices
         }
 
         // Get the device IDs
         let deviceCount = Int(propertySize) / MemoryLayout<AudioDeviceID>.size
         var deviceIDs = [AudioDeviceID](repeating: 0, count: deviceCount)
@@ -382,17 +382,17 @@ public class AudioProcessor: NSObject, AudioProcessing {
             Logging.error("Error: Unable to get the audio device IDs.")
             return devices
         }
 
         // Get device info for each device
         for deviceID in deviceIDs {
-            var deviceName: String = ""
-            var inputChannels: Int = 0
+            var deviceName = ""
+            var inputChannels = 0
 
             // Get device name
-            var propertySize: UInt32 = UInt32(MemoryLayout<Unmanaged<CFString>?>.size)
-            var name: Unmanaged<CFString>? = nil
+            var propertySize = UInt32(MemoryLayout<Unmanaged<CFString>?>.size)
+            var name: Unmanaged<CFString>?
             propertyAddress.mSelector = kAudioDevicePropertyDeviceNameCFString
 
             status = AudioObjectGetPropertyData(
                 deviceID,
                 &propertyAddress,
@@ -404,7 +404,7 @@ public class AudioProcessor: NSObject, AudioProcessing {
             if status == noErr, let deviceNameCF = name?.takeUnretainedValue() as String? {
                 deviceName = deviceNameCF
             }
 
             // Get input channels
             propertyAddress.mSelector = kAudioDevicePropertyStreamConfiguration
             propertyAddress.mScope = kAudioDevicePropertyScopeInput
@@ -420,12 +420,12 @@ public class AudioProcessor: NSObject, AudioProcessing {
                     }
                 }
             }
 
             if inputChannels > 0 {
                 devices.append(AudioDevice(id: deviceID, name: deviceName))
             }
         }
 
         return devices
     }
     #endif
@@ -461,14 +461,14 @@ public extension AudioProcessor {
             Logging.debug("Current audio size: \(self.audioSamples.count) samples, most recent buffer: \(buffer.count) samples, most recent energy: \(newEnergy)")
         }
     }
 
     #if os(macOS)
     func assignAudioInput(inputNode: AVAudioInputNode, inputDeviceID: AudioDeviceID) {
         guard let audioUnit = inputNode.audioUnit else {
             Logging.error("Failed to access the audio unit of the input node.")
             return
         }
 
         var inputDeviceID = inputDeviceID
 
         let error = AudioUnitSetProperty(
@@ -479,7 +479,7 @@ public extension AudioProcessor {
             &inputDeviceID,
             UInt32(MemoryLayout<AudioDeviceID>.size)
         )
 
         if error != noErr {
             Logging.error("Error setting Audio Unit property: \(error)")
         } else {
@@ -562,11 +562,11 @@ public extension AudioProcessor {
             audioSamples.removeFirst(audioSamples.count - keep)
         }
     }
 
     func startRecordingLive(inputDeviceID: DeviceID? = nil, callback: (([Float]) -> Void)? = nil) throws {
         audioSamples = []
         audioEnergy = []
 
         try? setupAudioSessionForDevice()
 
         audioEngine = try setupEngine(inputDeviceID: inputDeviceID)
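
The `AudioProcessor` hunks are mostly lint cleanups, but they show the relevant public surface: `getAudioDevices()` enumerates macOS input devices, and `startRecordingLive(inputDeviceID:callback:)` streams captured samples. A sketch wiring the two together (device choice and logging are illustrative):

```swift
import WhisperKit

#if os(macOS)
func recordFromFirstInputDevice() {
    let processor = AudioProcessor()
    // getAudioDevices() only returns devices that report input channels.
    guard let device = AudioProcessor.getAudioDevices().first else {
        print("No audio input devices found")
        return
    }
    do {
        try processor.startRecordingLive(inputDeviceID: device.id) { samples in
            // Called as new audio arrives; samples are Float PCM.
            print("Buffered \(samples.count) samples from \(device.name)")
        }
    } catch {
        print("Failed to start recording: \(error)")
    }
}
#endif
```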
2 changes: 1 addition & 1 deletion Sources/WhisperKit/Core/FeatureExtractor.swift
@@ -13,7 +13,7 @@ public protocol FeatureExtracting {
 }
 
 @available(macOS 13, iOS 16, watchOS 10, visionOS 1, *)
-public class FeatureExtractor: FeatureExtracting, WhisperMLModel {
+open class FeatureExtractor: FeatureExtracting, WhisperMLModel {
     public var model: MLModel?
 
     public init() {}
45 changes: 23 additions & 22 deletions Sources/WhisperKit/Core/LogitsFilter.swift
@@ -11,7 +11,7 @@ public protocol LogitsFiltering {
 }
 
 @available(macOS 13, iOS 16, watchOS 10, visionOS 1, *)
-public class SuppressTokensFilter: LogitsFiltering {
+open class SuppressTokensFilter: LogitsFiltering {
     let suppressTokens: [Int]
     private let suppressTokenIndexes: [[NSNumber]]
 
@@ -27,7 +27,7 @@ public class SuppressTokensFilter: LogitsFiltering {
 }
 
 @available(macOS 13, iOS 16, watchOS 10, visionOS 1, *)
-public class SuppressBlankFilter: LogitsFiltering {
+open class SuppressBlankFilter: LogitsFiltering {
     let suppressBlankTokens: [Int]
     let sampleBegin: Int
     private let suppressTokenIndexes: [[NSNumber]]
@@ -49,7 +49,7 @@ public class SuppressBlankFilter: LogitsFiltering {
 
 /// Implementation based on https://github.com/openai/whisper/blob/master/whisper/decoding.py#L441
 @available(macOS 13, iOS 16, watchOS 10, visionOS 1, *)
-public class TimestampRulesFilter: LogitsFiltering {
+open class TimestampRulesFilter: LogitsFiltering {
     let transcribeToken: Int
     let translateToken: Int
     let noTimestampsToken: Int
@@ -115,15 +115,15 @@ public class TimestampRulesFilter: LogitsFiltering {
             }
         }
 
-            if tokens.count == sampleBegin {
-                // suppress generating non-timestamp tokens at the beginning
-                logits.fillLastDimension(indexes: 0..<timeTokenBegin, with: -FloatType.infinity)
-                if let maxInitialTimestampIndex {
-                    // apply the `maxInitialTimestamp` option
-                    let lastAllowed = timeTokenBegin + maxInitialTimestampIndex + 1
-                    logits.fillLastDimension(indexes: lastAllowed..<logits.count, with: -FloatType.infinity)
-                }
-            }
+        if tokens.count == sampleBegin {
+            // suppress generating non-timestamp tokens at the beginning
+            logits.fillLastDimension(indexes: 0..<timeTokenBegin, with: -FloatType.infinity)
+            if let maxInitialTimestampIndex {
+                // apply the `maxInitialTimestamp` option
+                let lastAllowed = timeTokenBegin + maxInitialTimestampIndex + 1
+                logits.fillLastDimension(indexes: lastAllowed..<logits.count, with: -FloatType.infinity)
+            }
+        }
 
         // if sum of probability over timestamps is above any other token, sample timestamp
         if sumOfProbabilityOverTimestampsIsAboveAnyOtherToken(logits: logits, timeTokenBegin: timeTokenBegin) {
@@ -175,14 +175,14 @@ public class TimestampRulesFilter: LogitsFiltering {
             output: logprobs,
             batchSize: 1
         )
 
         let timeTokenCount = logits.count - timeTokenBeginOffset
         let noTimeTokenCount = timeTokenBeginOffset
         let logSumExpInputPointer = UnsafeMutableRawBufferPointer(
             start: logprobs.data!.advanced(by: timeTokenBeginOffset * MemoryLayout<FloatType>.stride),
             count: timeTokenCount * MemoryLayout<FloatType>.stride
         )
 
         guard let logSumExpInputDescriptor = BNNSNDArrayDescriptor(
             data: logSumExpInputPointer,
             scalarType: FloatType.self,
@@ -191,25 +191,25 @@ public class TimestampRulesFilter: LogitsFiltering {
             Logging.error("Cannot create `logSumExpInputDescriptor`")
             return false
         }
 
         let timestampLogProb = BNNSNDArrayDescriptor.allocateUninitialized(
             scalarType: FloatType.self,
             shape: .vector(1, stride: 1)
         )
         defer { timestampLogProb.deallocate() }
 
         try BNNS.applyReduction(
             .logSumExp,
             input: logSumExpInputDescriptor,
             output: timestampLogProb,
             weights: nil
         )
 
         let maxTextTokenLogProbInputPointer = UnsafeMutableRawBufferPointer(
             start: logprobs.data,
             count: noTimeTokenCount * MemoryLayout<FloatType>.stride
         )
 
         guard let maxTextTokenLogProbInputDescriptor = BNNSNDArrayDescriptor(
             data: maxTextTokenLogProbInputPointer,
             scalarType: FloatType.self,
@@ -218,22 +218,23 @@ public class TimestampRulesFilter: LogitsFiltering {
             Logging.error("Cannot create `maxTextTokenLogProbInputDescriptor`")
             return false
         }
 
         let maxTextTokenLogProb = BNNSNDArrayDescriptor.allocateUninitialized(
             scalarType: FloatType.self,
             shape: .vector(1, stride: 1)
         )
         defer { maxTextTokenLogProb.deallocate() }
 
         try BNNS.applyReduction(
             .max,
             input: maxTextTokenLogProbInputDescriptor,
             output: maxTextTokenLogProb,
             weights: nil
         )
 
         guard let timestampLogProbValue = timestampLogProb.makeArray(of: FloatType.self)?.first,
-              let maxTextTokenLogProbValue = maxTextTokenLogProb.makeArray(of: FloatType.self)?.first else {
+              let maxTextTokenLogProbValue = maxTextTokenLogProb.makeArray(of: FloatType.self)?.first
+        else {
             Logging.error("Cannot create logProb arrays")
             return false
         }
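
These filter classes (like `FeatureExtractor` above) moved from `public` to `open`, matching the "Update remaining open classes" item in the commit message, so client apps can subclass them outside the module. A minimal sketch of what that permits; the `init(suppressTokens:)` signature is an assumption inferred from the stored property shown above, not a confirmed API:

```swift
import WhisperKit

@available(macOS 13, iOS 16, watchOS 10, visionOS 1, *)
final class AppSuppressTokensFilter: SuppressTokensFilter {
    // Hypothetical app-side bookkeeping; inherited filtering is unchanged.
    let label: String

    init(label: String, suppressTokens: [Int]) {
        self.label = label
        // Assumed designated initializer, based on the `suppressTokens`
        // stored property visible in the diff.
        super.init(suppressTokens: suppressTokens)
    }
}
```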