Skip to content

Commit

Permalink
Test all the pytree and dataclass functions/metaclasses
Browse files Browse the repository at this point in the history
  • Loading branch information
rhaps0dy committed Oct 11, 2023
1 parent a9370d8 commit 57dfc30
Show file tree
Hide file tree
Showing 2 changed files with 229 additions and 27 deletions.
55 changes: 37 additions & 18 deletions stable_baselines3/common/pytree_dataclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@

SB3_NAMESPACE = "stable-baselines3"

_RESERVED_NAMES = ["_PyTreeDataclassBase", "FrozenPyTreeDataclass", "MutablePyTreeDataclass"]


# We need to inherit from `type(CustomTreeNode)` to prevent conflicts due to different-inheritance in metaclasses.
# - For some reason just inheriting from `typing._ProtocolMeta` does not get rid of that error.
Expand Down Expand Up @@ -69,30 +71,45 @@ def __new__(mcs, name, bases, namespace, slots=True, **kwargs):
# We've already created and annotated a class without __slots__, now we create the one with __slots__
# that will actually get returned after from the __new__ method.
assert mcs.currently_registering.__module__ == cls.__module__
assert mcs.currently_registering.__name__ == cls.__name__
mcs.currently_registering = None
return cls

else:
if not (
cls.__name__ in ["FrozenPyTreeDataclass", "MutablePyTreeDataclass"]
or issubclass(cls, (FrozenPyTreeDataclass, MutablePyTreeDataclass))
):
assert name not in _RESERVED_NAMES, (
f"Class with name {name}: classes {_RESERVED_NAMES} don't inherit from a dataclass, so they should "
"not be in this branch."
)

# Otherwise we just mark the current class as what we're registering.
if not issubclass(cls, (FrozenPyTreeDataclass, MutablePyTreeDataclass)):
raise TypeError(f"Dataclass {cls} should inherit from FrozenPyTreeDataclass or MutablePyTreeDataclass")
mcs.currently_registering = cls
else:
mcs.currently_registering = cls

if name != "_PyTreeDataclassBase":
if name not in ["FrozenPyTreeDataclass", "MutablePyTreeDataclass"]:
frozen = issubclass(cls, FrozenPyTreeDataclass)
if frozen:
if not (not issubclass(cls, MutablePyTreeDataclass) and issubclass(cls, FrozenPyTreeDataclass)):
raise TypeError(f"Frozen dataclass {cls} should inherit from FrozenPyTreeDataclass")
else:
if not (issubclass(cls, MutablePyTreeDataclass) and not issubclass(cls, FrozenPyTreeDataclass)):
raise TypeError(f"Mutable dataclass {cls} should inherit from MutablePyTreeDataclass")
if name in _RESERVED_NAMES:
if not (
namespace["__module__"] == "stable_baselines3.common.pytree_dataclass" and namespace["__qualname__"] == name
):
raise TypeError(f"You cannot have another class named {name} with metaclass=_PyTreeDataclassMeta")

if name == "_PyTreeDataclassBase":
return cls
frozen = kwargs.pop("frozen")
else:
if "frozen" in kwargs:
raise TypeError(
"You should not specify frozen= for descendants of FrozenPyTreeDataclass or MutablePyTreeDataclass"
)

frozen = issubclass(cls, FrozenPyTreeDataclass)
if frozen:
if not (not issubclass(cls, MutablePyTreeDataclass) and issubclass(cls, FrozenPyTreeDataclass)):
raise TypeError(f"Frozen dataclass {cls} should inherit from FrozenPyTreeDataclass")
else:
frozen = kwargs.pop("frozen")
if not (issubclass(cls, MutablePyTreeDataclass) and not issubclass(cls, FrozenPyTreeDataclass)):
raise TypeError(f"Mutable dataclass {cls} should inherit from MutablePyTreeDataclass")

# Calling `dataclasses.dataclass` here, with slots, is what triggers the EARLY RETURN path above.
cls = dataclasses.dataclass(frozen=frozen, slots=slots, **kwargs)(cls)
Expand Down Expand Up @@ -173,7 +190,7 @@ class MutablePyTreeDataclass(_PyTreeDataclassBase[T], Generic[T], frozen=False):
@overload
def tree_flatten(
tree: TensorTree,
is_leaf: Callable[[TensorTree], bool] | None,
is_leaf: Callable[[TensorTree], bool] | None = None,
*,
none_is_leaf: bool = False,
namespace: str = SB3_NAMESPACE,
Expand All @@ -184,7 +201,7 @@ def tree_flatten(
@overload
def tree_flatten(
tree: PyTree[T],
is_leaf: Callable[[T], bool] | None,
is_leaf: Callable[[T], bool] | None = None,
*,
none_is_leaf: bool = False,
namespace: str = SB3_NAMESPACE,
Expand Down Expand Up @@ -232,14 +249,16 @@ def tree_map(func, tree, *rests, is_leaf=None, none_is_leaf=False, namespace=SB3
return ot.tree_map(func, tree, *rests, is_leaf=is_leaf, none_is_leaf=none_is_leaf, namespace=namespace)


def tree_empty(tree: ot.PyTree, namespace: str = SB3_NAMESPACE) -> bool:
def tree_empty(
tree: ot.PyTree, *, is_leaf: Callable[[T], bool] | None = None, none_is_leaf: bool = False, namespace: str = SB3_NAMESPACE
) -> bool:
"""Is the tree `tree` empty, i.e. without leaves?
:param tree: the tree to check
:param namespace: when expanding nodes, use this namespace
:return: True iff the tree is empty
"""
flattened_state, _ = ot.tree_flatten(tree, namespace=namespace)
flattened_state, _ = ot.tree_flatten(tree, is_leaf=is_leaf, none_is_leaf=none_is_leaf, namespace=namespace)
return not bool(flattened_state)


Expand Down
201 changes: 192 additions & 9 deletions tests/test_pytree_dataclass.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,19 @@
from dataclasses import FrozenInstanceError
from typing import Optional

import pytest

from stable_baselines3.common.pytree_dataclass import (
FrozenPyTreeDataclass,
MutablePyTreeDataclass,
tree_map,
)
import stable_baselines3.common.pytree_dataclass as ptd


@pytest.mark.parametrize("ParentPyTreeClass", (ptd.FrozenPyTreeDataclass, ptd.MutablePyTreeDataclass))
def test_dataclass_mapped_have_slots(ParentPyTreeClass: type) -> None:
"""
If after running `tree_map` the class still has __slots__ and they're the same, then the correct class (the one with
__slots__) is what has been registered as a Pytree custom node.
"""

@pytest.mark.parametrize("parent_class", (FrozenPyTreeDataclass, MutablePyTreeDataclass))
def test_slots(parent_class):
class D(parent_class):
class D(ParentPyTreeClass):
a: int
b: str

Expand All @@ -18,7 +22,186 @@ class D(parent_class):
assert D.__slots__ == ("a", "b")
assert d.__slots__ == ("a", "b")

d2 = tree_map(lambda x: x * 2, d)
d2 = ptd.tree_map(lambda x: x * 2, d)

assert d2.a == 8 and d2.b == "bb"

assert isinstance(d2, D)
assert d2.__slots__ == d.__slots__


@pytest.mark.parametrize("ParentPyTreeClass", (ptd.FrozenPyTreeDataclass, ptd.MutablePyTreeDataclass))
def test_dataclass_frozen_explicit(ParentPyTreeClass: type) -> None:
class D(ParentPyTreeClass):
a: int

with pytest.raises(TypeError, match="You should not specify frozen= for descendants"):

class D(ParentPyTreeClass, frozen=True): # type: ignore
a: int


@pytest.mark.parametrize("frozen", (True, False))
def test_dataclass_must_be_descendant(frozen: bool) -> None:
"""classes with metaclass _PyTreeDataclassMeta must be descendants of FrozenPyTreeDataclass or MutablePyTreeDataclass"""

# First with arbitrary name
with pytest.raises(TypeError):

class D(ptd._PyTreeDataclassBase, frozen=frozen): # type: ignore
pass

with pytest.raises(TypeError):

class D(metaclass=ptd._PyTreeDataclassMeta, frozen=frozen): # type: ignore
pass

with pytest.raises(TypeError, match="[^ ]* dataclass .* should inherit"):

class D(ptd._PyTreeDataclassBase): # type: ignore
pass

with pytest.raises(TypeError, match="[^ ]* dataclass .* should inherit"):

class D(metaclass=ptd._PyTreeDataclassMeta): # type: ignore
pass

# Then try to copy each of the reserved names:
## _PyTreeDataclassBase
with pytest.raises(TypeError):

class _PyTreeDataclassBase(ptd._PyTreeDataclassBase, frozen=frozen): # type: ignore
pass

with pytest.raises(TypeError):

class _PyTreeDataclassBase(metaclass=ptd._PyTreeDataclassMeta, frozen=frozen): # type: ignore
pass

with pytest.raises(TypeError, match="You cannot have another class named"):

class _PyTreeDataclassBase(ptd._PyTreeDataclassBase): # type: ignore
pass

with pytest.raises(TypeError, match="You cannot have another class named"):

class _PyTreeDataclassBase(metaclass=ptd._PyTreeDataclassMeta): # type: ignore
pass

## FrozenPyTreeDataclass
with pytest.raises(TypeError):

class FrozenPyTreeDataclass(ptd._PyTreeDataclassBase, frozen=frozen): # type: ignore
pass

with pytest.raises(TypeError):

class FrozenPyTreeDataclass(metaclass=ptd._PyTreeDataclassMeta, frozen=frozen): # type: ignore
pass

with pytest.raises(TypeError, match="You cannot have another class named"):

class FrozenPyTreeDataclass(ptd._PyTreeDataclassBase): # type: ignore
pass

with pytest.raises(TypeError, match="You cannot have another class named"):

class FrozenPyTreeDataclass(metaclass=ptd._PyTreeDataclassMeta): # type: ignore
pass

## MutablePyTreeDataclass
with pytest.raises(TypeError):

class MutablePyTreeDataclass(ptd._PyTreeDataclassBase, frozen=frozen): # type: ignore
pass

with pytest.raises(TypeError):

class MutablePyTreeDataclass(metaclass=ptd._PyTreeDataclassMeta, frozen=frozen): # type: ignore
pass

with pytest.raises(TypeError, match="You cannot have another class named"):

class MutablePyTreeDataclass(ptd._PyTreeDataclassBase): # type: ignore
pass

with pytest.raises(TypeError, match="You cannot have another class named"):

class MutablePyTreeDataclass(metaclass=ptd._PyTreeDataclassMeta): # type: ignore
pass


def test_dataclass_frozen_or_not() -> None:
class MutA(ptd.MutablePyTreeDataclass):
a: int

class FrozenA(ptd.FrozenPyTreeDataclass):
a: int

inst1 = MutA(2)
inst2 = FrozenA(2)

inst1.a = 2
with pytest.raises(FrozenInstanceError):
inst2.a = 3 # type: ignore[misc]


@pytest.mark.parametrize("ParentPyTreeClass", (ptd.FrozenPyTreeDataclass, ptd.MutablePyTreeDataclass))
def test_dataclass_inheriting_dataclass(ParentPyTreeClass: type) -> None:
class A(ParentPyTreeClass):
a: int

inst = A(3)
assert inst.a == 3

class B(A):
b: int

inst = B(2, 4)
assert inst.a == 2
assert inst.b == 4


def test_tree_flatten() -> None:
class A(ptd.FrozenPyTreeDataclass):
a: Optional[int]

flat, _ = ptd.tree_flatten((A(3), A(None), {"a": A(4)}))
assert flat == [3, 4]


def test_tree_map() -> None:
class A(ptd.FrozenPyTreeDataclass):
a: Optional[int]

assert ptd.tree_map(lambda x: x * 2, ([2, 3], 4, A(5), None, {"a": 6})) == ([4, 6], 8, A(10), None, {"a": 12}) # type: ignore


def test_tree_empty() -> None:
assert ptd.tree_empty(()) # type: ignore
assert ptd.tree_empty([]) # type: ignore
assert ptd.tree_empty({}) # type: ignore
assert not ptd.tree_empty({"a": 2}) # type: ignore
assert not ptd.tree_empty([2]) # type: ignore

class A(ptd.FrozenPyTreeDataclass):
a: Optional[int]

assert ptd.tree_empty([A(None)]) # type: ignore
assert not ptd.tree_empty([A(None)], none_is_leaf=True) # type: ignore
assert not ptd.tree_empty([A(2)]) # type: ignore


def test_tree_index() -> None:
l1 = ["a", "b", "c"]
l2 = ["hi", "bye"]
idx = 1

e1 = l1[idx]
e2 = l2[idx]

class A(ptd.FrozenPyTreeDataclass):
a: str

out_tree = ptd.tree_index([A(l1), A(l2), l1, (l2, {"a": l1})], idx, is_leaf=lambda x: x is l1 or x is l2) # type: ignore
assert out_tree == [A(e1), A(e2), e1, (e2, {"a": e1})]

0 comments on commit 57dfc30

Please sign in to comment.