Skip to content

Commit f52e728

Browse files
authored
Merge pull request #635 from qinzzz/issue-606
implement add_annotation_raw
2 parents dff572f + 4de2e70 commit f52e728

File tree

2 files changed

+218
-55
lines changed

2 files changed

+218
-55
lines changed

forte/data/data_store.py

+147-26
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,11 @@
1111
# See the License for the specific language governing permissions and
1212
# limitations under the License.
1313

14-
15-
from typing import List, Iterator, Tuple, Optional, Any
14+
from typing import Dict, List, Iterator, Tuple, Optional, Any
1615
import uuid
1716
from bisect import bisect_left
1817
from heapq import heappush, heappop
18+
from sortedcontainers import SortedList
1919

2020
from forte.utils import get_class
2121
from forte.data.base_store import BaseStore
@@ -29,7 +29,9 @@ class DataStore(BaseStore):
2929
# TODO: temporarily disable this for development purposes.
3030
# pylint: disable=pointless-string-statement
3131

32-
def __init__(self, onto_file_path: Optional[str] = None):
32+
def __init__(
33+
self, onto_file_path: Optional[str] = None, dynamically_add_type=True
34+
):
3335
r"""An implementation of the data store object that mainly uses
3436
primitive types. This class will be used as the internal data
3537
representation behind data pack. The usage of primitive types provides
@@ -121,15 +123,23 @@ def __init__(self, onto_file_path: Optional[str] = None):
121123
onto_file_path (str, optional): the path to the ontology file.
122124
"""
123125
super().__init__()
124-
self.onto_file_path = onto_file_path
126+
127+
if onto_file_path is None and not dynamically_add_type:
128+
raise RuntimeError(
129+
"DataStore is initialized with no existing types. Setting"
130+
"dynamically_add_type to False without providing onto_file_path"
131+
"will lead to no usable type in DataStore."
132+
)
133+
self._onto_file_path = onto_file_path
134+
self._dynamically_add_type = dynamically_add_type
125135

126136
"""
127137
The ``_type_attributes`` is a private dictionary that provides
128138
``type_name``, their parent entry, and the order of corresponding attributes.
129139
The keys are fully qualified names of every type; The value is a dictionary with
130140
two keys. Key ``attribute`` provides an inner dictionary with all valid attributes
131-
for this type and the indices of attributes among these lists. Key ``parent_entry``
132-
is a string representing the direct parent of this type.
141+
for this type and the indices of attributes among these lists. Key ``parent_class``
142+
is a string representing the ancesters of this type.
133143
134144
This structure is supposed to be built dynamically. When a user adds new entries,
135145
data_store will check unknown types and add them to ``_type_attributes``.
@@ -144,20 +154,20 @@ def __init__(self, onto_file_path: Optional[str] = None):
144154
# "attributes": {"pos": 4, "ud_xpos": 5,
145155
# "lemma": 6, "chunk": 7, "ner": 8, "sense": 9,
146156
# "is_root": 10, "ud_features": 11, "ud_misc": 12},
147-
# "parent_entry": "forte.data.ontology.top.Annotation", },
157+
# "parent_class": "forte.data.ontology.top.Annotation", },
148158
# "ft.onto.base_ontology.Document": {
149159
# "attributes": {"document_class": 4,
150160
# "sentiment": 5, "classifications": 6},
151-
# "parent_entry": "forte.data.ontology.top.Annotation", },
161+
# "parent_class": "forte.data.ontology.top.Annotation", },
152162
# "ft.onto.base_ontology.Sentence": {
153163
# "attributes": {"speaker": 4,
154164
# "part_id": 5, "sentiment": 6,
155165
# "classification": 7, "classifications": 8},
156-
# "parent_entry": "forte.data.ontology.top.Annotation", }
166+
# "parent_class": "forte.data.ontology.top.Annotation", }
157167
# }
158168
"""
159169
self._type_attributes: dict = {}
160-
if self.onto_file_path:
170+
if self._onto_file_path:
161171
self._parse_onto_file()
162172

163173
"""
@@ -190,6 +200,104 @@ def _new_tid(self) -> int:
190200
r"""This function generates a new ``tid`` for an entry."""
191201
return uuid.uuid4().int
192202

203+
def _get_type_info(self, type_name: str) -> Dict[str, Any]:
204+
"""
205+
Get the dictionary containing type information from ``self._type_attributes``.
206+
If the ``type_name`` does not currecntly exists and dynamic import is enabled,
207+
this function will add a new key-value pair into ``self._type_attributes``. The
208+
value consists of a full attribute-to-index dictionary and an empty parent set.
209+
210+
This function returns a dictionary containing an attribute dict and a set of parent
211+
entries of the given type. For example:
212+
213+
.. code-block:: python
214+
215+
"ft.onto.base_ontology.Sentence": {
216+
"attributes": {
217+
"speaker": 4,
218+
"part_id": 5,
219+
"sentiment": 6,
220+
"classification": 7,
221+
"classifications": 8,
222+
},
223+
"parent_class": set(),
224+
}
225+
226+
Args:
227+
type_name (str): The fully qualified type name of a type.
228+
Returns:
229+
attr_dict (dict): The dictionary containing an attribute dict and a set of parent
230+
entries of the given type.
231+
Raises:
232+
RuntimeError: When the type is not provided by ontology file and
233+
dynamic import is disabled.
234+
"""
235+
# check if type is in dictionary
236+
if type_name in self._type_attributes:
237+
return self._type_attributes[type_name]
238+
if not self._dynamically_add_type:
239+
raise ValueError(
240+
f"{type_name} is not an existing type in current data store."
241+
f"Dynamically add type is disabled."
242+
f"Set dynamically_add_type=True if you need to use types other than"
243+
f"types specified in the ontology file."
244+
)
245+
# get attribute dictionary
246+
attributes = self._get_entry_attributes_by_class(type_name)
247+
248+
attr_dict = {}
249+
attr_idx = constants.ENTRY_TYPE_INDEX + 1
250+
for attr_name in attributes:
251+
attr_dict[attr_name] = attr_idx
252+
attr_idx += 1
253+
254+
new_entry_info = {
255+
"attributes": attr_dict,
256+
"parent_class": set(),
257+
}
258+
self._type_attributes[type_name] = new_entry_info
259+
260+
return new_entry_info
261+
262+
def _get_type_attribute_dict(self, type_name: str) -> Dict[str, int]:
263+
"""Get the attribute dict of an entry type. The attribute dict maps
264+
attribute names to a list of consecutive integers as indicies. For example:
265+
.. code-block:: python
266+
267+
"attributes": {
268+
"speaker": 4,
269+
"part_id": 5,
270+
"sentiment": 6,
271+
"classification": 7,
272+
"classifications": 8,
273+
},
274+
275+
Args:
276+
type_name (str): The fully qualified type name of a type.
277+
Returns:
278+
attr_dict (dict): The attribute-to-index dictionary of an entry.
279+
"""
280+
return self._get_type_info(type_name)["attributes"]
281+
282+
def _get_type_parent(self, type_name: str) -> str:
283+
"""Get a set of parent names of an entry type. The set is a subset of all
284+
ancestors of the given type.
285+
Args:
286+
type_name (str): The fully qualified type name of a type.
287+
Returns:
288+
parent_class (str): The parent entry name of an entry.
289+
"""
290+
return self._get_type_info(type_name)["parent_class"]
291+
292+
def _num_attributes_for_type(self, type_name: str) -> int:
293+
"""Get the length of the attribute dict of an entry type.
294+
Args:
295+
type_name (str): The fully qualified type name of the new entry.
296+
Returns:
297+
attr_dict (dict): The attributes-to-index dict of an entry.
298+
"""
299+
return len(self._get_type_attribute_dict(type_name))
300+
193301
def _new_annotation(self, type_name: str, begin: int, end: int) -> List:
194302
r"""This function generates a new annotation with default fields.
195303
All default fields are filled with None.
@@ -207,8 +315,10 @@ def _new_annotation(self, type_name: str, begin: int, end: int) -> List:
207315

208316
tid: int = self._new_tid()
209317
entry: List[Any]
318+
210319
entry = [begin, end, tid, type_name]
211-
entry += len(self._type_attributes[type_name]) * [None]
320+
entry += self._num_attributes_for_type(type_name) * [None]
321+
212322
return entry
213323

214324
def _new_link(
@@ -230,8 +340,10 @@ def _new_link(
230340

231341
tid: int = self._new_tid()
232342
entry: List[Any]
343+
233344
entry = [parent_tid, child_tid, tid, type_name]
234-
entry += len(self._type_attributes[type_name]) * [None]
345+
entry += self._num_attributes_for_type(type_name) * [None]
346+
235347
return entry
236348

237349
def _new_group(self, type_name: str, member_type: str) -> List:
@@ -249,21 +361,22 @@ def _new_group(self, type_name: str, member_type: str) -> List:
249361
"""
250362

251363
tid: int = self._new_tid()
364+
252365
entry = [member_type, [], tid, type_name]
253-
entry += len(self._type_attributes[type_name]) * [None]
366+
entry += self._num_attributes_for_type(type_name) * [None]
367+
254368
return entry
255369

256370
def _is_annotation(self, type_name: str) -> bool:
257371
r"""This function takes a type_id and returns whether a type
258372
is an annotation type or not.
259-
260373
Args:
261374
type_name (str): The name of type in `self.__elements`.
262-
263375
Returns:
264376
A boolean value whether this type_id belongs to an annotation
265377
type or not.
266378
"""
379+
# TODO: use is_subclass() in DataStore to replace this
267380
entry_class = get_class(type_name)
268381
return issubclass(entry_class, (Annotation, AudioAnnotation))
269382

@@ -286,7 +399,15 @@ def add_annotation_raw(self, type_name: str, begin: int, end: int) -> int:
286399
# annotation type entry data with default fields.
287400
# A reference to the entry should be store in both self.__elements and
288401
# self.__entry_dict.
289-
raise NotImplementedError
402+
entry = self._new_annotation(type_name, begin, end)
403+
try:
404+
self.__elements[type_name].add(entry)
405+
except KeyError:
406+
self.__elements[type_name] = SortedList(key=lambda s: (s[0], s[1]))
407+
self.__elements[type_name].add(entry)
408+
tid = entry[constants.TID_INDEX]
409+
self.__entry_dict[tid] = entry
410+
return tid
290411

291412
def add_link_raw(
292413
self, type_name: str, parent_tid: int, child_tid: int
@@ -340,16 +461,17 @@ def set_attribute(self, tid: int, attr_name: str, attr_value: Any):
340461
KeyError: when ``tid`` or ``attr_name`` is not found.
341462
"""
342463
try:
343-
entry_type = self.__entry_dict[tid][constants.ENTRY_TYPE_INDEX]
464+
entry = self.__entry_dict[tid]
465+
entry_type = entry[constants.ENTRY_TYPE_INDEX]
344466
except KeyError as e:
345467
raise KeyError(f"Entry with tid {tid} not found.") from e
346468

347469
try:
348-
attr_id = self._type_attributes[entry_type][attr_name]
470+
attr_id = self._get_type_attribute_dict(entry_type)[attr_name]
349471
except KeyError as e:
350472
raise KeyError(f"{entry_type} has no {attr_name} attribute.") from e
351473

352-
self._set_attr(tid, attr_id, attr_value)
474+
entry[attr_id] = attr_value
353475

354476
def _set_attr(self, tid: int, attr_id: int, attr_value: Any):
355477
r"""This function locates the entry data with ``tid`` and sets its
@@ -381,16 +503,17 @@ def get_attribute(self, tid: int, attr_name: str) -> Any:
381503
KeyError: when ``tid`` or ``attr_name`` is not found.
382504
"""
383505
try:
384-
entry_type = self.__entry_dict[tid][constants.ENTRY_TYPE_INDEX]
506+
entry = self.__entry_dict[tid]
507+
entry_type = entry[constants.ENTRY_TYPE_INDEX]
385508
except KeyError as e:
386509
raise KeyError(f"Entry with tid {tid} not found.") from e
387510

388511
try:
389-
attr_id = self._type_attributes[entry_type][attr_name]
512+
attr_id = self._get_type_attribute_dict(entry_type)[attr_name]
390513
except KeyError as e:
391514
raise KeyError(f"{entry_type} has no {attr_name} attribute.") from e
392515

393-
return self._get_attr(tid, attr_id)
516+
return entry[attr_id]
394517

395518
def _get_attr(self, tid: int, attr_id: int) -> Any:
396519
r"""This function locates the entry data with ``tid`` and gets the value
@@ -439,7 +562,7 @@ def delete_entry(self, tid: int):
439562
if self._is_annotation(type_name):
440563
entry_index = bisect_left(target_list, entry_data)
441564
else: # if it's group or link, use the index in entry_list
442-
entry_index = entry_data[-1]
565+
entry_index = entry_data[constants.ENTRY_INDEX_INDEX]
443566

444567
if (
445568
entry_index >= len(target_list)
@@ -455,8 +578,6 @@ def delete_entry(self, tid: int):
455578
def _delete_entry_by_loc(self, type_name: str, index_id: int):
456579
r"""It removes an entry of `index_id` by taking both the `type_id`
457580
and `index_id`. Called by `delete_entry()`.
458-
This function will raise an IndexError if the `type_id` or `index_id`
459-
is invalid.
460581
461582
Args:
462583
type_id (int): The index of the list in ``self.__elements``.
@@ -769,7 +890,7 @@ def _parse_onto_file(self):
769890
A user can use classes both in the ontology specification file and their parent
770891
entries's paths.
771892
"""
772-
if self.onto_file_path is None:
893+
if self._onto_file_path is None:
773894
return
774895
raise NotImplementedError
775896

0 commit comments

Comments
 (0)