Fix A3C PyTorch implementation #2036
Stateless functions should not be network layers.
Matches in_size and makes more sense.
Advantages and rewards both should be scalars, and therefore a list of them should be 1D.
Test PASSed. |
  overall_err.backward()
- torch.nn.utils.clip_grad_norm(
-     self._model.parameters(), self.config["grad_clip"])
+ torch.nn.utils.clip_grad_norm_(self._model.parameters(),
clip_grad_norm is deprecated in favor of the underscore version, hence the change
@@ -39,7 +39,7 @@ def __init__(self, in_channels, out_channels, kernel, stride, padding,
  conv = nn.Conv2d(in_channels, out_channels, kernel, stride)
  if initializer:
      initializer(conv.weight)
- nn.init.constant(conv.bias, bias_init)
+ nn.init.constant_(conv.bias, bias_init)
nn.init.constant is deprecated in favor of the underscore version, hence the change
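For anyone following along, here is a minimal sketch of the two underscore (in-place) calls mentioned in these comments. The tiny `nn.Linear` model, the constant weights, and `max_norm = 1.0` are made up for illustration; they are not taken from the RLlib code, where the threshold comes from `config["grad_clip"]`.

```python
import torch
import torch.nn as nn

# Hypothetical stand-in model; the real policy network is not shown here.
model = nn.Linear(4, 2)

# In-place initializers carry a trailing underscore.
nn.init.constant_(model.weight, 0.1)
nn.init.constant_(model.bias, 0.0)

# Toy loss so that every parameter receives a gradient.
loss = model(torch.ones(8, 4)).pow(2).sum()
loss.backward()

# Illustrative threshold; an assumption, not the project's default.
max_norm = 1.0

# Rescales all gradients in place and returns the norm measured before clipping.
total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)

# Recompute the global gradient norm to confirm it now respects max_norm.
grad_norm_after = torch.sqrt(sum(p.grad.pow(2).sum() for p in model.parameters()))
print(float(total_norm), float(grad_norm_after))
```

The trailing underscore is PyTorch's naming convention for functions that modify their arguments in place, which is why both renames follow the same pattern.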
Nice! Overall looks good; have you tested it out?
Tested on CartPole and it hit 200 reward pretty quickly, so I think it works.
…-- Alok
Awesome; can you make sure it runs on Pong, just as a sanity check? We should seriously add PyTorch to the test suite...
lgtm conditioned on Pong running
As it turns out, Pong doesn't run. I'm getting shape errors with the Conv2D layers, which I think are unrelated to these changes and have just been broken for a while. @richardliaw Do you mind trying out this PR on Pong and taking a look at the errors?
Test FAILed. |
Test FAILed. |
Torch does this for us now.
* master: Create RemoteFunction class, remove FunctionProperties, simplify worker Python code. (ray-project#2052) Don't crash on duplicate actor notifications (ray-project#2043) Fixed attribute name in code example (ray-project#2054) [xray] Add Travis build for testing xray on Linux. (ray-project#2047) Added missing comma to code example (ray-project#2050) Use more CPUs for testMultipleWaitsAndGets. (ray-project#2051) use jobid_nil (ray-project#2044) Fix typo in tune. (ray-project#2046) Fix error in api.rst. (ray-project#2048) Improve shared_ptr usage (ray-project#2030)
Test PASSed. |
@@ -21,7 +21,7 @@ def _init(self, inputs, num_outputs, options):
  filters = options.get("conv_filters", [
      [16, [8, 8], 4],
      [32, [4, 4], 2],
-     [512, [10, 10], 1]
+     [512, [10, 1], 1],
are you sure about this?
No. This worked so I ran with it, but I don't know much about convnets, so this is just a SWAG. I'd appreciate if you could take a look at it.
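One way to sanity-check a filter spec like this without running the model is to trace the spatial size through each layer using the standard convolution output formula. The sketch below is pure Python and assumes no padding and an 80x80 input (taken from the `"dim": 80` default); the helper names are hypothetical, and the padding the real layers apply is not visible in this thread, so treat the numbers as illustrative only.

```python
def conv_output_size(size, kernel, stride, padding=0):
    """Output length along one axis for a standard (dilation=1) convolution."""
    return (size + 2 * padding - kernel) // stride + 1


def trace_shapes(height, width, filters):
    """Return (height, width) after each [out_channels, [kh, kw], stride] entry.

    Assumes zero padding; raises ValueError once a kernel no longer fits.
    """
    shapes = []
    for _out_channels, (kh, kw), stride in filters:
        if kh > height or kw > width:
            raise ValueError(
                "kernel %dx%d does not fit a %dx%d input" % (kh, kw, height, width))
        height = conv_output_size(height, kh, stride)
        width = conv_output_size(width, kw, stride)
        shapes.append((height, width))
    return shapes


# The first two entries of the filter list from the diff above.
first_two = [[16, [8, 8], 4], [32, [4, 4], 2]]
shapes = trace_shapes(80, 80, first_two)
print(shapes)
```

Under these assumptions the feature map is already smaller than 10 pixels per side after the second layer, so whether a 10-wide kernel fits depends entirely on the padding and input size actually used.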
which pytorch version are you using? |
0.4.0
|
I've been using the following script to test:
|
Test PASSed. |
Test PASSed. |
* master: Pin Pandas version for Travis to 0.22 (ray-project#2075) Fix python linting (ray-project#2076) [xray] Fix GCS table prefixes (ray-project#2065) Some tests for _submit API. (ray-project#2062) [rllib] Queue lib for python 2.7 (ray-project#2057) [autoscaler] Remove faulty assert that breaks during downscaling, pull configs from env (ray-project#2006) [DataFrame] Refactor indexers and implement setitem (ray-project#2020) [rllib]Update bc/policy.py (ray-project#2012)
Test PASSed. |
OK, I'm going to add regression tests for PyTorch because, for some reason, CartPole-v0 does not work on my machine.
python/ray/rllib/a3c/a3c.py
@@ -52,7 +51,7 @@
  # (Image statespace) - Converts image to (dim, dim, C)
  "dim": 80,
  # (Image statespace) - Converts image shape to (C, dim, dim)
- "channel_major": False
+ "channel_major" : False,
is there supposed to be a space here?
No.
I'm looking at the best way to overhaul how state/action spaces are handled in torch. Since torch now supports scalars, we should be able to support the same range of envs as TF.
How much do you think should go in this PR and how much do you think should go into a subsequent one? Keep in mind #2149 is pretty big and will go in soon...
Test FAILed. |
I think this one should fix the shapes for Pendulum, Cartpole, and Pong. Anything else is probably best handled in a followup. |
Test FAILed. |
Ok awesome - Jenkins is running PyTorch tests and Pong is passing while CartPole is not. |
Test FAILed. |
@richardliaw I think the current torch version of A3C only works in discrete action spaces since it always samples from a multinomial distribution, so we could punt supporting Pendulum and related envs for another PR while fixing CartPole and Pong in this PR. |
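To illustrate why multinomial sampling ties a policy to discrete action spaces, here is a small pure-Python sketch. The helper names and logits are hypothetical and this is not the RLlib code; it only shows that the sampled value is an index into a finite set of actions, which has no direct counterpart for a continuous action such as Pendulum's torque.

```python
import math
import random


def softmax(logits):
    """Convert raw scores into probabilities that sum to 1."""
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def sample_discrete_action(logits, rng):
    """Draw one action index from the categorical distribution over logits."""
    probs = softmax(logits)
    return rng.choices(range(len(probs)), weights=probs, k=1)[0]


rng = random.Random(0)   # seeded so the run is repeatable
logits = [2.0, 0.5]      # e.g. two actions, as in CartPole
actions = [sample_discrete_action(logits, rng) for _ in range(1000)]
print(actions.count(0), actions.count(1))
```

A continuous-action policy would instead need to output distribution parameters (for example a mean and a spread) and sample a real number, which is the kind of change being deferred to a later PR.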
@richardliaw @ericl Can you check this? This should be an OK set of changes to merge before the larger overhaul. |
Pendulum doesn't work since it's an edge case (expects singleton arrays, which `.squeeze()` collapses to scalars).
It fails Pendulum but runs on CartPole and Pong. |
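The `.squeeze()` edge case can be reproduced in a few lines. This sketch uses NumPy as a stand-in (torch tensors collapse the same way), and the `np.atleast_1d` call at the end is only one possible workaround, not what this PR does.

```python
import numpy as np

# A one-element action vector, the kind of singleton array Pendulum expects.
action = np.array([0.5])
print(action.shape)      # one axis of length 1

# squeeze() drops every length-1 axis, leaving a 0-d array (a scalar).
squeezed = action.squeeze()
print(squeezed.shape)    # no axes left

# Possible workaround (an assumption): restore the leading axis before
# handing the action to the environment.
restored = np.atleast_1d(squeezed)
print(restored.shape)
```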
Yeah I just came to that decision too - that sounds good to me.
Test PASSed. |
Test FAILed. |
Test PASSed. |
Test PASSed. |
@richardliaw This passes lint and implements the decision to only support envs like CartPole and Pong (for now). |
Test failures unrelated |
thanks for contributing this! |
* master: [autoscaler] GCP node provider (ray-project#2061) [xray] Evict tasks from the lineage cache (ray-project#2152) [ASV] Add ray.init and simple Ray benchmarks (ray-project#2166) Re-encrypt key for uploading to S3 from travis to use travis-ci.com. (ray-project#2169) [rllib] Fix A3C PyTorch implementation (ray-project#2036) [JavaWorker] Do not kill local-scheduler-forked workers in RunManager.cleanup (ray-project#2151) Update Travis CI badge from travis-ci.org to travis-ci.com. (ray-project#2155) Implement Python global state API for xray. (ray-project#2125) [xray] Improve flush algorithm for the lineage cache (ray-project#2130) Fix support for actor classmethods (ray-project#2146) Add empty df test (ray-project#1879) [JavaWorker] Enable java worker support (ray-project#2094) [DataFrame] Fixing the code formatting of the tests (ray-project#2123) Update resource documentation (remove outdated limitations). (ray-project#2022) bugfix: use array redis_primary_addr out of its scope (ray-project#2139) Fix infinite retry in Push function. (ray-project#2133) [JavaWorker] Changes to the directory under src for support java worker (ray-project#2093) Integrate credis with Ray & route task table entries into credis. (ray-project#1841)
What do these changes do?
Related issue number
#2021. These changes are a subset of the ones in that PR, broken off to make
review easier.