111111from collections import defaultdict
112112from typing import (
113113 Hashable ,
114+ Literal ,
114115 Optional ,
115116 Dict ,
116117 Union ,
123124 Text ,
124125 TYPE_CHECKING ,
125126 NamedTuple ,
127+ overload ,
126128)
127129
128130import numpy as np
@@ -224,7 +226,7 @@ def _updateLabels(self):
224226
225227 # accumulate segments for updated labels
226228 _segments = {label : [] for label in update }
227- for segment , track , label in self .itertracks_with_labels ( ):
229+ for segment , track , label in self .itertracks ( yield_label = True ):
228230 if label in update :
229231 _segments [label ].append (segment )
230232
@@ -270,6 +272,13 @@ def itersegments(self):
270272 """
271273 return iter (self ._tracks )
272274
275+ @overload
276+ def itertracks (self , yield_label : Literal [False ] = ...) -> Iterator [SegmentTrack ]: ...
277+ @overload
278+ def itertracks (self , yield_label : Literal [True ]) -> Iterator [SegmentTrackLabel ]: ...
279+ @overload
280+ def itertracks (self , yield_label : bool ) -> Iterator [Union [SegmentTrack , SegmentTrackLabel ]]: ...
281+
273282 def itertracks (
274283 self , yield_label : bool = False
275284 ) -> Iterator [Union [SegmentTrack , SegmentTrackLabel ]]:
@@ -292,7 +301,7 @@ def itertracks(
292301 >>> for segment, track in annotation.itertracks():
293302 ... # do something with the track
294303
295- >>> for segment, track, label in annotation.itertracks_with_labels( ):
304+ >>> for segment, track, label in annotation.itertracks(yield_label=True ):
296305 ... # do something with the track and its label
297306 """
298307
@@ -307,11 +316,11 @@ def itertracks(
307316
308317 def itertracks_with_labels (self ) -> Iterator [SegmentTrackLabel ]:
309318 """Typed version of :func:`itertracks`(yield_label=True)"""
310- return self .itertracks (yield_label = True ) # type: ignore
319+ return self .itertracks (yield_label = True )
311320
312321 def itertracks_without_labels (self ) -> Iterator [SegmentTrack ]:
313322 """Typed version of :func:`itertracks`(yield_label=False)"""
314- return self .itertracks (yield_label = False ) # type: ignore
323+ return self .itertracks (yield_label = False )
315324
316325 def _updateTimeline (self ):
317326 self ._timeline = Timeline (segments = self ._tracks , uri = self .uri )
@@ -358,14 +367,14 @@ def __eq__(self, other: "Annotation"):
358367 labels are equal.
359368 """
360369 pairOfTracks = itertools .zip_longest (
361- self .itertracks_with_labels ( ), other .itertracks_with_labels ( )
370+ self .itertracks ( yield_label = True ), other .itertracks ( yield_label = True )
362371 )
363372 return all (t1 == t2 for t1 , t2 in pairOfTracks )
364373
365374 def __ne__ (self , other : "Annotation" ):
366375 """Inequality"""
367376 pairOfTracks = itertools .zip_longest (
368- self .itertracks_with_labels ( ), other .itertracks_with_labels ( )
377+ self .itertracks ( yield_label = True ), other .itertracks ( yield_label = True )
369378 )
370379
371380 return any (t1 != t2 for t1 , t2 in pairOfTracks )
@@ -404,7 +413,7 @@ def _iter_rttm(self) -> Iterator[Text]:
404413 f'containing spaces (got: "{ uri } ").'
405414 )
406415 raise ValueError (msg )
407- for segment , _ , label in self .itertracks_with_labels ( ):
416+ for segment , _ , label in self .itertracks ( yield_label = True ):
408417 if isinstance (label , Text ) and " " in label :
409418 msg = (
410419 f"Space-separated RTTM file format does not allow labels "
@@ -449,7 +458,7 @@ def _iter_lab(self) -> Iterator[Text]:
449458 iterator: Iterator[str]
450459 An iterator over LAB text lines
451460 """
452- for segment , _ , label in self .itertracks_with_labels ( ):
461+ for segment , _ , label in self .itertracks ( yield_label = True ):
453462 if isinstance (label , Text ) and " " in label :
454463 msg = (
455464 f"Space-separated LAB file format does not allow labels "
@@ -806,7 +815,7 @@ def __str__(self):
806815 """Human-friendly representation"""
807816 # TODO: use pandas.DataFrame
808817 return "\n " .join (
809- ["%s %s %s" % (s , t , l ) for s , t , l in self .itertracks_with_labels ( )]
818+ ["%s %s %s" % (s , t , l ) for s , t , l in self .itertracks ( yield_label = True )]
810819 )
811820
812821 def __delitem__ (self , key : Key ):
@@ -1051,7 +1060,7 @@ def update(self, annotation: "Annotation", copy: bool = False) -> "Annotation":
10511060 result = self .copy () if copy else self
10521061
10531062 # TODO speed things up by working directly with annotation internals
1054- for segment , track , label in annotation .itertracks_with_labels ( ):
1063+ for segment , track , label in annotation .itertracks ( yield_label = True ):
10551064 result [segment , track ] = label
10561065
10571066 return result
@@ -1255,7 +1264,7 @@ def rename_tracks(self, generator: Union[LabelGenerator, Iterable[str], Iterable
12551264 raise ValueError ("generator must be 'string', 'int', or iterable" )
12561265
12571266 # TODO speed things up by working directly with annotation internals
1258- for s , _ , label in self .itertracks_with_labels ( ):
1267+ for s , _ , label in self .itertracks ( yield_label = True ):
12591268 renamed [s , next (generator_ )] = label
12601269 return renamed
12611270
@@ -1338,7 +1347,7 @@ def relabel_tracks(self, generator: LabelGenerator = "string") -> "Annotation":
13381347 generator = int_generator ()
13391348
13401349 relabeled = self .empty ()
1341- for s , t , _ in self .itertracks_with_labels ( ):
1350+ for s , t , _ in self .itertracks ( yield_label = True ):
13421351 relabeled [s , t ] = next (generator )
13431352
13441353 return relabeled
0 commit comments