-
Notifications
You must be signed in to change notification settings - Fork 373
[Feature] Allow collectors to accept regular modules as policies #546
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
torchrl/collectors/collectors.py
Outdated
elif ( | ||
observation_spec is not None | ||
and isinstance(policy, nn.Module) | ||
and not isinstance(policy, TensorDictModule) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I guess we could write a helper function to determine if a policy is not reading tensordicts
Here are the possible cases:
- the policy is a TensorDictModule
- the policy is a nn.Module that has in_keys, out_keys and reads tensordicts but it is not a TensorDictModule
- the policy is a nn.Module that reads tensors (and maybe also tensordicts)
In the last case we should wrap it into a TensorDictModule
So what we could do is:
- if it's a TDModule we're good
- If it's not but it has in_keys, out_keys and the signature reads only one value we're good
- Otherwise we should wrap that thing in a TensorDictModule
Another option would be to look at the type-hinting in the forward of the module but I'm not sure we really need that
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, I think this makes sense.
When you say helper function, do you mean pulling the elif
condition out into a call to e.g. _module_is_tensordict_compatible(module)
or similar? Or did you mean the actual body of the elif
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Had a go at this, let me know if it's what you had in mind.
torchrl/collectors/collectors.py
Outdated
sig = inspect.signature(policy.forward) | ||
next_observation = { | ||
key[5:]: value | ||
for key, value in observation_spec.rand().items() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
no need for rand, we should be good with observation_spec.items()
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The reason for putting rand
in there was because a couple of lines down I want to do output = policy(**next_observation)
to count output arguments and if I attempt to do this just on the spec items I get errors like
TypeError: linear: argument 'input' (position 1) must be Tensor, not NdBoundedTensorSpec
2d47fda
to
9920785
Compare
torchrl/collectors/collectors.py
Outdated
@@ -81,6 +83,15 @@ def recursive_map_to_cpu(dictionary: OrderedDict) -> OrderedDict: | |||
) | |||
|
|||
|
|||
def _module_is_tensordict_compatible(module: nn.Module): | |||
sig = inspect.signature(module.forward) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I guess I was more thinking of this
a function that returns True in these cases
1.
class MyModule(nn.Module):
in_keys = ["stuff"]
out_keys = ["other_stuff"]
def forward(self, tensordict):
pass
- TensorDictModule
returns False in those cases
1.
class MyModule(nn.Module):
in_keys = ["stuff"]
out_keys = ["other_stuff"]
def forward(self, tensordict, tensor): # tensor is unexpected here
pass
class MyModule(nn.Module):
def forward(self, tensor):
pass
We could also be super restrictive
Return True iif one of:
- TensorDictmodule
- nn.Module + in_keys + out_keys + one input only (assumed to be of type tensordict)
Return False iif
- no in_keys / out_keys attribute and whatever input
if there are in_keys and/or out_keys but there are multiple inputs or similar, it's likely that the use is trying to do smth that will have an undetermined behaviour.
Other cases: throw an error
(in_keys / out_keys is an indirect indicator of "smth that works with TensorDict" which is always worse than a direct indicator but I guess that's the best we can do right now)
Codecov Report
@@ Coverage Diff @@
## main #546 +/- ##
==========================================
+ Coverage 86.78% 86.85% +0.07%
==========================================
Files 121 121
Lines 21667 21756 +89
==========================================
+ Hits 18804 18897 +93
+ Misses 2863 2859 -4
Flags with carried forward coverage won't be shown. Click here to find out more.
📣 We’re building smart automated test selection to slash your CI/CD build times. Learn more |
torchrl/collectors/collectors.py
Outdated
def _policy_is_tensordict_compatible(policy: nn.Module): | ||
sig = inspect.signature(policy.forward) | ||
|
||
if isinstance(policy, TensorDictModule) or ( | ||
len(sig.parameters) == 1 | ||
and hasattr(policy, "in_keys") | ||
and hasattr(policy, "out_keys") | ||
): | ||
# if the policy is a TensorDictModule or takes a single argument and defines | ||
# in_keys and out_keys then we assume it can already deal with TensorDict input | ||
# to forward and we return True | ||
return True | ||
elif not hasattr(policy, "in_keys") and not hasattr(policy, "out_keys"): | ||
# if it's not a TensorDictModule, and in_keys and out_keys are not defined then | ||
# we assume no TensorDict compatibility and will try to wrap it. | ||
return False | ||
|
||
# if in_keys or out_keys were defined but policy is not a TensorDictModule or | ||
# accepts multiple arguments then it's likely the user is trying to do something | ||
# that will have undetermined behaviour, we raise an error | ||
raise TypeError( | ||
"Received a policy that defines in_keys or out_keys and also expects multiple " | ||
"arguments to self.forward. If the policy is compatible with TensorDict, it " | ||
"should take a single argument of type TensorDict to self.forward and define " | ||
"both in_keys and out_keys. Alternatively, self.forward can accept arbitrarily " | ||
"many tensor inputs and leave in_keys and out_keys undefined and TorchRL will " | ||
"attempt to automatically wrap the policy with a TensorDictModule." | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@vmoens is this what you had in mind in your previous comment?
# TODO: revisit these checks when we have determined whether arbitrary | ||
# callables should be supported as policies. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
More thought is going into whether arbitrary callables should be supported as policies. Once there is a clear answer on this question these checks should perhaps be revisited.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yes I think that this is cool
if it's not a nn.Module we assume that it works out of the box
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM thanks for this!
# Conflicts: # torchrl/collectors/collectors.py
* [BugFix] Transformed ParallelEnv meta data are broken when passing to device (#531) * [Doc] Add coverage banner (#533) * add orb decov to circleci config.yml * Add codecov badge to Readme * Revert "[BugFix] Changing the dm_control import to fail if not installed (#515)" This reverts commit d194735. * codecov coverage w/o orb in circleci * Revert "Revert "[BugFix] Changing the dm_control import to fail if not installed (#515)"" This reverts commit d0dc7de. * [CI] generation of coverage reports (#534) * update test scripts to add coverage * update test scripts to add coverage Co-authored-by: Silvestre Bahi <silvestrebahi@fb.com> * [CI] Add xml coverage reports for codecov (#537) * update test scripts to add coverage * update test scripts to add coverage * generate xml file for coverage * Update run_test.sh lint end of file * Update run_test.sh lint end of file * Update run_test.sh lint end of file Co-authored-by: Silvestre Bahi <silvestrebahi@fb.com> * permissions * permissions Co-authored-by: Silvestre Bahi <silvestrebahi@fb.com> Co-authored-by: silvestrebahi <silvestre.bahi@gmail.com> * [BugFix] Fix colab link of coding_dqn.ipynb (#543) * [BugFix] Fix optional imports (#535) * [BugFix] Restore missing keys in data collector output (#521) * Ensure data collectors return all expected keys * Rerun CI * Add tests * Format code * correct unreachable test * Fix broken test * WIP: fix initialisation with policy + test * Fix initialisation with policy + test * Reset env after rollout initialisation * fix build from spec * Check policy has spec attribute before accessing * Address comments Co-authored-by: vmoens <vincentmoens@gmail.com> * [Lint] reorganize imports (#545) [Lint] reorganize imports * [BugFix] Single-cpu compatibility (#548) * [BugFix] vision install and other deps in optdeps (#552) * init * amend * amend * amend * [Feature] Implemented device argument for modules.models (#524) Co-authored-by: Yu Shiyang <yushiyang@fb.com> * [BugFix] Fix ellipsis indexing of 2d TensorDicts (#559) * [BugFix] Additive gaussian exploration spec fix (#560) * [BugFix] Disabling video step for wandb (#561) * [BugFix] Various device fix (#558) * [Feature] Allow collectors to accept regular modules as policies (#546) * [BugFix] Fix push binary nightly action (#566) Fix Push Binary Nightly Action for linux hosts Co-authored-by: Vincent Moens <vincentmoens@gmail.com> Co-authored-by: Silvestre Bahi <silvestrebahi@fb.com> Co-authored-by: silvestrebahi <silvestre.bahi@gmail.com> Co-authored-by: Bo Liu <benjaminliu.eecs@gmail.com> Co-authored-by: Tom Begley <tomcbegley@gmail.com> Co-authored-by: Alessandro Pietro Bardelli <apbard@users.noreply.github.com> Co-authored-by: Yu Shiyang <yushiyangk@users.noreply.github.com> Co-authored-by: Yu Shiyang <yushiyang@fb.com> Co-authored-by: Pavel Solikov <psolikov15@gmail.com>
Description
This PR makes changes to the collector classes in order that users can pass simple
nn.Module
s as policies.It checks that the signature of the
forward
method matches thenext_
entries in the environment'sobservation_spec
. The first returned value is assumed to be the action while remaining return values are simply enumerated and labelledoutput1
,output2
, etc.Motivation and Context
Closes #544
Types of changes
Checklist