Skip to content

Commit d7cbcfa

Browse files
mylibrar, Suqi Sun, Zhanyuan Zhang, hunterhector
authored
Add pipeline states to ir (#499)
* Add pipeline states to ir * Add doc for ir * Restructure IR * Update IR parse * Update selector serialization * Add informative error message * init implement of Selector serialization/deserialization * passed given unit tests * Selector inherits Configurable and assures backward compatibility * Fix pylint err * resolved PR comments on Selector 1.0 * Fix lint issue * pipeline calls selectors' method * removed is_initialized property from Selector * implemented Selector initialize method * Fix black issue * Ignore too-many-public-methods pylint error * Fix mypy error * Fix black issue Co-authored-by: Suqi Sun <suqi.sun@petuum.com> Co-authored-by: Zhanyuan Zhang <zhanyuan.zhang@petuum.com> Co-authored-by: Hector <hunterhector@gmail.com>
1 parent 8ee2c6c commit d7cbcfa

File tree

7 files changed

+474
-49
lines changed

7 files changed

+474
-49
lines changed

forte/data/ontology/code_generation_objects.py

+48
Original file line numberDiff line numberDiff line change
@@ -785,6 +785,54 @@ def collect_parents(self, node_dict: Dict[str, Set[str]]):
785785
] = found_node.parent.attributes
786786
found_node = found_node.parent
787787

788+
def todict(self) -> Dict[str, Any]:
789+
r"""Dump the EntryTree structure to a dictionary.
790+
791+
Returns:
792+
dict: A dictionary storing the EntryTree.
793+
"""
794+
795+
def node_to_dict(node: EntryTreeNode):
796+
return (
797+
None
798+
if not node
799+
else {
800+
"name": node.name,
801+
"attributes": list(node.attributes),
802+
"children": [
803+
node_to_dict(child) for child in node.children
804+
],
805+
}
806+
)
807+
808+
return node_to_dict(self.root)
809+
810+
def fromdict(
811+
self, tree_dict: Dict[str, Any], parent_entry_name: Optional[str] = None
812+
) -> Optional["EntryTree"]:
813+
r"""Load the EntryTree structure from a dictionary.
814+
815+
Args:
816+
tree_dict: A dictionary storing the EntryTree.
817+
parent_entry_name: The type name of the parent of the node to be
818+
built. Default value is None.
819+
"""
820+
if not tree_dict:
821+
return None
822+
823+
if parent_entry_name is None:
824+
self.root = EntryTreeNode(name=tree_dict["name"])
825+
self.root.attributes = set(tree_dict["attributes"])
826+
else:
827+
self.add_node(
828+
curr_entry_name=tree_dict["name"],
829+
parent_entry_name=parent_entry_name,
830+
curr_entry_attr=set(tree_dict["attributes"]),
831+
)
832+
for child in tree_dict["children"]:
833+
self.fromdict(child, tree_dict["name"])
834+
return self
835+
788836

789837
def search(node: EntryTreeNode, search_node_name: str):
790838
if node.name == search_node_name:

forte/data/selector.py

+74-12
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,12 @@
1515
This defines some selector interface used as glue to combine
1616
DataPack/multiPack processors and Pipeline.
1717
"""
18-
from typing import Generic, Iterator, TypeVar
18+
from typing import Generic, Iterator, TypeVar, Optional, Union, Dict, Any
1919

2020
import re
2121

22+
from forte.common.configuration import Config
23+
from forte.common.configurable import Configurable
2224
from forte.data.base_pack import BasePack
2325
from forte.data.data_pack import DataPack
2426
from forte.data.multi_pack import MultiPack
@@ -37,13 +39,18 @@
3739
]
3840

3941

40-
class Selector(Generic[InputPackType, OutputPackType]):
41-
def __init__(self, **kwargs):
42-
pass
42+
class Selector(Generic[InputPackType, OutputPackType], Configurable):
43+
def __init__(self):
44+
self.configs: Config = Config({}, {})
4345

4446
def select(self, pack: InputPackType) -> Iterator[OutputPackType]:
4547
raise NotImplementedError
4648

49+
def initialize(
50+
self, configs: Optional[Union[Config, Dict[str, Any]]] = None
51+
):
52+
self.configs = self.make_configs(configs)
53+
4754

4855
class DummySelector(Selector[InputPackType, InputPackType]):
4956
r"""Do nothing, return the data pack itself, which can be either
@@ -66,12 +73,23 @@ def select(self, pack: MultiPack) -> Iterator[DataPack]:
6673
class NameMatchSelector(SinglePackSelector):
6774
r"""Select a :class:`DataPack` from a :class:`MultiPack` with specified
6875
name.
76+
77+
This implementation takes special care for backward compatibility:
78+
Deprecated:
79+
selector = NameMatchSelector(select_name="foo")
80+
selector = NameMatchSelector("foo")
81+
Now:
82+
selector = NameMatchSelector()
83+
selector.initialize(
84+
configs={
85+
"select_name": "foo"
86+
}
87+
)
6988
"""
7089

71-
def __init__(self, select_name: str):
90+
def __init__(self, select_name: Optional[str] = None):
7291
super().__init__()
73-
assert select_name is not None
74-
self.select_name: str = select_name
92+
self.select_name = select_name
7593

7694
def select(self, m_pack: MultiPack) -> Iterator[DataPack]:
7795
matches = 0
@@ -85,23 +103,67 @@ def select(self, m_pack: MultiPack) -> Iterator[DataPack]:
85103
f"Pack name {self.select_name}" f" not in the MultiPack"
86104
)
87105

106+
def initialize(
107+
self, configs: Optional[Union[Config, Dict[str, Any]]] = None
108+
):
109+
if self.select_name is not None:
110+
super().initialize({"select_name": self.select_name})
111+
else:
112+
super().initialize(configs)
113+
114+
if self.configs["select_name"] is None:
115+
raise ValueError("select_name shouldn't be None.")
116+
self.select_name = self.configs["select_name"]
117+
118+
@classmethod
119+
def default_configs(cls):
120+
return {"select_name": None}
121+
88122

89123
class RegexNameMatchSelector(SinglePackSelector):
90-
r"""Select a :class:`DataPack` from a :class:`MultiPack` using a regex."""
124+
r"""Select a :class:`DataPack` from a :class:`MultiPack` using a regex.
125+
126+
This implementation takes special care for backward compatibility:
127+
Deprecated:
128+
selector = RegexNameMatchSelector(select_name="^.*\\d$")
129+
selector = RegexNameMatchSelector("^.*\\d$")
130+
Now:
131+
selector = RegexNameMatchSelector()
132+
selector.initialize(
133+
configs={
134+
"select_name": "^.*\\d$"
135+
}
136+
)
137+
"""
91138

92-
def __init__(self, select_name: str):
139+
def __init__(self, select_name: Optional[str] = None):
93140
super().__init__()
94-
assert select_name is not None
95-
self.select_name: str = select_name
141+
self.select_name = select_name
96142

97143
def select(self, m_pack: MultiPack) -> Iterator[DataPack]:
98144
if len(m_pack.packs) == 0:
99145
raise ValueError("Multi-pack is empty")
100146
else:
101147
for name, pack in m_pack.iter_packs():
102-
if re.match(self.select_name, name):
148+
if re.match(self.select_name, name): # type: ignore
103149
yield pack
104150

151+
def initialize(
152+
self, configs: Optional[Union[Config, Dict[str, Any]]] = None
153+
):
154+
if self.select_name is not None:
155+
super().initialize({"select_name": self.select_name})
156+
else:
157+
super().initialize(configs)
158+
159+
if self.configs["select_name"] is None:
160+
raise ValueError("select_name shouldn't be None.")
161+
self.select_name = self.configs["select_name"]
162+
163+
@classmethod
164+
def default_configs(cls):
165+
return {"select_name": None}
166+
105167

106168
class FirstPackSelector(SinglePackSelector):
107169
r"""Select the first entry from :class:`MultiPack` and yield it."""

0 commit comments

Comments
 (0)