Skip to content

Commit d48b1a9

Browse files
authored
Merge branch 'master' into payload_data_setter
2 parents 9fc8b3f + d5e69a4 commit d48b1a9

13 files changed

+1215
-703
lines changed

forte/common/constants.py

+27-15
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,21 @@
11
# DataStore constants
2-
# The index storing begin location in the internal entry data of DataStore.
3-
BEGIN_INDEX = 0
2+
# The name of the attribute storing the begin location in the internal
3+
# entry data of DataStore.
4+
BEGIN_ATTR_NAME = "begin"
45

5-
# The index storing end location in the internal entry data of DataStore.
6-
END_INDEX = 1
6+
# The name of the attribute storing the end location in the internal
7+
# entry data of DataStore.
8+
END_ATTR_NAME = "end"
79

810
# The index storing tid in the internal entry data of DataStore.
9-
TID_INDEX = 2
11+
TID_INDEX = 0
1012

1113
# The index storing entry type in the internal entry data of DataStore.
12-
ENTRY_TYPE_INDEX = 3
14+
ENTRY_TYPE_INDEX = 1
15+
16+
# The name of the attribute storing the payload index location in the
17+
# internal entry data of DataStore.
18+
PAYLOAD_ID_ATTR_NAME = "payload_idx"
1319

1420
# The index storing entry type (specific to Link and Group type). It is saved
1521
# in the `tid_idx_dict` in DataStore.
@@ -19,20 +25,26 @@
1925
# in the `tid_idx_dict` in DataStore.
2026
ENTRY_DICT_ENTRY_INDEX = 1
2127

22-
# The index storing parent entry tid in Link entries
23-
PARENT_TID_INDEX = 0
28+
# The name of the attribute storing the parent entry tid in Link entries
29+
PARENT_TID_ATTR_NAME = "parent"
30+
31+
# The name of the attribute storing the parent entry type in Link entries
32+
PARENT_TYPE_ATTR_NAME = "parent_type"
33+
34+
# The name of the attribute storing the child entry tid in Link entries
35+
CHILD_TID_ATTR_NAME = "child"
2436

25-
# The index storing child entry tid in Link entries
26-
CHILD_TID_INDEX = 1
37+
# The name of the attribute storing the child entry type in Link entries
38+
CHILD_TYPE_ATTR_NAME = "child_type"
2739

28-
# The index storing member entry type in Group entries
29-
MEMBER_TYPE_INDEX = 0
40+
# The name of the attribute storing the member entry type in Group entries
41+
MEMBER_TYPE_ATTR_NAME = "member_type"
3042

31-
# The index storing the list of member entries tid in Group entries
32-
MEMBER_TID_INDEX = 1
43+
# The name of the attribute storing the list of member entries tid in Group entries
44+
MEMBER_TID_ATTR_NAME = "members"
3345

3446
# The index where the first attribute appears in the internal entry data of DataStore.
35-
ATTR_BEGIN_INDEX = 4
47+
ATTR_BEGIN_INDEX = 2
3648

3749
# Name of the key to access the attribute dict of an entry type from
3850
# ``_type_attributes`` of ``DataStore``.

forte/data/base_pack.py

+77-4
Original file line numberDiff line numberDiff line change
@@ -436,12 +436,29 @@ def record_field(self, entry_id: int, field_name: str):
436436
self._field_records[c] = {(entry_id, field_name)}
437437

438438
def on_entry_creation(
439-
self, entry: Entry, component_name: Optional[str] = None
439+
self,
440+
entry: Entry,
441+
component_name: Optional[str] = None,
440442
):
441443
"""
442444
Call this when adding a new entry, will be called
443445
in :class:`~forte.data.ontology.core.Entry` when
444-
its `__init__` function is called.
446+
its `__init__` function is called. This method does
447+
the following 2 operations with regards to creating
448+
a new entry.
449+
450+
- All ``dataclass`` attributes of the entry to be created
451+
are stored in the class level dictionary of
452+
:class:`~forte.data.ontology.core.Entry` called
453+
``cached_attributes_data``. This is used to initialize
454+
the corresponding entry's objects data store entry
455+
- On creation of the data store entry, this methods associates
456+
``getter`` and ``setter`` properties to all `dataclass`
457+
attributes of this entry to allow direct interaction
458+
between the attributes of the entry and their copy being
459+
stored in the data store. For example, the `setter` method
460+
updates the data store value of an attribute of a given entry
461+
whenever the attribute in the entry's object is updated.
445462
446463
Args:
447464
entry: The entry to be added.
@@ -541,16 +558,51 @@ def entry_getter(cls: Entry, attr_name: str):
541558
def entry_setter(cls: Entry, value: Any, attr_name: str):
542559
"""A setter function for dataclass fields of entry object.
543560
When the value contains entry objects, we will convert them into
544-
``tid``s before storing to ``DataStore``.
561+
``tid``s before storing to ``DataStore``. Additionally, if the entry
562+
setter method is called on an attribute that does not have a pack
563+
associated with it (as is the case during intialization), the value
564+
of the atttribute is stored in the class level cache of the ``Entry``
565+
class. On the other hand, if a pack is associated with the entry,
566+
the value will directly be stored in the data store.
545567
546568
Args:
547569
cls: An ``Entry`` class object.
548570
value: The value to be assigned to the attribute.
549571
attr_name: The name of the attribute.
550572
"""
551573
attr_value: Any
574+
575+
try:
576+
pack = cls.pack
577+
except AttributeError as err:
578+
# This is the case when an object of an entry that has already been
579+
# created before (which means an setter and getter properties are
580+
# associated with its dataclass fields) is trying to be initialized.
581+
# In this case, a pack is not yet associated with this entry. Thus,
582+
# we store the initial values dataclass fields of such entries in the
583+
# _cached_attribute_data of the Entry class.
584+
585+
# pylint: disable=protected-access
586+
if cls.entry_type() not in Entry._cached_attribute_data:
587+
Entry._cached_attribute_data[cls.entry_type()] = {}
588+
589+
if (
590+
attr_name
591+
not in Entry._cached_attribute_data[cls.entry_type()]
592+
):
593+
Entry._cached_attribute_data[cls.entry_type()][
594+
attr_name
595+
] = value
596+
return
597+
else:
598+
raise KeyError(
599+
"You are trying to overwrite the value "
600+
f"of {attr_name} for a data store entry "
601+
"before it is created."
602+
) from err
603+
552604
data_store_ref = (
553-
cls.pack._data_store # pylint: disable=protected-access
605+
pack._data_store # pylint: disable=protected-access
554606
)
555607

556608
attr_type = data_store_ref.get_attr_type(
@@ -594,6 +646,27 @@ def entry_setter(cls: Entry, value: Any, attr_name: str):
594646
tid=cls.tid, attr_name=attr_name, attr_value=attr_value
595647
)
596648

649+
# If this is the first time an entry of this type is
650+
# created, its attributes do not have a getter and setter
651+
# property associated with them. We can thus assume that there
652+
# no key in the _cached_attribute_data dictionary that has yet
653+
# been created to store the dataclass fields of this entry. Thus,
654+
# we create an empty dictionary to store the dataclass fields
655+
# of this new entry and manually add all dataclass attributes
656+
# that have been initialized to the _cached_attribute_data dict.
657+
# We fetch the values of all dataclass fields by using the getattr
658+
# method.
659+
660+
# pylint: disable=protected-access
661+
if entry.entry_type() not in Entry._cached_attribute_data:
662+
Entry._cached_attribute_data[entry.entry_type()] = {}
663+
for name in entry.__dataclass_fields__:
664+
attr_val = getattr(entry, name, None)
665+
if attr_val is not None:
666+
Entry._cached_attribute_data[entry.entry_type()][
667+
name
668+
] = attr_val
669+
597670
# Save the input entry object in DataStore
598671
self._save_entry_to_data_store(entry=entry)
599672

forte/data/base_store.py

+9-11
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,8 @@
1313
# limitations under the License.
1414

1515
from abc import abstractmethod
16-
from typing import List, Iterator, Tuple, Any, Optional, Dict, Type
16+
from typing import List, Iterator, Tuple, Any, Optional, Dict
1717
import json
18-
from forte.data.ontology.core import Entry
1918

2019
__all__ = ["BaseStore"]
2120

@@ -128,10 +127,9 @@ def _deserialize(
128127
def add_entry_raw(
129128
self,
130129
type_name: str,
131-
attribute_data: List,
132-
base_class: Type[Entry],
133130
tid: Optional[int] = None,
134131
allow_duplicate: bool = True,
132+
attribute_data: Optional[List] = None,
135133
) -> int:
136134

137135
r"""
@@ -143,19 +141,19 @@ def add_entry_raw(
143141
144142
Args:
145143
type_name: The fully qualified type name of the new Entry.
146-
attribute_data: It is a list that stores attributes relevant to
147-
the entry being added. In order to keep the number of attributes
148-
same for all entries, the list is populated with trailing None's.
149-
base_class: The type of entry to add to the Data Store. This is
150-
a reference to the class of the entry that needs to be added
151-
to the DataStore. The reference can be to any of the classes
152-
supported by the function.
153144
tid: ``tid`` of the Entry that is being added.
154145
It's optional, and it will be
155146
auto-assigned if not given.
156147
allow_duplicate: Whether we allow duplicate in the DataStore. When
157148
it's set to False, the function will return the ``tid`` of
158149
existing entry if a duplicate is found. Default value is True.
150+
attribute_data: It is a `list` that stores attributes relevant to
151+
the entry being added. The attributes passed in
152+
`attributes_data` must be present in that entries
153+
`type_attributes` and must only be those entries which are
154+
relevant to the initialization of the entry. For example,
155+
begin and end position when creating an entry of type
156+
:class:`~forte.data.ontology.top.Annotation`.
159157
160158
Returns:
161159
``tid`` of the entry.

forte/data/data_pack.py

+5-4
Original file line numberDiff line numberDiff line change
@@ -465,7 +465,7 @@ def get_payload_at(
465465
466466
"""
467467
supported_modality = [enum.name for enum in Modality]
468-
468+
payloads_length: int = 0
469469
try:
470470
# if modality.name == "text":
471471
if modality == Modality.Text:
@@ -1411,6 +1411,7 @@ def _generate_link_entry_data(
14111411
a_dict["parent"].append(
14121412
np.where(data[parent_type]["tid"] == link.parent)[0][0]
14131413
)
1414+
14141415
a_dict["child"].append(
14151416
np.where(data[child_type]["tid"] == link.child)[0][0]
14161417
)
@@ -1640,13 +1641,13 @@ def _save_entry_to_data_store(self, entry: Entry):
16401641
self._entry_converter.save_entry_object(entry=entry, pack=self)
16411642

16421643
if isinstance(entry, Payload):
1643-
if entry.modality == Modality.Text:
1644+
if Modality.Text.name == entry.modality_name:
16441645
entry.set_payload_index(len(self.text_payloads))
16451646
self.text_payloads.append(entry)
1646-
elif entry.modality == Modality.Audio:
1647+
elif Modality.Audio.name == entry.modality_name:
16471648
entry.set_payload_index(len(self.audio_payloads))
16481649
self.audio_payloads.append(entry)
1649-
elif entry.modality == Modality.Image:
1650+
elif Modality.Image.name == entry.modality_name:
16501651
entry.set_payload_index(len(self.image_payloads))
16511652
self.image_payloads.append(entry)
16521653

0 commit comments

Comments
 (0)