diff --git a/stable_baselines3/common/pytree_dataclass.py b/stable_baselines3/common/pytree_dataclass.py index a8a0f4807..8f5306f25 100644 --- a/stable_baselines3/common/pytree_dataclass.py +++ b/stable_baselines3/common/pytree_dataclass.py @@ -84,9 +84,8 @@ def __new__(mcs, name, bases, namespace, slots=True, **kwargs): # Otherwise we just mark the current class as what we're registering. if not issubclass(cls, (FrozenPyTreeDataclass, MutablePyTreeDataclass)): raise TypeError(f"Dataclass {cls} should inherit from FrozenPyTreeDataclass or MutablePyTreeDataclass") - mcs.currently_registering = cls - else: - mcs.currently_registering = cls + + mcs.currently_registering = cls if name in _RESERVED_NAMES: if not ( @@ -105,10 +104,10 @@ def __new__(mcs, name, bases, namespace, slots=True, **kwargs): frozen = issubclass(cls, FrozenPyTreeDataclass) if frozen: - if not (not issubclass(cls, MutablePyTreeDataclass) and issubclass(cls, FrozenPyTreeDataclass)): + if issubclass(cls, MutablePyTreeDataclass) or not issubclass(cls, FrozenPyTreeDataclass): raise TypeError(f"Frozen dataclass {cls} should inherit from FrozenPyTreeDataclass") else: - if not (issubclass(cls, MutablePyTreeDataclass) and not issubclass(cls, FrozenPyTreeDataclass)): + if not issubclass(cls, MutablePyTreeDataclass) or issubclass(cls, FrozenPyTreeDataclass): raise TypeError(f"Mutable dataclass {cls} should inherit from MutablePyTreeDataclass") # Calling `dataclasses.dataclass` here, with slots, is what triggers the EARLY RETURN path above.