Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve state performance by diffing state #389

Merged
merged 10 commits into from
Jun 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions build_defs/defaults.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -104,3 +104,7 @@ THIRD_PARTY_PY_MYPY_PROTOBUF = [
THIRD_PARTY_PY_PANDAS = [
requirement("pandas"),
]

THIRD_PARTY_PY_DEEPDIFF = [
requirement("deepdiff"),
]
2 changes: 1 addition & 1 deletion build_defs/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,6 @@ protobuf
pydantic==1.10.13 --no-binary pydantic
pygments
libcst==1.1.0

deepdiff==6.7.1
matplotlib # only used for example
pandas # only used for example
8 changes: 8 additions & 0 deletions build_defs/requirements_lock.txt
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,10 @@ cycler==0.12.1 \
--hash=sha256:85cef7cff222d8644161529808465972e51340599459b8ac3ccbac5a854e0d30 \
--hash=sha256:88bb128f02ba341da8ef447245a9e138fae777f6a23943da4540077d3601eb1c
# via matplotlib
deepdiff==6.7.1 \
--hash=sha256:58396bb7a863cbb4ed5193f548c56f18218060362311aa1dc36397b2f25108bd \
--hash=sha256:b367e6fa6caac1c9f500adc79ada1b5b1242c50d5f716a1a4362030197847d30
# via -r build_defs/requirements.txt
exceptiongroup==1.1.3 \
--hash=sha256:097acd85d473d75af5bb98e41b61ff7fe35efe6675e4f9370ec6ec5126d160e9 \
--hash=sha256:343280667a4585d195ca1cf9cef84a4e178c4b6cf2274caef9859782b567d5e3
Expand Down Expand Up @@ -583,6 +587,10 @@ numpy==1.26.3 \
# contourpy
# matplotlib
# pandas
ordered-set==4.1.0 \
--hash=sha256:046e1132c71fcf3330438a539928932caf51ddbc582496833e23de611de14562 \
--hash=sha256:694a8e44c87657c59292ede72891eb91d34131f6531463aab3009191c77364a8
# via deepdiff
packaging==23.2 \
--hash=sha256:048fb0e9405036518eaaf48a55953c750c11e1a1b68e0dd1a9d62ed0c092cfc5 \
--hash=sha256:8c491190033a9af7e1d931d0b5dacc2ef47509b34dd0de67ed209b5203fc88c7
Expand Down
10 changes: 8 additions & 2 deletions mesop/dataclass_utils/BUILD
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
load("//build_defs:defaults.bzl", "THIRD_PARTY_PY_PANDAS", "THIRD_PARTY_PY_PYTEST", "py_library", "py_test")
load("//build_defs:defaults.bzl", "THIRD_PARTY_PY_DEEPDIFF", "THIRD_PARTY_PY_PANDAS", "THIRD_PARTY_PY_PYTEST", "py_library", "py_test")

package(
default_visibility = ["//build_defs:mesop_internal"],
Expand All @@ -10,11 +10,17 @@ py_library(
["*.py"],
exclude = ["*_test.py"],
),
deps = ["//mesop/exceptions"],
deps = ["//mesop/exceptions"] + THIRD_PARTY_PY_DEEPDIFF,
)

py_test(
name = "dataclass_utils_test",
srcs = ["dataclass_utils_test.py"],
deps = [":dataclass_utils"] + THIRD_PARTY_PY_PYTEST + THIRD_PARTY_PY_PANDAS,
)

py_test(
name = "diff_state_test",
srcs = ["diff_state_test.py"],
deps = [":dataclass_utils"] + THIRD_PARTY_PY_PYTEST + THIRD_PARTY_PY_PANDAS,
)
3 changes: 3 additions & 0 deletions mesop/dataclass_utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
from .dataclass_utils import (
dataclass_with_defaults as dataclass_with_defaults,
)
from .dataclass_utils import (
diff_state as diff_state,
)
from .dataclass_utils import (
serialize_dataclass as serialize_dataclass,
)
Expand Down
98 changes: 95 additions & 3 deletions mesop/dataclass_utils/dataclass_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,31 @@
from io import StringIO
from typing import Any, Type, TypeVar, cast, get_origin, get_type_hints

from deepdiff import DeepDiff, Delta
from deepdiff.operator import BaseOperator
from deepdiff.path import parse_path

from mesop.exceptions import MesopException

_PANDAS_OBJECT_KEY = "__pandas.DataFrame__"
_DIFF_ACTION_DATA_FRAME_CHANGED = "data_frame_changed"

C = TypeVar("C")


def _check_has_pandas():
"""Checks if pandas exists since it is an optional dependency for Mesop."""
try:
import pandas # noqa: F401

return True
except ImportError:
return False


_has_pandas = _check_has_pandas()


def dataclass_with_defaults(cls: Type[C]) -> Type[C]:
"""
Provides defaults for every attribute in a dataclass (recursively) so
Expand Down Expand Up @@ -106,6 +124,13 @@ def default(self, obj):
return {_PANDAS_OBJECT_KEY: pd.DataFrame.to_json(obj, orient="table")}
except ImportError:
pass

if is_dataclass(obj):
return asdict(obj)

if isinstance(obj, type):
return str(obj)

return super().default(obj)


Expand All @@ -119,11 +144,78 @@ def decode_mesop_json_state_hook(dct):

One thing to note is that pandas.NA becomes numpy.nan during deserialization.
"""
try:
if _has_pandas:
import pandas as pd

if _PANDAS_OBJECT_KEY in dct:
return pd.read_json(StringIO(dct[_PANDAS_OBJECT_KEY]), orient="table")
except ImportError:
pass
return dct


class DataFrameOperator(BaseOperator):
"""Custom operator to detect changes in DataFrames.

DeepDiff does not support diffing DataFrames. See https://github.com/seperman/deepdiff/issues/394.

This operator checks if the DataFrames are equal or not. It does not do a deep diff of
the contents of the DataFrame.
"""

def match(self, level) -> bool:
try:
import pandas as pd

return isinstance(level.t1, pd.DataFrame) and isinstance(
level.t2, pd.DataFrame
)
except ImportError:
# If Pandas is not installed, don't perform this check. We should log a warning.
return False

def give_up_diffing(self, level, diff_instance) -> bool:
if not level.t1.equals(level.t2):
diff_instance.custom_report_result(
_DIFF_ACTION_DATA_FRAME_CHANGED, level, {"value": level.t2}
)
return True


def diff_state(state1: Any, state2: Any) -> str:
"""
Diffs two state objects and returns the difference using DeepDiff's Delta format as a
JSON string.

DeepDiff does not support DataFrames yet. See `DataFrameOperator`.

The `to_flat_dicts` method does not include custom report results, so we need to add
those manually for the DataFrame case.
"""
if not is_dataclass(state1) or not is_dataclass(state2):
raise MesopException("Tried to diff state which was not a dataclass")

custom_actions = []

# Only use the `DataFrameOperator` if pandas exists.
if _has_pandas:
differences = DeepDiff(
state1, state2, custom_operators=[DataFrameOperator()]
)

# Manually format dataframe diffs to flat dict format.
if _DIFF_ACTION_DATA_FRAME_CHANGED in differences:
custom_actions = [
{
"path": parse_path(path),
"action": _DIFF_ACTION_DATA_FRAME_CHANGED,
**diff,
}
for path, diff in differences[_DIFF_ACTION_DATA_FRAME_CHANGED].items()
]
else:
differences = DeepDiff(state1, state2)

return json.dumps(
Delta(differences, always_include_values=True).to_flat_dicts()
+ custom_actions,
cls=MesopJSONEncoder,
)
Loading
Loading