
[rllib] Contribute DDPG to RLlib #1877


Merged: 33 commits merged into ray-project:master on Apr 20, 2018

Conversation

joneswong
Contributor

What do these changes do?

Implemented DDPG (see ./rllib/ddpg) in a style consistent with the DQN implementation:

  • DDPGAgent
  • DDPGEvaluator
  • DDPGGraph

Validated on Pendulum-v0 and MountainCarContinuous-v0 with LocalSyncReplayOptimizer and ApeXOptimizer:

  • Using LocalSyncReplayOptimizer on Pendulum-v0 (see ./rllib/tuned_examples/pendulum-ddpg.yaml), the mean 100-episode reward reaches -160 in around 30k to 40k timesteps.
    (training curve: pendulum_ddpg)

  • Using ApeXOptimizer on Pendulum-v0 (see ./rllib/tuned_examples/pendulum-apex-ddpg.yaml), the mean 100-episode reward reaches -160 within around 15 minutes with 16 workers.
    (training curve: pendulum_ddpg_apex)

  • Using LocalSyncReplayOptimizer on MountainCarContinuous-v0 (see ./rllib/tuned_examples/mountaincarcontinuous-ddpg.yaml), the mean 100-episode reward reaches 90 in around 20k to 30k timesteps.
    (training curve: mountaincar_ddpg)

  • Using ApeXOptimizer on MountainCarContinuous-v0 (see ./rllib/tuned_examples/mountaincarcontinuous-apex-ddpg.yaml), the mean 100-episode reward reaches 90 within around 15 minutes with 16 workers.
    (training curve: mountaincar_apex_ddpg)

Some functionality, e.g., the OU process for generating exploration noise and the schedulers, could be refactored into common utilities (i.e., moved into ./rllib/utils). However, we want to keep each pull request clean and focused on a single feature.
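For readers unfamiliar with it, the OU process mentioned above is the standard temporally correlated exploration noise used with DDPG. A minimal NumPy sketch (class name and default parameters are illustrative, not taken from this PR):

```python
import numpy as np

class OrnsteinUhlenbeckNoise:
    """Temporally correlated exploration noise for continuous actions.

    Discretized OU process: dx = theta * (mu - x) * dt + sigma * sqrt(dt) * N(0, I).
    """

    def __init__(self, dim, mu=0.0, theta=0.15, sigma=0.2, dt=1e-2, seed=0):
        self.mu, self.theta, self.sigma, self.dt = mu, theta, sigma, dt
        self.rng = np.random.default_rng(seed)
        self.state = np.full(dim, mu, dtype=np.float64)

    def reset(self):
        # Restart the process at its mean, e.g. at episode boundaries.
        self.state[:] = self.mu

    def sample(self):
        # Mean-reverting drift plus Gaussian diffusion.
        drift = self.theta * (self.mu - self.state) * self.dt
        diffusion = self.sigma * np.sqrt(self.dt) * self.rng.standard_normal(self.state.shape)
        self.state = self.state + drift + diffusion
        return self.state.copy()

noise = OrnsteinUhlenbeckNoise(dim=1)
action = np.clip(0.5 + noise.sample(), -1.0, 1.0)  # perturb a deterministic action
```

Because consecutive samples are correlated, the perturbed actions wander smoothly instead of jittering, which suits physical control tasks like Pendulum.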

Related issue number

#1868

@AmplabJenkins

Test PASSed.
Refer to this link for build results (access rights to CI server needed):
https://amplab.cs.berkeley.edu/jenkins//job/Ray-PRB/4809/

@ericl ericl changed the title Contribute DDPG to RLlib [rllib] Contribute DDPG to RLlib Apr 11, 2018
@ericl
Contributor

ericl commented Apr 11, 2018

Thanks for adding this! We'll need to think of a strategy for merging these DDPG implementations; one possibility is to have ddpg1, ddpg2, ddpg3 initially. Then we can compare their performance and consolidate them into a single algorithm.

#1685
#1868

@ericl ericl self-assigned this Apr 11, 2018
@ericl
Contributor

ericl commented Apr 12, 2018

cc @vlad17

@@ -8,7 +8,7 @@

 def _register_all():
-    for key in ["PPO", "ES", "DQN", "APEX", "A3C", "BC", "PG", "__fake",
+    for key in ["DDPG", "APEX_DDPG", "PPO", "ES", "DQN", "APEX", "A3C", "BC", "PG", "__fake",
Contributor

Let's call this DDPG2 and APEX_DDPG2 for now, since there is a conflicting PR that was merged earlier. We will resolve the differences later to combine the implementations.

The package directory should also be moved to rllib/ddpg2.

Contributor Author

Done.

from __future__ import print_function

from ray.rllib.ddpg.apex import ApexDDPGAgent
from ray.rllib.ddpg.ddpg import DDPGAgent, DEFAULT_CONFIG
Contributor

ApexDDPG2Agent, DDPG2Agent

Contributor Author

Done.

@@ -0,0 +1,108 @@
"""This file is used for specifying various schedules that evolve over
Contributor

This file is literally a copy of dqn/schedules.py. Let's move that to common to avoid this code duplication.

Contributor Author

To avoid changing dqn, I currently just reuse the schedulers from dqn.
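For context, the kind of schedule dqn/schedules.py provides can be sketched roughly as follows (a simplified illustration, not the actual file, which also includes piecewise schedules):

```python
class LinearSchedule:
    """Linear interpolation from initial_p to final_p over schedule_timesteps,
    then held constant at final_p thereafter."""

    def __init__(self, schedule_timesteps, final_p, initial_p=1.0):
        self.schedule_timesteps = schedule_timesteps
        self.final_p = final_p
        self.initial_p = initial_p

    def value(self, t):
        # Fraction of the schedule elapsed, capped at 1.0.
        fraction = min(float(t) / self.schedule_timesteps, 1.0)
        return self.initial_p + fraction * (self.final_p - self.initial_p)

# e.g. decay an exploration noise scale from 1.0 to 0.1 over 100k timesteps
exploration = LinearSchedule(schedule_timesteps=100000, final_p=0.1)
```

Since DQN and DDPG both anneal exploration this way, sharing one implementation under ./rllib/utils (as discussed above) avoids the duplication.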

# Override atari default to use the deepmind wrappers.
# TODO(ekl) this logic should be pushed to the catalog.
if is_atari and "custom_preprocessor" not in options:
    return wrap_deepmind(env, random_starts=random_starts)
Contributor

Since DDPG isn't typically used with discrete action spaces (e.g., Atari), how about we remove this wrapper and just use ModelCatalog.get_preprocessor...?

This means we can remove this file.

Contributor Author

I tried your solution. However, without a complicated assertion on the observation type (say, Box(0.0, 1.0, (80, 80, 1), dtype=np.float32) is allowed while Box(0.0, 1.0, (210, 160, 3), dtype=np.float32) is unsupported), DDPG could NOT pass the test/test_supported_spaces.py test. Thus, I removed this file per your comment, but I directly import and use the wrapper from dqn rather than following your proposal. If you have any concerns, please let me know and I will revise according to your comments.

Contributor

I see, that sounds fine then, we can clean up Atari handling later.

smoothing_num_episodes=100,



Contributor

Two newlines only.

Contributor Author

OK, got it; I'll follow your conventions.

print(str(result))
pretty_print(result)

if __name__=="__main__":
Contributor

Is it intended for this test to be run as part of the automated tests?

If so, consider adding it as an entry in run_multi_node_tests.sh, otherwise removing it.

Contributor Author

Sorry for the mess; I removed these files.

print(mean_100ep_reward)
"""

if __name__=="__main__":
Contributor

Is this intended to be an automated test?

Contributor Author

removed.

ob = new_ob

if __name__=="__main__":
    main()
Contributor

Is this intended to be an automated test?

Contributor Author

removed.

self.saved_mean_reward = data[3]
self.obs = data[4]
self.global_timestep = data[5]
self.local_timestep = data[6]
Contributor

I'm a little concerned this file has too much in common with dqn_evaluator.py. However, it's not clear if coupling DQN and DDPG would be a good idea either. @richardliaw any thoughts?

Contributor Author

You are right. Both of them follow Q-learning-style behavior, and the differences can be expressed by DQNGraph and DDPGGraph. We can consider distilling a parent class like QAgent/QEvaluator later.
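The refactor suggested here could be sketched roughly as follows (class names QEvaluatorBase and ToyGreedyGraph are hypothetical, purely illustrative; real evaluators would also share replay and exploration plumbing):

```python
class QEvaluatorBase:
    """Hypothetical shared parent for DQNEvaluator/DDPGEvaluator: common
    stepping bookkeeping lives here, while action selection is delegated to
    an injected graph object (DQNGraph would pick argmax-Q, DDPGGraph would
    query the deterministic policy network)."""

    def __init__(self, graph):
        self.graph = graph
        self.local_timestep = 0

    def step(self, obs):
        action = self.graph.act(obs)  # algorithm-specific part
        self.local_timestep += 1      # shared bookkeeping
        return action


class ToyGreedyGraph:
    """Toy stand-in for DQNGraph: argmax over per-action values."""

    def act(self, obs):
        return max(range(len(obs)), key=lambda i: obs[i])


evaluator = QEvaluatorBase(ToyGreedyGraph())
evaluator.step([0.1, 0.9, 0.3])  # returns 1, the argmax action
```

The point of the design is that the evaluator loop stays identical across both algorithms; only the injected graph differs.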

@@ -0,0 +1,15 @@
pendulum-ddpg:
env: Pendulum-v0
run: DDPG
Contributor

DDPG2 here and elsewhere in the YAML examples.

Btw, how long does this usually take to complete?

Contributor Author

Done. It takes around 650 to 750 seconds.

@ericl
Contributor

ericl commented Apr 12, 2018

Two other requests for tests:

  • Add an entry for DDPG2 to test_checkpoint_restore.py, to verify checkpointing correctness.
  • Add an entry for DDPG2 to test_supported_spaces.py, to verify action / observation space support.
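The requested checkpoint test boils down to: act, save, restore into a fresh agent, act again, and compare. A toy sketch of that pattern (a NumPy linear "policy" stands in for a real agent here; with exploration noise disabled, as in the later noise_scale: 0.0 config, the comparison can be exact):

```python
import numpy as np

def get_mean_action(weights, obs):
    # Toy deterministic linear "policy" standing in for a real agent's network.
    return obs @ weights

rng = np.random.default_rng(0)
weights = rng.standard_normal((4, 2))
obs = rng.standard_normal(4)

before = get_mean_action(weights, obs)   # action from the trained agent
checkpoint = weights.copy()              # analogous to alg.save()
restored = checkpoint                    # analogous to alg2.restore(...)
after = get_mean_action(restored, obs)

assert np.allclose(before, after)        # restored agent must act identically
```

If the restored agent produced different actions on the same observation, that would indicate some state (e.g. network weights or target-network weights) was not captured by the checkpoint.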

@AmplabJenkins

Test PASSed.
Refer to this link for build results (access rights to CI server needed):
https://amplab.cs.berkeley.edu/jenkins//job/Ray-PRB/4827/

@@ -0,0 +1,108 @@
"""This file is used for specifying various schedules that evolve over
Contributor

This file can be removed.

@@ -0,0 +1,408 @@
from __future__ import absolute_import
from __future__ import division
Contributor

@vladf can you take a look over the models here?

s_func_vars = _scope_vars(scope.name)

# Actor: P (policy) network
p_scope_name = TOWER_SCOPE_NAME + "/p_func"
Contributor

We can drop TOWER_SCOPE_NAME; it was only used for the multi-GPU optimizer, which is being refactored to not need it.

@joneswong
Contributor Author

Hi teachers,
we don't have MuJoCo now; I will buy a license later and run experiments on it.
I know how to run a Ray cluster, but I have no idea how to prepare an entry (you mean a .yaml?) for you so that Jenkins can automatically run the corresponding scripts on different machines.
Sorry for making this PR trivial. If you have any requirements, please let me know. My group and I really want to make more contributions to Ray.

@AmplabJenkins

Test PASSed.
Refer to this link for build results (access rights to CI server needed):
https://amplab.cs.berkeley.edu/jenkins//job/Ray-PRB/4872/

@ericl
Contributor

ericl commented Apr 14, 2018

Looks like there is a conflicting file. I think we can merge this once that's fixed, but let's make sure to add results for HalfCheetah experiments afterwards.

@AmplabJenkins

Test PASSed.
Refer to this link for build results (access rights to CI server needed):
https://amplab.cs.berkeley.edu/jenkins//job/Ray-PRB/4945/

@ericl
Contributor

ericl commented Apr 16, 2018

@joneswong the lint tests are failing in Travis. Can you run `find . -name '*.py' -type f -exec yapf -i -r {} \;` in the ddpg2 directory to fix the formatting?

@@ -22,15 +22,20 @@ def get_mean_action(alg, obs):
 CONFIGS = {
     "ES": {"episodes_per_batch": 10, "timesteps_per_batch": 100},
     "DQN": {},
+    "DDPG2": {"noise_scale": 0.0},
Contributor
@richardliaw commented Apr 18, 2018

We should add @alvkao58's DDPG to this file too (in a separate PR that is)

@@ -114,6 +114,7 @@ class ModelSupportedSpaces(unittest.TestCase):
     def testAll(self):
         ray.init()
         stats = {}
+        check_support("DDPG2", {"timesteps_per_iteration": 1}, stats)
Contributor
@richardliaw commented Apr 18, 2018

same comment about @alvkao58 (not in this pr)

@AmplabJenkins

Test PASSed.
Refer to this link for build results (access rights to CI server needed):
https://amplab.cs.berkeley.edu/jenkins//job/Ray-PRB/4982/

@joneswong
Contributor Author

@ericl I formatted the .py files in the ddpg2 folder with yapf using the command you provided.

@ericl
Contributor

ericl commented Apr 18, 2018

Hm, there seem to be some lint errors still: https://api.travis-ci.org/v3/job/367974621/log.txt (click the travis details -> go to the LINT job)

$ flake8 --exclude=python/ray/core/src/common/flatbuffers_ep-prefix/,python/ray/core/generated/,src/common/format/,doc/source/conf.py,python/ray/cloudpickle/
./python/ray/rllib/ddpg2/models.py:25:80: E501 line too long (83 > 79 characters)
./python/ray/rllib/ddpg2/models.py:40:80: E501 line too long (87 > 79 characters)
./python/ray/rllib/ddpg2/models.py:40:87: E502 the backslash is redundant between brackets
./python/ray/rllib/ddpg2/models.py:41:26: E128 continuation line under-indented for visual indent
./python/ray/rllib/ddpg2/models.py:44:59: E502 the backslash is redundant between brackets
./python/ray/rllib/ddpg2/models.py:45:26: E128 continuation line under-indented for visual indent
./python/ray/rllib/ddpg2/models.py:45:80: E501 line too long (81 > 79 characters)
./python/ray/rllib/ddpg2/models.py:220:13: F841 local variable 'q_values' is assigned to but never used
./python/ray/rllib/ddpg2/models.py:270:20: E713 test for membership should be 'not in'
./python/ray/rllib/ddpg2/models.py:273:20: E713 test for membership should be 'not in'
./python/ray/rllib/ddpg2/models.py:300:80: E501 line too long (84 > 79 characters)
./python/ray/rllib/ddpg2/ddpg_evaluator.py:44:44: E225 missing whitespace around operator
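The E713 and E225 findings above translate into mechanical fixes like the following (an illustrative snippet, not the actual models.py code; the scope_vars dict is made up for the example):

```python
scope_vars = {"p_func": [1.0], "q_func": [2.0]}

# E713: write membership tests as "x not in y" rather than "not x in y".
assert not ("target_p_func" in scope_vars)   # equivalent to the flagged `not x in y` form
assert "target_p_func" not in scope_vars     # preferred, lint-clean form

# E225: put spaces around operators, e.g. rewrite
#     if __name__=="__main__":
# as:
if __name__ == "__main__":
    print("lint-clean")
```

yapf handles most whitespace issues (like E225) automatically, but naming and logic findings such as F841 (unused variable) still need manual fixes.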

@AmplabJenkins

Test FAILed.
Refer to this link for build results (access rights to CI server needed):
https://amplab.cs.berkeley.edu/jenkins//job/Ray-PRB/5001/

@AmplabJenkins

Test PASSed.
Refer to this link for build results (access rights to CI server needed):
https://amplab.cs.berkeley.edu/jenkins//job/Ray-PRB/5002/

@ericl ericl merged commit c9a7744 into ray-project:master Apr 20, 2018
@ericl
Contributor

ericl commented Apr 20, 2018

Merged, thanks!

@robertnishihara
Collaborator

Nice work!

@joneswong
Contributor Author

Thanks for Eric's patience. I learned some conventions of open-source projects from this PR, and I will adhere to the code style in the following PRs.

@richardliaw
Contributor

richardliaw commented Apr 20, 2018 via email

royf added a commit to royf/ray that referenced this pull request Apr 22, 2018
* master:
  Handle interrupts correctly for ASIO synchronous reads and writes. (ray-project#1929)
  [DataFrame] Adding read methods and tests (ray-project#1712)
  Allow task_table_update to fail when tasks are finished. (ray-project#1927)
  [rllib] Contribute DDPG to RLlib (ray-project#1877)
  [xray] Workers blocked in a `ray.get` release their resources (ray-project#1920)
  Raylet task dispatch and throttling worker startup (ray-project#1912)
  [DataFrame] Eval fix (ray-project#1903)
  [tune] Polishing docs (ray-project#1846)
  [tune] [rllib] Automatically determine RLlib resources and add queueing mechanism for autoscaling (ray-project#1848)
  Preemptively push local arguments for actor tasks (ray-project#1901)
  [tune] Allow fetching pinned objects from trainable functions (ray-project#1895)
  Multithreading refactor for ObjectManager. (ray-project#1911)
  Add slice functionality (ray-project#1832)
  [DataFrame] Pass read_csv kwargs to _infer_column (ray-project#1894)
  Addresses missed comments from multichunk object transfer PR. (ray-project#1908)
  Allow numpy arrays to be passed by value into tasks (and inlined in the task spec). (ray-project#1816)
  [xray] Lineage cache requests notifications from the GCS about remote tasks (ray-project#1834)
  Fix UI issue for non-json-serializable task arguments. (ray-project#1892)
  Remove unnecessary calls to .hex() for object IDs. (ray-project#1910)
  Allow multiple raylets to be started on a single machine. (ray-project#1904)

# Conflicts:
#	python/ray/rllib/__init__.py
#	python/ray/rllib/dqn/dqn.py
alok added a commit to alok/ray that referenced this pull request Apr 28, 2018
* master:
  updates (ray-project#1958)
  Pin Cython in autoscaler development example. (ray-project#1951)
  Incorporate C++ Buffer management and Seal global threadpool fix from arrow (ray-project#1950)
  [XRay] Add consistency check for protocol between node_manager and local_scheduler_client (ray-project#1944)
  Remove smart_open install. (ray-project#1943)
  [DataFrame] Fully implement append, concat and join (ray-project#1932)
  [DataFrame] Fix for __getitem__ string indexing (ray-project#1939)
  [DataFrame] Implementing write methods (ray-project#1918)
  [rllib] arr[end] was excluded when end is not None (ray-project#1931)
  [DataFrame] Implementing API correct groupby with aggregation methods (ray-project#1914)
  Handle interrupts correctly for ASIO synchronous reads and writes. (ray-project#1929)
  [DataFrame] Adding read methods and tests (ray-project#1712)
  Allow task_table_update to fail when tasks are finished. (ray-project#1927)
  [rllib] Contribute DDPG to RLlib (ray-project#1877)
  [xray] Workers blocked in a `ray.get` release their resources (ray-project#1920)
  Raylet task dispatch and throttling worker startup (ray-project#1912)
  [DataFrame] Eval fix (ray-project#1903)