-
Notifications
You must be signed in to change notification settings - Fork 373
[Feature] Implemented device
argument for modules.models
#524
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
23551b0
to
e3616a0
Compare
torchrl/modules/models/utils.py
Outdated
*args: positional arguments to be passed to the module constructor. | ||
**kwargs: keyword arguments to be passed to the module constructor. | ||
""" | ||
try: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
try / except are a bit expensive
Even if we don't create modules very often in training loops, we do in the tests, and on the long run this risks to have a negative effect on our test speed.
Do you think we can find some other way of doing this?
416264d
to
ac2c3fc
Compare
Codecov Report
@@ Coverage Diff @@
## main #524 +/- ##
==========================================
- Coverage 86.72% 86.55% -0.18%
==========================================
Files 121 121
Lines 21671 21622 -49
==========================================
- Hits 18795 18715 -80
- Misses 2876 2907 +31
Flags with carried forward coverage won't be shown. Click here to find out more.
📣 We’re building smart automated test selection to slash your CI/CD build times. Learn more |
f8d5acf
to
0ad34f1
Compare
*args: positional arguments to be passed to the module constructor. | ||
**kwargs: keyword arguments to be passed to the module constructor. | ||
""" | ||
fullargspec = inspect.getfullargspec(module_class.__init__) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
brilliant!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is reaaaally cool!
Thanks a million!
* [BugFix] Transformed ParallelEnv meta data are broken when passing to device (#531) * [Doc] Add coverage banner (#533) * add orb decov to circleci config.yml * Add codecov badge to Readme * Revert "[BugFix] Changing the dm_control import to fail if not installed (#515)" This reverts commit d194735. * codecov coverage w/o orb in circleci * Revert "Revert "[BugFix] Changing the dm_control import to fail if not installed (#515)"" This reverts commit d0dc7de. * [CI] generation of coverage reports (#534) * update test scripts to add coverage * update test scripts to add coverage Co-authored-by: Silvestre Bahi <silvestrebahi@fb.com> * [CI] Add xml coverage reports for codecov (#537) * update test scripts to add coverage * update test scripts to add coverage * generate xml file for coverage * Update run_test.sh lint end of file * Update run_test.sh lint end of file * Update run_test.sh lint end of file Co-authored-by: Silvestre Bahi <silvestrebahi@fb.com> * permissions * permissions Co-authored-by: Silvestre Bahi <silvestrebahi@fb.com> Co-authored-by: silvestrebahi <silvestre.bahi@gmail.com> * [BugFix] Fix colab link of coding_dqn.ipynb (#543) * [BugFix] Fix optional imports (#535) * [BugFix] Restore missing keys in data collector output (#521) * Ensure data collectors return all expected keys * Rerun CI * Add tests * Format code * correct unreachable test * Fix broken test * WIP: fix initialisation with policy + test * Fix initialisation with policy + test * Reset env after rollout initialisation * fix build from spec * Check policy has spec attribute before accessing * Address comments Co-authored-by: vmoens <vincentmoens@gmail.com> * [Lint] reorganize imports (#545) [Lint] reorganize imports * [BugFix] Single-cpu compatibility (#548) * [BugFix] vision install and other deps in optdeps (#552) * init * amend * amend * amend * [Feature] Implemented device argument for modules.models (#524) Co-authored-by: Yu Shiyang <yushiyang@fb.com> * [BugFix] Fix ellipsis indexing of 2d TensorDicts (#559) * [BugFix] Additive gaussian exploration spec fix (#560) * [BugFix] Disabling video step for wandb (#561) * [BugFix] Various device fix (#558) * [Feature] Allow collectors to accept regular modules as policies (#546) * [BugFix] Fix push binary nightly action (#566) Fix Push Binary Nightly Action for linux hosts Co-authored-by: Vincent Moens <vincentmoens@gmail.com> Co-authored-by: Silvestre Bahi <silvestrebahi@fb.com> Co-authored-by: silvestrebahi <silvestre.bahi@gmail.com> Co-authored-by: Bo Liu <benjaminliu.eecs@gmail.com> Co-authored-by: Tom Begley <tomcbegley@gmail.com> Co-authored-by: Alessandro Pietro Bardelli <apbard@users.noreply.github.com> Co-authored-by: Yu Shiyang <yushiyangk@users.noreply.github.com> Co-authored-by: Yu Shiyang <yushiyang@fb.com> Co-authored-by: Pavel Solikov <psolikov15@gmail.com>
Description
Implemented a
device
argument for modules intorchrl.modules.models
to allow them to be directly constructed on a particular device, rather than first constructing it then moving it with.to()
. This follows the behaviour of most modules intorch.nn
. Updatedtest_modules.py
to make use of this.However, an arbitrary
nn.Module
subclass does not necessarily support thedevice
argument, in particular for modules that do not contain parameters or buffers. The utility function added intorchrl/modules/models/utils.py
addresses this. This functionality is tested intest_mlp
intest_modules.py
, parameterised with bothnn.ReLU
(which does not supportdevice
) andnn.PReLU
(which does).Motivation and Context
Addresses #507
Types of changes
What types of changes does your code introduce? Remove all that do not apply:
Checklist
Go over all the following points, and put an
x
in all the boxes that apply.If you are unsure about any of these, don't hesitate to ask. We are here to help!