Merge branch 'pa-speaker-detector-api' into 'main'
[PASpeakerDetector] Change init method

See merge request heka/medkit!238

changelog: [PASpeakerDetector] Change init method
ghisvail committed Nov 27, 2023
2 parents 9f9f50f + 5bfb163 commit daa7eaa
Showing 7 changed files with 62 additions and 89 deletions.
4 changes: 0 additions & 4 deletions .gitignore
@@ -22,10 +22,6 @@ __pycache__/
*.py[cod]
*$py.class

-# Large test files
-tests/large_data/*
-!tests/large_data/README.md
-
# Sphinx documentation
docs/_build/
docs/api-gen/_autosummary/
4 changes: 2 additions & 2 deletions medkit/audio/metrics/transcription.py
@@ -196,12 +196,12 @@ def _convert_speech_segs_to_words(self, segments: Sequence[Segment]) -> List[str]:

if not transcription_attrs:
raise ValueError(
f"Attribute with label '{self.speaker_label}' not found on"
f"Attribute with label '{self.transcription_label}' not found on"
" speech segment"
)
if len(transcription_attrs) > 1:
logger.warning(
f"Found several attributes with label '{self.speaker_label}',"
f"Found several attributes with label '{self.transcription_label}',"
" ignoring all but first"
)
transcription = transcription_attrs[0].value
57 changes: 25 additions & 32 deletions medkit/audio/segmentation/pa_speaker_detector.py
@@ -6,8 +6,7 @@
__all__ = ["PASpeakerDetector"]

from pathlib import Path
-from typing import Dict, Iterator, List, Optional, Union
-from typing_extensions import Literal
+from typing import Iterator, List, Optional, Union

# When pyannote and spacy are both installed, a conflict might occur between the
# ujson library used by pandas (a pyannote dependency) and the ujson library used
@@ -18,6 +17,7 @@
# we import pandas manually first.
# So as a workaround, we always import pandas before importing something from pyannote
import pandas # noqa: F401
+from pyannote.audio import Pipeline
from pyannote.audio.pipelines import SpeakerDiarization
import torch

@@ -48,13 +48,11 @@ class PASpeakerDetector(SegmentationOperation):

def __init__(
self,
-        segmentation_model: Union[str, Path],
-        embedding_model: Union[str, Path],
+        model: Union[str, Path],
output_label: str,
-        pipeline_params: Optional[Dict] = None,
min_nb_speakers: Optional[int] = None,
max_nb_speakers: Optional[int] = None,
-        clustering: Literal["AgglomerativeClustering"] = "AgglomerativeClustering",
+        min_duration: float = 0.1,
device: int = -1,
segmentation_batch_size: int = 1,
embedding_batch_size: int = 1,
@@ -64,28 +62,19 @@ def __init__(
"""
Parameters
----------
-        segmentation_model:
-            Name (on the HuggingFace models hub) or path of the `PyanNet`
-            segmentation model. When a path, should point to the .bin file
-            containing the model.
-        embedding_model:
-            Name (on the HuggingFace models hub) or path to the embedding model.
-            When a path to a speechbrain model, should point to the directory containing
-            the model weights and hyperparameters.
+        model:
+            Name (on the HuggingFace models hub) or path of a pretrained
+            pipeline. When a path, should point to the .yaml file containing the
+            pipeline configuration.
output_label:
Label of generated turn segments.
-        pipeline_params:
-            Dictionary of segmentation and clustering parameters. The dictionary
-            can hold a "segmentation" key and a "clustering" key pointing to
-            sub dictionaries. Refer to the pyannote documentation for the
-            supported segmentation and clustering parameters
-            (clustering parameters depend on the clustering method used).
min_nb_speakers:
Minimum number of speakers expected to be found.
max_nb_speakers:
Maximum number of speakers expected to be found.
-        clustering:
-            Clustering method to use.
+        min_duration:
+            Minimum duration of speech segments, in seconds (short segments will
+            be discarded).
device:
Device to use for pytorch models. Follows the Hugging Face
convention (`-1` for cpu and device number for gpu, for instance `0`
@@ -110,18 +99,20 @@ def __init__(
self.output_label = output_label
self.min_nb_speakers = min_nb_speakers
self.max_nb_speakers = max_nb_speakers
+        self.min_duration = min_duration

torch_device = torch.device("cpu" if device < 0 else f"cuda:{device}")
-        self._pipeline = SpeakerDiarization(
-            segmentation=str(segmentation_model),
-            embedding=str(embedding_model),
-            clustering=clustering,
-            embedding_exclude_overlap=True,
-            segmentation_batch_size=segmentation_batch_size,
-            embedding_batch_size=embedding_batch_size,
-            use_auth_token=hf_auth_token,
-        ).to(torch_device)
-        self._pipeline.instantiate(pipeline_params)
+        self._pipeline = Pipeline.from_pretrained(model, use_auth_token=hf_auth_token)
+        if self._pipeline is None:
+            raise Exception(f"Could not instantiate pretrained pipeline with '{model}'")
+        if not isinstance(self._pipeline, SpeakerDiarization):
+            raise Exception(
+                f"'{model}' does not correspond to a SpeakerDiarization pipeline. Got"
+                f" object of type {type(self._pipeline)}"
+            )
+        self._pipeline.to(torch_device)
+        self._pipeline.segmentation_batch_size = segmentation_batch_size
+        self._pipeline.embedding_batch_size = embedding_batch_size

def run(self, segments: List[Segment]) -> List[Segment]:
"""Return all turn segments detected for all input `segments`.
@@ -157,6 +148,8 @@ def _detect_turns_in_segment(self, segment: Segment) -> Iterator[Segment]:
)

for turn, _, speaker in diarization.itertracks(yield_label=True):
+            if turn.duration < self.min_duration:
+                continue
# trim original audio to turn start/end points
turn_audio = audio.trim_duration(turn.start, turn.end)

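The new init takes a single pretrained pipeline instead of separate segmentation/embedding models plus clustering parameters. A minimal usage sketch of the changed API — the config path and labels here are illustrative, and `speech_segment` stands in for a medkit audio segment obtained upstream:

from medkit.audio.segmentation.pa_speaker_detector import PASpeakerDetector

# model accepts a HuggingFace hub name or a local .yaml pipeline config
speaker_detector = PASpeakerDetector(
    model="tests/large/diar_pipeline_config.yaml",
    output_label="turn",
    min_nb_speakers=1,
    max_nb_speakers=2,
    min_duration=0.1,  # turns shorter than 0.1 s are discarded
)
turn_segments = speaker_detector.run([speech_segment])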
19 changes: 19 additions & 0 deletions tests/large/diar_pipeline_config.yaml
@@ -0,0 +1,19 @@
+version: 3.0.0
+
+pipeline:
+  name: pyannote.audio.pipelines.SpeakerDiarization
+  params:
+    clustering: AgglomerativeClustering
+    embedding: speechbrain/spkrec-ecapa-voxceleb
+    embedding_batch_size: 1
+    embedding_exclude_overlap: true
+    segmentation: pyannote/segmentation-3.0
+    segmentation_batch_size: 32
+
+params:
+  clustering:
+    method: centroid
+    min_cluster_size: 12
+    threshold: 0.7
+  segmentation:
+    min_duration_off: 0.0
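This config follows pyannote's pretrained-pipeline format, which is what lets the new init load it through `Pipeline.from_pretrained`. A quick sanity-check sketch, assuming the hub models referenced above are reachable:

from pyannote.audio import Pipeline
from pyannote.audio.pipelines import SpeakerDiarization

# loading mirrors what PASpeakerDetector now does internally
pipeline = Pipeline.from_pretrained("tests/large/diar_pipeline_config.yaml")
assert isinstance(pipeline, SpeakerDiarization)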
25 changes: 2 additions & 23 deletions tests/large/test_pa_speaker_detector.py
@@ -13,24 +13,7 @@
PASpeakerDetector,
) # noqa: E402

-# model weights provided by pyannote and speechbrain on huggingface hub
-_TEST_DATA_DIR = Path(__file__).parent.parent / "large_data"
-_SEGMENTATION_MODEL = _TEST_DATA_DIR / "pyannote" / "segmentation" / "pytorch_model.bin"
-_EMBEDDING_MODEL = _TEST_DATA_DIR / "speechbrain" / "spkrec-ecapa-voxceleb"
-# simple params that will work with our test file
-_CLUSTERING = "AgglomerativeClustering"
-_PIPELINE_PARAMS = {
-    "segmentation": {
-        "min_duration_off": 0.0,
-    },
-    "clustering": {
-        "method": "centroid",
-        "min_cluster_size": 12,
-        "threshold": 0.7,
-    },
-}
-
-
+_PIPELINE_MODEL = Path(__file__).parent / "diar_pipeline_config.yaml"
_AUDIO = FileAudioBuffer("tests/data/audio/dialog_long.ogg")
_SPEAKER_CHANGE_TIME = 4.0
_MARGIN = 1.0
@@ -44,14 +27,10 @@ def _get_segment():
)


@pytest.mark.xfail
def test_basic():
speaker_detector = PASpeakerDetector(
-        segmentation_model=_SEGMENTATION_MODEL,
-        embedding_model=_EMBEDDING_MODEL,
-        clustering=_CLUSTERING,
+        model=_PIPELINE_MODEL,
output_label="turn",
-        pipeline_params=_PIPELINE_PARAMS,
min_nb_speakers=2,
max_nb_speakers=2,
)
11 changes: 0 additions & 11 deletions tests/large_data/README.md

This file was deleted.

31 changes: 14 additions & 17 deletions tests/unit/audio/segmentation/test_pa_speaker_detector.py
@@ -27,6 +27,10 @@ class _MockedPASegment(NamedTuple):
start: float
end: float

+    @property
+    def duration(self) -> float:
+        return self.end - self.start
+

class _MockedPAAnnotation:
def __init__(self, segments, labels):
@@ -42,15 +46,13 @@ def itertracks(self, yield_label):

# mock of SpeakerDiarization class used by PASpeakerDetector
class _MockedPipeline:
-    def __init__(self, *args, **kwargs):
-        pass
+    def __init__(self):
+        self.segmentation_batch_size = 1
+        self.embedding_batch_size = 1

def to(self, device):
return self

-    def instantiate(self, params):
-        pass
-
def apply(self, file, **kwargs):
# return hard coded results (always split in half)
duration = file["waveform"].shape[-1] / file["sample_rate"]
@@ -68,6 +70,10 @@ def _mocked_pipeline(module_mocker):
"medkit.audio.segmentation.pa_speaker_detector.SpeakerDiarization",
_MockedPipeline,
)
+    module_mocker.patch(
+        "medkit.audio.segmentation.pa_speaker_detector.Pipeline.from_pretrained",
+        lambda *args, **kwargs: _MockedPipeline(),
+    )
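Patching `Pipeline.from_pretrained` in addition to the `SpeakerDiarization` name lets the detector be built entirely offline: `from_pretrained` hands back a `_MockedPipeline`, which then also satisfies the `isinstance` check in the new init. A hypothetical sketch of a test relying on this fixture, assuming it is applied module-wide as the surrounding tests suggest (the `model` string is arbitrary since loading is mocked):

def test_init_uses_mock():
    speaker_detector = PASpeakerDetector(model="mock-pipeline", output_label=_OUTPUT_LABEL)
    # from_pretrained was patched, so no download happened and the
    # pipeline attribute is the mock defined above
    assert isinstance(speaker_detector._pipeline, _MockedPipeline)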


def _get_segment(duration):
@@ -85,10 +91,7 @@ def test_basic():
"""Basic behavior"""

speaker_detector = PASpeakerDetector(
segmentation_model="mock-segmentation-model",
embedding_model="mock-segmentation-model",
clustering="MockClusteringMethod",
pipeline_params={},
model="mock-pipeline",
output_label=_OUTPUT_LABEL,
min_nb_speakers=2,
max_nb_speakers=2,
@@ -135,10 +138,7 @@ def test_basic():
def test_multiple():
"""Several segments passed as input"""
speaker_detector = PASpeakerDetector(
segmentation_model="mock-segmentation-model",
embedding_model="mock-segmentation-model",
clustering="MockClusteringMethod",
pipeline_params={},
model="mock-pipeline",
output_label=_OUTPUT_LABEL,
min_nb_speakers=2,
max_nb_speakers=2,
@@ -194,10 +194,7 @@ def test_prov():
"""Generated provenance nodes"""

speaker_detector = PASpeakerDetector(
segmentation_model="mock-segmentation-model",
embedding_model="mock-segmentation-model",
clustering="MockClusteringMethod",
pipeline_params={},
model="mock-pipeline",
output_label=_OUTPUT_LABEL,
min_nb_speakers=2,
max_nb_speakers=2,