Skip to content

Commit

Permalink
Merge pull request #4 from AlignmentResearch/contrib-recurrent
Browse files Browse the repository at this point in the history
Generic hidden state for RecurrentPPO
  • Loading branch information
rhaps0dy authored Oct 13, 2023
2 parents 2ce1723 + 188b417 commit c0ac130
Showing 1 changed file with 4 additions and 5 deletions.
9 changes: 4 additions & 5 deletions stable_baselines3/common/pytree_dataclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,9 +84,8 @@ def __new__(mcs, name, bases, namespace, slots=True, **kwargs):
# Otherwise we just mark the current class as what we're registering.
if not issubclass(cls, (FrozenPyTreeDataclass, MutablePyTreeDataclass)):
raise TypeError(f"Dataclass {cls} should inherit from FrozenPyTreeDataclass or MutablePyTreeDataclass")
mcs.currently_registering = cls
else:
mcs.currently_registering = cls

mcs.currently_registering = cls

if name in _RESERVED_NAMES:
if not (
Expand All @@ -105,10 +104,10 @@ def __new__(mcs, name, bases, namespace, slots=True, **kwargs):

frozen = issubclass(cls, FrozenPyTreeDataclass)
if frozen:
if not (not issubclass(cls, MutablePyTreeDataclass) and issubclass(cls, FrozenPyTreeDataclass)):
if issubclass(cls, MutablePyTreeDataclass) or not issubclass(cls, FrozenPyTreeDataclass):
raise TypeError(f"Frozen dataclass {cls} should inherit from FrozenPyTreeDataclass")
else:
if not (issubclass(cls, MutablePyTreeDataclass) and not issubclass(cls, FrozenPyTreeDataclass)):
if not issubclass(cls, MutablePyTreeDataclass) or issubclass(cls, FrozenPyTreeDataclass):
raise TypeError(f"Mutable dataclass {cls} should inherit from MutablePyTreeDataclass")

# Calling `dataclasses.dataclass` here, with slots, is what triggers the EARLY RETURN path above.
Expand Down

0 comments on commit c0ac130

Please sign in to comment.