Skip to content

Commit

Permalink
fix: move NeMo imports inside the function
Browse files Browse the repository at this point in the history
  • Loading branch information
elanmart committed Dec 20, 2022
1 parent e18d7ba commit f7db007
Showing 1 changed file with 23 additions and 15 deletions.
38 changes: 23 additions & 15 deletions cbp_translate/components/speakers.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,11 @@
""" Speaker diarization - detecting and annotating unique speakers. """

import json
import os
from dataclasses import dataclass
from pathlib import Path

import wget
from nemo.collections.asr.models import ClusteringDiarizer
from omegaconf import OmegaConf

from cbp_translate.modal_ import ROOT, gpu_image, hf_secret, stub, volume
from cbp_translate.modal_ import SHARED, gpu_image, nemo_secret, hf_secret, stub, volume


@dataclass
Expand All @@ -19,6 +16,7 @@ class SpeakerSegment:


def combine_segments(speakers: list[SpeakerSegment]) -> list[SpeakerSegment]:
"""Combine consecutive segments where speaker ID stays the same."""

ret = []
s = speakers[0]
Expand All @@ -38,14 +36,17 @@ def combine_segments(speakers: list[SpeakerSegment]) -> list[SpeakerSegment]:


def parse_nemo_output(path: str):
dia = Path(path).read_text()
lines = dia.splitlines()
lines = [l.strip() for l in lines if len(l.strip()) > 1]
"""Parse the output of the NeMo diarization model."""

results = Path(path).read_text()
lines = results.splitlines()
lines = [line.strip() for line in lines if len(line.strip()) > 1]

ret = []
for line in lines:
_, _, _, t0, dur, _, _, ID, *_ = line.split()
t0, dur = float(t0), float(dur)
seg = SpeakerSegment(ID.capitalize(), start=t0, end=t0 + dur)
_, _, _, t0, duration, _, _, ID, *_ = line.split()
t0, duration = float(t0), float(duration)
seg = SpeakerSegment(ID.capitalize(), start=t0, end=t0 + duration)
ret.append(seg)

return ret
Expand All @@ -54,11 +55,17 @@ def parse_nemo_output(path: str):
@stub.function(
image=gpu_image,
gpu=True,
shared_volumes={str(ROOT): volume},
secret=hf_secret,
shared_volumes={str(SHARED): volume},
secret=nemo_secret,
timeout=30 * 60,
)
def extract_speakers(path: str, combine: bool = True) -> list[SpeakerSegment]:
"""Extract speaker IDs from an audio file."""

# Local imports are required by Modal
import wget
from nemo.collections.asr.models import ClusteringDiarizer
from omegaconf import OmegaConf

meta = {
"audio_filepath": path,
Expand Down Expand Up @@ -128,17 +135,18 @@ def extract_speakers(path: str, combine: bool = True) -> list[SpeakerSegment]:
@stub.function(
image=gpu_image,
gpu=True,
shared_volumes={str(ROOT): volume},
shared_volumes={str(SHARED): volume},
secret=hf_secret,
timeout=30 * 60,
)
def extract_speakers_pyannote(path_audio: str) -> list[SpeakerSegment]:
"""Legacy implementation using pyannote.audio"""

# Local imports are required for Modal
from pyannote.audio import Pipeline

# Note that we're downloading the model to a shared volume
(cache_dir := (ROOT / ".hf")).mkdir(exist_ok=True)
(cache_dir := (SHARED / ".hf")).mkdir(exist_ok=True)
auth_token = os.environ["HUGGINGFACE_TOKEN"]

# Run the pipeline
Expand Down

0 comments on commit f7db007

Please sign in to comment.