Skip to content

Commit

Permalink
Remove conditionality from mcs.currently_registering
Browse files Browse the repository at this point in the history
  • Loading branch information
rhaps0dy committed Oct 13, 2023
1 parent f10abdd commit 2280768
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions stable_baselines3/common/pytree_dataclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,10 @@
import optree as ot
import torch as th
from optree import CustomTreeNode, PyTree
from typing_extensions import dataclass_transform

from stable_baselines3.common.type_aliases import TensorIndex
from stable_baselines3.common.utils import zip_strict
from typing_extensions import dataclass_transform

__all__ = [
"FrozenPyTreeDataclass",
Expand Down Expand Up @@ -83,9 +84,8 @@ def __new__(mcs, name, bases, namespace, slots=True, **kwargs):
# Otherwise we just mark the current class as what we're registering.
if not issubclass(cls, (FrozenPyTreeDataclass, MutablePyTreeDataclass)):
raise TypeError(f"Dataclass {cls} should inherit from FrozenPyTreeDataclass or MutablePyTreeDataclass")
mcs.currently_registering = cls
else:
mcs.currently_registering = cls

mcs.currently_registering = cls

if name in _RESERVED_NAMES:
if not (
Expand All @@ -107,7 +107,7 @@ def __new__(mcs, name, bases, namespace, slots=True, **kwargs):
if issubclass(cls, MutablePyTreeDataclass) or not issubclass(cls, FrozenPyTreeDataclass):
raise TypeError(f"Frozen dataclass {cls} should inherit from FrozenPyTreeDataclass")
else:
if not issubclass(cls, MutablePyTreeDataclass) or issubclass(cls, FrozenPyTreeDataclass)):
if not issubclass(cls, MutablePyTreeDataclass) or issubclass(cls, FrozenPyTreeDataclass):
raise TypeError(f"Mutable dataclass {cls} should inherit from MutablePyTreeDataclass")

# Calling `dataclasses.dataclass` here, with slots, is what triggers the EARLY RETURN path above.
Expand Down

0 comments on commit 2280768

Please sign in to comment.