Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -148,10 +148,10 @@ import FluidAudio
// Initialize and process audio
Task {
let diarizer = DiarizerManager()
try await diarizer.initialize()
diarizer.initialize(models: try await .downloadIfNeeded())

let audioSamples: [Float] = // your 16kHz audio data
let result = try await diarizer.performCompleteDiarization(audioSamples, sampleRate: 16000)
let result = try diarizer.performCompleteDiarization(audioSamples, sampleRate: 16000)

for segment in result.segments {
print("\(segment.speakerId): \(segment.startTimeSeconds)s - \(segment.endTimeSeconds)s")
Expand Down
257 changes: 32 additions & 225 deletions Sources/FluidAudio/DiarizerManager.swift
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ public struct DiarizerConfig: Sendable {
public var numClusters: Int = -1 // Number of speakers to detect (-1 = auto-detect)
public var minActivityThreshold: Float = 10.0 // Minimum activity threshold (frames) for speaker to be considered active
public var debugMode: Bool = false
public var modelCacheDirectory: URL?

public static let `default` = DiarizerConfig()

Expand All @@ -21,8 +20,7 @@ public struct DiarizerConfig: Sendable {
minDurationOff: 0.5,
numClusters: -1,
minActivityThreshold: 10.0,
debugMode: false,
modelCacheDirectory: nil
debugMode: false
)
#endif

Expand All @@ -32,16 +30,14 @@ public struct DiarizerConfig: Sendable {
minDurationOff: Float = 0.5,
numClusters: Int = -1,
minActivityThreshold: Float = 10.0,
debugMode: Bool = false,
modelCacheDirectory: URL? = nil
debugMode: Bool = false
) {
self.clusteringThreshold = clusteringThreshold
self.minDurationOn = minDurationOn
self.minDurationOff = minDurationOff
self.numClusters = numClusters
self.minActivityThreshold = minActivityThreshold
self.debugMode = debugMode
self.modelCacheDirectory = modelCacheDirectory
}
}

Expand Down Expand Up @@ -114,6 +110,15 @@ public struct PipelineTimings: Sendable, Codable {
}
}

extension Duration {

/// Converts this duration to a Foundation TimeInterval - i.e. a `Double` number of seconds.
///
internal var timeInterval: TimeInterval {
self / .seconds(1)
}
}

/// Complete diarization result with consistent speaker IDs and embeddings
public struct DiarizationResult: Sendable {
public let segments: [TimedSpeakerSegment]
Expand Down Expand Up @@ -168,10 +173,10 @@ public struct SpeakerEmbedding: Sendable {
}

public struct ModelPaths: Sendable {
public let segmentationPath: String
public let embeddingPath: String
public let segmentationPath: URL
public let embeddingPath: URL

public init(segmentationPath: String, embeddingPath: String) {
public init(segmentationPath: URL, embeddingPath: URL) {
self.segmentationPath = segmentationPath
self.embeddingPath = embeddingPath
}
Expand Down Expand Up @@ -252,124 +257,40 @@ public final class DiarizerManager {
private let logger = Logger(subsystem: "com.fluidinfluence.diarizer", category: "Diarizer")
private let config: DiarizerConfig

// ML models
private var segmentationModel: MLModel?
private var embeddingModel: MLModel?

// Timing tracking
private var modelDownloadTime: TimeInterval = 0
private var modelCompilationTime: TimeInterval = 0
private var models: DiarizerModels?

public init(config: DiarizerConfig = .default) {
self.config = config
}

public var isAvailable: Bool {
return segmentationModel != nil && embeddingModel != nil
models != nil
}

/// Get the initialization timing data
public var initializationTimings: (downloadTime: TimeInterval, compilationTime: TimeInterval) {
return (modelDownloadTime, modelCompilationTime)
models.map { ($0.downloadTime.timeInterval, $0.compilationTime.timeInterval) } ?? (0, 0)
}

public func initialize() async throws {
let initStartTime = Date()
public func initialize(models: consuming DiarizerModels) {
logger.info("Initializing diarization system")

try await cleanupBrokenModels()

let downloadStartTime = Date()
let modelPaths = try await downloadModels()
self.modelDownloadTime = Date().timeIntervalSince(downloadStartTime)

let segmentationURL = URL(fileURLWithPath: modelPaths.segmentationPath)
let embeddingURL = URL(fileURLWithPath: modelPaths.embeddingPath)

let compilationStartTime = Date()
try await loadModelsWithAutoRecovery(
segmentationURL: segmentationURL, embeddingURL: embeddingURL)
self.modelCompilationTime = Date().timeIntervalSince(compilationStartTime)

let totalInitTime = Date().timeIntervalSince(initStartTime)
logger.info(
"Diarization system initialized successfully in \(String(format: "%.2f", totalInitTime))s (download: \(String(format: "%.2f", self.modelDownloadTime))s, compilation: \(String(format: "%.2f", self.modelCompilationTime))s)"
)
self.models = consume models
}

/// Load models with automatic recovery on compilation failures
private func loadModelsWithAutoRecovery(
segmentationURL: URL, embeddingURL: URL, maxRetries: Int = 2
) async throws {
let config: MLModelConfiguration = MLModelConfiguration()
config.computeUnits = .cpuAndNeuralEngine

let modelPaths = [
(url: segmentationURL, name: "segmentation"),
(url: embeddingURL, name: "embedding")
]

let models = try await DownloadUtils.loadModelsWithAutoRecovery(
modelPaths: modelPaths,
config: config,
maxRetries: maxRetries,
recoveryAction: {
try await self.performModelRecovery(
segmentationURL: segmentationURL,
embeddingURL: embeddingURL
)
}
)

self.segmentationModel = models[0]
self.embeddingModel = models[1]
}

/// Perform model recovery by deleting and re-downloading corrupted models
private func performModelRecovery(segmentationURL: URL, embeddingURL: URL) async throws {
try await DownloadUtils.performModelRecovery(
modelPaths: [segmentationURL, embeddingURL],
downloadAction: {
// Re-download segmentation model
try await DownloadUtils.downloadMLModelBundle(
repoPath: "FluidInference/speaker-diarization-coreml",
modelName: "pyannote_segmentation.mlmodelc",
outputPath: segmentationURL
)

// Re-download embedding model
try await DownloadUtils.downloadMLModelBundle(
repoPath: "FluidInference/speaker-diarization-coreml",
modelName: "wespeaker.mlmodelc",
outputPath: embeddingURL
)
}
)
@available(*, deprecated, message: "Use initialize(models:) instead")
public func initialize() async throws {
self.initialize(models: try await .downloadIfNeeded())
}

private func cleanupBrokenModels() async throws {
let modelsDirectory = getModelsDirectory()
let segmentationModelPath = modelsDirectory.appendingPathComponent(
"pyannote_segmentation.mlmodelc")
let embeddingModelPath = modelsDirectory.appendingPathComponent("wespeaker.mlmodelc")

if FileManager.default.fileExists(atPath: segmentationModelPath.path)
&& !DownloadUtils.isModelCompiled(at: segmentationModelPath)
{
logger.info("Removing broken segmentation model")
try FileManager.default.removeItem(at: segmentationModelPath)
}

if FileManager.default.fileExists(atPath: embeddingModelPath.path)
&& !DownloadUtils.isModelCompiled(at: embeddingModelPath)
{
logger.info("Removing broken embedding model")
try FileManager.default.removeItem(at: embeddingModelPath)
}
/// Clean up resources
public func cleanup() {
models = nil
logger.info("Diarization resources cleaned up")
}

private func getSegments(audioChunk: ArraySlice<Float>, chunkSize: Int = 160_000) throws -> [[[Float]]] {
guard let segmentationModel = self.segmentationModel else {

guard let segmentationModel = models?.segmentationModel else {
throw DiarizerError.notInitialized
}

Expand Down Expand Up @@ -613,111 +534,6 @@ public final class DiarizerManager {
annotation[finalSegment] = currentSpeaker // Use raw speaker index
}

// MARK: - Model Management

/// Download required models for diarization
public func downloadModels() async throws -> ModelPaths {
logger.info("Checking for existing diarization models")

let modelsDirectory = getModelsDirectory()

let segmentationModelPath = modelsDirectory.appendingPathComponent(
"pyannote_segmentation.mlmodelc"
).path
let embeddingModelPath = modelsDirectory.appendingPathComponent("wespeaker.mlmodelc").path

let segmentationURL = URL(fileURLWithPath: segmentationModelPath)
let embeddingURL = URL(fileURLWithPath: embeddingModelPath)

// Check if models already exist and are valid
let segmentationExists =
FileManager.default.fileExists(atPath: segmentationModelPath)
&& DownloadUtils.isModelCompiled(at: segmentationURL)
let embeddingExists =
FileManager.default.fileExists(atPath: embeddingModelPath)
&& DownloadUtils.isModelCompiled(at: embeddingURL)

if segmentationExists && embeddingExists {
logger.info("Valid models already exist, skipping download")
return ModelPaths(
segmentationPath: segmentationModelPath, embeddingPath: embeddingModelPath)
}

logger.info("Downloading missing or invalid diarization models from Hugging Face")

// Download segmentation model if needed
if !segmentationExists {
logger.info("Downloading segmentation model bundle from Hugging Face")
try await DownloadUtils.downloadMLModelBundle(
repoPath: "FluidInference/speaker-diarization-coreml",
modelName: "pyannote_segmentation.mlmodelc",
outputPath: segmentationURL
)
logger.info("Downloaded segmentation model bundle from Hugging Face")
}

// Download embedding model if needed
if !embeddingExists {
logger.info("Downloading embedding model bundle from Hugging Face")
try await DownloadUtils.downloadMLModelBundle(
repoPath: "FluidInference/speaker-diarization-coreml",
modelName: "wespeaker.mlmodelc",
outputPath: embeddingURL
)
logger.info("Downloaded embedding model bundle from Hugging Face")
}

logger.info("Successfully ensured diarization models are available")
return ModelPaths(
segmentationPath: segmentationModelPath, embeddingPath: embeddingModelPath)
}


/// Compile a model
private func compileModel(at sourceURL: URL, outputPath: URL) async throws -> URL {
logger.info("Compiling model from \(sourceURL.lastPathComponent)")

// Remove existing compiled model if it exists
if FileManager.default.fileExists(atPath: outputPath.path) {
try FileManager.default.removeItem(at: outputPath)
}

// Compile the model
let compiledModelURL = try await MLModel.compileModel(at: sourceURL)

// Move to the desired location
try FileManager.default.moveItem(at: compiledModelURL, to: outputPath)

// Clean up the source file
try? FileManager.default.removeItem(at: sourceURL)

logger.info("Successfully compiled model to \(outputPath.lastPathComponent)")
return outputPath
}

private func getModelsDirectory() -> URL {
let directory: URL

if let customDirectory = config.modelCacheDirectory {
directory = customDirectory.appendingPathComponent("coreml", isDirectory: true)
} else {
#if os(iOS)
// Use Documents directory on iOS for better compatibility with sandboxing
let documents = FileManager.default.urls(for: .documentDirectory, in: .userDomainMask).first!
directory = documents.appendingPathComponent("FluidAudio/models/diarization", isDirectory: true)
#else
// Use Application Support on macOS
let appSupport = FileManager.default.urls(
for: .applicationSupportDirectory, in: .userDomainMask
).first!
directory = appSupport.appendingPathComponent(
"SpeakerKitModels/coreml", isDirectory: true)
#endif
}

try? FileManager.default.createDirectory(at: directory, withIntermediateDirectories: true)
return directory.standardizedFileURL
}

// MARK: - Audio Analysis

Expand Down Expand Up @@ -871,16 +687,14 @@ public final class DiarizerManager {
return (embeddings[maxActivityIndex], normalizedActivity)
}

// MARK: - Cleanup

// MARK: - Combined Efficient Diarization

/// Perform complete diarization with consistent speaker IDs across chunks
/// This is more efficient than calling performSegmentation + extractEmbedding separately
public func performCompleteDiarization(_ samples: [Float], sampleRate: Int = 16000) throws
-> DiarizationResult
{
guard segmentationModel != nil, embeddingModel != nil else {
guard let models else {
throw DiarizerError.notInitialized
}

Expand Down Expand Up @@ -922,8 +736,8 @@ public final class DiarizerManager {
let totalProcessingTime = Date().timeIntervalSince(processingStartTime)

let timings = PipelineTimings(
modelDownloadSeconds: self.modelDownloadTime,
modelCompilationSeconds: self.modelCompilationTime,
modelDownloadSeconds: models.downloadTime.timeInterval,
modelCompilationSeconds: models.compilationTime.timeInterval,
audioLoadingSeconds: 0, // Will be set by CLI
segmentationSeconds: segmentationTime,
embeddingExtractionSeconds: embeddingTime,
Expand Down Expand Up @@ -972,7 +786,7 @@ public final class DiarizerManager {
let embeddingStartTime = Date()

// Step 2: Get embeddings using same segmentation results
guard let embeddingModel = self.embeddingModel else {
guard let embeddingModel = models?.embeddingModel else {
throw DiarizerError.notInitialized
}

Expand Down Expand Up @@ -1210,11 +1024,4 @@ public final class DiarizerManager {
qualityScore: quality
)
}

/// Clean up resources
public func cleanup() {
segmentationModel = nil
embeddingModel = nil
logger.info("Diarization resources cleaned up")
}
}
Loading
Loading