[Feature] Implemented device argument for modules.models #524

Merged
merged 1 commit into from
Oct 12, 2022

Conversation

yushiyangk
Contributor

@yushiyangk yushiyangk commented Oct 6, 2022

Description

Implemented a device argument for modules in torchrl.modules.models so that they can be constructed directly on a particular device, rather than being constructed first and then moved with .to(). This follows the behaviour of most modules in torch.nn. Updated test_modules.py to make use of this.

However, an arbitrary nn.Module subclass does not necessarily support the device argument; in particular, modules that do not contain parameters or buffers often do not accept it. The utility function added in torchrl/modules/models/utils.py addresses this. This functionality is tested in test_mlp in test_modules.py, parameterised with both nn.ReLU (which does not support device) and nn.PReLU (which does).
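
As a rough illustration of the idea, a helper of this kind could inspect the constructor signature and forward device only when the constructor accepts it. This is a minimal sketch and not the code in torchrl/modules/models/utils.py: the name create_on_device and the two stand-in classes are hypothetical. The stand-ins mimic the constructor signatures of nn.ReLU and nn.PReLU so that the snippet runs without torch installed.

```python
import inspect


def create_on_device(module_class, device, *args, **kwargs):
    """Instantiate module_class, passing `device` only if its constructor accepts it.

    Hypothetical helper for illustration; the real utility may differ.
    """
    fullargspec = inspect.getfullargspec(module_class.__init__)
    # `device` may be declared as a regular or as a keyword-only argument.
    accepts_device = (
        "device" in fullargspec.args or "device" in fullargspec.kwonlyargs
    )
    if accepts_device:
        return module_class(*args, device=device, **kwargs)
    return module_class(*args, **kwargs)


class NoDeviceActivation:
    """Stand-in for a parameter-free module such as nn.ReLU (no device kwarg)."""

    def __init__(self, inplace=False):
        self.inplace = inplace
        self.device = None


class DeviceAwareActivation:
    """Stand-in for a module with parameters such as nn.PReLU (has device kwarg)."""

    def __init__(self, num_parameters=1, device=None):
        self.num_parameters = num_parameters
        self.device = device


# `device` is silently dropped for the first class and forwarded for the second.
relu_like = create_on_device(NoDeviceActivation, "cpu")
prelu_like = create_on_device(DeviceAwareActivation, "cpu", num_parameters=3)
```

One limitation of this sketch is that a constructor which accepts device only through **kwargs would not be detected, since the check looks at explicitly named parameters.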

Motivation and Context

Addresses #507

Types of changes

What types of changes does your code introduce? Remove all that do not apply:

  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds core functionality)
  • Breaking change (fix or feature that would cause existing functionality to change)
  • Documentation (update in the documentation)
  • Example (update in the folder of examples)

Checklist

Go over all the following points, and put an x in all the boxes that apply.
If you are unsure about any of these, don't hesitate to ask. We are here to help!

  • I have read the CONTRIBUTION guide (required)
  • My change requires a change to the documentation.
  • I have updated the tests accordingly (required for a bug fix or a new feature).
  • I have updated the documentation accordingly.

@facebook-github-bot facebook-github-bot added the CLA Signed label Oct 6, 2022
@yushiyangk yushiyangk force-pushed the main branch 3 times, most recently from 23551b0 to e3616a0 Compare October 6, 2022 17:21
*args: positional arguments to be passed to the module constructor.
**kwargs: keyword arguments to be passed to the module constructor.
"""
try:
Collaborator

try / except are a bit expensive
Even if we don't create modules very often in training loops, we do in the tests, and in the long run this risks slowing down our test suite.
Do you think we can find some other way of doing this?

@yushiyangk yushiyangk force-pushed the main branch 3 times, most recently from 416264d to ac2c3fc Compare October 7, 2022 20:22
@codecov

codecov bot commented Oct 7, 2022

Codecov Report

Merging #524 (e9cee79) into main (f351fd7) will decrease coverage by 0.17%.
The diff coverage is n/a.

❗ Current head e9cee79 differs from pull request most recent head 3bfc354. Consider uploading reports for the commit 3bfc354 to get more accurate results

@@            Coverage Diff             @@
##             main     #524      +/-   ##
==========================================
- Coverage   86.72%   86.55%   -0.18%     
==========================================
  Files         121      121              
  Lines       21671    21622      -49     
==========================================
- Hits        18795    18715      -80     
- Misses       2876     2907      +31     
| Flag | Coverage Δ |
|---|---|
| linux-cpu | 85.08% <0.00%> (-0.02%) ⬇️ |
| linux-gpu | 86.39% <0.00%> (-0.16%) ⬇️ |
| linux-outdeps-gpu | 77.74% <0.00%> (-0.11%) ⬇️ |
| linux-stable-cpu | 85.06% <0.00%> (-0.02%) ⬇️ |
| linux-stable-gpu | 86.39% <0.00%> (-0.16%) ⬇️ |
| macos-cpu | 84.85% <0.00%> (-0.01%) ⬇️ |

Flags with carried forward coverage won't be shown. Click here to find out more.

| Impacted Files | Coverage Δ |
|---|---|
| test/test_collector.py | 90.18% <0.00%> (-6.48%) ⬇️ |
| torchrl/trainers/loggers/mlflow.py | 78.68% <0.00%> (-3.28%) ⬇️ |
| torchrl/trainers/loggers/wandb.py | 80.00% <0.00%> (-2.67%) ⬇️ |
| torchrl/collectors/collectors.py | 67.26% <0.00%> (-1.70%) ⬇️ |
| test/test_helpers.py | 88.62% <0.00%> (-0.34%) ⬇️ |
| test/test_loggers.py | 98.83% <0.00%> (-0.01%) ⬇️ |
| test/test_cost.py | 96.41% <0.00%> (ø) |
| test/test_tensordictmodules.py | 98.42% <0.00%> (ø) |
| test/test_transforms.py | 95.79% <0.00%> (+0.01%) ⬆️ |
| torchrl/data/tensordict/tensordict.py | 82.56% <0.00%> (+0.09%) ⬆️ |
... and 2 more

@yushiyangk yushiyangk force-pushed the main branch 2 times, most recently from f8d5acf to 0ad34f1 Compare October 11, 2022 10:49
@yushiyangk yushiyangk marked this pull request as ready for review October 11, 2022 12:26
*args: positional arguments to be passed to the module constructor.
**kwargs: keyword arguments to be passed to the module constructor.
"""
fullargspec = inspect.getfullargspec(module_class.__init__)
Collaborator

brilliant!

Collaborator

@vmoens vmoens left a comment

This is reaaaally cool!
Thanks a million!

@vmoens vmoens merged commit 26e3083 into pytorch:main Oct 12, 2022
vmoens added a commit that referenced this pull request Oct 13, 2022
* [BugFix] Transformed ParallelEnv meta data are broken when passing to device (#531)

* [Doc] Add coverage banner (#533)

* add orb decov to circleci config.yml

* Add codecov badge to Readme

* Revert "[BugFix] Changing the dm_control import to fail if not installed (#515)"

This reverts commit d194735.

* codecov coverage w/o orb in circleci

* Revert "Revert "[BugFix] Changing the dm_control import to fail if not installed (#515)""

This reverts commit d0dc7de.

* [CI] generation of coverage reports (#534)

* update test scripts to add coverage

* update test scripts to add coverage

Co-authored-by: Silvestre Bahi <silvestrebahi@fb.com>

* [CI] Add xml coverage reports for codecov (#537)

* update test scripts to add coverage

* update test scripts to add coverage

* generate xml file for coverage

* Update run_test.sh

lint end of file

* Update run_test.sh

lint end of file

* Update run_test.sh

lint end of file

Co-authored-by: Silvestre Bahi <silvestrebahi@fb.com>

* permissions

* permissions

Co-authored-by: Silvestre Bahi <silvestrebahi@fb.com>
Co-authored-by: silvestrebahi <silvestre.bahi@gmail.com>

* [BugFix] Fix colab link of coding_dqn.ipynb (#543)

* [BugFix] Fix optional imports (#535)

* [BugFix] Restore missing keys in data collector output (#521)

* Ensure data collectors return all expected keys

* Rerun CI

* Add tests

* Format code

* correct unreachable test

* Fix broken test

* WIP: fix initialisation with policy + test

* Fix initialisation with policy + test

* Reset env after rollout initialisation

* fix build from spec

* Check policy has spec attribute before accessing

* Address comments

Co-authored-by: vmoens <vincentmoens@gmail.com>

* [Lint] reorganize imports (#545)

[Lint] reorganize imports

* [BugFix] Single-cpu compatibility (#548)

* [BugFix] vision install and other deps in optdeps (#552)

* init

* amend

* amend

* amend

* [Feature] Implemented device argument for modules.models (#524)

Co-authored-by: Yu Shiyang <yushiyang@fb.com>

* [BugFix] Fix ellipsis indexing of 2d TensorDicts (#559)

* [BugFix] Additive gaussian exploration spec fix (#560)

* [BugFix] Disabling video step for wandb (#561)

* [BugFix] Various device fix (#558)

* [Feature] Allow collectors to accept regular modules as policies (#546)

* [BugFix] Fix push binary nightly action (#566)

Fix Push Binary Nightly Action for linux hosts

Co-authored-by: Vincent Moens <vincentmoens@gmail.com>
Co-authored-by: Silvestre Bahi <silvestrebahi@fb.com>
Co-authored-by: silvestrebahi <silvestre.bahi@gmail.com>
Co-authored-by: Bo Liu <benjaminliu.eecs@gmail.com>
Co-authored-by: Tom Begley <tomcbegley@gmail.com>
Co-authored-by: Alessandro Pietro Bardelli <apbard@users.noreply.github.com>
Co-authored-by: Yu Shiyang <yushiyangk@users.noreply.github.com>
Co-authored-by: Yu Shiyang <yushiyang@fb.com>
Co-authored-by: Pavel Solikov <psolikov15@gmail.com>
@GavinPHR GavinPHR mentioned this pull request Oct 13, 2022
@vmoens vmoens added the enhancement label Oct 17, 2022
Labels
CLA Signed, enhancement
Development

Successfully merging this pull request may close these issues.

[Feature Request] All models in modules.models should have a device kwarg
3 participants