
[rllib] Document creating an ensemble of envs; also add vector_index attribute to env config #2513

Merged: 15 commits, Aug 1, 2018
20 changes: 20 additions & 0 deletions doc/source/rllib-env.rst
@@ -26,6 +26,26 @@ In the high-level agent APIs, environments are identified with string names. By
    while True:
        print(trainer.train())

Configuring Environments
------------------------

In the above example, note that the ``env_creator`` function takes in an ``env_config`` object. This is a dict containing the options passed in through your agent configuration. You can also access ``env_config.worker_index`` and ``env_config.vector_index`` to get the worker id and the env id within the worker (if ``num_envs_per_worker > 0``). This can be useful if you want to train over an ensemble of different environments, for example (a possible ``choose_env_for`` helper is sketched after the code):

.. code-block:: python

    class MultiEnv(gym.Env):
        def __init__(self, env_config):
            # pick actual env based on worker and env indexes
            self.env = gym.make(
                choose_env_for(env_config.worker_index, env_config.vector_index))
            self.action_space = self.env.action_space
            self.observation_space = self.env.observation_space
        def reset(self):
            return self.env.reset()
        def step(self, action):
            return self.env.step(action)

    register_env("multienv", lambda config: MultiEnv(config))
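
Note that ``choose_env_for`` is not an RLlib API; it stands for whatever mapping from worker and env indexes to environments you want. A minimal sketch, assuming a fixed list of Gym env names (both the names and the mapping are illustrative):

.. code-block:: python

    # Hypothetical helper: deterministically pick a Gym env name from the
    # worker id and the env id within that worker.
    ENV_NAMES = ["CartPole-v0", "MountainCar-v0", "Acrobot-v1"]

    def choose_env_for(worker_index, vector_index):
        return ENV_NAMES[(worker_index + vector_index) % len(ENV_NAMES)]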

OpenAI Gym
----------
38 changes: 37 additions & 1 deletion doc/source/rllib-training.rst
@@ -108,7 +108,43 @@ Here is an example of the basic usage:
    checkpoint = agent.save()
    print("checkpoint saved at", checkpoint)

All RLlib agents implement the tune Trainable API, which means they support incremental training and checkpointing. This enables them to be easily used in experiments with Ray Tune.
.. note::

    It's recommended that you run RLlib agents with `Tune <tune.html>`__, for easy experiment management and visualization of results. Just set ``"run": AGENT_NAME, "env": ENV_NAME`` in the experiment config.

All RLlib agents are compatible with the `Tune API <tune.html#concepts>`__, which makes them easy to use in experiments with `Tune <tune.html>`__. For example, the following code performs a simple hyperparameter sweep of PPO:

.. code-block:: python

    import ray
    import ray.tune as tune

    ray.init()
    tune.run_experiments({
        "my_experiment": {
            "run": "PPO",
            "env": "CartPole-v0",
            "stop": {"episode_reward_mean": 200},
            "config": {
                "num_workers": 1,
                "sgd_stepsize": tune.grid_search([0.01, 0.001, 0.0001]),
            },
        },
    })

Tune will schedule the trials to run in parallel on your Ray cluster:

::

    == Status ==
    Using FIFO scheduling algorithm.
    Resources requested: 4/4 CPUs, 0/0 GPUs
    Result logdir: /home/eric/ray_results/my_experiment
    PENDING trials:
     - PPO_CartPole-v0_2_sgd_stepsize=0.0001: PENDING
    RUNNING trials:
     - PPO_CartPole-v0_0_sgd_stepsize=0.01: RUNNING [pid=21940], 16 s, 4013 ts, 22 rew
     - PPO_CartPole-v0_1_sgd_stepsize=0.001: RUNNING [pid=21942], 27 s, 8111 ts, 54.7 rew

Accessing Global State
~~~~~~~~~~~~~~~~~~~~~~
2 changes: 1 addition & 1 deletion python/ray/rllib/env/async_vector_env.py
@@ -251,7 +251,7 @@ def __init__(self, make_env, existing_envs, num_envs):
        self.num_envs = num_envs
        self.dones = set()
        while len(self.envs) < self.num_envs:
            self.envs.append(self.make_env())
            self.envs.append(self.make_env(len(self.envs)))
        for env in self.envs:
            assert isinstance(env, MultiAgentEnv)
        self.env_states = [_MultiAgentEnvState(env) for env in self.envs]
9 changes: 8 additions & 1 deletion python/ray/rllib/env/env_context.py
@@ -15,8 +15,15 @@ class EnvContext(dict):
    Attributes:
        worker_index (int): When there are multiple workers created, this
            uniquely identifies the worker the env is created in.
        vector_index (int): When there are multiple envs per worker, this
            uniquely identifies the env index within the worker.
    """

    def __init__(self, env_config, worker_index):
    def __init__(self, env_config, worker_index, vector_index=0):
        dict.__init__(self, env_config)
        self.worker_index = worker_index
        self.vector_index = vector_index

    def with_vector_index(self, vector_index):
        return EnvContext(
            self, worker_index=self.worker_index, vector_index=vector_index)
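
For reference, a minimal usage sketch of the new field (the config key below is illustrative, not part of this PR): ``with_vector_index`` copies the base context, keeping the options and ``worker_index`` while replacing only ``vector_index``.

    from ray.rllib.env.env_context import EnvContext

    # Base context as a worker would build it for worker 0.
    base = EnvContext({"corridor_length": 5}, worker_index=0)

    # Per-sub-env copy: same options, same worker_index, new vector_index.
    ctx = base.with_vector_index(3)
    assert ctx["corridor_length"] == 5
    assert ctx.worker_index == 0 and ctx.vector_index == 3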
52 changes: 2 additions & 50 deletions python/ray/rllib/env/vector_env.py
@@ -2,9 +2,6 @@
from __future__ import division
from __future__ import print_function

import queue
import threading


class VectorEnv(object):
"""An environment that supports batch evaluation.
@@ -70,20 +67,14 @@ def __init__(self, make_env, existing_envs, num_envs):
        self.make_env = make_env
        self.envs = existing_envs
        self.num_envs = num_envs
        if make_env and num_envs > 1:
            self.resetter = _AsyncResetter(make_env, int(self.num_envs**0.5))
        else:
            self.resetter = _SimpleResetter(make_env)
        while len(self.envs) < self.num_envs:
            self.envs.append(self.make_env())
            self.envs.append(self.make_env(len(self.envs)))

    def vector_reset(self):
        return [e.reset() for e in self.envs]

    def reset_at(self, index):
        new_obs, new_env = self.resetter.trade_for_resetted(self.envs[index])
        self.envs[index] = new_env
        return new_obs
        return self.envs[index].reset()

    def vector_step(self, actions):
        obs_batch, rew_batch, done_batch, info_batch = [], [], [], []
@@ -97,42 +88,3 @@ def vector_step(self, actions):

    def get_unwrapped(self):
        return self.envs[0]


class _AsyncResetter(threading.Thread):
    """Does env reset asynchronously in the background.

    This is useful since resetting an env can be 100x slower than stepping."""

    def __init__(self, make_env, pool_size):
        threading.Thread.__init__(self)
        self.make_env = make_env
        self.pool_size = 0
        self.to_reset = queue.Queue()
        self.resetted = queue.Queue()
        self.daemon = True
        self.pool_size = pool_size
        while self.resetted.qsize() < self.pool_size:
            env = self.make_env()
            obs = env.reset()
            self.resetted.put((obs, env))
        self.start()

    def run(self):
        while True:
            env = self.to_reset.get()
            obs = env.reset()
            self.resetted.put((obs, env))

    def trade_for_resetted(self, env):
        self.to_reset.put(env)
        new_obs, new_env = self.resetted.get(timeout=30)
        return new_obs, new_env


class _SimpleResetter(object):
    def __init__(self, make_env):
        pass

    def trade_for_resetted(self, env):
        return env.reset(), env
5 changes: 3 additions & 2 deletions python/ray/rllib/evaluation/policy_evaluator.py
@@ -190,8 +190,9 @@ def wrap(env):

        self.env = wrap(self.env)

        def make_env():
            return wrap(env_creator(env_context))
        def make_env(vector_index):
            return wrap(
                env_creator(env_context.with_vector_index(vector_index)))

        self.tf_sess = None
        policy_dict = _validate_and_canonicalize(policy_graph, self.env)
4 changes: 2 additions & 2 deletions python/ray/rllib/test/test_multi_agent_env.py
@@ -160,7 +160,7 @@ def testRoundRobinMock(self):
        self.assertEqual(done["__all__"], True)

    def testVectorizeBasic(self):
        env = _MultiAgentEnvToAsync(lambda: BasicMultiAgent(2), [], 2)
        env = _MultiAgentEnvToAsync(lambda v: BasicMultiAgent(2), [], 2)
        obs, rew, dones, _, _ = env.poll()
        self.assertEqual(obs, {0: {0: 0, 1: 0}, 1: {0: 0, 1: 0}})
        self.assertEqual(rew, {0: {0: None, 1: None}, 1: {0: None, 1: None}})
@@ -236,7 +236,7 @@ def testVectorizeBasic(self):
        })

    def testVectorizeRoundRobin(self):
        env = _MultiAgentEnvToAsync(lambda: RoundRobinMultiAgent(2), [], 2)
        env = _MultiAgentEnvToAsync(lambda v: RoundRobinMultiAgent(2), [], 2)
        obs, rew, dones, _, _ = env.poll()
        self.assertEqual(obs, {0: {0: 0}, 1: {0: 0}})
        self.assertEqual(rew, {0: {0: None}, 1: {0: None}})
10 changes: 8 additions & 2 deletions python/ray/rllib/test/test_policy_evaluator.py
@@ -33,8 +33,9 @@ def postprocess_trajectory(self, batch, other_agent_batches=None):


class MockEnv(gym.Env):
    def __init__(self, episode_length):
    def __init__(self, episode_length, config=None):
        self.episode_length = episode_length
        self.config = config
        self.i = 0
        self.observation_space = gym.spaces.Discrete(1)
        self.action_space = gym.spaces.Discrete(2)
@@ -150,7 +151,7 @@ def testAutoConcat(self):

    def testAutoVectorization(self):
        ev = PolicyEvaluator(
            env_creator=lambda _: MockEnv(episode_length=20),
            env_creator=lambda cfg: MockEnv(episode_length=20, config=cfg),
            policy_graph=MockPolicyGraph,
            batch_mode="truncate_episodes",
            batch_steps=16,
@@ -165,6 +166,11 @@ def testAutoVectorization(self):
        self.assertEqual(batch.count, 16)
        result = collect_metrics(ev, [])
        self.assertEqual(result.episodes_total, 8)
        indices = []
        for env in ev.async_env.vector_env.envs:
            self.assertEqual(env.unwrapped.config.worker_index, 0)
            indices.append(env.unwrapped.config.vector_index)
        self.assertEqual(indices, [0, 1, 2, 3, 4, 5, 6, 7])

    def testBatchDivisibilityCheck(self):
        self.assertRaises(