Merged

85 commits
682ae7e
wip
ericl May 27, 2018
846a3a6
cls
ericl May 27, 2018
a5e1416
re
ericl May 27, 2018
cfb77be
wip
ericl May 28, 2018
7966e63
Merge branch 'fix-classmethod' into v2-refactor
ericl May 28, 2018
3c07c29
wip
ericl May 28, 2018
332683c
a3c working
ericl May 28, 2018
3cea2c9
torch support
ericl May 28, 2018
d7472e5
pg works
ericl May 28, 2018
b4a782b
lint
ericl May 28, 2018
8738fa3
rm v2
ericl May 28, 2018
a88957c
consumer id
ericl May 28, 2018
370abf0
clean up pg
ericl May 28, 2018
6c2bcbb
clean up more
ericl May 28, 2018
56429fb
fix python 2.7
ericl May 28, 2018
2380c8f
Merge branch 'fix-classmethod' into v2-refactor
ericl May 28, 2018
f16f8f0
tf session management
ericl May 28, 2018
71d78b5
docs
ericl May 28, 2018
5ab8723
dqn wip
ericl May 29, 2018
c6d68ff
fix compile
ericl May 29, 2018
fa015ff
dqn
ericl May 29, 2018
e2a41a9
apex runs
ericl May 29, 2018
84624fe
up
ericl May 29, 2018
3c4a9fd
impotrs
ericl May 29, 2018
c56dcef
ddpg
ericl May 29, 2018
04220bf
quotes
ericl May 29, 2018
6f5ef1b
Merge remote-tracking branch 'upstream/master' into v2-refactor
ericl May 29, 2018
95a69df
fix tests
ericl May 29, 2018
c62a236
fix last r
ericl May 29, 2018
a9090a4
fix tests
ericl May 29, 2018
a63efae
lint
ericl May 29, 2018
19db8bd
pass checkpoint restore
ericl May 29, 2018
c2b4243
kwar
ericl May 29, 2018
0e56fd4
nits
ericl May 30, 2018
ed0b359
policy graph
ericl May 30, 2018
70ea79d
fix yapf
ericl May 30, 2018
496946f
com
ericl May 30, 2018
53b4e55
class
ericl May 30, 2018
da02fa9
Merge remote-tracking branch 'upstream/master' into v2-refactor
ericl May 30, 2018
3657108
pyt
ericl May 30, 2018
f08544b
vectorization
ericl May 31, 2018
1f435f7
update
ericl Jun 7, 2018
6dbd0e8
Merge remote-tracking branch 'upstream/master' into v2-refactor
ericl Jun 7, 2018
f910464
test cpe
ericl Jun 7, 2018
5685e32
unit test
ericl Jun 7, 2018
f2af5dc
fix ddpg2
ericl Jun 7, 2018
06ba0af
Merge branch 'v2-refactor' into v2-vectorization
ericl Jun 7, 2018
5b27640
changes
ericl Jun 8, 2018
550fe45
wip
ericl Jun 8, 2018
ad9a205
args
ericl Jun 8, 2018
1b9b192
faster test
ericl Jun 8, 2018
21cecdd
common
ericl Jun 8, 2018
8ec7c2d
tests
ericl Jun 8, 2018
20715c5
Merge remote-tracking branch 'upstream/master' into v2-vectorization
ericl Jun 9, 2018
66f2c9d
fix
ericl Jun 9, 2018
0140db4
add alg option
ericl Jun 9, 2018
c51f799
batch mode and policy serving
ericl Jun 10, 2018
833aeb8
multi serving test
ericl Jun 10, 2018
dbce953
todo
ericl Jun 10, 2018
9ac1da2
wip
ericl Jun 10, 2018
fd9bc5b
serving test
ericl Jun 10, 2018
03f78b5
doc async env
ericl Jun 10, 2018
9bfcdaf
num envs
ericl Jun 10, 2018
23a5022
comments
ericl Jun 10, 2018
f9ff790
thread
ericl Jun 10, 2018
9f8ac7b
remove init hook
ericl Jun 10, 2018
e30424b
update
ericl Jun 11, 2018
722917b
fix ppo
ericl Jun 12, 2018
923df74
comments1
ericl Jun 12, 2018
6fc62fe
fix
ericl Jun 12, 2018
f1f7d5e
updates
ericl Jun 12, 2018
74375de
add jenkins tests
ericl Jun 13, 2018
7daf94d
fix
ericl Jun 13, 2018
544bb4e
fix pytorch
ericl Jun 13, 2018
da230c2
Merge branch 'v2-vectorization' of github.com:ericl/ray into v2-vecto…
ericl Jun 14, 2018
beaab29
fix
ericl Jun 14, 2018
27bea6b
fixes
ericl Jun 15, 2018
c8f85ce
fix a3c policy
ericl Jun 15, 2018
09df795
Merge remote-tracking branch 'upstream/master' into v2-vectorization
ericl Jun 15, 2018
f5bb43d
fix squeeze
ericl Jun 15, 2018
b3214f4
fix trunc on apex
ericl Jun 15, 2018
8ebf32f
fix squeezing for real
ericl Jun 16, 2018
516a595
update
ericl Jun 16, 2018
e191e82
remove horizon test for now
ericl Jun 17, 2018
dcb6eba
fix race condition
ericl Jun 18, 2018
6 changes: 5 additions & 1 deletion python/ray/rllib/__init__.py
@@ -9,6 +9,9 @@
from ray.rllib.utils.policy_graph import PolicyGraph
from ray.rllib.utils.tf_policy_graph import TFPolicyGraph
from ray.rllib.utils.common_policy_evaluator import CommonPolicyEvaluator
from ray.rllib.utils.async_vector_env import AsyncVectorEnv
from ray.rllib.utils.vector_env import VectorEnv
from ray.rllib.utils.serving_env import ServingEnv
from ray.rllib.optimizers.sample_batch import SampleBatch


@@ -23,5 +26,6 @@ def _register_all():
_register_all()

__all__ = [
"PolicyGraph", "TFPolicyGraph", "CommonPolicyEvaluator", "SampleBatch"
"PolicyGraph", "TFPolicyGraph", "CommonPolicyEvaluator", "SampleBatch",
"AsyncVectorEnv", "VectorEnv", "ServingEnv",
]
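After this change the three new environment interfaces are importable from the top-level rllib package. A minimal import sketch (class names taken from the diff above; their definitions live under ray.rllib.utils and are not shown here):

```python
# Minimal sketch: the new environment interfaces exported by this change.
# Subclassing details are defined in ray/rllib/utils/*_env.py (not shown here).
from ray.rllib import AsyncVectorEnv, VectorEnv, ServingEnv
```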
8 changes: 6 additions & 2 deletions python/ray/rllib/a3c/a3c.py
@@ -17,6 +17,8 @@
DEFAULT_CONFIG = {
# Number of workers (excluding master)
"num_workers": 2,
# Number of environments to evaluate vectorwise per worker.
"num_envs": 1,
# Size of rollout batch
"batch_size": 10,
# Use LSTM model - only applicable for image states
@@ -101,15 +103,17 @@ def session_creator():
batch_mode="truncate_episodes",
tf_session_creator=session_creator,
registry=self.registry, env_config=self.config["env_config"],
model_config=self.config["model"], policy_config=self.config)
model_config=self.config["model"], policy_config=self.config,
num_envs=self.config["num_envs"])
self.remote_evaluators = [
remote_cls.remote(
self.env_creator, self.policy_cls,
batch_steps=self.config["batch_size"],
batch_mode="truncate_episodes", sample_async=True,
tf_session_creator=session_creator,
registry=self.registry, env_config=self.config["env_config"],
model_config=self.config["model"], policy_config=self.config)
model_config=self.config["model"], policy_config=self.config,
num_envs=self.config["num_envs"])
for i in range(self.config["num_workers"])]

self.optimizer = AsyncOptimizer(
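A hypothetical config sketch for the new option, using the defaults shown above; "num_envs" controls how many environments each evaluator steps in a single vectorized batch (the same option is added to DDPG and DQN below):

```python
# Hypothetical usage sketch (agent construction elided): each of the two
# workers would step four environments per sample batch instead of one.
config = {
    "num_workers": 2,   # number of remote evaluators (excluding the master)
    "num_envs": 4,      # environments evaluated vectorwise per worker
    "batch_size": 10,   # rollout batch size, as in the defaults above
}
```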
8 changes: 4 additions & 4 deletions python/ray/rllib/a3c/a3c_torch_policy.py
@@ -2,6 +2,7 @@
from __future__ import division
from __future__ import print_function

import numpy as np
from threading import Lock

import torch
@@ -33,13 +34,12 @@ def setup_graph(self, obs_space, action_space):
self.optimizer = torch.optim.Adam(
self._model.parameters(), lr=self.config["lr"])

def compute_single_action(self, obs, state, is_training=False):
def compute_actions(self, obs, state, is_training=False):
assert not state, "RNN not supported"
with self.lock:
ob = torch.from_numpy(obs).float().unsqueeze(0)
ob = torch.from_numpy(np.array(obs)).float()
logits, values = self._model(ob)
samples = F.softmax(logits, dim=1).multinomial(1).squeeze()
values = values.squeeze()
samples = F.softmax(logits, dim=1).multinomial(1).squeeze(0)
return var_to_np(samples), [], {"vf_preds": var_to_np(values)}

def compute_gradients(self, samples):
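The rename from compute_single_action to compute_actions reflects that observations now arrive as a batch, one row per vectorized environment. A rough shape sketch of the per-row categorical sampling this implies, with torch.distributions standing in for the softmax/multinomial calls above and a random tensor standing in for the model output:

```python
import numpy as np
import torch

# Sketch only: `logits` stands in for the policy model's output over a batch
# of observations, one row per environment in the vector.
obs_batch = np.zeros((3, 4), dtype=np.float32)      # 3 envs, 4-dim observations
ob = torch.from_numpy(obs_batch).float()            # shape [3, 4]
logits = torch.randn(3, 6)                          # stand-in model output, 6 actions
actions = torch.distributions.Categorical(logits=logits).sample()
print(actions.shape)                                # torch.Size([3]) - one action per env
```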
85 changes: 0 additions & 85 deletions python/ray/rllib/a3c/shared_torch_policy.py

This file was deleted.

82 changes: 0 additions & 82 deletions python/ray/rllib/a3c/torchpolicy.py

This file was deleted.

2 changes: 2 additions & 0 deletions python/ray/rllib/ddpg/ddpg.py
@@ -95,6 +95,8 @@
# to increase if your environment is particularly slow to sample, or if
# you"re using the Async or Ape-X optimizers.
"num_workers": 0,
# Number of environments to evaluate vectorwise per worker.
"num_envs": 1,
# Whether to allocate GPUs for workers (if > 0).
"num_gpus_per_worker": 0,
# Whether to allocate CPUs for workers (if > 0).
12 changes: 8 additions & 4 deletions python/ray/rllib/dqn/dqn.py
@@ -89,6 +89,8 @@
# to increase if your environment is particularly slow to sample, or if
# you"re using the Async or Ape-X optimizers.
"num_workers": 0,
# Number of environments to evaluate vectorwise per worker.
"num_envs": 1,
# Whether to allocate GPUs for workers (if > 0).
"num_gpus_per_worker": 0,
# Whether to allocate CPUs for workers (if > 0).
@@ -125,21 +127,23 @@ def _init(self):
self.local_evaluator = CommonPolicyEvaluator(
self.env_creator, self._policy_graph,
batch_steps=adjusted_batch_size,
batch_mode="pack_episodes", preprocessor_pref="deepmind",
batch_mode="truncate_episodes", preprocessor_pref="deepmind",
compress_observations=True,
registry=self.registry, env_config=self.config["env_config"],
model_config=self.config["model"], policy_config=self.config)
model_config=self.config["model"], policy_config=self.config,
num_envs=self.config["num_envs"])
remote_cls = CommonPolicyEvaluator.as_remote(
num_cpus=self.config["num_cpus_per_worker"],
num_gpus=self.config["num_gpus_per_worker"])
self.remote_evaluators = [
remote_cls.remote(
self.env_creator, self._policy_graph,
batch_steps=adjusted_batch_size,
batch_mode="pack_episodes", preprocessor_pref="deepmind",
batch_mode="truncate_episodes", preprocessor_pref="deepmind",
compress_observations=True,
registry=self.registry, env_config=self.config["env_config"],
model_config=self.config["model"], policy_config=self.config)
model_config=self.config["model"], policy_config=self.config,
num_envs=self.config["num_envs"])
for _ in range(self.config["num_workers"])]

self.exploration0 = self._make_exploration_schedule(0)
4 changes: 1 addition & 3 deletions python/ray/rllib/dqn/dqn_policy_graph.py
@@ -223,11 +223,9 @@ def _postprocess_dqn(policy_graph, sample_batch):
"obs": obs, "actions": actions, "rewards": rewards,
"new_obs": new_obs, "dones": dones,
"weights": np.ones_like(rewards)})
assert batch.count == policy_graph.config["sample_batch_size"], \
(batch.count, policy_graph.config["sample_batch_size"])

# Prioritize on the worker side
if policy_graph.config["worker_side_prioritization"]:
if batch.count > 0 and policy_graph.config["worker_side_prioritization"]:
td_errors = policy_graph.compute_td_error(
batch["obs"], batch["actions"], batch["rewards"],
batch["new_obs"], batch["dones"], batch["weights"])
2 changes: 1 addition & 1 deletion python/ray/rllib/models/action_dist.py
@@ -63,7 +63,7 @@ def kl(self, other):
reduction_indices=[1])

def sample(self):
return tf.multinomial(self.inputs, 1)[0]
return tf.squeeze(tf.multinomial(self.inputs, 1), axis=1)


class DiagGaussian(ActionDistribution):
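The sample() change returns one action id per batch row instead of indexing out a single sample. A small shape sketch under the TF 1.x graph API in use here (the Categorical distribution's inputs are its logits):

```python
import numpy as np
import tensorflow as tf

# Shape sketch, assuming the TF 1.x graph API used by RLlib at this point:
# logits [batch, num_actions] -> sampled action ids of shape [batch].
logits = tf.constant(np.zeros((4, 6), dtype=np.float32))
actions = tf.squeeze(tf.multinomial(logits, 1), axis=1)

with tf.Session() as sess:
    print(sess.run(actions).shape)  # (4,)
```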
6 changes: 0 additions & 6 deletions python/ray/rllib/models/preprocessors.py
@@ -125,22 +125,16 @@ def get_preprocessor(space):

legacy_patch_shapes(space)
obs_shape = space.shape
print("Observation shape is {}".format(obs_shape))

if isinstance(space, gym.spaces.Discrete):
print("Using one-hot preprocessor for discrete envs.")
preprocessor = OneHotPreprocessor
elif obs_shape == ATARI_OBS_SHAPE:
print("Assuming Atari pixel env, using AtariPixelPreprocessor.")
preprocessor = AtariPixelPreprocessor
elif obs_shape == ATARI_RAM_OBS_SHAPE:
print("Assuming Atari ram env, using AtariRamPreprocessor.")
preprocessor = AtariRamPreprocessor
elif isinstance(space, gym.spaces.Tuple):
print("Using a TupleFlatteningPreprocessor")
preprocessor = TupleFlatteningPreprocessor
else:
print("Not using any observation preprocessor.")
preprocessor = NoPreprocessor

return preprocessor
2 changes: 1 addition & 1 deletion python/ray/rllib/models/pytorch/fcnet.py
@@ -56,5 +56,5 @@ def forward(self, obs):
value: value function for each state"""
res = self.hidden_layers(obs)
logits = self.logits(res)
value = self.value_branch(res).reshape(-1)
value = self.value_branch(res).squeeze(1)
return logits, value
2 changes: 1 addition & 1 deletion python/ray/rllib/models/pytorch/visionnet.py
@@ -65,5 +65,5 @@ def forward(self, obs):
value (PyTorch): value function for each state"""
res = self.hidden_layers(obs)
logits = self.logits(res)
value = self.value_branch(res)
value = self.value_branch(res).squeeze(1)
return logits, value
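Both torch models now squeeze the trailing dimension of the value branch, so the value head returns one entry per sample rather than a [batch, 1] column. A quick shape check of that pattern:

```python
import torch

# Shape sketch: the value branch emits [batch, 1]; squeeze(1) yields a
# per-sample value vector of shape [batch], matching the logits batch size.
value_out = torch.zeros(4, 1)        # stand-in for self.value_branch(res)
print(value_out.squeeze(1).shape)    # torch.Size([4])
```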
5 changes: 5 additions & 0 deletions python/ray/rllib/optimizers/policy_evaluator.py
@@ -105,6 +105,11 @@ def get_host(self):

return os.uname()[1]

def apply(self, func, *args):
"""Apply the given function to this evaluator instance."""

return func(self, *args)


class TFMultiGPUSupport(PolicyEvaluator):
"""The multi-GPU TF optimizer requires additional TF-specific support.
20 changes: 20 additions & 0 deletions python/ray/rllib/optimizers/policy_optimizer.py
@@ -110,3 +110,23 @@ def restore(self, data):

self.num_steps_trained = data[0]
self.num_steps_sampled = data[1]

def foreach_evaluator(self, func):
"""Apply the given function to each evaluator instance."""

local_result = [func(self.local_evaluator)]
remote_results = ray.get(
[ev.apply.remote(func) for ev in self.remote_evaluators])
return local_result + remote_results

def foreach_evaluator_with_index(self, func):
"""Apply the given function to each evaluator instance.

The index will be passed as the second arg to the given function.
"""

local_result = [func(self.local_evaluator, 0)]
remote_results = ray.get(
[ev.apply.remote(func, i + 1)
for i, ev in enumerate(self.remote_evaluators)])
return local_result + remote_results
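A hypothetical usage sketch for the two new helpers; `optimizer` stands in for any constructed PolicyOptimizer, and get_host() is the evaluator method shown earlier in this diff:

```python
# Hypothetical usage sketch (optimizer construction elided). Results come back
# as [local_result, remote_result_0, remote_result_1, ...].
hosts = optimizer.foreach_evaluator(lambda ev: ev.get_host())

# The evaluator index (0 for the local evaluator, 1.. for remote evaluators)
# is passed as the second argument.
tagged = optimizer.foreach_evaluator_with_index(
    lambda ev, i: (i, ev.get_host()))
```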