
Fix A3C PyTorch implementation #2036


Merged
29 commits merged on May 30, 2018
8 changes: 5 additions & 3 deletions docker/examples/Dockerfile
@@ -1,8 +1,10 @@
# The examples Docker image adds dependencies needed to run the examples

FROM ray-project/deploy
RUN conda install -y -c conda-forge tensorflow

# This updates numpy to 1.14 and mutes errors from other libraries
RUN conda install -y numpy
RUN apt-get install -y zlib1g-dev
RUN pip install gym[atari] opencv-python==3.2.0.8
RUN pip install gym[atari] opencv-python==3.2.0.8 tensorflow
RUN pip install --upgrade git+git://github.com/hyperopt/hyperopt.git
# RUN conda install -y -q pytorch torchvision -c soumith
RUN conda install pytorch-cpu torchvision-cpu -c pytorch
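
As a quick sanity check (not part of this diff), something like the following could be run inside the built image to confirm that the CPU-only PyTorch build and the other example dependencies import cleanly; the package list is assumed from the Dockerfile above.

# Hypothetical import check for the examples image; package names are assumed
# from the Dockerfile above and are not part of this PR.
import importlib

for pkg in ("torch", "torchvision", "tensorflow", "gym", "cv2", "hyperopt"):
    module = importlib.import_module(pkg)
    print(pkg, getattr(module, "__version__", "unknown"))
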
56 changes: 32 additions & 24 deletions python/ray/rllib/a3c/a3c.py
@@ -15,7 +15,6 @@
from ray.tune.result import TrainingResult
from ray.tune.trial import Resources


DEFAULT_CONFIG = {
# Number of workers (excluding master)
"num_workers": 4,
@@ -52,7 +51,7 @@
# (Image statespace) - Converts image to (dim, dim, C)
"dim": 80,
# (Image statespace) - Converts image shape to (C, dim, dim)
"channel_major": False
"channel_major": False,
},
# Arguments to pass to the rllib optimizer
"optimizer": {
@@ -73,46 +72,53 @@ class A3CAgent(Agent):
def default_resource_request(cls, config):
cf = dict(cls._default_config, **config)
return Resources(
cpu=1, gpu=0,
cpu=1,
gpu=0,
extra_cpu=cf["num_workers"],
extra_gpu=cf["use_gpu_for_workers"] and cf["num_workers"] or 0)

def _init(self):
self.local_evaluator = A3CEvaluator(
self.registry, self.env_creator, self.config, self.logdir,
self.registry,
self.env_creator,
self.config,
self.logdir,
start_sampler=False)
if self.config["use_gpu_for_workers"]:
remote_cls = GPURemoteA3CEvaluator
else:
remote_cls = RemoteA3CEvaluator
self.remote_evaluators = [
remote_cls.remote(
self.registry, self.env_creator, self.config, self.logdir)
for i in range(self.config["num_workers"])]
self.optimizer = AsyncOptimizer(
self.config["optimizer"], self.local_evaluator,
self.remote_evaluators)
remote_cls.remote(self.registry, self.env_creator, self.config,
self.logdir)
for i in range(self.config["num_workers"])
]
self.optimizer = AsyncOptimizer(self.config["optimizer"],
self.local_evaluator,
self.remote_evaluators)

def _train(self):
self.optimizer.step()
FilterManager.synchronize(
self.local_evaluator.filters, self.remote_evaluators)
FilterManager.synchronize(self.local_evaluator.filters,
self.remote_evaluators)
res = self._fetch_metrics_from_remote_evaluators()
return res

def _fetch_metrics_from_remote_evaluators(self):
episode_rewards = []
episode_lengths = []
metric_lists = [a.get_completed_rollout_metrics.remote()
for a in self.remote_evaluators]
metric_lists = [
a.get_completed_rollout_metrics.remote()
for a in self.remote_evaluators
]
for metrics in metric_lists:
for episode in ray.get(metrics):
episode_lengths.append(episode.episode_length)
episode_rewards.append(episode.episode_reward)
avg_reward = (
np.mean(episode_rewards) if episode_rewards else float('nan'))
avg_length = (
np.mean(episode_lengths) if episode_lengths else float('nan'))
avg_reward = (np.mean(episode_rewards)
if episode_rewards else float('nan'))
avg_length = (np.mean(episode_lengths)
if episode_lengths else float('nan'))
timesteps = np.sum(episode_lengths) if episode_lengths else 0

result = TrainingResult(
@@ -129,21 +135,23 @@ def _stop(self):
ev.__ray_terminate__.remote()

def _save(self, checkpoint_dir):
checkpoint_path = os.path.join(
checkpoint_dir, "checkpoint-{}".format(self.iteration))
checkpoint_path = os.path.join(checkpoint_dir,
"checkpoint-{}".format(self.iteration))
agent_state = ray.get(
[a.save.remote() for a in self.remote_evaluators])
extra_data = {
"remote_state": agent_state,
"local_state": self.local_evaluator.save()}
"local_state": self.local_evaluator.save()
}
pickle.dump(extra_data, open(checkpoint_path + ".extra_data", "wb"))
return checkpoint_path

def _restore(self, checkpoint_path):
extra_data = pickle.load(open(checkpoint_path + ".extra_data", "rb"))
ray.get(
[a.restore.remote(o) for a, o in zip(
self.remote_evaluators, extra_data["remote_state"])])
ray.get([
a.restore.remote(o)
for a, o in zip(self.remote_evaluators, extra_data["remote_state"])
])
self.local_evaluator.restore(extra_data["local_state"])

def compute_action(self, observation):
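
For reference, the averaging in _fetch_metrics_from_remote_evaluators above reduces to the standalone sketch below: rewards and lengths of completed rollouts are averaged, with a NaN fallback when no episodes have finished. The episode numbers here are made up.

# Illustrative sketch of the metric aggregation; the episode data is made up.
import numpy as np

episode_rewards = [10.0, 12.5, 9.0]
episode_lengths = [200, 250, 180]

avg_reward = np.mean(episode_rewards) if episode_rewards else float('nan')
avg_length = np.mean(episode_lengths) if episode_lengths else float('nan')
timesteps = np.sum(episode_lengths) if episode_lengths else 0
print(avg_reward, avg_length, timesteps)  # 10.5 210.0 630
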
49 changes: 28 additions & 21 deletions python/ray/rllib/a3c/shared_torch_policy.py
@@ -3,7 +3,6 @@
from __future__ import print_function

import torch
from torch.autograd import Variable
import torch.nn.functional as F

from ray.rllib.a3c.torchpolicy import TorchPolicy
@@ -18,8 +17,8 @@ class SharedTorchPolicy(TorchPolicy):
is_recurrent = False

def __init__(self, registry, ob_space, ac_space, config, **kwargs):
super(SharedTorchPolicy, self).__init__(
registry, ob_space, ac_space, config, **kwargs)
super(SharedTorchPolicy, self).__init__(registry, ob_space, ac_space,
config, **kwargs)

def _setup_graph(self, ob_space, ac_space):
_, self.logit_dim = ModelCatalog.get_action_dist(ac_space)
@@ -31,48 +30,56 @@ def _setup_graph(self, ob_space, ac_space):
def compute(self, ob, *args):
"""Should take in a SINGLE ob"""
with self.lock:
ob = Variable(torch.from_numpy(ob).float().unsqueeze(0))
ob = torch.from_numpy(ob).float().unsqueeze(0)
logits, values = self._model(ob)
samples = self._model.probs(logits).multinomial().squeeze()
values = values.squeeze(0)
return var_to_np(samples), {"vf_preds": var_to_np(values)}
# TODO(alok): Support non-categorical distributions. Multinomial
# is only for categorical.
sampled_actions = F.softmax(logits, dim=1).multinomial(1).squeeze()
values = values.squeeze()
return var_to_np(sampled_actions), {"vf_preds": var_to_np(values)}

def compute_logits(self, ob, *args):
with self.lock:
ob = Variable(torch.from_numpy(ob).float().unsqueeze(0))
ob = torch.from_numpy(ob).float().unsqueeze(0)
res = self._model.hidden_layers(ob)
return var_to_np(self._model.logits(res))

def value(self, ob, *args):
with self.lock:
ob = Variable(torch.from_numpy(ob).float().unsqueeze(0))
ob = torch.from_numpy(ob).float().unsqueeze(0)
res = self._model.hidden_layers(ob)
res = self._model.value_branch(res)
res = res.squeeze(0)
res = res.squeeze()
return var_to_np(res)

def _evaluate(self, obs, actions):
"""Passes in multiple obs."""
logits, values = self._model(obs)
log_probs = F.log_softmax(logits)
probs = self._model.probs(logits)
log_probs = F.log_softmax(logits, dim=1)
probs = F.softmax(logits, dim=1)
action_log_probs = log_probs.gather(1, actions.view(-1, 1))
# TODO(alok): set distribution based on action space and use its
# `.entropy()` method to calculate automatically
entropy = -(log_probs * probs).sum(-1).sum()
return values, action_log_probs, entropy

def _backward(self, batch):
"""Loss is encoded in here. Defining a new loss function
would start by rewriting this function"""

states, acs, advs, rs, _ = convert_batch(batch)
values, ac_logprobs, entropy = self._evaluate(states, acs)
pi_err = -(advs * ac_logprobs).sum()
value_err = 0.5 * (values - rs).pow(2).sum()
states, actions, advs, rs, _ = convert_batch(batch)
values, action_log_probs, entropy = self._evaluate(states, actions)
pi_err = -advs.dot(action_log_probs.reshape(-1))
value_err = F.mse_loss(values.reshape(-1), rs)

self.optimizer.zero_grad()
overall_err = (pi_err +
value_err * self.config["vf_loss_coeff"] +
entropy * self.config["entropy_coeff"])

overall_err = sum([
pi_err,
self.config["vf_loss_coeff"] * value_err,
self.config["entropy_coeff"] * entropy,
])

overall_err.backward()
torch.nn.utils.clip_grad_norm(
self._model.parameters(), self.config["grad_clip"])
torch.nn.utils.clip_grad_norm_(self._model.parameters(),
self.config["grad_clip"])

@alok (Contributor, Author) commented on May 11, 2018:
clip_grad_norm is deprecated in favor of the underscore version, hence the change.
14 changes: 9 additions & 5 deletions python/ray/rllib/a3c/torchpolicy.py
@@ -3,7 +3,6 @@
from __future__ import print_function

import torch
from torch.autograd import Variable

from ray.rllib.a3c.policy import Policy
from threading import Lock
@@ -15,8 +14,13 @@ class TorchPolicy(Policy):
The model is a separate object than the policy. This could be changed
in the future."""

def __init__(self, registry, ob_space, action_space, config,
name="local", summarize=True):
def __init__(self,
registry,
ob_space,
action_space,
config,
name="local",
summarize=True):
self.registry = registry
self.local_steps = 0
self.config = config
@@ -28,7 +32,7 @@ def __init__(self, registry, ob_space, action_space, config,
def apply_gradients(self, grads):
self.optimizer.zero_grad()
for g, p in zip(grads, self._model.parameters()):
p.grad = Variable(torch.from_numpy(g))
p.grad = torch.from_numpy(g)
self.optimizer.step()

def get_weights(self):
@@ -69,7 +73,7 @@ def _setup_graph(ob_space, action_space):

def _backward(self, batch):
"""Implements the loss function and calculates the gradient.
Pytorch automatically generates a backward trace for each variable.
Pytorch automatically generates a backward trace for each tensor.
Assumption right now is that variables are moved, so the backward
trace is lost.

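
For reference, the apply_gradients pattern in this file reduces to the sketch below: numpy gradients produced elsewhere are converted to tensors, assigned to the local model's .grad fields, and then applied. The model, gradient values, and SGD optimizer here are illustrative.

# Illustrative sketch of apply_gradients: numpy gradients are converted to
# tensors and assigned to .grad before the optimizer step.
import numpy as np
import torch
import torch.nn as nn

model = nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# Pretend these arrived from a remote worker.
numpy_grads = [np.ones(p.shape, dtype=np.float32) for p in model.parameters()]

optimizer.zero_grad()
for g, p in zip(numpy_grads, model.parameters()):
    p.grad = torch.from_numpy(g)
optimizer.step()
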
3 changes: 3 additions & 0 deletions python/ray/rllib/models/catalog.py
@@ -180,11 +180,14 @@ def get_torch_model(registry, input_shape, num_outputs, options={}):
return registry.get(RLLIB_MODEL, model)(
input_shape, num_outputs, options)

# TODO(alok): fix to handle Discrete(n) state spaces
obs_rank = len(input_shape) - 1

if obs_rank > 1:
return PyTorchVisionNet(input_shape, num_outputs, options)

# TODO(alok): overhaul PyTorchFCNet so it can just
# take input shape directly
return PyTorchFCNet(input_shape[0], num_outputs, options)

@staticmethod
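
The dispatch rule in get_torch_model comes down to the observation rank, roughly as in the sketch below: image-like observation shapes route to the vision net, flat ones to the fully connected net. The shapes are examples only.

# Rough sketch of the model-dispatch rule in get_torch_model; shapes are examples.
def choose_net(input_shape):
    obs_rank = len(input_shape) - 1
    return "PyTorchVisionNet" if obs_rank > 1 else "PyTorchFCNet"

assert choose_net((80, 80, 3)) == "PyTorchVisionNet"
assert choose_net((4,)) == "PyTorchFCNet"
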
22 changes: 13 additions & 9 deletions python/ray/rllib/models/pytorch/fcnet.py
@@ -9,6 +9,7 @@

class FullyConnectedNetwork(Model):
"""TODO(rliaw): Logits, Value should both be contained here"""

def _init(self, inputs, num_outputs, options):
assert type(inputs) is int
hiddens = options.get("fcnet_hiddens", [256, 256])
@@ -23,26 +24,29 @@ def _init(self, inputs, num_outputs, options):
layers = []
last_layer_size = inputs
for size in hiddens:
layers.append(SlimFC(
last_layer_size, size,
initializer=normc_initializer(1.0),
activation_fn=activation))
layers.append(
SlimFC(
in_size=last_layer_size,
out_size=size,
initializer=normc_initializer(1.0),
activation_fn=activation))
last_layer_size = size

self.hidden_layers = nn.Sequential(*layers)

self.logits = SlimFC(
last_layer_size, num_outputs,
in_size=last_layer_size,
out_size=num_outputs,
initializer=normc_initializer(0.01),
activation_fn=None)
self.probs = nn.Softmax()
self.value_branch = SlimFC(
last_layer_size, 1,
in_size=last_layer_size,
out_size=1,
initializer=normc_initializer(1.0),
activation_fn=None)

def forward(self, obs):
""" Internal method - pass in Variables, not numpy arrays
""" Internal method - pass in torch tensors, not numpy arrays

Args:
obs: observations and features
@@ -52,5 +56,5 @@ def forward(self, obs):
value: value function for each state"""
res = self.hidden_layers(obs)
logits = self.logits(res)
value = self.value_branch(res)
value = self.value_branch(res).reshape(-1)
return logits, value
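
A plain-PyTorch approximation of the structure this network builds, under the assumed defaults of two 256-unit tanh hidden layers (the actual hidden sizes and activation come from the options dict), showing the separate logits and value heads and the flattened value output:

# Plain-PyTorch approximation of FullyConnectedNetwork's structure; hidden
# sizes and tanh activation are assumed defaults, not read from options.
import torch
import torch.nn as nn

obs_dim, num_outputs = 4, 2
hidden_layers = nn.Sequential(
    nn.Linear(obs_dim, 256), nn.Tanh(),
    nn.Linear(256, 256), nn.Tanh(),
)
logits_branch = nn.Linear(256, num_outputs)
value_branch = nn.Linear(256, 1)

obs = torch.randn(8, obs_dim)
features = hidden_layers(obs)
logits = logits_branch(features)
value = value_branch(features).reshape(-1)  # flatten to shape (batch,)
print(logits.shape, value.shape)  # torch.Size([8, 2]) torch.Size([8])
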
22 changes: 8 additions & 14 deletions python/ray/rllib/models/pytorch/misc.py
@@ -5,38 +5,32 @@

import numpy as np
import torch
from torch.autograd import Variable


def convert_batch(trajectory, has_features=False):
"""Convert trajectory from numpy to PT variable"""
states = Variable(torch.from_numpy(
trajectory["observations"]).float())
acs = Variable(torch.from_numpy(
trajectory["actions"]))
advs = Variable(torch.from_numpy(
trajectory["advantages"].copy()).float())
advs = advs.view(-1, 1)
rs = Variable(torch.from_numpy(
trajectory["value_targets"]).float())
rs = rs.view(-1, 1)
states = torch.from_numpy(trajectory["obs"]).float()
acs = torch.from_numpy(trajectory["actions"])
advs = torch.from_numpy(
trajectory["advantages"].copy()).float().reshape(-1)
rs = torch.from_numpy(trajectory["rewards"]).float().reshape(-1)
if has_features:
features = [Variable(torch.from_numpy(f))
for f in trajectory["features"]]
features = [torch.from_numpy(f) for f in trajectory["features"]]
else:
features = trajectory["features"]
return states, acs, advs, rs, features


def var_to_np(var):
return var.data.numpy()[0]
return var.detach().numpy()


def normc_initializer(std=1.0):
def initializer(tensor):
tensor.data.normal_(0, 1)
tensor.data *= std / torch.sqrt(
tensor.data.pow(2).sum(1, keepdim=True))

return initializer
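
A quick usage sketch of the two helpers above (the weight shape is illustrative): var_to_np now detaches the full tensor instead of indexing into .data, and normc_initializer normalizes each row of the tensor to norm std.

# Usage sketch for var_to_np and normc_initializer as defined above; the
# weight shape here is illustrative.
import torch
from ray.rllib.models.pytorch.misc import normc_initializer, var_to_np

w = torch.empty(3, 5)
normc_initializer(std=1.0)(w)
print(w.pow(2).sum(1))      # each row has squared norm ~1
print(var_to_np(w).shape)   # (3, 5)
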

