Skip to content

Commit ba42638

Browse files
authored
Merge pull request #7 from argmaxinc/custom-vocabulary2
Bump version to 1.9.6 with `upToNextMajorVersion` and add support for custom vocabulary
2 parents b11179a + 32e6113 commit ba42638

14 files changed

+904
-165
lines changed

Playground.xcodeproj/project.pbxproj

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111
1677AFC42B57618A008C61C0 /* Preview Assets.xcassets in Resources */ = {isa = PBXBuildFile; fileRef = 1677AFAE2B57618A008C61C0 /* Preview Assets.xcassets */; };
1212
1677AFE12B57678E008C61C0 /* Assets.xcassets in Resources */ = {isa = PBXBuildFile; fileRef = 1677AFE02B57678E008C61C0 /* Assets.xcassets */; };
1313
1677AFE62B57704E008C61C0 /* ContentView.swift in Sources */ = {isa = PBXBuildFile; fileRef = 1677AFE52B57704E008C61C0 /* ContentView.swift */; };
14+
5539A56B2EA719360020D5CE /* CustomVocabularySheet.swift in Sources */ = {isa = PBXBuildFile; fileRef = 5539A56A2EA719360020D5CE /* CustomVocabularySheet.swift */; };
15+
5539A56D2EA71B2A0020D5CE /* HighlightedTextView.swift in Sources */ = {isa = PBXBuildFile; fileRef = 5539A56C2EA71B2A0020D5CE /* HighlightedTextView.swift */; };
1416
740F6DA12E2CD3ED00429FE9 /* AudioDeviceDiscoverer.swift in Sources */ = {isa = PBXBuildFile; fileRef = 740F6DA02E2CD3ED00429FE9 /* AudioDeviceDiscoverer.swift */; };
1517
740F6DA42E2DB07400429FE9 /* ArgmaxSDKCoordinator.swift in Sources */ = {isa = PBXBuildFile; fileRef = 740F6DA32E2DB07400429FE9 /* ArgmaxSDKCoordinator.swift */; };
1618
74312CDC2E1D02E3000D994A /* MacAudioDevicesView.swift in Sources */ = {isa = PBXBuildFile; fileRef = 74312CDB2E1D02E3000D994A /* MacAudioDevicesView.swift */; };
@@ -22,6 +24,7 @@
2224
7446C9172E4D536400290EAB /* SwiftUI.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = 7446C9162E4D536400290EAB /* SwiftUI.framework */; };
2325
7446C9242E4D536500290EAB /* TranscriptionLiveActivityExtension.appex in Embed Foundation Extensions */ = {isa = PBXBuildFile; fileRef = 7446C9122E4D536400290EAB /* TranscriptionLiveActivityExtension.appex */; platformFilter = ios; settings = {ATTRIBUTES = (RemoveHeadersOnCopy, ); }; };
2426
7469F1FA2E3AC3D00090AEBA /* README.md in Resources */ = {isa = PBXBuildFile; fileRef = 7469F1F92E3AC3D00090AEBA /* README.md */; };
27+
746D9D512EC7D12300CBFB4A /* VocabularyResults.swift in Sources */ = {isa = PBXBuildFile; fileRef = 746D9D502EC7D11E00CBFB4A /* VocabularyResults.swift */; };
2528
746E4C062E39874F009623D7 /* DefaultEnvInitializer.swift in Sources */ = {isa = PBXBuildFile; fileRef = 746E4C042E39874F009623D7 /* DefaultEnvInitializer.swift */; };
2629
746E4C0A2E398757009623D7 /* PlaygroundEnvInitializer.swift in Sources */ = {isa = PBXBuildFile; fileRef = 746E4C082E398757009623D7 /* PlaygroundEnvInitializer.swift */; };
2730
746F2A072E39D7030081D0D6 /* NoOpAnalyticsLogger.swift in Sources */ = {isa = PBXBuildFile; fileRef = 746F2A062E39D6FC0081D0D6 /* NoOpAnalyticsLogger.swift */; };
@@ -88,6 +91,8 @@
8891
1677AFE02B57678E008C61C0 /* Assets.xcassets */ = {isa = PBXFileReference; lastKnownFileType = folder.assetcatalog; path = Assets.xcassets; sourceTree = "<group>"; };
8992
1677AFE52B57704E008C61C0 /* ContentView.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ContentView.swift; sourceTree = "<group>"; };
9093
167B345E2B05431E0076F261 /* Playground.app */ = {isa = PBXFileReference; explicitFileType = wrapper.application; includeInIndex = 0; path = Playground.app; sourceTree = BUILT_PRODUCTS_DIR; };
94+
5539A56A2EA719360020D5CE /* CustomVocabularySheet.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = CustomVocabularySheet.swift; sourceTree = "<group>"; };
95+
5539A56C2EA71B2A0020D5CE /* HighlightedTextView.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = HighlightedTextView.swift; sourceTree = "<group>"; };
9196
740F6DA02E2CD3ED00429FE9 /* AudioDeviceDiscoverer.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = AudioDeviceDiscoverer.swift; sourceTree = "<group>"; };
9297
740F6DA32E2DB07400429FE9 /* ArgmaxSDKCoordinator.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ArgmaxSDKCoordinator.swift; sourceTree = "<group>"; };
9398
74312CDB2E1D02E3000D994A /* MacAudioDevicesView.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = MacAudioDevicesView.swift; sourceTree = "<group>"; };
@@ -99,6 +104,7 @@
99104
7446C9142E4D536400290EAB /* WidgetKit.framework */ = {isa = PBXFileReference; lastKnownFileType = wrapper.framework; name = WidgetKit.framework; path = System/Library/Frameworks/WidgetKit.framework; sourceTree = SDKROOT; };
100105
7446C9162E4D536400290EAB /* SwiftUI.framework */ = {isa = PBXFileReference; lastKnownFileType = wrapper.framework; name = SwiftUI.framework; path = System/Library/Frameworks/SwiftUI.framework; sourceTree = SDKROOT; };
101106
7469F1F92E3AC3D00090AEBA /* README.md */ = {isa = PBXFileReference; lastKnownFileType = net.daringfireball.markdown; path = README.md; sourceTree = "<group>"; };
107+
746D9D502EC7D11E00CBFB4A /* VocabularyResults.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = VocabularyResults.swift; sourceTree = "<group>"; };
102108
746E4C042E39874F009623D7 /* DefaultEnvInitializer.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = DefaultEnvInitializer.swift; sourceTree = "<group>"; };
103109
746E4C082E398757009623D7 /* PlaygroundEnvInitializer.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = PlaygroundEnvInitializer.swift; sourceTree = "<group>"; };
104110
746F2A062E39D6FC0081D0D6 /* NoOpAnalyticsLogger.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = NoOpAnalyticsLogger.swift; sourceTree = "<group>"; };
@@ -193,6 +199,8 @@
193199
1677AFE42B5769E5008C61C0 /* Views */ = {
194200
isa = PBXGroup;
195201
children = (
202+
5539A56C2EA71B2A0020D5CE /* HighlightedTextView.swift */,
203+
5539A56A2EA719360020D5CE /* CustomVocabularySheet.swift */,
196204
74F897782E4F9B130045252E /* TranscriptionModeSelection.swift */,
197205
74312CDD2E1DA46C000D994A /* StreamResultView.swift */,
198206
1677AFE52B57704E008C61C0 /* ContentView.swift */,
@@ -266,6 +274,7 @@
266274
74C12F122E2EB54A00C772C3 /* Utils */ = {
267275
isa = PBXGroup;
268276
children = (
277+
746D9D502EC7D11E00CBFB4A /* VocabularyResults.swift */,
269278
74F860952E2B19060007163C /* CoreAudioUtils.swift */,
270279
);
271280
path = Utils;
@@ -394,9 +403,11 @@
394403
isa = PBXSourcesBuildPhase;
395404
buildActionMask = 2147483647;
396405
files = (
406+
5539A56D2EA71B2A0020D5CE /* HighlightedTextView.swift in Sources */,
397407
747E67082E300F780061E778 /* TranscribeResultView.swift in Sources */,
398408
74312CDC2E1D02E3000D994A /* MacAudioDevicesView.swift in Sources */,
399409
74F3B7BE2E1CF44F00C544D1 /* AudioProcessDiscoverer.swift in Sources */,
410+
746D9D512EC7D12300CBFB4A /* VocabularyResults.swift in Sources */,
400411
1677AFE62B57704E008C61C0 /* ContentView.swift in Sources */,
401412
74F3B7C12E1CF4F400C544D1 /* AudioProcess.swift in Sources */,
402413
746F2A072E39D7030081D0D6 /* NoOpAnalyticsLogger.swift in Sources */,
@@ -413,6 +424,7 @@
413424
74F860962E2B19060007163C /* CoreAudioUtils.swift in Sources */,
414425
747E67062E3008000061E778 /* TranscribeViewModel.swift in Sources */,
415426
74F860942E29A9D20007163C /* ProcessTapper.swift in Sources */,
427+
5539A56B2EA719360020D5CE /* CustomVocabularySheet.swift in Sources */,
416428
740F6DA42E2DB07400429FE9 /* ArgmaxSDKCoordinator.swift in Sources */,
417429
);
418430
runOnlyForDeploymentPostprocessing = 0;
@@ -758,7 +770,7 @@
758770
repositoryURL = "argmaxinc.argmax-sdk-swift-alpha";
759771
requirement = {
760772
kind = upToNextMajorVersion;
761-
minimumVersion = 1.7.11;
773+
minimumVersion = 1.9.6;
762774
};
763775
};
764776
/* End XCRemoteSwiftPackageReference section */

Playground/Playground.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ struct Playground: App {
9292
}
9393

9494
var body: some Scene {
95-
WindowGroup {
95+
WindowGroup("Argmax Playground") {
9696
ContentView(analyticsLogger: analyticsLogger)
9797
#if os(macOS)
9898
.environmentObject(audioProcessDiscoverer)

Playground/Services/ArgmaxSDKCoordinator.swift

Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,8 @@ public final class ArgmaxSDKCoordinator: ObservableObject {
159159
public func prepare(modelName: String,
160160
repository: String? = nil,
161161
config: WhisperKitProConfig,
162-
redownload: Bool = false) async throws {
162+
redownload: Bool = false,
163+
clustererVersion: ClustererVersion) async throws {
163164
guard let apiKey = apiKey, !apiKey.isEmpty else {
164165
self.whisperKitModelState = .unloaded
165166
self.speakerKitModelState = .unloaded
@@ -197,7 +198,7 @@ public final class ArgmaxSDKCoordinator: ObservableObject {
197198
self.whisperKit = whisperKitPro
198199

199200
// --- Then prepare SpeakerKit
200-
let speakerKitPro = try await initializeSpeakerKitPro()
201+
let speakerKitPro = try await initializeSpeakerKitPro(clustererVersion: clustererVersion)
201202
self.speakerKit = speakerKitPro
202203
self.speakerKitModelState = speakerKitPro.modelState
203204

@@ -210,6 +211,20 @@ public final class ArgmaxSDKCoordinator: ObservableObject {
210211
throw error
211212
}
212213
}
214+
215+
/// Forwards `words` to `setCustomVocabulary(_:)` on the currently held WhisperKit instance.
///
/// - Parameter words: The vocabulary entries handed to the model unchanged.
/// - Throws: `ArgmaxError.modelUnavailable` when no WhisperKit instance is set;
///   otherwise rethrows whatever `setCustomVocabulary(_:)` throws, after logging it.
@MainActor
216+
public func updateCustomVocabulary(words: [String]) throws {
217+
// Nothing to update until a WhisperKit instance has been assigned.
guard let whisperKit else {
218+
throw ArgmaxError.modelUnavailable("WhisperKit model is not loaded")
219+
}
220+
221+
do {
222+
try whisperKit.setCustomVocabulary(words)
223+
} catch {
224+
// Log for diagnostics, then rethrow so the caller still observes the failure.
Logging.error("Failed to update custom vocabulary: \(error)")
225+
throw error
226+
}
227+
}
213228

214229
public func delete(modelName: String,
215230
repository: String? = nil,
@@ -226,6 +241,12 @@ public final class ArgmaxSDKCoordinator: ObservableObject {
226241
throw ArgmaxError.generic("Failed to delete model")
227242
}
228243
}
244+
245+
/// Deletes the `canary-1b-v2` and `parakeet-tdt_ctc-110m` variants of the
/// `argmaxinc/ctckit-pro` repository through `modelStore`, in that order.
///
/// NOTE(review): presumably these are the models backing the custom vocabulary
/// feature (inferred from the method name) — confirm the list is complete.
///
/// - Throws: Rethrows the first error raised by `modelStore.deleteModel`; the loop
///   stops there, so any later variant is not attempted.
public func deleteCustomVocabularyModels() async throws {
246+
for model in ["canary-1b-v2", "parakeet-tdt_ctc-110m"] {
247+
try await modelStore.deleteModel(variant: model, from: "argmaxinc/ctckit-pro")
248+
}
249+
}
229250

230251
public func reset() async {
231252
modelStore.cancelDownload()
@@ -322,8 +343,8 @@ public final class ArgmaxSDKCoordinator: ObservableObject {
322343
}
323344

324345
/// Initializes and loads SpeakerKitPro
325-
private func initializeSpeakerKitPro() async throws -> SpeakerKitPro {
326-
var config = SpeakerKitProConfig(load: true)
346+
private func initializeSpeakerKitPro(clustererVersion: ClustererVersion) async throws -> SpeakerKitPro {
347+
var config = SpeakerKitProConfig(load: true, clustererVersion: clustererVersion)
327348
let connected = await ArgmaxSDK.isConnected()
328349
if !connected {
329350
config.download = false

Playground/TranscriptionLiveActivity/TranscriptionAttributes.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ struct TranscriptionAttributes: ActivityAttributes {
1212
/// Dynamic state that is updated over the course of the live activity session
1313
public struct ContentState: Codable, Hashable {
1414
/// Current transcription hypothesis text being processed
15-
var currentHypothesis: String
15+
var currentHypothesis: AttributedString
1616

1717
/// Duration of audio processed in seconds
1818
var audioSeconds: Double

Playground/TranscriptionLiveActivity/TranscriptionLiveActivity.swift

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -46,10 +46,10 @@ struct TranscriptionLiveActivity: Widget {
4646
.font(.caption)
4747
.foregroundColor(.secondary)
4848
.frame(minHeight: 32, alignment: .topLeading)
49-
} else if !context.state.currentHypothesis.isEmpty {
49+
} else if !context.state.currentHypothesis.characters.isEmpty {
5050
Text(context.state.currentHypothesis)
5151
.font(.caption)
52-
.lineLimit(nil)
52+
.lineLimit(3)
5353
.truncationMode(.head)
5454
.frame(minHeight: 32, alignment: .topLeading)
5555
.fixedSize(horizontal: false, vertical: true)
@@ -107,7 +107,7 @@ struct LockScreenLiveActivityView: View {
107107
Text("Microphone session interrupted. Restart transcription from the app.")
108108
.font(.subheadline)
109109
.foregroundColor(.secondary)
110-
} else if !context.state.currentHypothesis.isEmpty {
110+
} else if !context.state.currentHypothesis.characters.isEmpty {
111111
Text(context.state.currentHypothesis)
112112
.font(.subheadline)
113113
.lineLimit(3)
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
import Argmax
2+
3+
/// Convenient alias used for mapping `WordTiming` values to their matching custom vocabulary hits.
4+
typealias VocabularyResults = [WordTiming: [WordTiming]]

Playground/ViewModels/StreamViewModel.swift

Lines changed: 44 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,11 @@ import ActivityKit
4242
/// - **`AudioProcessDiscoverer` / `AudioDeviceDiscoverer`:** (macOS only) These are used to determine which
4343
/// audio sources are available for streaming.
4444
class StreamViewModel: ObservableObject {
45+
#if os(macOS)
46+
private let energyHistoryLimit = 512
47+
#else
48+
private let energyHistoryLimit = 256
49+
#endif
4550
// Stream Results - per-stream data for UI
4651
@Published var deviceResult: StreamResult?
4752
@Published var systemResult: StreamResult?
@@ -137,12 +142,11 @@ class StreamViewModel: ObservableObject {
137142
/// Contains all transcription data for a single stream including text results, timing information, and audio energy data
138143
struct StreamResult {
139144
var title: String = ""
140-
var confirmedText: String = ""
141-
var hypothesisText: String = ""
145+
var confirmedSegments: [TranscriptionSegment] = []
146+
var hypothesisSegments: [TranscriptionSegment] = []
147+
var customVocabularyResults: VocabularyResults = [:]
142148
var streamEndSeconds: Float?
143149
var bufferEnergy: [Float] = []
144-
var bufferSeconds: Double = 0
145-
var transcribeResult: TranscriptionResultPro? = nil
146150
var streamTimestampText: String {
147151
guard let end = streamEndSeconds else {
148152
return ""
@@ -311,20 +315,37 @@ class StreamViewModel: ObservableObject {
311315
}
312316
}
313317

318+
/// Merges `newResults` into `existing` in place.
///
/// For every key in `newResults`, its occurrences are appended to the array already
/// stored under that key, or stored as-is when the key is not present yet.
/// Occurrences are appended without de-duplication, so merging the same results
/// twice stores them twice.
///
/// - Parameters:
///   - existing: The accumulated results; mutated by this call.
///   - newResults: The results to fold in. An empty dictionary leaves `existing` untouched.
private func mergeVocabularyResults(
319+
existing: inout VocabularyResults,
320+
newResults: VocabularyResults
321+
) {
322+
guard !newResults.isEmpty else { return }
323+
for (key, occurrences) in newResults {
324+
// Copy out, extend, and write back: the value is an array (value type).
if var stored = existing[key] {
325+
stored.append(contentsOf: occurrences)
326+
existing[key] = stored
327+
} else {
328+
existing[key] = occurrences
329+
}
330+
}
331+
}
332+
314333
@MainActor
315334
private func handleResult(_ result: LiveResult, for sourceId: String) {
316335
switch result {
317-
case .hypothesis(let text, _, _):
336+
case .hypothesis(_, _, let hypothesisResult):
318337
let now = Date().timeIntervalSince1970
319338
let last = lastHypothesisUpdateAtBySource[sourceId] ?? 0
320339
// Update at most 10 times per second per source
321340
guard now - last >= 0.1 else { return }
322341
lastHypothesisUpdateAtBySource[sourceId] = now
323-
let trimmed = text.trimmingCharacters(in: .whitespacesAndNewlines)
324-
guard trimmed != (isDeviceSource(sourceId) ? deviceResult?.hypothesisText : systemResult?.hypothesisText) else { return }
325342
updateStreamResult(sourceId: sourceId) { oldResult in
326343
var newResult = oldResult
327-
newResult.hypothesisText = trimmed
344+
newResult.hypothesisSegments = hypothesisResult.hypothesisSegments
345+
mergeVocabularyResults(
346+
existing: &newResult.customVocabularyResults,
347+
newResults: hypothesisResult.customVocabularyResults
348+
)
328349
return newResult
329350
}
330351

@@ -333,28 +354,31 @@ class StreamViewModel: ObservableObject {
333354
Task {
334355
await liveActivityManager.updateContentState { oldState in
335356
var state = oldState
336-
state.currentHypothesis = trimmed
357+
let highlightedHypothesis = HighlightedTextView.createHighlightedAttributedString(
358+
segments: deviceResult?.hypothesisSegments ?? [],
359+
customVocabularyResults: deviceResult?.customVocabularyResults ?? [:],
360+
font: .body,
361+
foregroundColor: .primary
362+
)
363+
state.currentHypothesis = highlightedHypothesis
337364
return state
338365
}
339366
}
340367
#endif
341368

342-
case .confirm(let text, let seconds, let transcriptionResult):
369+
case .confirm(_, let seconds, let confirmedResult):
343370
updateStreamResult(sourceId: sourceId) { oldResult in
344371
var newResult = oldResult
345-
let newText = text.trimmingCharacters(in: .whitespaces)
346-
if !newText.isEmpty {
347-
if !newResult.confirmedText.isEmpty {
348-
newResult.confirmedText += " "
349-
}
350-
newResult.confirmedText += newText
351-
}
372+
newResult.confirmedSegments += confirmedResult.segments
352373
newResult.streamEndSeconds = Float(seconds)
353-
newResult.transcribeResult = transcriptionResult
374+
mergeVocabularyResults(
375+
existing: &newResult.customVocabularyResults,
376+
newResults: confirmedResult.customVocabularyResults
377+
)
354378
return newResult
355379
}
356380
if let confirmedresultCallback = self.confirmedresultCallback {
357-
confirmedresultCallback(sourceId, transcriptionResult)
381+
confirmedresultCallback(sourceId, confirmedResult)
358382
}
359383
}
360384
}
@@ -373,18 +397,13 @@ class StreamViewModel: ObservableObject {
373397

374398
// Limit the number of energy samples passed to the UI for performance
375399
let energies = whisperKitPro.audioProcessor.relativeEnergy
376-
#if os(iOS)
377-
let newBufferEnergy = Array(energies.suffix(256))
378-
#else
379-
let newBufferEnergy = energies
380-
#endif
400+
let newBufferEnergy = Array(energies.suffix(self.energyHistoryLimit))
381401
let sampleCount = whisperKitPro.audioProcessor.audioSamples.count
382402
let audioSeconds = Double(sampleCount) / Double(WhisperKit.sampleRate)
383403

384404
updateStreamResult(sourceId: source.id) { oldResult in
385405
var newResult = oldResult
386406
newResult.bufferEnergy = newBufferEnergy
387-
newResult.bufferSeconds = audioSeconds
388407
return newResult
389408
}
390409

0 commit comments

Comments
 (0)