[Feature] Allow collectors to accept regular modules as policies #546


Merged
vmoens merged 9 commits from the auto-tensordictmodule branch into pytorch:main on Oct 12, 2022

Conversation

tcbegley (Contributor)

Description

This PR changes the collector classes so that users can pass plain nn.Modules as policies.

It checks that the signature of the policy's forward method matches the next_ entries in the environment's observation_spec. The first returned value is assumed to be the action, while any remaining return values are enumerated and labelled output1, output2, etc.
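For illustration, a minimal sketch of the kind of plain-module policy this enables (the environment, dimensions and the collector call are hypothetical and only indicate intended usage):

import torch
from torch import nn

class PlainPolicy(nn.Module):
    # a policy that reads and writes regular tensors rather than TensorDicts
    def __init__(self, obs_dim: int = 4, action_dim: int = 2):
        super().__init__()
        self.net = nn.Linear(obs_dim, action_dim)

    # the parameter name is matched against the "next_observation" entry of the
    # environment's observation_spec (with the "next_" prefix stripped)
    def forward(self, observation: torch.Tensor) -> torch.Tensor:
        # a single return value is treated as the action
        return self.net(observation)

# the collector then wraps the module internally, e.g. (class and arguments assumed):
# collector = SyncDataCollector(env, policy=PlainPolicy(), frames_per_batch=64, total_frames=1000)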

Motivation and Context

Closes #544

Types of changes

  • New feature (non-breaking change which adds core functionality)

Checklist

  • I have read the CONTRIBUTION guide (required)
  • My change requires a change to the documentation.
  • I have updated the tests accordingly (required for a bug fix or a new feature).
  • I have updated the documentation accordingly.

@facebook-github-bot added the CLA Signed label on Oct 10, 2022
@vmoens changed the title from "Allow collectors to accept regular modules as policies" to "[Feature] Allow collectors to accept regular modules as policies" on Oct 10, 2022
elif (
    observation_spec is not None
    and isinstance(policy, nn.Module)
    and not isinstance(policy, TensorDictModule)
Collaborator

I guess we could write a helper function to determine whether a policy is not reading tensordicts.
Here are the possible cases:

  • the policy is a TensorDictModule
  • the policy is an nn.Module that has in_keys, out_keys and reads tensordicts, but is not a TensorDictModule
  • the policy is an nn.Module that reads tensors (and maybe also tensordicts)
    In the last case we should wrap it in a TensorDictModule

So what we could do is:

  • if it's a TDModule, we're good
  • if it's not, but it has in_keys, out_keys and the signature reads only one value, we're good
  • otherwise we should wrap that thing in a TensorDictModule (see the wrapping sketch below)

Another option would be to look at the type hinting in the forward of the module, but I'm not sure we really need that.
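A rough sketch of the wrapping step for that last case (not the exact code in this PR; the import path and the keys are assumptions and may vary across versions):

from torch import nn
from torchrl.modules import TensorDictModule  # import location assumed

# a plain module that reads a tensor and returns a tensor
policy = nn.Linear(4, 2)

# wrapping it lets it read from / write to a TensorDict instead
wrapped_policy = TensorDictModule(
    policy,
    in_keys=["observation"],  # which tensordict entry is passed to forward
    out_keys=["action"],      # where the returned tensor is written
)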

Contributor Author

Yes, I think this makes sense.

When you say helper function, do you mean pulling the elif condition out into a call to e.g. _module_is_tensordict_compatible(module) or similar? Or did you mean the actual body of the elif?

Contributor Author

Had a go at this, let me know if it's what you had in mind.

sig = inspect.signature(policy.forward)
next_observation = {
    key[5:]: value
    for key, value in observation_spec.rand().items()
Collaborator

no need for rand, we should be good with observation_spec.items()

Contributor Author

The reason for putting rand in there is that a couple of lines down I want to do output = policy(**next_observation) to count the output arguments, and if I attempt to do this on the spec items directly I get errors like:

TypeError: linear: argument 'input' (position 1) must be Tensor, not NdBoundedTensorSpec
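In other words, the idea is roughly the following sketch (not the exact diff; observation_spec and policy are as in the snippet above): fake observations are drawn from the spec so the policy can actually be called and its outputs counted.

next_observation = {
    key[len("next_"):]: value  # strip the "next_" prefix from the spec keys
    for key, value in observation_spec.rand().items()
}
output = policy(**next_observation)
# first output is the action, any extra outputs are labelled output1, output2, ...
if isinstance(output, tuple):
    out_keys = ["action"] + [f"output{i}" for i in range(1, len(output))]
else:
    out_keys = ["action"]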

@tcbegley force-pushed the auto-tensordictmodule branch from 2d47fda to 9920785 on October 11, 2022 13:26
@vmoens added the enhancement (New feature or request) label on Oct 11, 2022
@@ -81,6 +83,15 @@ def recursive_map_to_cpu(dictionary: OrderedDict) -> OrderedDict:
)


def _module_is_tensordict_compatible(module: nn.Module):
    sig = inspect.signature(module.forward)
Collaborator

I guess I was more thinking of this: a function that returns True in these cases:

1.
class MyModule(nn.Module):
    in_keys = ["stuff"]
    out_keys = ["other_stuff"]
    def forward(self, tensordict):
        pass

2. TensorDictModule

and returns False in these cases:

1.
class MyModule(nn.Module):
    in_keys = ["stuff"]
    out_keys = ["other_stuff"]
    def forward(self, tensordict, tensor):  # tensor is unexpected here
        pass

2.
class MyModule(nn.Module):
    def forward(self, tensor):
        pass

We could also be super restrictive.
Return True iff one of:

  • TensorDictModule
  • nn.Module + in_keys + out_keys + one input only (assumed to be of type TensorDict)

Return False iff:

  • no in_keys / out_keys attribute, whatever the input

If there are in_keys and/or out_keys but there are multiple inputs or similar, it's likely that the user is trying to do something that will have undetermined behaviour.

Other cases: throw an error.

(in_keys / out_keys is an indirect indicator of "something that works with TensorDict", which is always worse than a direct indicator, but I guess that's the best we can do right now)

codecov bot commented Oct 11, 2022

Codecov Report

Merging #546 (9920785) into main (59b1f2b) will increase coverage by 0.07%.
The diff coverage is 94.62%.

❗ Current head 9920785 differs from pull request most recent head bba04d4. Consider uploading reports for the commit bba04d4 to get more accurate results

@@            Coverage Diff             @@
##             main     #546      +/-   ##
==========================================
+ Coverage   86.78%   86.85%   +0.07%     
==========================================
  Files         121      121              
  Lines       21667    21756      +89     
==========================================
+ Hits        18804    18897      +93     
+ Misses       2863     2859       -4     
Flag Coverage Δ
linux-cpu 85.19% <94.62%> (+0.07%) ⬆️
linux-gpu 86.63% <94.62%> (+0.07%) ⬆️
linux-outdeps-gpu 75.11% <35.48%> (-0.14%) ⬇️
linux-stable-cpu 85.17% <94.62%> (+0.07%) ⬆️
linux-stable-gpu 86.63% <94.62%> (+0.07%) ⬆️
macos-cpu 84.95% <94.62%> (+0.08%) ⬆️

Flags with carried forward coverage won't be shown.

Impacted Files Coverage Δ
torchrl/modules/models/models.py 95.07% <0.00%> (ø)
torchrl/collectors/collectors.py 69.58% <88.88%> (+0.62%) ⬆️
test/test_collector.py 98.55% <98.46%> (+1.89%) ⬆️


Comment on lines 86 to 113
def _policy_is_tensordict_compatible(policy: nn.Module):
    sig = inspect.signature(policy.forward)

    if isinstance(policy, TensorDictModule) or (
        len(sig.parameters) == 1
        and hasattr(policy, "in_keys")
        and hasattr(policy, "out_keys")
    ):
        # if the policy is a TensorDictModule or takes a single argument and defines
        # in_keys and out_keys then we assume it can already deal with TensorDict input
        # to forward and we return True
        return True
    elif not hasattr(policy, "in_keys") and not hasattr(policy, "out_keys"):
        # if it's not a TensorDictModule, and in_keys and out_keys are not defined then
        # we assume no TensorDict compatibility and will try to wrap it.
        return False

    # if in_keys or out_keys were defined but policy is not a TensorDictModule or
    # accepts multiple arguments then it's likely the user is trying to do something
    # that will have undetermined behaviour, we raise an error
    raise TypeError(
        "Received a policy that defines in_keys or out_keys and also expects multiple "
        "arguments to self.forward. If the policy is compatible with TensorDict, it "
        "should take a single argument of type TensorDict to self.forward and define "
        "both in_keys and out_keys. Alternatively, self.forward can accept arbitrarily "
        "many tensor inputs and leave in_keys and out_keys undefined and TorchRL will "
        "attempt to automatically wrap the policy with a TensorDictModule."
    )
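For illustration (hypothetical class names; the commented calls just restate the behaviour of the code above), the helper would classify the earlier examples like this:

from torch import nn

class TensorDictPolicy(nn.Module):
    in_keys = ["observation"]
    out_keys = ["action"]

    def forward(self, tensordict):
        ...

class TensorPolicy(nn.Module):
    def forward(self, tensor):
        ...

# _policy_is_tensordict_compatible(TensorDictPolicy())  # True: single argument plus in_keys/out_keys
# _policy_is_tensordict_compatible(TensorPolicy())      # False: no in_keys/out_keys, so it gets wrapped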
Contributor Author

@vmoens is this what you had in mind in your previous comment?

Comment on lines +162 to +163
# TODO: revisit these checks when we have determined whether arbitrary
# callables should be supported as policies.
Contributor Author

More thought is going into whether arbitrary callables should be supported as policies. Once there is a clear answer to that question, these checks should perhaps be revisited.

Collaborator

Yes, I think this is cool. If it's not an nn.Module, we assume that it works out of the box.

@vmoens left a comment (Collaborator)

LGTM thanks for this!

# Conflicts:
#	torchrl/collectors/collectors.py
@vmoens merged commit 4862459 into pytorch:main on Oct 12, 2022
vmoens added a commit that referenced this pull request Oct 13, 2022
* [BugFix] Transformed ParallelEnv meta data are broken when passing to device (#531)

* [Doc] Add coverage banner (#533)

* add orb decov to circleci config.yml

* Add codecov badge to Readme

* Revert "[BugFix] Changing the dm_control import to fail if not installed (#515)"

This reverts commit d194735.

* codecov coverage w/o orb in circleci

* Revert "Revert "[BugFix] Changing the dm_control import to fail if not installed (#515)""

This reverts commit d0dc7de.

* [CI] generation of coverage reports (#534)

* update test scripts to add coverage

* update test scripts to add coverage

Co-authored-by: Silvestre Bahi <silvestrebahi@fb.com>

* [CI] Add xml coverage reports for codecov (#537)

* update test scripts to add coverage

* update test scripts to add coverage

* generate xml file for coverage

* Update run_test.sh

lint end of file

* Update run_test.sh

lint end of file

* Update run_test.sh

lint end of file

Co-authored-by: Silvestre Bahi <silvestrebahi@fb.com>

* permissions

* permissions

Co-authored-by: Silvestre Bahi <silvestrebahi@fb.com>
Co-authored-by: silvestrebahi <silvestre.bahi@gmail.com>

* [BugFix] Fix colab link of coding_dqn.ipynb (#543)

* [BugFix] Fix optional imports (#535)

* [BugFix] Restore missing keys in data collector output (#521)

* Ensure data collectors return all expected keys

* Rerun CI

* Add tests

* Format code

* correct unreachable test

* Fix broken test

* WIP: fix initialisation with policy + test

* Fix initialisation with policy + test

* Reset env after rollout initialisation

* fix build from spec

* Check policy has spec attribute before accessing

* Address comments

Co-authored-by: vmoens <vincentmoens@gmail.com>

* [Lint] reorganize imports (#545)

[Lint] reorganize imports

* [BugFix] Single-cpu compatibility (#548)

* [BugFix] vision install and other deps in optdeps (#552)

* init

* amend

* amend

* amend

* [Feature] Implemented device argument for modules.models (#524)

Co-authored-by: Yu Shiyang <yushiyang@fb.com>

* [BugFix] Fix ellipsis indexing of 2d TensorDicts (#559)

* [BugFix] Additive gaussian exploration spec fix (#560)

* [BugFix] Disabling video step for wandb (#561)

* [BugFix] Various device fix (#558)

* [Feature] Allow collectors to accept regular modules as policies (#546)

* [BugFix] Fix push binary nightly action (#566)

Fix Push Binary Nightly Action for linux hosts

Co-authored-by: Vincent Moens <vincentmoens@gmail.com>
Co-authored-by: Silvestre Bahi <silvestrebahi@fb.com>
Co-authored-by: silvestrebahi <silvestre.bahi@gmail.com>
Co-authored-by: Bo Liu <benjaminliu.eecs@gmail.com>
Co-authored-by: Tom Begley <tomcbegley@gmail.com>
Co-authored-by: Alessandro Pietro Bardelli <apbard@users.noreply.github.com>
Co-authored-by: Yu Shiyang <yushiyangk@users.noreply.github.com>
Co-authored-by: Yu Shiyang <yushiyang@fb.com>
Co-authored-by: Pavel Solikov <psolikov15@gmail.com>
@tcbegley deleted the auto-tensordictmodule branch on November 10, 2022 16:36
Labels
  • CLA Signed: This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed.
  • enhancement: New feature or request
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[Feature Request] Allow data collectors to receive policies that read and write regular tensors
3 participants