Skip to content

Commit 395612b

Browse files
author
Suqi Sun
committed
Add AudioAnnotation entry to ontology system
1 parent ab5248e commit 395612b

File tree

5 files changed

+390
-4
lines changed

5 files changed

+390
-4
lines changed

forte/data/data_pack.py

+79-4
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@
4646
Group,
4747
SinglePackEntries,
4848
Generics,
49+
AudioAnnotation,
4950
)
5051
from forte.data.span import Span
5152
from forte.data.types import ReplaceOperationsType, DataRequest
@@ -166,6 +167,7 @@ def __init__(self, pack_name: Optional[str] = None):
166167
self.links: SortedList[Link] = SortedList()
167168
self.groups: SortedList[Group] = SortedList()
168169
self.generics: SortedList[Generics] = SortedList()
170+
self.audio_annotations: SortedList[AudioAnnotation] = SortedList()
169171

170172
self.__replace_back_operations: ReplaceOperationsType = []
171173
self.__processed_original_spans: List[Tuple[Span, Span]] = []
@@ -185,6 +187,7 @@ def __getstate__(self):
185187
state["links"] = list(state["links"])
186188
state["groups"] = list(state["groups"])
187189
state["generics"] = list(state["generics"])
190+
state["audio_annotations"] = list(state["audio_annotations"])
188191
return state
189192

190193
def __setstate__(self, state):
@@ -212,12 +215,14 @@ def __setstate__(self, state):
212215
self.links = as_sorted_error_check(self.links)
213216
self.groups = as_sorted_error_check(self.groups)
214217
self.generics = as_sorted_error_check(self.generics)
218+
self.audio_annotations = as_sorted_error_check(self.audio_annotations)
215219

216220
self._index = DataIndex()
217221
self._index.update_basic_index(list(self.annotations))
218222
self._index.update_basic_index(list(self.links))
219223
self._index.update_basic_index(list(self.groups))
220224
self._index.update_basic_index(list(self.generics))
225+
self._index.update_basic_index(list(self.audio_annotations))
221226

222227
for a in self.annotations:
223228
a.set_pack(self)
@@ -231,11 +236,15 @@ def __setstate__(self, state):
231236
for a in self.generics:
232237
a.set_pack(self)
233238

239+
for a in self.audio_annotations:
240+
a.set_pack(self)
241+
234242
def __iter__(self):
    """Iterate over every entry stored in this pack.

    Entries are produced one collection at a time, in this order:
    annotations, links, groups, generics and audio annotations.
    """
    # Each collection is looked up only when the previous one is exhausted,
    # mirroring a sequence of consecutive ``yield from`` statements.
    for collection_name in (
        "annotations",
        "links",
        "groups",
        "generics",
        "audio_annotations",
    ):
        yield from getattr(self, collection_name)
239248

240249
def _init_meta(self, pack_name: Optional[str] = None) -> Meta:
241250
return Meta(pack_name)
@@ -341,6 +350,27 @@ def num_generics_entries(self):
341350
"""
342351
return len(self.generics)
343352

353+
@property
def all_audio_annotations(self) -> Iterator[AudioAnnotation]:
    """
    Iterate over every audio annotation stored in this data pack.

    Returns: An iterator over the entries of ``self.audio_annotations``,
    each of type :class:`~forte.data.ontology.top.AudioAnnotation`.

    """
    for audio_annotation in self.audio_annotations:
        yield audio_annotation
363+
364+
@property
def num_audio_annotations(self):
    """
    Count the audio annotations held by this data pack.

    Returns: The length of ``self.audio_annotations``.

    """
    audio_annotation_count = len(self.audio_annotations)
    return audio_annotation_count
373+
344374
def get_span_text(self, begin: int, end: int) -> str:
345375
r"""Get the text in the data pack contained in the span.
346376
@@ -353,6 +383,23 @@ def get_span_text(self, begin: int, end: int) -> str:
353383
"""
354384
return self._text[begin:end]
355385

386+
def get_span_audio(self, begin: int, end: int):
    r"""Get the audio in the data pack contained in the span.

    Args:
        begin (int): begin index to query.
        end (int): end index to query.

    Returns:
        The slice ``self._audio[begin:end]`` of the audio payload. The
        result has whatever type slicing the stored payload produces; it
        is not a string.

    Raises:
        ProcessExecutionException: if the audio payload of this pack has
            not been set yet.
    """
    # NOTE(review): the previous `-> str` return annotation was carried over
    # from `get_span_text` and does not describe an audio payload. The
    # payload is presumably a numpy array set through `set_audio` -- confirm
    # and add the precise return annotation.
    if self._audio is None:
        raise ProcessExecutionException(
            "The audio payload of this DataPack is not set. Please call"
            " method `set_audio` before running `get_span_audio`."
        )
    return self._audio[begin:end]
402+
356403
def set_text(
357404
self,
358405
text: str,
@@ -619,10 +666,13 @@ def __add_entry_with_check(
619666
target = self.groups
620667
elif isinstance(entry, Generics):
621668
target = self.generics
669+
elif isinstance(entry, AudioAnnotation):
670+
target = self.audio_annotations
622671
else:
623672
raise ValueError(
624673
f"Invalid entry type {type(entry)}. A valid entry "
625-
f"should be an instance of Annotation, Link, Group of Generics."
674+
f"should be an instance of Annotation, Link, Group, Generics "
675+
"or AudioAnnotation."
626676
)
627677

628678
if not allow_duplicate:
@@ -664,6 +714,8 @@ def delete_entry(self, entry: EntryType):
664714
target = self.groups
665715
elif isinstance(entry, Generics):
666716
target = self.generics
717+
elif isinstance(entry, AudioAnnotation):
718+
target = self.audio_annotations
667719
else:
668720
raise ValueError(
669721
f"Invalid entry type {type(entry)}. A valid entry "
@@ -779,6 +831,9 @@ def get_data(
779831
link_types: Dict[Type[Link], Union[Dict, List]] = {}
780832
group_types: Dict[Type[Group], Union[Dict, List]] = {}
781833
generics_types: Dict[Type[Generics], Union[Dict, List]] = {}
834+
audio_annotation_types: Dict[
835+
Type[AudioAnnotation], Union[Dict, List]
836+
] = {}
782837

783838
if request is not None:
784839
for key_, value in request.items():
@@ -791,6 +846,8 @@ def get_data(
791846
group_types[key] = value
792847
elif issubclass(key, Generics):
793848
generics_types[key] = value
849+
elif issubclass(key, AudioAnnotation):
850+
audio_annotation_types[key] = value
794851

795852
context_args = annotation_types.get(context_type_)
796853

@@ -866,6 +923,12 @@ def get_data(
866923
"currently not supported."
867924
)
868925

926+
if audio_annotation_types:
927+
raise NotImplementedError(
928+
"Querying audio annotation types based on ranges is "
929+
"currently not supported."
930+
)
931+
869932
yield data
870933

871934
def _parse_request_args(self, a_type, a_args):
@@ -1127,6 +1190,12 @@ def iter_in_range(
11271190
for group in self.groups:
11281191
if self._index.in_span(group, range_annotation.span):
11291192
yield group
1193+
elif issubclass(entry_type, AudioAnnotation):
1194+
for audio_annotation in self.audio_annotations:
1195+
if self._index.in_span(
1196+
audio_annotation, range_annotation.span
1197+
):
1198+
yield audio_annotation
11301199

11311200
def get( # type: ignore
11321201
self,
@@ -1224,6 +1293,7 @@ def require_annotations() -> bool:
12241293
issubclass(entry_type_, Annotation)
12251294
or issubclass(entry_type_, Link)
12261295
or issubclass(entry_type_, Group)
1296+
or issubclass(entry_type_, AudioAnnotation)
12271297
):
12281298
entry_iter = self.iter_in_range(entry_type_, range_annotation)
12291299
elif issubclass(entry_type_, Annotation):
@@ -1232,6 +1302,8 @@ def require_annotations() -> bool:
12321302
entry_iter = self.links
12331303
elif issubclass(entry_type_, Group):
12341304
entry_iter = self.groups
1305+
elif issubclass(entry_type_, AudioAnnotation):
1306+
entry_iter = self.audio_annotations
12351307
else:
12361308
raise ValueError(
12371309
f"The requested type {str(entry_type_)} is not supported."
@@ -1426,8 +1498,9 @@ def in_span(self, inner_entry: Union[int, Entry], span: Span) -> bool:
14261498
r"""Check whether the ``inner entry`` is within the given ``span``. The
14271499
criteria are as follows:
14281500
1429-
Annotation entries: they are considered in a span if the begin is not
1430-
smaller than `span.begin` and the end is not larger than `span.end`.
1501+
Annotation/AudioAnnotation entries: they are considered in a span if the
1502+
begin is not smaller than `span.begin` and the end is not larger than
1503+
`span.end`.
14311504
14321505
Link entries: if the parent and child of the links are both
14331506
`Annotation` type, this link will be considered in span if both parent
@@ -1463,7 +1536,9 @@ def in_span(self, inner_entry: Union[int, Entry], span: Span) -> bool:
14631536
inner_begin = -1
14641537
inner_end = -1
14651538

1466-
if isinstance(inner_entry, Annotation):
1539+
if isinstance(inner_entry, Annotation) or isinstance(
1540+
inner_entry, AudioAnnotation
1541+
):
14671542
inner_begin = inner_entry.begin
14681543
inner_end = inner_entry.end
14691544
elif isinstance(inner_entry, Link):

forte/data/ontology/top.py

+129
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
"Query",
3939
"SinglePackEntries",
4040
"MultiPackEntries",
41+
"AudioAnnotation",
4142
]
4243

4344
QueryType = Union[Dict[str, Any], np.ndarray]
@@ -545,5 +546,133 @@ def update_results(self, pid_to_score: Dict[str, float]):
545546
self.results.update(pid_to_score)
546547

547548

549+
@total_ordering
class AudioAnnotation(Entry):
    r"""AudioAnnotation type entries, such as "recording" and "audio utterance".
    Each audio annotation has a :class:`Span` corresponding to its offset
    in the audio. Most methods in this class are the same as the ones in
    :class:`Annotation`, except that it replaces property `text` with `audio`.

    Args:
        pack (PackType): The container that this audio annotation
            will be added to.
        begin (int): The offset of the first sample in the audio annotation.
        end (int): The offset of the last sample in the audio annotation + 1.
    """

    def __init__(self, pack: PackType, begin: int, end: int):
        # The span object is created lazily (see `span`); only the raw
        # offsets are stored up front.
        self._span: Optional[Span] = None
        self._begin: int = begin
        self._end: int = end
        super().__init__(pack)

    @property
    def audio(self):
        """The part of the pack's audio payload covered by this annotation.

        Raises:
            ValueError: if this annotation is not attached to a data pack.
        """
        if self.pack is None:
            raise ValueError(
                "Cannot get audio because annotation is not "
                "attached to any data pack."
            )
        return self.pack.get_span_audio(self.begin, self.end)

    def __getstate__(self):
        r"""For serializing AudioAnnotation, we should create Span annotations
        for compatibility purposes.
        """
        # Materialize the span so it is part of the serialized state, then
        # drop the raw offsets, which `__setstate__` restores from the span.
        self._span = Span(self._begin, self._end)
        state = super().__getstate__()
        state.pop("_begin")
        state.pop("_end")
        return state

    def __setstate__(self, state):
        """
        For de-serializing AudioAnnotation, we load the begin, end from Span,
        for compatibility purposes.
        """
        super().__setstate__(state)
        self._begin = self._span.begin
        self._end = self._span.end

    @property
    def span(self) -> Span:
        """The :class:`Span` built from ``begin`` and ``end``."""
        # Delay span creation at usage.
        if self._span is None:
            self._span = Span(self._begin, self._end)
        return self._span

    @property
    def begin(self) -> int:
        """The offset of the first sample in this audio annotation."""
        return self._begin

    @property
    def end(self) -> int:
        """The offset of the last sample in this audio annotation + 1."""
        return self._end

    def __eq__(self, other):
        r"""The eq function of :class:`AudioAnnotation`.
        By default, :class:`AudioAnnotation` objects are regarded as the same
        if they have the same type and the same begin and end offsets.

        Users can define their own eq function by themselves but this must
        be consistent to :meth:`__hash__`.
        """
        if other is None:
            return False
        if not isinstance(other, AudioAnnotation):
            # Let Python try the reflected comparison instead of failing with
            # an AttributeError on operands that have no `begin`/`end`.
            return NotImplemented
        return (type(self), self.begin, self.end) == (
            type(other),
            other.begin,
            other.end,
        )

    def __hash__(self) -> int:
        r"""The hash function of :class:`AudioAnnotation`.

        A class that defines ``__eq__`` without ``__hash__`` becomes
        unhashable, so the hash is defined explicitly here, over the same
        fields that :meth:`__eq__` compares.
        """
        return hash((type(self), self.begin, self.end))

    def __lt__(self, other):
        r"""To support total_ordering, `AudioAnnotation` must implement
        `__lt__`. The ordering is defined in the following way:

        1. If the begin of the audio annotations are different, the one with
           larger begin will be larger.
        2. In the case where the begins are the same, the one with larger
           end will be larger.
        3. In the case where both offsets are the same, we break the tie using
           the normal sorting of the class name.
        """
        if self.begin == other.begin:
            if self.end == other.end:
                return str(type(self)) < str(type(other))
            return self.end < other.end
        else:
            return self.begin < other.begin

    @property
    def index_key(self) -> int:
        """The key used to index this entry, which is its ``tid``."""
        return self.tid

    def get(
        self,
        entry_type: Union[str, Type[EntryType]],
        components: Optional[Union[str, Iterable[str]]] = None,
        include_sub_type=True,
    ) -> Iterable[EntryType]:
        """
        This function wraps the :meth:`~forte.data.DataPack.get()` method to find
        entries "covered" by this audio annotation. See that method for more
        information. For usage details, refer to
        :meth:`forte.data.ontology.top.Annotation.get()`.

        Args:
            entry_type (type): The type of entries requested.
            components (str or list, optional): The component (creator)
                generating the entries requested. If `None`, will return valid
                entries generated by any component.
            include_sub_type (bool): whether to consider the sub types of
                the provided entry type. Default `True`.

        Yields:
            Each `Entry` found using this method.

        """
        # This annotation itself is passed as the range to search within.
        yield from self.pack.get(entry_type, self, components, include_sub_type)
675+
676+
548677
SinglePackEntries = (Link, Group, Annotation, Generics)
549678
MultiPackEntries = (MultiPackLink, MultiPackGroup, MultiPackGeneric)

forte/ontology_specs/base_ontology.json

+24
Original file line numberDiff line numberDiff line change
@@ -415,6 +415,30 @@
415415
"item_type": "ft.onto.base_ontology.Phrase"
416416
}
417417
]
418+
},
419+
{
420+
"entry_name": "ft.onto.base_ontology.Recording",
421+
"parent_entry": "forte.data.ontology.top.AudioAnnotation",
422+
"description": "A span based annotation `Recording`, normally used to represent a recording.",
423+
"attributes": [
424+
{
425+
"name": "recording_class",
426+
"type": "List",
427+
"item_type": "str",
428+
"description": "A list of class names that the recording belongs to."
429+
}
430+
]
431+
},
432+
{
433+
"entry_name": "ft.onto.base_ontology.AudioUtterance",
434+
"parent_entry": "forte.data.ontology.top.AudioAnnotation",
435+
"description": "A span based annotation `AudioUtterance`, normally used to represent an utterance in dialogue.",
436+
"attributes": [
437+
{
438+
"name": "speaker",
439+
"type": "str"
440+
}
441+
]
418442
}
419443
]
420444
}

0 commit comments

Comments
 (0)