Fix A3C PyTorch implementation #2036

Merged
merged 29 commits into from
May 30, 2018

Contributor

@alok alok commented May 11, 2018

What do these changes do?

  1. Fixes up old broken torch code.
  2. Ensures that data is the proper shape.
  3. Renames some variables and removes unused ones.
  4. Makes code more idiomatic.

Related issue number

#2021. These changes are a subset of the ones in that PR, broken off to make
review easier.

alok added 5 commits May 11, 2018 10:53
Stateless functions should not be network layers.
Matches in_size and makes more sense.
Advantages and rewards both should be scalars, and therefore a list of them
should be 1D.
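
To illustrate why the shape matters, here is a minimal sketch of the pitfall. The tensor values and variable names are hypothetical and not taken from the PR. Mixing a 1-D list of rewards with an (N, 1) column of value estimates silently broadcasts into an (N, N) matrix in PyTorch:

```python
import torch

# Hypothetical per-step data for a 3-step rollout (illustration only).
rewards = torch.tensor([1.0, 0.0, -1.0])        # shape (3,): one scalar per step
values = torch.tensor([[0.5], [0.2], [0.1]])    # shape (3, 1): e.g. raw critic output

# (3,) minus (3, 1) broadcasts to (3, 3) instead of raising an error.
wrong = rewards - values

# Flattening the critic output first keeps one advantage per step.
advantages = rewards - values.squeeze(-1)       # shape (3,)

print(tuple(wrong.shape), tuple(advantages.shape))
```

Keeping both lists 1-D means the subtraction is elementwise, which is what a per-step advantage calls for.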
@AmplabJenkins
Test PASSed.
Refer to this link for build results (access rights to CI server needed):
https://amplab.cs.berkeley.edu/jenkins//job/Ray-PRB/5335/

  overall_err.backward()
- torch.nn.utils.clip_grad_norm(
-     self._model.parameters(), self.config["grad_clip"])
+ torch.nn.utils.clip_grad_norm_(self._model.parameters(),
Contributor Author

@alok alok May 11, 2018

clip_grad_norm is deprecated in favor of the underscore version, hence the change
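
For reference, a minimal sketch of the underscore version in isolation. The model, the loss, and the `max_norm` value are made up for illustration and are not the PR's code. In PyTorch a trailing underscore marks an in-place operation; here the gradients are rescaled in place:

```python
import torch
from torch import nn

model = nn.Linear(4, 2)                      # hypothetical stand-in for the policy network
loss = model(torch.ones(8, 4)).pow(2).sum() * 100
loss.backward()

max_norm = 40.0                              # hypothetical stand-in for config["grad_clip"]
# Rescales all gradients in place so that their combined L2 norm is at most max_norm.
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)

grad_norm_after = torch.sqrt(sum(p.grad.pow(2).sum() for p in model.parameters()))
print(float(grad_norm_after))
```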

@@ -39,7 +39,7 @@ def __init__(self, in_channels, out_channels, kernel, stride, padding,
  conv = nn.Conv2d(in_channels, out_channels, kernel, stride)
  if initializer:
      initializer(conv.weight)
- nn.init.constant(conv.bias, bias_init)
+ nn.init.constant_(conv.bias, bias_init)
Contributor Author

nn.init.constant is deprecated in favor of the underscore version, hence the change
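
A short sketch of the in-place initializers on a standalone layer. The layer sizes and the bias value are hypothetical, and `xavier_uniform_` is used only as an example of an `initializer` callable:

```python
import torch
from torch import nn

# Hypothetical layer: 4 input channels, 16 filters, 8x8 kernel, stride 4.
conv = nn.Conv2d(in_channels=4, out_channels=16, kernel_size=8, stride=4)

nn.init.xavier_uniform_(conv.weight)   # in-place weight initializer
nn.init.constant_(conv.bias, 0.1)      # in-place: every bias entry becomes 0.1

print(conv.bias.data[:3])
```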

Contributor

@richardliaw richardliaw left a comment

Nice! Overall looks good; have you tested it out?

@alok
Contributor Author

alok commented May 11, 2018 via email

@richardliaw
Contributor

richardliaw commented May 12, 2018

awesome; can you make sure it runs on Pong? just as a sanity check.

We should seriously add pytorch to the test suite...

Contributor

@richardliaw richardliaw left a comment

lgtm conditioned on Pong running

@alok
Contributor Author

alok commented May 12, 2018

As it turns out, Pong doesn't run. I'm getting shape errors with the Conv2d layers, which I think are unrelated to these changes and have just been broken for a while. @richardliaw Do you mind trying out this PR on Pong and taking a look at the errors?

@AmplabJenkins
Test FAILed.
Refer to this link for build results (access rights to CI server needed):
https://amplab.cs.berkeley.edu/jenkins//job/Ray-PRB/5346/

@AmplabJenkins
Test FAILed.
Refer to this link for build results (access rights to CI server needed):
https://amplab.cs.berkeley.edu/jenkins//job/Ray-PRB/5381/

alok added 4 commits May 14, 2018 23:06
Torch does this for us now.
* master:
  Create RemoteFunction class, remove FunctionProperties, simplify worker Python code. (ray-project#2052)
  Don't crash on duplicate actor notifications (ray-project#2043)
  Fixed attribute name in code example (ray-project#2054)
  [xray] Add Travis build for testing xray on Linux. (ray-project#2047)
  Added missing comma to code example (ray-project#2050)
  Use more CPUs for testMultipleWaitsAndGets. (ray-project#2051)
  use jobid_nil (ray-project#2044)
  Fix typo in tune. (ray-project#2046)
  Fix error in api.rst. (ray-project#2048)
  Improve shared_ptr usage (ray-project#2030)
@AmplabJenkins
Test PASSed.
Refer to this link for build results (access rights to CI server needed):
https://amplab.cs.berkeley.edu/jenkins//job/Ray-PRB/5390/

@@ -21,7 +21,7 @@ def _init(self, inputs, num_outputs, options):
  filters = options.get("conv_filters", [
      [16, [8, 8], 4],
      [32, [4, 4], 2],
-     [512, [10, 10], 1]
+     [512, [10, 1], 1],
Contributor

are you sure about this?

Contributor Author

No. This worked so I ran with it, but I don't know much about convnets, so this is just a SWAG. I'd appreciate if you could take a look at it.
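
One way to sanity-check a filter spec is to compute the spatial size after each layer by hand. The sketch below uses the standard output-size formula for `nn.Conv2d` (dilation 1) on an 80-pixel side, matching the `"dim": 80` default. The helper names and the padding values are assumptions for illustration, since the thread does not say what padding the model actually applies.

```python
def conv_out(size, kernel, stride, padding=0):
    """Output length along one axis for a Conv2d with dilation 1."""
    return (size + 2 * padding - kernel) // stride + 1


def trace(size, layers):
    """Return the spatial size after each (kernel, stride, padding) layer."""
    sizes = []
    for kernel, stride, padding in layers:
        size = conv_out(size, kernel, stride, padding)
        if size <= 0:
            raise ValueError(
                f"kernel {kernel} does not fit: output size would be {size}")
        sizes.append(size)
    return sizes


# First two default filters with no padding at all (assumption).
unpadded = [(8, 4, 0), (4, 2, 0)]
print(trace(80, unpadded))   # [19, 8]

# Hypothetical paddings (2, 1, 0), chosen to mimic TensorFlow-style "SAME"
# padding on the first two layers, followed by the original 10x10 kernel.
padded = [(8, 4, 2), (4, 2, 1), (10, 1, 0)]
print(trace(80, padded))     # [20, 10, 1]
```

Under the no-padding assumption the feature map is only 8 pixels wide after two layers, so a 10-pixel kernel cannot fit along that axis and `trace` raises. Whether that is the cause of the errors in this PR depends on the padding the model really uses.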

@richardliaw
Contributor

which pytorch version are you using?

@alok
Contributor Author

alok commented May 15, 2018 via email

@richardliaw
Contributor

I've been using the following script to test:

import copy

import ray
import torch
from ray.rllib.a3c import A3CAgent, DEFAULT_CONFIG

ray.init()

# Deep-copy so the nested "model" and "optimizer" dicts in DEFAULT_CONFIG
# are not mutated (dict.copy() is shallow).
config = copy.deepcopy(DEFAULT_CONFIG)
config["use_pytorch"] = True
config["model"]["channel_major"] = True  # (C, dim, dim) image layout
config["num_workers"] = 1
config["optimizer"]["grads_per_step"] = 10

import ipdb; ipdb.set_trace()  # drop into the debugger before building the agent
agent = A3CAgent(config=config, env="Pong-v0")
evaluator = agent.local_evaluator
agent.train()
policy = evaluator.policy

# Reset the env and turn the first observation into a batch of one.
state = evaluator.sampler.env.reset()
ob = torch.from_numpy(state).float().unsqueeze(0)

@AmplabJenkins
Test PASSed.
Refer to this link for build results (access rights to CI server needed):
https://amplab.cs.berkeley.edu/jenkins//job/Ray-PRB/5417/

@AmplabJenkins
Test PASSed.
Refer to this link for build results (access rights to CI server needed):
https://amplab.cs.berkeley.edu/jenkins//job/Ray-PRB/5418/

* master:
  Pin Pandas version for Travis to 0.22 (ray-project#2075)
  Fix python linting (ray-project#2076)
  [xray] Fix GCS table prefixes (ray-project#2065)
  Some tests for _submit API. (ray-project#2062)
  [rllib] Queue lib for python 2.7 (ray-project#2057)
  [autoscaler] Remove faulty assert that breaks during downscaling, pull configs from env (ray-project#2006)
  [DataFrame] Refactor indexers and implement setitem (ray-project#2020)
  [rllib]Update bc/policy.py (ray-project#2012)
@AmplabJenkins
Test PASSed.
Refer to this link for build results (access rights to CI server needed):
https://amplab.cs.berkeley.edu/jenkins//job/Ray-PRB/5683/

Contributor

@richardliaw richardliaw left a comment

OK I'm going to add regression tests for PyTorch because for some reason, CartPole-v0 does not work on my machine

@@ -52,7 +51,7 @@
  # (Image statespace) - Converts image to (dim, dim, C)
  "dim": 80,
  # (Image statespace) - Converts image shape to (C, dim, dim)
- "channel_major": False
+ "channel_major" : False,
Contributor

is there supposed to be a space here?

Contributor Author

No.

@alok
Contributor Author

alok commented May 29, 2018

I'm looking at the best way to overhaul how state/action spaces are handled in torch. Since torch now supports scalars, we should be able to support the same range of envs as TF.

@richardliaw
Contributor

How much do you think should go in this PR and how much do you think should go into a subsequent one? Keep in mind #2149 is pretty big and will go in soon...

@AmplabJenkins
Test FAILed.
Refer to this link for build results (access rights to CI server needed):
https://amplab.cs.berkeley.edu/jenkins//job/Ray-PRB/5692/

@alok
Contributor Author

alok commented May 29, 2018

I think this one should fix the shapes for Pendulum, CartPole, and Pong. Anything else is probably best handled in a followup.

@AmplabJenkins
Test FAILed.
Refer to this link for build results (access rights to CI server needed):
https://amplab.cs.berkeley.edu/jenkins//job/Ray-PRB/5694/

@richardliaw
Contributor

Ok awesome - Jenkins is running PyTorch tests and Pong is passing while CartPole is not.

@AmplabJenkins
Test FAILed.
Refer to this link for build results (access rights to CI server needed):
https://amplab.cs.berkeley.edu/jenkins//job/Ray-PRB/5695/

@alok
Contributor Author

alok commented May 29, 2018

@richardliaw I think the current torch version of A3C only works in discrete action spaces since it always samples from a multinomial distribution, so we could punt supporting Pendulum and related envs for another PR while fixing CartPole and Pong in this PR.
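
To make the limitation concrete, here is a minimal sketch of multinomial action sampling. The batch size, the number of actions, and the variable names are hypothetical, and this is not the PR's policy code. The sampled action is always an integer index into a fixed set of choices, which fits discrete envs such as CartPole and Pong but not a continuous action space like Pendulum's:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical policy output: a batch of 5 states, 2 discrete actions.
logits = torch.randn(5, 2)

probs = F.softmax(logits, dim=1)                     # each row sums to 1
actions = torch.multinomial(probs, num_samples=1)    # shape (5, 1), integer indices
log_probs = F.log_softmax(logits, dim=1).gather(1, actions)

print(actions.squeeze(1).tolist())
```

Supporting continuous actions would need a different distribution (for example a Gaussian over real-valued actions), which is the kind of change deferred to a later PR.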

@alok
Contributor Author

alok commented May 30, 2018

@richardliaw @ericl Can you check this? This should be an OK set of changes to merge before the larger overhaul.

Pendulum doesn't work since it's an edge case (expects singleton arrays, which
`.squeeze()` collapses to scalars).
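
A minimal sketch of that edge case, using a made-up action value rather than the PR's code. Calling `.squeeze()` with no argument drops every size-1 dimension, so a batch of one single-dimensional action becomes a 0-dim scalar, while squeezing only the batch dimension keeps the singleton array:

```python
import torch

# Hypothetical action for a batch of one in an env with a 1-dimensional action.
action = torch.tensor([[0.3]])        # shape (1, 1)

collapsed = action.squeeze()          # drops every size-1 dim -> shape ()
kept = action.squeeze(0)              # drops only the batch dim -> shape (1,)

print(collapsed.numpy().shape, kept.numpy().shape)
```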
@alok
Contributor Author

alok commented May 30, 2018

It fails Pendulum but runs on CartPole and Pong.

@richardliaw
Contributor

richardliaw commented May 30, 2018 via email

@AmplabJenkins
Test PASSed.
Refer to this link for build results (access rights to CI server needed):
https://amplab.cs.berkeley.edu/jenkins//job/Ray-PRB/5697/

@AmplabJenkins
Test FAILed.
Refer to this link for build results (access rights to CI server needed):
https://amplab.cs.berkeley.edu/jenkins//job/Ray-PRB/5700/

@AmplabJenkins
Test PASSed.
Refer to this link for build results (access rights to CI server needed):
https://amplab.cs.berkeley.edu/jenkins//job/Ray-PRB/5699/

@AmplabJenkins
Test PASSed.
Refer to this link for build results (access rights to CI server needed):
https://amplab.cs.berkeley.edu/jenkins//job/Ray-PRB/5708/

@alok
Contributor Author

alok commented May 30, 2018

@richardliaw This passes lint and implements the decision to only support envs like CartPole and Pong (for now).

@richardliaw
Contributor

Test failures unrelated

@richardliaw richardliaw merged commit fd234e3 into ray-project:master May 30, 2018
@richardliaw
Contributor

thanks for contributing this!

@alok alok deleted the fix-a3c-torch branch June 1, 2018 05:24
alok added a commit to alok/ray that referenced this pull request Jun 3, 2018
* master:
  [autoscaler] GCP node provider (ray-project#2061)
  [xray] Evict tasks from the lineage cache (ray-project#2152)
  [ASV] Add ray.init and simple Ray benchmarks (ray-project#2166)
  Re-encrypt key for uploading to S3 from travis to use travis-ci.com. (ray-project#2169)
  [rllib] Fix A3C PyTorch implementation (ray-project#2036)
  [JavaWorker] Do not kill local-scheduler-forked workers in RunManager.cleanup (ray-project#2151)
  Update Travis CI badge from travis-ci.org to travis-ci.com. (ray-project#2155)
  Implement Python global state API for xray. (ray-project#2125)
  [xray] Improve flush algorithm for the lineage cache (ray-project#2130)
  Fix support for actor classmethods (ray-project#2146)
  Add empty df test (ray-project#1879)
  [JavaWorker] Enable java worker support (ray-project#2094)
  [DataFrame] Fixing the code formatting of the tests (ray-project#2123)
  Update resource documentation (remove outdated limitations). (ray-project#2022)
  bugfix: use array redis_primary_addr out of its scope (ray-project#2139)
  Fix infinite retry in Push function. (ray-project#2133)
  [JavaWorker] Changes to the directory under src for support java worker (ray-project#2093)
  Integrate credis with Ray & route task table entries into credis. (ray-project#1841)
4 participants