Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Pytree-Dataclass utilities #7

Merged
merged 10 commits into from
Oct 13, 2023
Merged

Add Pytree-Dataclass utilities #7

merged 10 commits into from
Oct 13, 2023

Conversation

rhaps0dy
Copy link
Collaborator

@rhaps0dy rhaps0dy commented Oct 8, 2023

  • Make a class into a Dataclass that is also a Pytree (via the optree library), by making it inherit from PyTreeDataclass or MutablePytreeDataclass.

  • Make some choice functions from optree default to the namespace SB3_NAMESPACE. This is needed so the classes defined using the above are expanded by default.

  • Annotate PyTree[th.Tensor] as TensorTree, and overload functions for that particular case. Annoyingly, Python type checkers are still not smart enough to have recursive generic references, so we specialize the type instead.

As with the previous Pr, I don't expect mypy and pytype to check in general, but I do expect them to typecheck correctly for the newly introduced file. That's why I modified config.yml accordingly.

Copy link

@dan-pandori dan-pandori left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Take my review with a grain of salt since I haven't made custom Python metaclasses myself before, so I'm not familiar with the style and bugbears.

This code does add a fair amount of non-trivial logic, and it would be nice to have some basic tests of it so that bitrot doesn't introduce breakages later (and confirm that the logic is doing what is expected now).

.circleci/config.yml Outdated Show resolved Hide resolved
stable_baselines3/common/pytree_dataclass.py Outdated Show resolved Hide resolved
stable_baselines3/common/pytree_dataclass.py Outdated Show resolved Hide resolved

if name != "_PyTreeDataclassBase":
if name not in ["PyTreeDataclass", "MutablePyTreeDataclass"]:
frozen = issubclass(cls, PyTreeDataclass)

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we assert that frozen is either not in kwargs, or is set to the correct value here? I'm worried that folks might set frozen to something else, and then get surprised when this is overridden silently.

Something like:
assert "frozen" not in kwargs or frozen == kwargs.pop("frozen")

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That makes sense, good point!

stable_baselines3/common/pytree_dataclass.py Show resolved Hide resolved
@rhaps0dy rhaps0dy force-pushed the pytree-dataclass branch 2 times, most recently from f98816a to 1e7cf77 Compare October 11, 2023 18:34
Copy link

@dan-pandori dan-pandori left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the test suite!

mcs.currently_registering = None
return cls

else:

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This 'else' is unnecessary as the if always kicks out of the function (either through an assertion failure, or a direct return). Removing the 'else' makes the nesting a little less severe, which is IMO slightly nice for reading (and makes it so I don't have to think about which if it corresponds to).

stable_baselines3/common/pytree_dataclass.py Show resolved Hide resolved
stable_baselines3/common/pytree_dataclass.py Show resolved Hide resolved
stable_baselines3/common/pytree_dataclass.py Show resolved Hide resolved
@rhaps0dy rhaps0dy merged commit 8f4a7d4 into main Oct 13, 2023
3 checks passed
Copy link
Collaborator Author

@rhaps0dy rhaps0dy left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I didn't realize I hadn't addressed everything yet when I closed this >.<

stable_baselines3/common/pytree_dataclass.py Outdated Show resolved Hide resolved

if name != "_PyTreeDataclassBase":
if name not in ["PyTreeDataclass", "MutablePyTreeDataclass"]:
frozen = issubclass(cls, PyTreeDataclass)
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That makes sense, good point!

stable_baselines3/common/pytree_dataclass.py Show resolved Hide resolved
stable_baselines3/common/pytree_dataclass.py Show resolved Hide resolved
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants