From f10abdd923369d1f80b515f85ab9a17e63fffa51 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adri=C3=A0=20Garriga-Alonso?= Date: Fri, 13 Oct 2023 10:13:01 -0700 Subject: [PATCH 1/2] DeMorgan's law to make ifs clearer --- stable_baselines3/common/pytree_dataclass.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/stable_baselines3/common/pytree_dataclass.py b/stable_baselines3/common/pytree_dataclass.py index a8a0f4807..4b9b8babb 100644 --- a/stable_baselines3/common/pytree_dataclass.py +++ b/stable_baselines3/common/pytree_dataclass.py @@ -17,10 +17,9 @@ import optree as ot import torch as th from optree import CustomTreeNode, PyTree -from typing_extensions import dataclass_transform - from stable_baselines3.common.type_aliases import TensorIndex from stable_baselines3.common.utils import zip_strict +from typing_extensions import dataclass_transform __all__ = [ "FrozenPyTreeDataclass", @@ -105,10 +104,10 @@ def __new__(mcs, name, bases, namespace, slots=True, **kwargs): frozen = issubclass(cls, FrozenPyTreeDataclass) if frozen: - if not (not issubclass(cls, MutablePyTreeDataclass) and issubclass(cls, FrozenPyTreeDataclass)): + if issubclass(cls, MutablePyTreeDataclass) or not issubclass(cls, FrozenPyTreeDataclass): raise TypeError(f"Frozen dataclass {cls} should inherit from FrozenPyTreeDataclass") else: - if not (issubclass(cls, MutablePyTreeDataclass) and not issubclass(cls, FrozenPyTreeDataclass)): + if not issubclass(cls, MutablePyTreeDataclass) or issubclass(cls, FrozenPyTreeDataclass)): raise TypeError(f"Mutable dataclass {cls} should inherit from MutablePyTreeDataclass") # Calling `dataclasses.dataclass` here, with slots, is what triggers the EARLY RETURN path above. From 228076831c78d8928738b87e65c37f8a2c2cc88c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adri=C3=A0=20Garriga-Alonso?= Date: Fri, 13 Oct 2023 10:15:45 -0700 Subject: [PATCH 2/2] Remove conditionality from `mcs.currently_registering` --- stable_baselines3/common/pytree_dataclass.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/stable_baselines3/common/pytree_dataclass.py b/stable_baselines3/common/pytree_dataclass.py index 4b9b8babb..8f5306f25 100644 --- a/stable_baselines3/common/pytree_dataclass.py +++ b/stable_baselines3/common/pytree_dataclass.py @@ -17,9 +17,10 @@ import optree as ot import torch as th from optree import CustomTreeNode, PyTree +from typing_extensions import dataclass_transform + from stable_baselines3.common.type_aliases import TensorIndex from stable_baselines3.common.utils import zip_strict -from typing_extensions import dataclass_transform __all__ = [ "FrozenPyTreeDataclass", @@ -83,9 +84,8 @@ def __new__(mcs, name, bases, namespace, slots=True, **kwargs): # Otherwise we just mark the current class as what we're registering. if not issubclass(cls, (FrozenPyTreeDataclass, MutablePyTreeDataclass)): raise TypeError(f"Dataclass {cls} should inherit from FrozenPyTreeDataclass or MutablePyTreeDataclass") - mcs.currently_registering = cls - else: - mcs.currently_registering = cls + + mcs.currently_registering = cls if name in _RESERVED_NAMES: if not ( @@ -107,7 +107,7 @@ def __new__(mcs, name, bases, namespace, slots=True, **kwargs): if issubclass(cls, MutablePyTreeDataclass) or not issubclass(cls, FrozenPyTreeDataclass): raise TypeError(f"Frozen dataclass {cls} should inherit from FrozenPyTreeDataclass") else: - if not issubclass(cls, MutablePyTreeDataclass) or issubclass(cls, FrozenPyTreeDataclass)): + if not issubclass(cls, MutablePyTreeDataclass) or issubclass(cls, FrozenPyTreeDataclass): raise TypeError(f"Mutable dataclass {cls} should inherit from MutablePyTreeDataclass") # Calling `dataclasses.dataclass` here, with slots, is what triggers the EARLY RETURN path above.