diff --git a/python/pyfury/_fury.py b/python/pyfury/_fury.py index 9be66e74eb..0c8f3288fd 100644 --- a/python/pyfury/_fury.py +++ b/python/pyfury/_fury.py @@ -21,7 +21,7 @@ import sys import warnings from dataclasses import dataclass -from typing import Dict, Tuple, TypeVar, Optional, Union, Iterable +from typing import Dict, Tuple, TypeVar, Union, Iterable from pyfury.lib import mmh3 @@ -75,10 +75,12 @@ except ImportError: np = None +from cloudpickle import Pickler + if sys.version_info[:2] < (3, 8): # pragma: no cover - import pickle5 as pickle # nosec # pylint: disable=import_pickle + from pickle5 import Unpickler else: - import pickle # nosec # pylint: disable=import_pickle + from pickle import Unpickler logger = logging.getLogger(__name__) @@ -599,7 +601,6 @@ class Fury: "_native_objects", ) serialization_context: "SerializationContext" - unpickler: Optional[pickle.Unpickler] def __init__( self, @@ -637,7 +638,7 @@ def __init__( RuntimeWarning, stacklevel=2, ) - self.pickler = pickle.Pickler(self.buffer) + self.pickler = Pickler(self.buffer) else: self.pickler = _PicklerStub(self.buffer) self.unpickler = None @@ -685,7 +686,7 @@ def _serialize( self._buffer_callback = buffer_callback self._unsupported_callback = unsupported_callback if buffer is not None: - self.pickler = pickle.Pickler(buffer) + self.pickler = Pickler(buffer) else: self.buffer.writer_index = 0 buffer = self.buffer @@ -832,7 +833,7 @@ def _deserialize( if self.require_class_registration: self.unpickler = _UnpicklerStub(buffer) else: - self.unpickler = pickle.Unpickler(buffer) + self.unpickler = Unpickler(buffer) if unsupported_objects is not None: self._unsupported_objects = iter(unsupported_objects) reader_index = buffer.reader_index diff --git a/python/pyfury/_serialization.pyx b/python/pyfury/_serialization.pyx index 53bfd9f306..832970dd46 100644 --- a/python/pyfury/_serialization.pyx +++ b/python/pyfury/_serialization.pyx @@ -14,7 +14,7 @@ from typing import TypeVar, Union, Iterable, get_type_hints from pyfury._util import get_bit, set_bit, clear_bit from pyfury._fury import Language, OpaqueObject -from pyfury._fury import _PicklerStub, _UnpicklerStub +from pyfury._fury import _PicklerStub, _UnpicklerStub, Pickler, Unpickler from pyfury._fury import _ENABLE_CLASS_REGISTRATION_FORCIBLY from pyfury.error import ClassNotCompatibleError from pyfury.lib import mmh3 @@ -832,7 +832,7 @@ cdef class Fury: RuntimeWarning, stacklevel=2, ) - self.pickler = pickle.Pickler(self.buffer) + self.pickler = Pickler(self.buffer) else: self.pickler = _PicklerStub(self.buffer) self.unpickler = None @@ -872,7 +872,7 @@ cdef class Fury: self._buffer_callback = buffer_callback self._unsupported_callback = unsupported_callback if buffer is not None: - self.pickler = pickle.Pickler(self.buffer) + self.pickler = Pickler(self.buffer) else: self.buffer.writer_index = 0 buffer = self.buffer @@ -1032,7 +1032,7 @@ cdef class Fury: if self.require_class_registration: self.unpickler = _UnpicklerStub(buffer) else: - self.unpickler = pickle.Unpickler(buffer) + self.unpickler = Unpickler(buffer) if unsupported_objects is not None: self._unsupported_objects = iter(unsupported_objects) cdef int32_t reader_index = buffer.reader_index diff --git a/python/pyfury/_serializer.py b/python/pyfury/_serializer.py index b25f072ad5..8b6f8d96ee 100644 --- a/python/pyfury/_serializer.py +++ b/python/pyfury/_serializer.py @@ -129,13 +129,13 @@ def get_xtype_id(self): """ return NOT_SUPPORT_CROSS_LANGUAGE - @abstractmethod def get_xtype_tag(self): """ Returns ------- a type tag used for setup type mapping between languages. """ + raise RuntimeError("Tag is only for struct.") def write(self, buffer, value): raise NotImplementedError diff --git a/python/pyfury/tests/test_serializer.py b/python/pyfury/tests/test_serializer.py index ab1164df39..6875367578 100644 --- a/python/pyfury/tests/test_serializer.py +++ b/python/pyfury/tests/test_serializer.py @@ -402,6 +402,10 @@ def test_pickle_fallback(): new_o1 = fury.deserialize(data1) assert o1 == new_o1 + df = pd.DataFrame({"a": list(range(10))}) + df2 = fury.deserialize(fury.serialize(df)) + assert df2.equals(df) + def test_unsupported_callback(): fury = Fury(language=Language.PYTHON, ref_tracking=True) @@ -529,3 +533,21 @@ def test_py_serialize_dataclass(): f1=None, f2=-2.0, f3="abc", f4=None, f5="xyz", f6=None, f7=None ) assert ser_de(fury, obj2) == obj2 + + +def test_function(): + fury = Fury( + language=Language.PYTHON, ref_tracking=True, require_class_registration=False + ) + c = fury.deserialize(fury.serialize(lambda x: x * 2)) + assert c(2) == 4 + + def func(x): + return x * 2 + + c = fury.deserialize(fury.serialize(func)) + assert c(2) == 4 + + df = pd.DataFrame({"a": list(range(10))}) + df_sum = fury.deserialize(fury.serialize(df.sum)) + assert df_sum().equals(df.sum()) diff --git a/python/setup.py b/python/setup.py index 270934f287..3e474ebce0 100644 --- a/python/setup.py +++ b/python/setup.py @@ -144,6 +144,7 @@ def parse_version(): install_requires=[ 'dataclasses; python_version<"3.7"', 'pickle5; python_version<"3.8"', + "cloudpickle", ], extras_require={ "format": [f"pyarrow == {pyarrow_version}"], @@ -155,6 +156,7 @@ def parse_version(): f"pyarrow == {pyarrow_version}", "numpy" 'dataclasses; python_version<"3.7"', 'pickle5; python_version<"3.8"', + "cloudpickle", ], distclass=BinaryDistribution, ext_modules=ext_modules,