Skip to content

Commit

Permalink
Make nodes and operations pickable.
Browse files Browse the repository at this point in the history
  • Loading branch information
marcenacp committed Sep 3, 2024
1 parent d1193d4 commit 589c070
Show file tree
Hide file tree
Showing 5 changed files with 79 additions and 28 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -23,33 +23,6 @@ def __init__(self):
super().__init__(self)
self.last_operations: dict[Node, Operation] = {}

def add_node(self, operation: Operation) -> None:
"""Overloads nx.add_node to keep track of the last operations."""
if not self.has_node(operation):
super().add_node(operation)

def add_edge(self, operation1: Operation, operation2: Operation) -> None:
"""Overloads nx.add_node to keep track of the last operations."""
if not self.has_edge(operation1, operation2):
super().add_edge(operation1, operation2)

@property
def nodes(self) -> Iterable[Operation]:
"""Overloads nx.nodes to return an interator of operations."""
return super().nodes()

def is_leaf(self, operation: Operation | None) -> bool:
"""Tests whether an operation is a leaf in the graph."""
return self.out_degree(operation) == 0

def entry_operations(self) -> list[Operation]:
"""Lists all operations without a parent in the graph of operations."""
return [
operation
for operation, indegree in self.in_degree(self.nodes)
if indegree == 0 and isinstance(operation, Operation)
]


@dataclasses.dataclass(frozen=True, repr=False)
class Operation(abc.ABC):
Expand Down Expand Up @@ -116,3 +89,31 @@ def __rrshift__(
for left_operation in left_operations:
self.operations.add_edge(left_operation, right_operation)
return right_operation

def __reduce__(self):
"""Allows pickling the operation.
`self.operations` is stored separately in the state to break the cyclic
reference between Operation and Operations. We could refactor the codebase to
not have this dependency, but it would be a bigger change and it's convenient
to be able to reference all operations from a single operation.
"""
state = self.__getstate__()
args = tuple(state.values())
return (
self.__class__,
args,
{"operations": self.operations},
)

def __getstate__(self):
state = {}
for field in dataclasses.fields(self):
if field.name == "operations":
state[field.name] = Operations()
else:
state[field.name] = getattr(self, field.name)
return state

def __setstate__(self, state):
object.__setattr__(self, "operations", state["operations"])
6 changes: 5 additions & 1 deletion python/mlcroissant/mlcroissant/_src/operation_graph/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,11 @@ def from_nodes(cls, ctx: Context, metadata: Node) -> "OperationGraph":
_add_operations_for_field(operations, node)

# Attach all entry nodes to a single `start` node
entry_operations = operations.entry_operations()
entry_operations = [
operation
for operation, indegree in operations.in_degree(operations.nodes)
if indegree == 0 and isinstance(operation, Operation)
]
init_operation = InitOperation(operations=operations, node=metadata)
for entry_operation in entry_operations:
operations.add_edge(init_operation, entry_operation)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""read_test module."""

import pathlib
import pickle
import tempfile
from unittest import mock

Expand Down Expand Up @@ -84,3 +85,14 @@ def test_reading_method():
assert _reading_method(empty_file_object, (filename,)) == ReadingMethod.NONE
with pytest.raises(ValueError):
_reading_method(empty_file_object, (content_field, lines_field))


def test_pickable():
operation = Read(
operations=operations(),
node=empty_file_object,
folder=epath.Path("/foo/bar"),
fields=(),
)
operation = pickle.loads(pickle.dumps(operation))
assert operation.folder == epath.Path("/foo/bar")
21 changes: 21 additions & 0 deletions python/mlcroissant/mlcroissant/_src/structure_graph/base_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,6 +330,27 @@ def __deepcopy__(self, memo):
memo[id(self)] = copy
return copy

def __reduce__(self):
"""Allows pickling the node.
`self.ctx` is stored separately in the state because it's not pickable directly.
"""
state = self.__getstate__()
args = tuple(state.values())
return (self.__class__, args, {"ctx": self.ctx})

def __getstate__(self):
state = {}
for field in dataclasses.fields(self):
if field.name == "ctx":
state[field.name] = Context()
else:
state[field.name] = getattr(self, field.name)
return state

def __setstate__(self, state):
self.ctx = state["ctx"]

def to_json(self) -> Json:
"""Converts the Python class to JSON."""
cls = self.__class__
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Tests for FileObjects."""

import pickle
from unittest import mock

from etils import epath
Expand Down Expand Up @@ -91,3 +92,15 @@ def test_from_jsonld(encoding):
== "48a7c257f3c90b2a3e529ddd2cca8f4f1bd8e49ed244ef53927649504ac55354"
)
assert not ctx.issues.errors


@pytest.mark.parametrize(
["conforms_to"],
[[CroissantVersion.V_0_8], [CroissantVersion.V_1_0]],
)
def test_pickable(conforms_to):
ctx = Context(conforms_to=conforms_to)
file_object = create_test_node(FileObject, ctx=ctx)
file_object = pickle.loads(pickle.dumps(file_object))
# Test that the context was successfully restored:
assert file_object.ctx.conforms_to == conforms_to

0 comments on commit 589c070

Please sign in to comment.