Skip to content

Commit 7299fd6

Browse files
authored
Merge branch 'master' into suqi-localtest
2 parents b3de999 + 8ee2c6c commit 7299fd6

24 files changed

+660
-266
lines changed

docs/spelling_wordlist.txt

+1
Original file line numberDiff line numberDiff line change
@@ -136,3 +136,4 @@ embedding
136136
embeddings
137137
docstrings
138138
numpy
139+
jsonpickle

examples/content_rewriter/pipeline.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,9 @@ def do_process(input_pack_str: str):
2929
# Let's assume there is a JSON string for us to use.
3030
datapack: DataPack = pipeline.process([input_pack_str])
3131
# You can get the JSON form like this.
32-
data_json = datapack.serialize()
32+
data_json = datapack.to_string()
3333
# Let's write it out.
34-
with open("generation.txt", "w") as fo:
34+
with open("generation.txt", "w", encoding="utf-8") as fo:
3535
fo.write(data_json)
3636

3737

@@ -42,6 +42,6 @@ def do_process(input_pack_str: str):
4242
# You should initialize the model here, so we only do it once.
4343
pipeline.initialize()
4444

45-
with open("rewriting_input.json") as fi:
45+
with open("rewriting_input.json", encoding="utf-8") as fi:
4646
test_str = fi.read()
4747
do_process(test_str)

forte/data/base_pack.py

+101-26
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,11 @@
1313
# limitations under the License.
1414

1515
import copy
16+
import gzip
17+
import pickle
1618
import uuid
1719
from abc import abstractmethod
20+
from pathlib import Path
1821
from typing import (
1922
List,
2023
Optional,
@@ -162,19 +165,44 @@ def pack_name(self, pack_name: str):
162165
self._meta.pack_name = pack_name
163166

164167
@classmethod
165-
def _deserialize(cls, string: str) -> "PackType":
168+
def _deserialize(
169+
cls,
170+
data_source: Union[Path, str],
171+
serialize_method: str = "jsonpickle",
172+
zip_pack: bool = False,
173+
) -> "PackType":
166174
"""
167175
This function should deserialize a Pack from a data source. The
168-
implementation should decide the specific pack type.
176+
implementation should decide the specific pack type.
169177
170178
Args:
171-
string: The serialized string to be deserialized.
179+
data_source: The data path containing pack data. The content
180+
of the data could be string or bytes depending on the method of
181+
serialization.
182+
serialize_method: The method used to serialize the data, this
183+
should be the same as how serialization is done. The current
184+
options are "jsonpickle" and "pickle". The default method
185+
is "jsonpickle".
186+
zip_pack: Boolean value indicating whether the input source is
187+
zipped.
172188
173189
Returns:
174-
An pack object deserialized from the string.
190+
A pack object deserialized from the data.
175191
"""
176-
pack = jsonpickle.decode(string)
177-
return pack
192+
_open = gzip.open if zip_pack else open
193+
194+
if serialize_method == "jsonpickle":
195+
with _open(data_source, mode="rt") as f: # type: ignore
196+
pack = cls.from_string(f.read())
197+
else:
198+
with _open(data_source, mode="rb") as f: # type: ignore
199+
pack = pickle.load(f)
200+
201+
return pack # type: ignore
202+
203+
@classmethod
204+
def from_string(cls, data_content: str) -> "BasePack":
205+
return jsonpickle.decode(data_content)
178206

179207
@abstractmethod
180208
def delete_entry(self, entry: EntryType):
@@ -238,13 +266,76 @@ def add_all_remaining_entries(self, component: Optional[str] = None):
238266
self.add_entry(entry, c_)
239267
self._pending_entries.clear()
240268

241-
def serialize(self, drop_record: Optional[bool] = False) -> str:
242-
r"""Serializes a pack to a string."""
269+
def to_string(
270+
self,
271+
drop_record: Optional[bool] = False,
272+
json_method: str = "jsonpickle",
273+
indent: Optional[int] = None,
274+
) -> str:
275+
"""
276+
Return the string representation (json encoded) of this pack.
277+
278+
Args:
279+
drop_record: Whether to drop the creation records, default is False.
280+
json_method: What method is used to convert data pack to json.
281+
Only supports `jsonpickle` for now. Default value is
282+
`jsonpickle`.
283+
indent: The indent used for json string.
284+
285+
Returns: String representation of the data pack.
286+
"""
287+
if drop_record:
288+
self._creation_records.clear()
289+
self._field_records.clear()
290+
if json_method == "jsonpickle":
291+
return jsonpickle.encode(self, unpicklable=True, indent=indent)
292+
else:
293+
raise ValueError(f"Unsupported JSON method {json_method}.")
294+
295+
def serialize(
296+
self,
297+
output_path: Union[str, Path],
298+
zip_pack: bool = False,
299+
drop_record: bool = False,
300+
serialize_method: str = "jsonpickle",
301+
indent: Optional[int] = None,
302+
):
303+
r"""
304+
Serializes the data pack to the provided path. The output of this
305+
function depends on the serialization method chosen.
306+
307+
Args:
308+
output_path: The path to write data to.
309+
zip_pack: Whether to compress the result with `gzip`.
310+
drop_record: Whether to drop the creation records, default is False.
311+
serialize_method: The method used to serialize the data. Currently
312+
supports "jsonpickle" (outputs str) and Python's built-in
313+
"pickle" (outputs bytes).
314+
indent: Whether to indent the file if written as JSON.
315+
316+
Returns: Results of serialization.
317+
"""
318+
if zip_pack:
319+
_open = gzip.open
320+
else:
321+
_open = open # type:ignore
322+
243323
if drop_record:
244324
self._creation_records.clear()
245325
self._field_records.clear()
246326

247-
return jsonpickle.encode(self, unpicklable=True)
327+
if serialize_method == "pickle":
328+
with _open(output_path, mode="wb") as pickle_out:
329+
pickle.dump(self, pickle_out) # type:ignore
330+
elif serialize_method == "jsonpickle":
331+
with _open(output_path, mode="wt", encoding="utf-8") as json_out:
332+
json_out.write(
333+
self.to_string(drop_record, "jsonpickle", indent=indent)
334+
)
335+
else:
336+
raise NotImplementedError(
337+
f"Unsupported serialization method {serialize_method}"
338+
)
248339

249340
def view(self):
250341
return copy.deepcopy(self)
@@ -457,22 +548,6 @@ def get_ids_from(self, components: List[str]) -> Set[int]:
457548
valid_component_id |= self.get_ids_by_creator(component)
458549
return valid_component_id
459550

460-
def get_ids_by_type_subtype(self, entry_type: Type[EntryType]) -> Set[int]:
461-
r"""Look up the type_index with key ``entry_type``.
462-
463-
Args:
464-
entry_type: The type of the entry you are looking for.
465-
466-
Returns:
467-
A set of entry ids. The entries are instances of `entry_type` (
468-
and also includes instances of the subclasses of `entry_type`).
469-
"""
470-
subclass_index: Set[int] = set()
471-
for index_key, index_val in self._index.iter_type_index():
472-
if issubclass(index_key, entry_type):
473-
subclass_index.update(index_val)
474-
return subclass_index
475-
476551
def _expand_to_sub_types(self, entry_type: Type[EntryType]) -> Set[Type]:
477552
"""
478553
Return all the types and the sub types that inherit from the provided
@@ -511,7 +586,7 @@ def get_entries_of(
511586
for tid in self._index.query_by_type(entry_type):
512587
yield self.get_entry(tid)
513588
else:
514-
for tid in self.get_ids_by_type_subtype(entry_type):
589+
for tid in self._index.query_by_type_subtype(entry_type):
515590
yield self.get_entry(tid)
516591

517592
@classmethod

forte/data/base_reader.py

+21-10
Original file line numberDiff line numberDiff line change
@@ -95,13 +95,16 @@ def default_configs(cls):
9595
values. Used to replace the missing values of input `configs`
9696
during pipeline construction.
9797
98-
.. code-block:: python
98+
Here:
99+
- zip_pack (bool): whether to zip the results. The default value is
100+
False.
101+
102+
- serialize_method: The method used to serialize the data. Current
103+
available options are "jsonpickle" and "pickle". Default is
104+
"jsonpickle".
99105
100-
{
101-
"name": "reader"
102-
}
103106
"""
104-
return {"name": "reader"}
107+
return {"zip_pack": False, "serialize_method": "jsonpickle"}
105108

106109
@property
107110
def pack_type(self):
@@ -323,11 +326,19 @@ def cache_data(self, collection: Any, pack: PackType, append: bool):
323326

324327
logger.info("Caching pack to %s", cache_filename)
325328
if append:
326-
with open(cache_filename, "a", encoding="utf-8") as cache:
327-
cache.write(pack.serialize() + "\n")
329+
with open(
330+
cache_filename,
331+
"a",
332+
encoding="utf-8",
333+
) as cache:
334+
cache.write(pack.to_string() + "\n")
328335
else:
329-
with open(cache_filename, "w", encoding="utf-8") as cache:
330-
cache.write(pack.serialize() + "\n")
336+
with open(
337+
cache_filename,
338+
"w",
339+
encoding="utf-8",
340+
) as cache:
341+
cache.write(pack.to_string() + "\n")
331342

332343
def read_from_cache(
333344
self, cache_filename: Union[Path, str]
@@ -343,7 +354,7 @@ def read_from_cache(
343354
logger.info("reading from cache file %s", cache_filename)
344355
with open(cache_filename, "r", encoding="utf-8") as cache_file:
345356
for line in cache_file:
346-
pack = DataPack.deserialize(line.strip())
357+
pack = DataPack.from_string(line.strip())
347358
if not isinstance(pack, self.pack_type):
348359
raise TypeError(
349360
f"Pack deserialized from {cache_filename} "

forte/data/data_pack.py

+32-24
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414

1515
import logging
16+
from pathlib import Path
1617
from typing import (
1718
Dict,
1819
Iterable,
@@ -504,20 +505,30 @@ def get_original_index(
504505
return Span(orig_begin, orig_end)
505506

506507
@classmethod
507-
def deserialize(cls, data_pack_string: str) -> "DataPack":
508+
def deserialize(
509+
cls,
510+
data_source: Union[Path, str],
511+
serialize_method: str = "jsonpickle",
512+
zip_pack: bool = False,
513+
) -> "DataPack":
508514
"""
509515
Deserialize a Data Pack from a string. This internally calls the
510516
internal :meth:`~forte.data.base_pack.BasePack._deserialize` function
511517
from :class:`~forte.data.base_pack.BasePack`.
512518
513519
Args:
514-
data_pack_string: The serialized string of a data pack to be
515-
deserialized.
520+
data_source: The path storing data source.
521+
serialize_method: The method used to serialize the data, this
522+
should be the same as how serialization is done. The current
523+
options are "jsonpickle" and "pickle". The default method
524+
is "jsonpickle".
525+
zip_pack: Boolean value indicating whether the input source is
526+
zipped.
516527
517528
Returns:
518529
A data pack object deserialized from the data source.
519530
"""
520-
return cls._deserialize(data_pack_string)
531+
return cls._deserialize(data_source, serialize_method, zip_pack)
521532

522533
def _add_entry(self, entry: EntryType) -> EntryType:
523534
r"""Force add an :class:`~forte.data.ontology.core.Entry` object to the
@@ -586,25 +597,22 @@ def __add_entry_with_check(
586597
f"should be an instance of Annotation, Link, Group of Generics."
587598
)
588599

589-
# TODO: duplicate is ill-defined.
590-
add_new = allow_duplicate or (entry not in target)
591-
592-
if add_new:
593-
target.add(entry)
594-
595-
# update the data pack index if needed
596-
self._index.update_basic_index([entry])
597-
if self._index.link_index_on and isinstance(entry, Link):
598-
self._index.update_link_index([entry])
599-
if self._index.group_index_on and isinstance(entry, Group):
600-
self._index.update_group_index([entry])
601-
self._index.deactivate_coverage_index()
602-
603-
self._pending_entries.pop(entry.tid)
604-
605-
return entry
606-
else:
607-
return target[target.index(entry)]
600+
if not allow_duplicate:
601+
index = target.index(entry)
602+
if index < 0:
603+
# Return the existing entry if duplicate is not allowed.
604+
return target[index]
605+
606+
target.add(entry)
607+
# update the data pack index if needed
608+
self._index.update_basic_index([entry])
609+
if self._index.link_index_on and isinstance(entry, Link):
610+
self._index.update_link_index([entry])
611+
if self._index.group_index_on and isinstance(entry, Group):
612+
self._index.update_group_index([entry])
613+
self._index.deactivate_coverage_index()
614+
self._pending_entries.pop(entry.tid)
615+
return entry
608616

609617
def delete_entry(self, entry: EntryType):
610618
r"""Delete an :class:`~forte.data.ontology.core.Entry` object from the
@@ -762,7 +770,7 @@ def get_data(
762770
context_type_, context_args
763771
)
764772

765-
valid_context_ids: Set[int] = self.get_ids_by_type_subtype(
773+
valid_context_ids: Set[int] = self._index.query_by_type_subtype(
766774
context_type_
767775
)
768776
if context_components:

0 commit comments

Comments
 (0)