-
Notifications
You must be signed in to change notification settings - Fork 373
[BugFix] Restore missing keys in data collector output #521
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I've submitted as a draft because I'm expecting CI to fail on test/test_collectors:test_update_weights
.
I've been having a bit of trouble pinning down the issue because it originates in a sub-process, but it seems to be related to
TypeError: no implementation found for 'torch.nn.linear' on types that implement __torch_function__: [<class 'torchrl.data.tensordict.tensordict.TensorDict'>]
The test defines policy = torch.nn.Linear(3, 4).cuda(1)
,and it seems the problem arises during env.rollout(3, policy)
used to determine the keys for _tensordict_out
because torch doesn't know how to handle
tensordict = policy(tensordict)
Does the policy need to be wrapped first or something before doing env.rollout
?
Weirdly test I was expecting to fail is passing after all... Could be a quirk of my dev environment? Happy to look into that more if you like, but since the tests are passing in CI perhaps this can be reviewed as is. |
Is it the |
Yes, that's the one I saw running the tests myself but hasn't failed in CircleCI checks. |
torchrl/collectors/collectors.py
Outdated
batch_size=[*self.env.batch_size, self.frames_per_batch], | ||
device=self.passing_device, | ||
|
||
# TODO: perhaps check type of policy and raise TypeError if something |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
we can remove this comment
Long term plan is to support any kind of policy so i'm happy with the state of the checks
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM thanks for this contribuiton!
Codecov Report
@@ Coverage Diff @@
## main #521 +/- ##
==========================================
+ Coverage 86.57% 86.73% +0.16%
==========================================
Files 121 121
Lines 21632 21680 +48
==========================================
+ Hits 18727 18804 +77
+ Misses 2905 2876 -29
Flags with carried forward coverage won't be shown. Click here to find out more.
📣 We’re building smart automated test selection to slash your CI/CD build times. Learn more |
* [BugFix] Transformed ParallelEnv meta data are broken when passing to device (#531) * [Doc] Add coverage banner (#533) * add orb decov to circleci config.yml * Add codecov badge to Readme * Revert "[BugFix] Changing the dm_control import to fail if not installed (#515)" This reverts commit d194735. * codecov coverage w/o orb in circleci * Revert "Revert "[BugFix] Changing the dm_control import to fail if not installed (#515)"" This reverts commit d0dc7de. * [CI] generation of coverage reports (#534) * update test scripts to add coverage * update test scripts to add coverage Co-authored-by: Silvestre Bahi <silvestrebahi@fb.com> * [CI] Add xml coverage reports for codecov (#537) * update test scripts to add coverage * update test scripts to add coverage * generate xml file for coverage * Update run_test.sh lint end of file * Update run_test.sh lint end of file * Update run_test.sh lint end of file Co-authored-by: Silvestre Bahi <silvestrebahi@fb.com> * permissions * permissions Co-authored-by: Silvestre Bahi <silvestrebahi@fb.com> Co-authored-by: silvestrebahi <silvestre.bahi@gmail.com> * [BugFix] Fix colab link of coding_dqn.ipynb (#543) * [BugFix] Fix optional imports (#535) * [BugFix] Restore missing keys in data collector output (#521) * Ensure data collectors return all expected keys * Rerun CI * Add tests * Format code * correct unreachable test * Fix broken test * WIP: fix initialisation with policy + test * Fix initialisation with policy + test * Reset env after rollout initialisation * fix build from spec * Check policy has spec attribute before accessing * Address comments Co-authored-by: vmoens <vincentmoens@gmail.com> * [Lint] reorganize imports (#545) [Lint] reorganize imports * [BugFix] Single-cpu compatibility (#548) * [BugFix] vision install and other deps in optdeps (#552) * init * amend * amend * amend * [Feature] Implemented device argument for modules.models (#524) Co-authored-by: Yu Shiyang <yushiyang@fb.com> * [BugFix] Fix ellipsis indexing of 2d TensorDicts (#559) * [BugFix] Additive gaussian exploration spec fix (#560) * [BugFix] Disabling video step for wandb (#561) * [BugFix] Various device fix (#558) * [Feature] Allow collectors to accept regular modules as policies (#546) * [BugFix] Fix push binary nightly action (#566) Fix Push Binary Nightly Action for linux hosts Co-authored-by: Vincent Moens <vincentmoens@gmail.com> Co-authored-by: Silvestre Bahi <silvestrebahi@fb.com> Co-authored-by: silvestrebahi <silvestre.bahi@gmail.com> Co-authored-by: Bo Liu <benjaminliu.eecs@gmail.com> Co-authored-by: Tom Begley <tomcbegley@gmail.com> Co-authored-by: Alessandro Pietro Bardelli <apbard@users.noreply.github.com> Co-authored-by: Yu Shiyang <yushiyangk@users.noreply.github.com> Co-authored-by: Yu Shiyang <yushiyang@fb.com> Co-authored-by: Pavel Solikov <psolikov15@gmail.com>
Description
This PR fixes a bug with data collectors that saw some keys being omitted from the output when e.g. random actions were applied during initialisation.
Data collectors will now attempt to determine keys from the policy spec if it is supplied and complete, otherwise they will perform a short rollout with the policy to determine which keys will be returned.
Motivation and Context
Closes #505
Types of changes
What types of changes does your code introduce? Remove all that do not apply:
Checklist
I still need to add a test, currently seeing an error on one particular test I don't understand.