Skip to content

Commit

Permalink
[Python] Use cloudpickle for local function serialization (apache#914)
Browse files Browse the repository at this point in the history
* Use cloudpickle for unsupported objects serialization

* add tests

* add df.sum tests

* lint code

* install cloudpickle

* lint code
  • Loading branch information
chaokunyang authored Sep 17, 2023
1 parent f8f92d7 commit 9403d28
Show file tree
Hide file tree
Showing 5 changed files with 37 additions and 12 deletions.
15 changes: 8 additions & 7 deletions python/pyfury/_fury.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import sys
import warnings
from dataclasses import dataclass
from typing import Dict, Tuple, TypeVar, Optional, Union, Iterable
from typing import Dict, Tuple, TypeVar, Union, Iterable

from pyfury.lib import mmh3

Expand Down Expand Up @@ -75,10 +75,12 @@
except ImportError:
np = None

from cloudpickle import Pickler

if sys.version_info[:2] < (3, 8): # pragma: no cover
import pickle5 as pickle # nosec # pylint: disable=import_pickle
from pickle5 import Unpickler
else:
import pickle # nosec # pylint: disable=import_pickle
from pickle import Unpickler

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -599,7 +601,6 @@ class Fury:
"_native_objects",
)
serialization_context: "SerializationContext"
unpickler: Optional[pickle.Unpickler]

def __init__(
self,
Expand Down Expand Up @@ -637,7 +638,7 @@ def __init__(
RuntimeWarning,
stacklevel=2,
)
self.pickler = pickle.Pickler(self.buffer)
self.pickler = Pickler(self.buffer)
else:
self.pickler = _PicklerStub(self.buffer)
self.unpickler = None
Expand Down Expand Up @@ -685,7 +686,7 @@ def _serialize(
self._buffer_callback = buffer_callback
self._unsupported_callback = unsupported_callback
if buffer is not None:
self.pickler = pickle.Pickler(buffer)
self.pickler = Pickler(buffer)
else:
self.buffer.writer_index = 0
buffer = self.buffer
Expand Down Expand Up @@ -832,7 +833,7 @@ def _deserialize(
if self.require_class_registration:
self.unpickler = _UnpicklerStub(buffer)
else:
self.unpickler = pickle.Unpickler(buffer)
self.unpickler = Unpickler(buffer)
if unsupported_objects is not None:
self._unsupported_objects = iter(unsupported_objects)
reader_index = buffer.reader_index
Expand Down
8 changes: 4 additions & 4 deletions python/pyfury/_serialization.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ from typing import TypeVar, Union, Iterable, get_type_hints

from pyfury._util import get_bit, set_bit, clear_bit
from pyfury._fury import Language, OpaqueObject
from pyfury._fury import _PicklerStub, _UnpicklerStub
from pyfury._fury import _PicklerStub, _UnpicklerStub, Pickler, Unpickler
from pyfury._fury import _ENABLE_CLASS_REGISTRATION_FORCIBLY
from pyfury.error import ClassNotCompatibleError
from pyfury.lib import mmh3
Expand Down Expand Up @@ -832,7 +832,7 @@ cdef class Fury:
RuntimeWarning,
stacklevel=2,
)
self.pickler = pickle.Pickler(self.buffer)
self.pickler = Pickler(self.buffer)
else:
self.pickler = _PicklerStub(self.buffer)
self.unpickler = None
Expand Down Expand Up @@ -872,7 +872,7 @@ cdef class Fury:
self._buffer_callback = buffer_callback
self._unsupported_callback = unsupported_callback
if buffer is not None:
self.pickler = pickle.Pickler(self.buffer)
self.pickler = Pickler(self.buffer)
else:
self.buffer.writer_index = 0
buffer = self.buffer
Expand Down Expand Up @@ -1032,7 +1032,7 @@ cdef class Fury:
if self.require_class_registration:
self.unpickler = _UnpicklerStub(buffer)
else:
self.unpickler = pickle.Unpickler(buffer)
self.unpickler = Unpickler(buffer)
if unsupported_objects is not None:
self._unsupported_objects = iter(unsupported_objects)
cdef int32_t reader_index = buffer.reader_index
Expand Down
2 changes: 1 addition & 1 deletion python/pyfury/_serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,13 +129,13 @@ def get_xtype_id(self):
"""
return NOT_SUPPORT_CROSS_LANGUAGE

@abstractmethod
def get_xtype_tag(self):
    """
    Return the type tag used to set up type mapping between languages.

    Returns
    -------
    a type tag used for setup type mapping between languages.

    Raises
    ------
    RuntimeError
        Always raised by this implementation, with the message
        "Tag is only for struct.".
    """
    raise RuntimeError("Tag is only for struct.")

def write(self, buffer, value):
raise NotImplementedError
Expand Down
22 changes: 22 additions & 0 deletions python/pyfury/tests/test_serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -402,6 +402,10 @@ def test_pickle_fallback():
new_o1 = fury.deserialize(data1)
assert o1 == new_o1

df = pd.DataFrame({"a": list(range(10))})
df2 = fury.deserialize(fury.serialize(df))
assert df2.equals(df)


def test_unsupported_callback():
fury = Fury(language=Language.PYTHON, ref_tracking=True)
Expand Down Expand Up @@ -529,3 +533,21 @@ def test_py_serialize_dataclass():
f1=None, f2=-2.0, f3="abc", f4=None, f5="xyz", f6=None, f7=None
)
assert ser_de(fury, obj2) == obj2


def test_function():
    """Round-trip a lambda, a local function and a bound method through Fury.

    Class registration is disabled on the Fury instance, and each callable
    is serialized and then deserialized by that same instance before being
    invoked to check that it still produces the expected result.
    """
    fury = Fury(
        language=Language.PYTHON,
        ref_tracking=True,
        require_class_registration=False,
    )

    def round_trip(obj):
        # Serialize, then immediately deserialize with the same instance.
        return fury.deserialize(fury.serialize(obj))

    # Anonymous function.
    restored_lambda = round_trip(lambda x: x * 2)
    assert restored_lambda(2) == 4

    # Named function defined in a local scope.
    def func(x):
        return x * 2

    restored_func = round_trip(func)
    assert restored_func(2) == 4

    # Bound method of a pandas DataFrame.
    df = pd.DataFrame({"a": list(range(10))})
    restored_sum = round_trip(df.sum)
    assert restored_sum().equals(df.sum())
2 changes: 2 additions & 0 deletions python/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,7 @@ def parse_version():
install_requires=[
'dataclasses; python_version<"3.7"',
'pickle5; python_version<"3.8"',
"cloudpickle",
],
extras_require={
"format": [f"pyarrow == {pyarrow_version}"],
Expand All @@ -155,6 +156,7 @@ def parse_version():
f"pyarrow == {pyarrow_version}",
"numpy",
'dataclasses; python_version<"3.7"',
'pickle5; python_version<"3.8"',
"cloudpickle",
],
distclass=BinaryDistribution,
ext_modules=ext_modules,
Expand Down

0 comments on commit 9403d28

Please sign in to comment.