-
Notifications
You must be signed in to change notification settings - Fork 0
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add Pytree-Dataclass utilities #7
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Take my review with a grain of salt since I haven't made custom Python metaclasses myself before, so I'm not familiar with the style and bugbears.
This code does add a fair amount of non-trivial logic, and it would be nice to have some basic tests of it so that bitrot doesn't introduce breakages later (and confirm that the logic is doing what is expected now).
|
||
if name != "_PyTreeDataclassBase": | ||
if name not in ["PyTreeDataclass", "MutablePyTreeDataclass"]: | ||
frozen = issubclass(cls, PyTreeDataclass) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should we assert that frozen
is either not in kwargs
, or is set to the correct value here? I'm worried that folks might set frozen
to something else, and then get surprised when this is overridden silently.
Something like:
assert "frozen" not in kwargs or frozen == kwargs.pop("frozen")
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That makes sense, good point!
f98816a
to
1e7cf77
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the test suite!
mcs.currently_registering = None | ||
return cls | ||
|
||
else: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This 'else' is unnecessary as the if
always kicks out of the function (either through an assertion failure, or a direct return). Removing the 'else' makes the nesting a little less severe, which is IMO slightly nice for reading (and makes it so I don't have to think about which if
it corresponds to).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I didn't realize I hadn't addressed everything yet when I closed this >.<
|
||
if name != "_PyTreeDataclassBase": | ||
if name not in ["PyTreeDataclass", "MutablePyTreeDataclass"]: | ||
frozen = issubclass(cls, PyTreeDataclass) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That makes sense, good point!
Make a class into a Dataclass that is also a Pytree (via the
optree
library), by making it inherit fromPyTreeDataclass
orMutablePytreeDataclass
.Make some choice functions from
optree
default to the namespaceSB3_NAMESPACE
. This is needed so the classes defined using the above are expanded by default.Annotate
PyTree[th.Tensor]
asTensorTree
, and overload functions for that particular case. Annoyingly, Python type checkers are still not smart enough to have recursive generic references, so we specialize the type instead.As with the previous Pr, I don't expect
mypy
andpytype
to check in general, but I do expect them to typecheck correctly for the newly introduced file. That's why I modifiedconfig.yml
accordingly.