Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

MNT Refactor with a SaveState and avoid duplicate numpy arrays #173

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 32 additions & 25 deletions skops/io/_general.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,31 @@
from __future__ import annotations

import json
from functools import partial
from types import FunctionType
from typing import Any

import numpy as np

from ._utils import _import_obj, get_instance, get_module, get_state, gettype
from ._utils import SaveState, _import_obj, get_instance, get_module, get_state, gettype
from .exceptions import UnsupportedTypeException


def dict_get_state(obj, dst):
def dict_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]:
res = {
"__class__": obj.__class__.__name__,
"__module__": get_module(type(obj)),
}

key_types = get_state([type(key) for key in obj.keys()], dst)
key_types = get_state([type(key) for key in obj.keys()], save_state)
content = {}
for key, value in obj.items():
if isinstance(value, property):
continue
if np.isscalar(key) and hasattr(key, "item"):
# convert numpy value to python object
key = key.item()
content[key] = get_state(value, dst)
key = key.item() # type: ignore
content[key] = get_state(value, save_state)
res["content"] = content
res["key_types"] = key_types
return res
Expand All @@ -36,14 +39,14 @@ def dict_get_instance(state, src):
return content


def list_get_state(obj, dst):
def list_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]:
res = {
"__class__": obj.__class__.__name__,
"__module__": get_module(type(obj)),
}
content = []
for value in obj:
content.append(get_state(value, dst))
content.append(get_state(value, save_state))
res["content"] = content
return res

Expand All @@ -55,12 +58,12 @@ def list_get_instance(state, src):
return content


def tuple_get_state(obj, dst):
def tuple_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]:
res = {
"__class__": obj.__class__.__name__,
"__module__": get_module(type(obj)),
}
content = tuple(get_state(value, dst) for value in obj)
content = tuple(get_state(value, save_state) for value in obj)
res["content"] = content
return res

Expand All @@ -86,7 +89,7 @@ def isnamedtuple(t):
return content


def function_get_state(obj, dst):
def function_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]:
res = {
"__class__": obj.__class__.__name__,
"__module__": get_module(obj),
Expand All @@ -103,16 +106,16 @@ def function_get_instance(state, src):
return loaded


def partial_get_state(obj, dst):
def partial_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]:
_, _, (func, args, kwds, namespace) = obj.__reduce__()
res = {
"__class__": "partial", # don't allow any subclass
"__module__": get_module(type(obj)),
"content": {
"func": get_state(func, dst),
"args": get_state(args, dst),
"kwds": get_state(kwds, dst),
"namespace": get_state(namespace, dst),
"func": get_state(func, save_state),
"args": get_state(args, save_state),
"kwds": get_state(kwds, save_state),
"namespace": get_state(namespace, save_state),
},
}
return res
Expand All @@ -129,7 +132,7 @@ def partial_get_instance(state, src):
return instance


def type_get_state(obj, dst):
def type_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]:
# To serialize a type, we first need to set the metadata to tell that it's
# a type, then store the type's info itself in the content field.
res = {
Expand All @@ -148,7 +151,7 @@ def type_get_instance(state, src):
return loaded


def slice_get_state(obj, dst):
def slice_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]:
res = {
"__class__": obj.__class__.__name__,
"__module__": get_module(type(obj)),
Expand All @@ -168,13 +171,19 @@ def slice_get_instance(state, src):
return slice(start, stop, step)


def object_get_state(obj, dst):
def object_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]:
# This method is for objects which can either be persisted with json, or
# the ones for which we can get/set attributes through
# __getstate__/__setstate__ or reading/writing to __dict__.
try:
# if we can simply use json, then we're done.
return json.dumps(obj)
obj_str = json.dumps(obj)
return {
"__class__": "str",
"__module__": "builtins",
"content": obj_str,
"is_json": True,
}
except Exception:
pass

Expand All @@ -192,18 +201,16 @@ def object_get_state(obj, dst):
else:
return res

content = get_state(attrs, dst)
content = get_state(attrs, save_state)
# it's sufficient to store the "content" because we know that this dict can
# only have str type keys
res["content"] = content
return res


def object_get_instance(state, src):
try:
return json.loads(state)
except Exception:
pass
if state.get("is_json", False):
return json.loads(state["content"])

cls = gettype(state)

Expand All @@ -225,7 +232,7 @@ def object_get_instance(state, src):
return instance


def unsupported_get_state(obj, dst):
def unsupported_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]:
raise UnsupportedTypeException(obj)


Expand Down
75 changes: 41 additions & 34 deletions skops/io/_numpy.py
Original file line number Diff line number Diff line change
@@ -1,42 +1,49 @@
from __future__ import annotations

import io
from pathlib import Path
from uuid import uuid4
from typing import Any

import numpy as np

from ._general import function_get_instance
from ._utils import _import_obj, get_instance, get_module, get_state
from ._utils import SaveState, _import_obj, get_instance, get_module, get_state
from .exceptions import UnsupportedTypeException


def ndarray_get_state(obj, dst):
def ndarray_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]:
res = {
"__class__": obj.__class__.__name__,
"__module__": get_module(type(obj)),
}

# First, try to save object with np.save and allow_pickle=False, which
# should generally work as long as the dtype is not object.
try:
f_name = f"{uuid4()}.npy"
with open(Path(dst) / f_name, "wb") as f:
np.save(f, obj, allow_pickle=False)
res.update(type="numpy", file=f_name)
# If the dtype is object, np.save should not work with
# allow_pickle=False, therefore we convert them to a list and
# recursively call get_state on it.
if obj.dtype == object:
obj_serialized = get_state(obj.tolist(), save_state)
res["content"] = obj_serialized["content"]
res["type"] = "json"
res["shape"] = get_state(obj.shape, save_state)
else:
# Memoize the object and then check if it's file name (containing
# the object id) already exists. If it does, there is no need to
# save the object again. Memoizitation is necessary since for
# ephemeral objects, the same id might otherwise be reused.
obj_id = save_state.memoize(obj)
f_name = f"{obj_id}.npy"
path = save_state.path / f_name
if not path.exists():
with open(path, "wb") as f:
np.save(f, obj, allow_pickle=False)
res.update(type="numpy", file=f_name)
except ValueError:
# Object arrays cannot be saved with allow_pickle=False, therefore we
# convert them to a list and recursively call get_state on it. For this,
# we expect the dtype to be object.
if obj.dtype != object:
raise UnsupportedTypeException(
f"numpy arrays of dtype {obj.dtype} are not supported yet, please "
"open an issue at https://github.com/skops-dev/skops/issues and "
"report your error"
)

obj_serialized = get_state(obj.tolist(), dst)
res["content"] = obj_serialized["content"]
res["type"] = "json"
res["shape"] = get_state(obj.shape, dst)
# Couldn't save the numpy array with either method
raise UnsupportedTypeException(
f"numpy arrays of dtype {obj.dtype} are not supported yet, please "
"open an issue at https://github.com/skops-dev/skops/issues and "
"report your error"
)

return res

Expand Down Expand Up @@ -67,13 +74,13 @@ def ndarray_get_instance(state, src):
return val


def maskedarray_get_state(obj, dst):
def maskedarray_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]:
res = {
"__class__": obj.__class__.__name__,
"__module__": get_module(type(obj)),
"content": {
"data": get_state(obj.data, dst),
"mask": get_state(obj.mask, dst),
"data": get_state(obj.data, save_state),
"mask": get_state(obj.mask, save_state),
},
}
return res
Expand All @@ -85,8 +92,8 @@ def maskedarray_get_instance(state, src):
return np.ma.MaskedArray(data, mask)


def random_state_get_state(obj, dst):
content = get_state(obj.get_state(legacy=False), dst)
def random_state_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]:
content = get_state(obj.get_state(legacy=False), save_state)
res = {
"__class__": obj.__class__.__name__,
"__module__": get_module(type(obj)),
Expand All @@ -103,7 +110,7 @@ def random_state_get_instance(state, src):
return random_state


def random_generator_get_state(obj, dst):
def random_generator_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]:
bit_generator_state = obj.bit_generator.state
res = {
"__class__": obj.__class__.__name__,
Expand All @@ -128,7 +135,7 @@ def random_generator_get_instance(state, src):
# For numpy.ufunc we need to get the type from the type's module, but for other
# functions we get it from objet's module directly. Therefore sett a especial
# get_state method for them here. The load is the same as other functions.
def ufunc_get_state(obj, dst):
def ufunc_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]:
res = {
"__class__": obj.__class__.__name__, # ufunc
"__module__": get_module(type(obj)), # numpy
Expand All @@ -140,14 +147,14 @@ def ufunc_get_state(obj, dst):
return res


def dtype_get_state(obj, dst):
def dtype_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]:
# we use numpy's internal save mechanism to store the dtype by
# saving/loading an empty array with that dtype.
tmp = np.ndarray(0, dtype=obj)
tmp: np.typing.NDArray = np.ndarray(0, dtype=obj)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not a fan of typing local variables. if mypy is failing on such a line, we should just disable those checks if possible.

res = {
"__class__": "dtype",
"__module__": "numpy",
"content": ndarray_get_state(tmp, dst),
"content": ndarray_get_state(tmp, save_state),
}
return res

Expand Down
16 changes: 8 additions & 8 deletions skops/io/_persist.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,7 @@

import skops

from ._utils import _get_instance, _get_state, get_instance, get_state

# For now, there is just one protocol version
PROTOCOL = 0

from ._utils import SaveState, _get_instance, _get_state, get_instance, get_state

# We load the dispatch functions from the corresponding modules and register
# them.
Expand Down Expand Up @@ -53,9 +49,13 @@ def save(obj, file):

"""
with tempfile.TemporaryDirectory() as dst:
with open(Path(dst) / "schema.json", "w") as f:
state = get_state(obj, dst)
state["protocol"] = PROTOCOL
path = Path(dst)
with open(path / "schema.json", "w") as f:
save_state = SaveState(path=path)
state = get_state(obj, save_state)
save_state.clear_memo()

state["protocol"] = save_state.protocol
state["_skops_version"] = skops.__version__
json.dump(state, f, indent=2)

Expand Down
22 changes: 15 additions & 7 deletions skops/io/_scipy.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,31 @@
from __future__ import annotations

import io
from pathlib import Path
from uuid import uuid4
from typing import Any

from scipy.sparse import load_npz, save_npz, spmatrix

from ._utils import get_module
from ._utils import SaveState, get_module


def sparse_matrix_get_state(obj, dst):
def sparse_matrix_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]:
res = {
"__class__": obj.__class__.__name__,
"__module__": get_module(type(obj)),
}

f_name = f"{uuid4()}.npz"
save_npz(Path(dst) / f_name, obj)
# Memoize the object and then check if it's file name (containing the object
# id) already exists. If it does, there is no need to save the object again.
# Memoizitation is necessary since for ephemeral objects, the same id might
# otherwise be reused.
obj_id = save_state.memoize(obj)
f_name = f"{obj_id}.npz"
path = save_state.path / f_name
if not path.exists():
save_npz(path, obj)

res["type"] = "scipy"
res["file"] = f_name

return res


Expand Down
Loading