[rllib] Support torch device and distributions. #4553

Merged (8 commits, Apr 12, 2019)
7 changes: 7 additions & 0 deletions ci/jenkins_tests/run_rllib_tests.sh
@@ -406,6 +406,13 @@ docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \
--stop '{"training_iteration": 1}' \
--config '{"num_workers": 2, "use_pytorch": true, "sample_async": false}'

docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \
/ray/ci/suppress_output /ray/python/ray/rllib/train.py \
--env Pendulum-v0 \
--run A3C \
--stop '{"training_iteration": 1}' \
--config '{"num_workers": 2, "use_pytorch": true, "sample_async": false}'

docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \
/ray/ci/suppress_output /ray/python/ray/rllib/train.py \
--env PongDeterministic-v4 \
71 changes: 46 additions & 25 deletions python/ray/rllib/agents/a3c/a3c_torch_policy_graph.py
@@ -17,25 +17,28 @@


class A3CLoss(nn.Module):
def __init__(self, policy_model, vf_loss_coeff=0.5, entropy_coeff=0.01):
def __init__(self, dist_class, vf_loss_coeff=0.5, entropy_coeff=0.01):
nn.Module.__init__(self)
self.policy_model = policy_model
self.dist_class = dist_class
self.vf_loss_coeff = vf_loss_coeff
self.entropy_coeff = entropy_coeff

def forward(self, observations, actions, advantages, value_targets):
logits, _, values, _ = self.policy_model({"obs": observations}, [])
log_probs = F.log_softmax(logits, dim=1)
probs = F.softmax(logits, dim=1)
action_log_probs = log_probs.gather(1, actions.view(-1, 1))
entropy = -(log_probs * probs).sum(-1).sum()
pi_err = -advantages.dot(action_log_probs.reshape(-1))
value_err = F.mse_loss(values.reshape(-1), value_targets)
def forward(self, policy_model, observations, actions, advantages,
value_targets):
logits, _, values, _ = policy_model({
SampleBatch.CUR_OBS: observations
}, [])
dist = self.dist_class(logits)
log_probs = dist.logp(actions)
self.entropy = dist.entropy().mean()
self.pi_err = -advantages.dot(log_probs.reshape(-1))
self.value_err = F.mse_loss(values.reshape(-1), value_targets)
overall_err = sum([
pi_err,
self.vf_loss_coeff * value_err,
-self.entropy_coeff * entropy,
self.pi_err,
self.vf_loss_coeff * self.value_err,
-self.entropy_coeff * self.entropy,
])

return overall_err


@@ -44,7 +47,7 @@ class A3CPostprocessing(object):

@override(TorchPolicyGraph)
def extra_action_out(self, model_out):
return {SampleBatch.VF_PREDS: model_out[2].numpy()}
return {SampleBatch.VF_PREDS: model_out[2].cpu().numpy()}

@override(PolicyGraph)
def postprocess_trajectory(self,
@@ -66,29 +69,47 @@ class A3CTorchPolicyGraph(A3CPostprocessing, TorchPolicyGraph):
def __init__(self, obs_space, action_space, config):
config = dict(ray.rllib.agents.a3c.a3c.DEFAULT_CONFIG, **config)
self.config = config
_, self.logit_dim = ModelCatalog.get_action_dist(
action_space, self.config["model"])
self.model = ModelCatalog.get_torch_model(obs_space, self.logit_dim,
self.config["model"])
loss = A3CLoss(self.model, self.config["vf_loss_coeff"],
dist_class, self.logit_dim = ModelCatalog.get_action_dist(
action_space, self.config["model"], torch=True)
model = ModelCatalog.get_torch_model(obs_space, self.logit_dim,
self.config["model"])
loss = A3CLoss(dist_class, self.config["vf_loss_coeff"],
self.config["entropy_coeff"])
TorchPolicyGraph.__init__(
self,
obs_space,
action_space,
self.model,
model,
loss,
loss_inputs=[
SampleBatch.CUR_OBS, SampleBatch.ACTIONS,
Postprocessing.ADVANTAGES, Postprocessing.VALUE_TARGETS
])
],
action_distribution_cls=dist_class)

@override(TorchPolicyGraph)
def optimizer(self):
return torch.optim.Adam(self.model.parameters(), lr=self.config["lr"])
return torch.optim.Adam(self._model.parameters(), lr=self.config["lr"])

@override(TorchPolicyGraph)
def extra_grad_process(self):
info = {}
if self.config["grad_clip"]:
total_norm = nn.utils.clip_grad_norm_(self._model.parameters(),
self.config["grad_clip"])
info["grad_gnorm"] = total_norm
return info

@override(TorchPolicyGraph)
def extra_grad_info(self):
return {
"policy_entropy": self._loss.entropy.item(),
"policy_loss": self._loss.pi_err.item(),
"vf_loss": self._loss.value_err.item()
}

def _value(self, obs):
with self.lock:
obs = torch.from_numpy(obs).float().unsqueeze(0)
_, _, vf, _ = self.model({"obs": obs}, [])
return vf.detach().numpy().squeeze()
obs = torch.from_numpy(obs).float().unsqueeze(0).to(self.device)
_, _, vf, _ = self._model({"obs": obs}, [])
return vf.detach().cpu().numpy().squeeze()
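
A note for readers on the new torch action distributions: the concrete class is whatever ModelCatalog.get_action_dist(..., torch=True) returns, so the snippet below is only a sketch of the interface that A3CLoss and PGLoss rely on (logp, entropy, sample). The class name TorchCategoricalSketch and its implementation are illustrative assumptions, not code added by this PR; it wraps torch.distributions.Categorical for a discrete action space.

# Illustrative only: mimics the interface the new losses call on dist_class.
import torch


class TorchCategoricalSketch(object):
    def __init__(self, logits):
        # Build a categorical distribution directly from raw model logits.
        self.dist = torch.distributions.Categorical(logits=logits)

    def logp(self, actions):
        # Log-probability of the taken actions, shape [batch].
        return self.dist.log_prob(actions)

    def entropy(self):
        # Per-sample entropy, shape [batch]; A3CLoss takes .mean() of this.
        return self.dist.entropy()

    def sample(self):
        # One sampled action index per batch row, used in compute_actions.
        return self.dist.sample()
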
46 changes: 26 additions & 20 deletions python/ray/rllib/agents/pg/torch_pg_policy_graph.py
@@ -3,7 +3,6 @@
from __future__ import print_function

import torch
import torch.nn.functional as F
from torch import nn

import ray
@@ -17,24 +16,26 @@


class PGLoss(nn.Module):
def __init__(self, policy_model):
def __init__(self, dist_class):
nn.Module.__init__(self)
self.policy_model = policy_model
self.dist_class = dist_class

def forward(self, observations, actions, advantages):
logits, _, values, _ = self.policy_model({"obs": observations}, [])
log_probs = F.log_softmax(logits, dim=1)
action_log_probs = log_probs.gather(1, actions.view(-1, 1))
pi_err = -advantages.dot(action_log_probs.reshape(-1))
return pi_err
def forward(self, policy_model, observations, actions, advantages):
logits, _, values, _ = policy_model({
SampleBatch.CUR_OBS: observations
}, [])
dist = self.dist_class(logits)
log_probs = dist.logp(actions)
self.pi_err = -advantages.dot(log_probs.reshape(-1))
return self.pi_err


class PGPostprocessing(object):
"""Adds the value func output and advantages field to the trajectory."""

@override(TorchPolicyGraph)
def extra_action_out(self, model_out):
return {SampleBatch.VF_PREDS: model_out[2].numpy()}
return {SampleBatch.VF_PREDS: model_out[2].cpu().numpy()}

@override(PolicyGraph)
def postprocess_trajectory(self,
@@ -49,29 +50,34 @@ class PGTorchPolicyGraph(PGPostprocessing, TorchPolicyGraph):
def __init__(self, obs_space, action_space, config):
config = dict(ray.rllib.agents.a3c.a3c.DEFAULT_CONFIG, **config)
self.config = config
_, self.logit_dim = ModelCatalog.get_action_dist(
action_space, self.config["model"])
self.model = ModelCatalog.get_torch_model(obs_space, self.logit_dim,
self.config["model"])
loss = PGLoss(self.model)
dist_class, self.logit_dim = ModelCatalog.get_action_dist(
action_space, self.config["model"], torch=True)
model = ModelCatalog.get_torch_model(obs_space, self.logit_dim,
self.config["model"])
loss = PGLoss(dist_class)

TorchPolicyGraph.__init__(
self,
obs_space,
action_space,
self.model,
model,
loss,
loss_inputs=[
SampleBatch.CUR_OBS, SampleBatch.ACTIONS,
Postprocessing.ADVANTAGES
])
],
action_distribution_cls=dist_class)

@override(TorchPolicyGraph)
def optimizer(self):
return torch.optim.Adam(self.model.parameters(), lr=self.config["lr"])
return torch.optim.Adam(self._model.parameters(), lr=self.config["lr"])

@override(TorchPolicyGraph)
def extra_grad_info(self):
return {"policy_loss": self._loss.pi_err.item()}

def _value(self, obs):
with self.lock:
obs = torch.from_numpy(obs).float().unsqueeze(0)
obs = torch.from_numpy(obs).float().unsqueeze(0).to(self.device)
_, _, vf, _ = self.model({"obs": obs}, [])
return vf.detach().numpy().squeeze()
return vf.detach().cpu().numpy().squeeze()
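
The PG and A3C losses now receive the policy model at call time instead of storing it, so the loss module itself holds no parameters or device state. Below is a self-contained smoke test of that convention; DummyTorchModel, pg_loss, and the tensor shapes are invented for illustration and only mirror the pattern in PGLoss.forward, they are not part of this PR.

# Hypothetical smoke test of the call-time-model convention (names invented).
import torch
from torch import nn


class DummyTorchModel(nn.Module):
    """Stand-in for a ModelCatalog torch model: returns (logits, feats, value, state)."""

    def __init__(self, obs_dim=4, num_actions=2):
        nn.Module.__init__(self)
        self.logits = nn.Linear(obs_dim, num_actions)
        self.vf = nn.Linear(obs_dim, 1)

    def forward(self, input_dict, state):
        obs = input_dict["obs"]
        return self.logits(obs), obs, self.vf(obs).squeeze(-1), state


def pg_loss(policy_model, observations, actions, advantages):
    # Mirrors PGLoss.forward: fresh logits -> distribution -> advantage-weighted
    # negative log-likelihood of the actions actually taken.
    logits, _, _, _ = policy_model({"obs": observations}, [])
    dist = torch.distributions.Categorical(logits=logits)
    return -advantages.dot(dist.log_prob(actions))


model = DummyTorchModel()
obs = torch.randn(8, 4)
acts = torch.randint(0, 2, (8,))
advs = torch.randn(8)
pg_loss(model, obs, acts, advs).backward()  # gradients land on the model's parameters
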
56 changes: 42 additions & 14 deletions python/ray/rllib/evaluation/torch_policy_graph.py
@@ -2,15 +2,17 @@
from __future__ import division
from __future__ import print_function

import os

import numpy as np
from threading import Lock

try:
import torch
import torch.nn.functional as F
except ImportError:
pass # soft dep

from ray.rllib.evaluation.metrics import LEARNER_STATS_KEY
from ray.rllib.evaluation.policy_graph import PolicyGraph
from ray.rllib.utils.annotations import override

@@ -28,11 +30,11 @@ class TorchPolicyGraph(PolicyGraph):
"""

def __init__(self, observation_space, action_space, model, loss,
loss_inputs):
loss_inputs, action_distribution_cls):
"""Build a policy graph from policy and loss torch modules.

Note that module inputs will be CPU tensors. The model and loss modules
are responsible for moving inputs to the right device.
Note that model will be placed on GPU device if CUDA_VISIBLE_DEVICES
is set. Only single GPU is supported for now.

Arguments:
observation_space (gym.Space): observation space of the policy.
@@ -47,14 +49,20 @@ def __init__(self, observation_space, action_space, model, loss,
loss_inputs (list): List of SampleBatch columns that will be
passed to the loss module's forward() function when computing
the loss. For example, ["obs", "action", "advantages"].
action_distribution_cls (ActionDistribution): Class for action
distribution.
"""
self.observation_space = observation_space
self.action_space = action_space
self.lock = Lock()
self._model = model
self.device = (torch.device("cuda")
if bool(os.environ.get("CUDA_VISIBLE_DEVICES", None))
else torch.device("cpu"))
self._model = model.to(self.device)
self._loss = loss
self._loss_inputs = loss_inputs
self._optimizer = self.optimizer()
self._action_dist_cls = action_distribution_cls

@override(PolicyGraph)
def compute_actions(self,
@@ -67,45 +75,55 @@ def compute_actions(self,
**kwargs):
with self.lock:
with torch.no_grad():
ob = torch.from_numpy(np.array(obs_batch)).float()
ob = torch.from_numpy(np.array(obs_batch)) \
.float().to(self.device)
model_out = self._model({"obs": ob}, state_batches)
logits, _, vf, state = model_out
actions = F.softmax(logits, dim=1).multinomial(1).squeeze(0)
return (actions.numpy(), [h.numpy() for h in state],
action_dist = self._action_dist_cls(logits)
[Review comment, Contributor] Now that A2C/PG presumably work with continuous action spaces, you can add two entries to run_rllib_tests.sh to check they work on Pendulum-v0, similar to the CartPole-v0 entries:
https://github.com/ray-project/ray/blob/master/ci/jenkins_tests/run_rllib_tests.sh#L407

actions = action_dist.sample()
return (actions.cpu().numpy(),
[h.cpu().numpy() for h in state],
self.extra_action_out(model_out))

@override(PolicyGraph)
def compute_gradients(self, postprocessed_batch):
with self.lock:
loss_in = []
for key in self._loss_inputs:
loss_in.append(torch.from_numpy(postprocessed_batch[key]))
loss_out = self._loss(*loss_in)
loss_in.append(
torch.from_numpy(postprocessed_batch[key]).to(self.device))
loss_out = self._loss(self._model, *loss_in)
self._optimizer.zero_grad()
loss_out.backward()

grad_process_info = self.extra_grad_process()

# Note that return values are just references;
# calling zero_grad will modify the values
grads = []
for p in self._model.parameters():
if p.grad is not None:
grads.append(p.grad.data.numpy())
grads.append(p.grad.data.cpu().numpy())
else:
grads.append(None)
return grads, {}

grad_info = self.extra_grad_info()
grad_info.update(grad_process_info)
return grads, {LEARNER_STATS_KEY: grad_info}

@override(PolicyGraph)
def apply_gradients(self, gradients):
with self.lock:
for g, p in zip(gradients, self._model.parameters()):
if g is not None:
p.grad = torch.from_numpy(g)
p.grad = torch.from_numpy(g).to(self.device)
self._optimizer.step()
return {}

@override(PolicyGraph)
def get_weights(self):
with self.lock:
return self._model.state_dict()
return {k: v.cpu() for k, v in self._model.state_dict().items()}

@override(PolicyGraph)
def set_weights(self, weights):
@@ -116,13 +134,23 @@ def set_weights(self, weights):
def get_initial_state(self):
return [s.numpy() for s in self._model.state_init()]

def extra_grad_process(self):
"""Allow subclass to do extra processing on gradients and
return processing info."""
return {}

def extra_action_out(self, model_out):
"""Returns dict of extra info to include in experience batch.

Arguments:
model_out (list): Outputs of the policy model module."""
return {}

def extra_grad_info(self):
"""Return dict of extra grad info."""

return {}

def optimizer(self):
"""Custom PyTorch optimizer to use."""
return torch.optim.Adam(self._model.parameters())
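
To summarize the device handling this PR introduces in TorchPolicyGraph (single GPU only): the device is chosen from CUDA_VISIBLE_DEVICES at construction, the model lives on that device, inputs are moved over per call, and everything handed back to the rest of RLlib is converted to CPU numpy. The condensed sketch below restates that pattern with a toy nn.Linear; it is a paraphrase of the diff above, not additional API.

import os

import numpy as np
import torch
from torch import nn

# Same device rule as TorchPolicyGraph.__init__: GPU only when
# CUDA_VISIBLE_DEVICES is set to something non-empty.
device = (torch.device("cuda")
          if bool(os.environ.get("CUDA_VISIBLE_DEVICES", None))
          else torch.device("cpu"))

model = nn.Linear(4, 2).to(device)                 # model stays on the device

obs_batch = np.random.randn(8, 4).astype(np.float32)
ob = torch.from_numpy(obs_batch).to(device)        # inputs are moved per call
with torch.no_grad():
    out = model(ob)

actions = out.argmax(dim=1).cpu().numpy()          # outputs leave as CPU numpy
cpu_weights = {k: v.cpu() for k, v in model.state_dict().items()}  # as in get_weights()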