Skip to content

Commit 230cf3b

Browse files
use overload bool-literal typing
1 parent f81a75a commit 230cf3b

File tree

1 file changed

+21
-12
lines changed

1 file changed

+21
-12
lines changed

pyannote/core/annotation.py

Lines changed: 21 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,7 @@
111111
from collections import defaultdict
112112
from typing import (
113113
Hashable,
114+
Literal,
114115
Optional,
115116
Dict,
116117
Union,
@@ -123,6 +124,7 @@
123124
Text,
124125
TYPE_CHECKING,
125126
NamedTuple,
127+
overload,
126128
)
127129

128130
import numpy as np
@@ -224,7 +226,7 @@ def _updateLabels(self):
224226

225227
# accumulate segments for updated labels
226228
_segments = {label: [] for label in update}
227-
for segment, track, label in self.itertracks_with_labels():
229+
for segment, track, label in self.itertracks(yield_label=True):
228230
if label in update:
229231
_segments[label].append(segment)
230232

@@ -270,6 +272,13 @@ def itersegments(self):
270272
"""
271273
return iter(self._tracks)
272274

275+
@overload
276+
def itertracks(self, yield_label: Literal[False] = ...) -> Iterator[SegmentTrack]: ...
277+
@overload
278+
def itertracks(self, yield_label: Literal[True]) -> Iterator[SegmentTrackLabel]: ...
279+
@overload
280+
def itertracks(self, yield_label: bool) -> Iterator[Union[SegmentTrack, SegmentTrackLabel]]: ...
281+
273282
def itertracks(
274283
self, yield_label: bool = False
275284
) -> Iterator[Union[SegmentTrack, SegmentTrackLabel]]:
@@ -292,7 +301,7 @@ def itertracks(
292301
>>> for segment, track in annotation.itertracks():
293302
... # do something with the track
294303
295-
>>> for segment, track, label in annotation.itertracks_with_labels():
304+
>>> for segment, track, label in annotation.itertracks(yield_label=True):
296305
... # do something with the track and its label
297306
"""
298307

@@ -307,11 +316,11 @@ def itertracks(
307316

308317
def itertracks_with_labels(self) -> Iterator[SegmentTrackLabel]:
309318
"""Typed version of :func:`itertracks`(yield_label=True)"""
310-
return self.itertracks(yield_label=True) # type: ignore
319+
return self.itertracks(yield_label=True)
311320

312321
def itertracks_without_labels(self) -> Iterator[SegmentTrack]:
313322
"""Typed version of :func:`itertracks`(yield_label=False)"""
314-
return self.itertracks(yield_label=False) # type: ignore
323+
return self.itertracks(yield_label=False)
315324

316325
def _updateTimeline(self):
317326
self._timeline = Timeline(segments=self._tracks, uri=self.uri)
@@ -358,14 +367,14 @@ def __eq__(self, other: "Annotation"):
358367
labels are equal.
359368
"""
360369
pairOfTracks = itertools.zip_longest(
361-
self.itertracks_with_labels(), other.itertracks_with_labels()
370+
self.itertracks(yield_label=True), other.itertracks(yield_label=True)
362371
)
363372
return all(t1 == t2 for t1, t2 in pairOfTracks)
364373

365374
def __ne__(self, other: "Annotation"):
366375
"""Inequality"""
367376
pairOfTracks = itertools.zip_longest(
368-
self.itertracks_with_labels(), other.itertracks_with_labels()
377+
self.itertracks(yield_label=True), other.itertracks(yield_label=True)
369378
)
370379

371380
return any(t1 != t2 for t1, t2 in pairOfTracks)
@@ -404,7 +413,7 @@ def _iter_rttm(self) -> Iterator[Text]:
404413
f'containing spaces (got: "{uri}").'
405414
)
406415
raise ValueError(msg)
407-
for segment, _, label in self.itertracks_with_labels():
416+
for segment, _, label in self.itertracks(yield_label=True):
408417
if isinstance(label, Text) and " " in label:
409418
msg = (
410419
f"Space-separated RTTM file format does not allow labels "
@@ -449,7 +458,7 @@ def _iter_lab(self) -> Iterator[Text]:
449458
iterator: Iterator[str]
450459
An iterator over LAB text lines
451460
"""
452-
for segment, _, label in self.itertracks_with_labels():
461+
for segment, _, label in self.itertracks(yield_label=True):
453462
if isinstance(label, Text) and " " in label:
454463
msg = (
455464
f"Space-separated LAB file format does not allow labels "
@@ -806,7 +815,7 @@ def __str__(self):
806815
"""Human-friendly representation"""
807816
# TODO: use pandas.DataFrame
808817
return "\n".join(
809-
["%s %s %s" % (s, t, l) for s, t, l in self.itertracks_with_labels()]
818+
["%s %s %s" % (s, t, l) for s, t, l in self.itertracks(yield_label=True)]
810819
)
811820

812821
def __delitem__(self, key: Key):
@@ -1051,7 +1060,7 @@ def update(self, annotation: "Annotation", copy: bool = False) -> "Annotation":
10511060
result = self.copy() if copy else self
10521061

10531062
# TODO speed things up by working directly with annotation internals
1054-
for segment, track, label in annotation.itertracks_with_labels():
1063+
for segment, track, label in annotation.itertracks(yield_label=True):
10551064
result[segment, track] = label
10561065

10571066
return result
@@ -1255,7 +1264,7 @@ def rename_tracks(self, generator: Union[LabelGenerator, Iterable[str], Iterable
12551264
raise ValueError("generator must be 'string', 'int', or iterable")
12561265

12571266
# TODO speed things up by working directly with annotation internals
1258-
for s, _, label in self.itertracks_with_labels():
1267+
for s, _, label in self.itertracks(yield_label=True):
12591268
renamed[s, next(generator_)] = label
12601269
return renamed
12611270

@@ -1338,7 +1347,7 @@ def relabel_tracks(self, generator: LabelGenerator = "string") -> "Annotation":
13381347
generator = int_generator()
13391348

13401349
relabeled = self.empty()
1341-
for s, t, _ in self.itertracks_with_labels():
1350+
for s, t, _ in self.itertracks(yield_label=True):
13421351
relabeled[s, t] = next(generator)
13431352

13441353
return relabeled

0 commit comments

Comments
 (0)