[RLlib] Working/learning example: PPO + torch + LSTM. (ray-project#7797)
sven1977 authored Apr 1, 2020
1 parent c23e56c commit 66df8b8
Showing 17 changed files with 593 additions and 228 deletions.
9 changes: 9 additions & 0 deletions rllib/BUILD
@@ -1266,6 +1266,15 @@ py_test(
args = ["--iters=2", "--num-cpus=4"]
)

py_test(
name = "examples/custom_torch_rnn_model",
main = "examples/custom_torch_rnn_model.py",
tags = ["examples", "examples_C"],
size = "medium",
srcs = ["examples/custom_torch_rnn_model.py"],
args = ["--run=PPO", "--stop=90", "--num-cpus=4"]
)

py_test(
name = "examples/custom_torch_policy",
tags = ["examples", "examples_C"],
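
The new BUILD target above registers the torch RNN example as a medium-sized test and runs it end to end with the listed flags. Effectively, the test boils down to the following (a hedged sketch of the invocation, not the actual Bazel machinery):

import subprocess

# Hypothetical equivalent of what the py_test target executes.
subprocess.check_call([
    "python", "rllib/examples/custom_torch_rnn_model.py",
    "--run=PPO", "--stop=90", "--num-cpus=4",
])
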
17 changes: 9 additions & 8 deletions rllib/agents/ppo/ppo_torch_policy.py
@@ -68,9 +68,10 @@ def __init__(self,
use_gae (bool): If true, use the Generalized Advantage Estimator.
"""
if valid_mask is not None:
num_valid = torch.sum(valid_mask)

def reduce_mean_valid(t):
return torch.mean(t * valid_mask)
return torch.sum(t * valid_mask) / num_valid

else:
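
The corrected reduce_mean_valid divides the masked sum by the number of valid (non-padded) entries rather than by the full tensor size, so padding no longer drags the mean toward zero. A minimal standalone sketch (not part of the patch) of the difference:

import torch

t = torch.tensor([2.0, 4.0, 0.0, 0.0])          # last two entries are padding
valid_mask = torch.tensor([1.0, 1.0, 0.0, 0.0])

naive = torch.mean(t * valid_mask)                          # 1.5 -- divides by 4
masked = torch.sum(t * valid_mask) / torch.sum(valid_mask)  # 3.0 -- divides by 2
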

@@ -190,14 +191,14 @@ def __init__(self, obs_space, action_space, config):

def value(ob, prev_action, prev_reward, *state):
model_out, _ = self.model({
SampleBatch.CUR_OBS: torch.Tensor([ob]).to(self.device),
SampleBatch.PREV_ACTIONS: torch.Tensor([prev_action]).to(
self.device),
SampleBatch.PREV_REWARDS: torch.Tensor([prev_reward]).to(
self.device),
SampleBatch.CUR_OBS: self._convert_to_tensor([ob]),
SampleBatch.PREV_ACTIONS: self._convert_to_tensor(
[prev_action]),
SampleBatch.PREV_REWARDS: self._convert_to_tensor(
[prev_reward]),
"is_training": False,
}, [torch.Tensor([s]).to(self.device) for s in state],
torch.Tensor([1]).to(self.device))
}, [self._convert_to_tensor(s) for s in state],
self._convert_to_tensor([1]))
return self.model.value_function()[0]

else:
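
The value() closure now routes observations, previous actions/rewards, and the RNN state through the policy's _convert_to_tensor helper instead of hand-building torch.Tensor(...).to(self.device) for each input, keeping dtype handling and device placement in one place. A rough, hypothetical sketch of what such a helper might do (the actual RLlib implementation may differ):

import numpy as np
import torch

def convert_to_tensor(arr, device="cpu"):
    # Hypothetical stand-in for the policy helper: build a tensor from
    # array-like input, downcast float64 to float32, move it to the device.
    tensor = torch.from_numpy(np.asarray(arr))
    if tensor.dtype == torch.double:
        tensor = tensor.float()
    return tensor.to(device)
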
115 changes: 51 additions & 64 deletions rllib/agents/ppo/tests/test_ppo.py
@@ -94,10 +94,10 @@ def test_ppo_loss_function(self):
"""Tests the PPO loss function math."""
config = ppo.DEFAULT_CONFIG.copy()
config["num_workers"] = 0 # Run locally.
config["eager"] = True
config["gamma"] = 0.99
config["model"]["fcnet_hiddens"] = [10]
config["model"]["fcnet_activation"] = "linear"
config["vf_share_layers"] = True

# Fake CartPole episode of n time steps.
train_batch = {
@@ -114,69 +114,56 @@ def test_ppo_loss_function(self):
ACTION_LOGP: np.array([-0.5, -0.1, -0.2], dtype=np.float32)
}

# tf.
trainer = ppo.PPOTrainer(config=config, env="CartPole-v0")
policy = trainer.get_policy()

# Post-process (calculate simple (non-GAE) advantages) and attach to
# train_batch dict.
# A = [0.99^2 * 0.5 + 0.99 * -1.0 + 1.0, 0.99 * 0.5 - 1.0, 0.5] =
# [0.50005, -0.505, 0.5]
train_batch = postprocess_ppo_gae_tf(policy, train_batch)
# Check Advantage values.
check(train_batch[Postprocessing.VALUE_TARGETS],
[0.50005, -0.505, 0.5])

# Calculate actual PPO loss (results are stored in policy.loss_obj) for
# tf.
ppo_surrogate_loss_tf(policy, policy.model, Categorical, train_batch)

vars = policy.model.trainable_variables()
expected_logits = fc(
fc(train_batch[SampleBatch.CUR_OBS], vars[0].numpy(),
vars[1].numpy()), vars[4].numpy(), vars[5].numpy())
expected_value_outs = fc(
fc(train_batch[SampleBatch.CUR_OBS], vars[2].numpy(),
vars[3].numpy()), vars[6].numpy(), vars[7].numpy())

kl, entropy, pg_loss, vf_loss, overall_loss = \
self._ppo_loss_helper(
policy, policy.model, Categorical, train_batch,
expected_logits, expected_value_outs
)
check(policy.loss_obj.mean_kl, kl)
check(policy.loss_obj.mean_entropy, entropy)
check(policy.loss_obj.mean_policy_loss, np.mean(-pg_loss))
check(policy.loss_obj.mean_vf_loss, np.mean(vf_loss), decimals=4)
check(policy.loss_obj.loss, overall_loss, decimals=4)

# Torch.
config["use_pytorch"] = True
trainer = ppo.PPOTrainer(config=config, env="CartPole-v0")
policy = trainer.get_policy()
train_batch = postprocess_ppo_gae_torch(policy, train_batch)
train_batch = policy._lazy_tensor_dict(train_batch)

# Check Advantage values.
check(train_batch[Postprocessing.VALUE_TARGETS],
[0.50005, -0.505, 0.5])

# Calculate actual PPO loss (results are stored in policy.loss_obj)
# for tf.
ppo_surrogate_loss_torch(policy, policy.model, TorchCategorical,
train_batch)

kl, entropy, pg_loss, vf_loss, overall_loss = \
self._ppo_loss_helper(
policy, policy.model, TorchCategorical, train_batch,
policy.model.last_output(),
policy.model.value_function().detach().numpy()
)
check(policy.loss_obj.mean_kl, kl)
check(policy.loss_obj.mean_entropy, entropy)
check(policy.loss_obj.mean_policy_loss, np.mean(-pg_loss))
check(policy.loss_obj.mean_vf_loss, np.mean(vf_loss), decimals=4)
check(policy.loss_obj.loss, overall_loss, decimals=4)
for fw in ["tf", "torch"]:
print("framework={}".format(fw))
config["use_pytorch"] = fw == "torch"
config["eager"] = fw == "tf"

trainer = ppo.PPOTrainer(config=config, env="CartPole-v0")
policy = trainer.get_policy()

# Post-process (calculate simple (non-GAE) advantages) and attach
# to train_batch dict.
# A = [0.99^2 * 0.5 + 0.99 * -1.0 + 1.0, 0.99 * 0.5 - 1.0, 0.5] =
# [0.50005, -0.505, 0.5]
if fw == "tf":
train_batch = postprocess_ppo_gae_tf(policy, train_batch)
else:
train_batch = postprocess_ppo_gae_torch(policy, train_batch)
train_batch = policy._lazy_tensor_dict(train_batch)

# Check Advantage values.
check(train_batch[Postprocessing.VALUE_TARGETS],
[0.50005, -0.505, 0.5])

# Calculate actual PPO loss (results are stored in policy.loss_obj).
if fw == "tf":
ppo_surrogate_loss_tf(policy, policy.model, Categorical,
train_batch)
else:
ppo_surrogate_loss_torch(policy, policy.model,
TorchCategorical, train_batch)

vars = policy.model.variables() if fw == "tf" else \
list(policy.model.parameters())
expected_shared_out = fc(train_batch[SampleBatch.CUR_OBS], vars[0],
vars[1])
expected_logits = fc(expected_shared_out, vars[2], vars[3])
expected_value_outs = fc(expected_shared_out, vars[4], vars[5])

kl, entropy, pg_loss, vf_loss, overall_loss = \
self._ppo_loss_helper(
policy, policy.model,
Categorical if fw == "tf" else TorchCategorical,
train_batch,
expected_logits, expected_value_outs
)
check(policy.loss_obj.mean_kl, kl)
check(policy.loss_obj.mean_entropy, entropy)
check(policy.loss_obj.mean_policy_loss, np.mean(-pg_loss))
check(policy.loss_obj.mean_vf_loss, np.mean(vf_loss), decimals=4)
check(policy.loss_obj.loss, overall_loss, decimals=4)

def _ppo_loss_helper(self, policy, model, dist_class, train_batch, logits,
vf_outs):
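
The VALUE_TARGETS check in the test is plain discounted returns with gamma=0.99 and no bootstrap value at the terminal step. A quick standalone verification of the arithmetic in the comment, assuming the fake batch uses rewards [1.0, -1.0, 0.5]:

import numpy as np

gamma = 0.99
rewards = [1.0, -1.0, 0.5]   # assumed from the comment in the test
returns = []
running = 0.0
for r in reversed(rewards):
    running = r + gamma * running
    returns.insert(0, running)
print(np.round(returns, 5))  # -> [0.50005, -0.505, 0.5]
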
50 changes: 29 additions & 21 deletions rllib/examples/custom_keras_rnn_model.py
@@ -1,10 +1,10 @@
"""Example of using a custom RNN keras model."""

import argparse
import gym
from gym.spaces import Discrete
import numpy as np
import random
import argparse

import ray
from ray import tune
@@ -89,13 +89,17 @@ def value_function(self):


class RepeatInitialEnv(gym.Env):
"""Simple env in which the policy learns to repeat the initial observation
seen at timestep 0."""
"""Simple env where policy has to always repeat the initial observation.
Runs for 100 steps.
r=1 if action correct, -1 otherwise (max. R=100).
"""

def __init__(self):
def __init__(self, episode_len=100):
self.observation_space = Discrete(2)
self.action_space = Discrete(2)
self.token = None
self.episode_len = episode_len
self.num_steps = 0

def reset(self):
@@ -109,7 +113,7 @@ def step(self, action):
else:
reward = -1
self.num_steps += 1
done = self.num_steps > 100
done = self.num_steps >= self.episode_len
return 0, reward, done, {}
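
With the episode length now a constructor argument, a short usage sketch of RepeatInitialEnv (assuming reset() returns the token the agent must keep repeating):

env = RepeatInitialEnv(episode_len=100)
token = env.reset()                          # observation to repeat all episode
total_reward, done = 0, False
while not done:
    _, reward, done, _ = env.step(token)     # always answer with the initial token
    total_reward += reward
print(total_reward)                          # 100 if every step is correct
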


@@ -148,22 +152,26 @@ def _next_obs(self):
ModelCatalog.register_custom_model("rnn", MyKerasRNN)
register_env("RepeatAfterMeEnv", lambda c: RepeatAfterMeEnv(c))
register_env("RepeatInitialEnv", lambda _: RepeatInitialEnv())

config = {
"env": args.env,
"env_config": {
"repeat_delay": 2,
},
"gamma": 0.9,
"num_workers": 0,
"num_envs_per_worker": 20,
"entropy_coeff": 0.001,
"num_sgd_iter": 5,
"vf_loss_coeff": 1e-5,
"model": {
"custom_model": "rnn",
"max_seq_len": 20,
},
}

tune.run(
args.run,
config=config,
stop={"episode_reward_mean": args.stop},
config={
"env": args.env,
"env_config": {
"repeat_delay": 2,
},
"gamma": 0.9,
"num_workers": 0,
"num_envs_per_worker": 20,
"entropy_coeff": 0.001,
"num_sgd_iter": 5,
"vf_loss_coeff": 1e-5,
"model": {
"custom_model": "rnn",
"max_seq_len": 20,
},
})
)
128 changes: 128 additions & 0 deletions rllib/examples/custom_torch_rnn_model.py
@@ -0,0 +1,128 @@
import argparse

import ray
from ray.rllib.examples.cartpole_lstm import CartPoleStatelessEnv
from ray.rllib.examples.custom_keras_rnn_model import RepeatInitialEnv, \
RepeatAfterMeEnv
from ray.rllib.models.preprocessors import get_preprocessor
from ray.rllib.models.torch.recurrent_torch_model import RecurrentTorchModel
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.utils.annotations import override
from ray.rllib.utils import try_import_torch
from ray.rllib.models import ModelCatalog
import ray.tune as tune

torch, nn = try_import_torch()

parser = argparse.ArgumentParser()
parser.add_argument("--run", type=str, default="PPO")
parser.add_argument("--env", type=str, default="repeat_initial")
parser.add_argument("--stop", type=int, default=90)
parser.add_argument("--num-cpus", type=int, default=0)
parser.add_argument("--fc-size", type=int, default=64)
parser.add_argument("--lstm-cell-size", type=int, default=256)


class RNNModel(RecurrentTorchModel):
def __init__(self,
obs_space,
action_space,
num_outputs,
model_config,
name,
fc_size=64,
lstm_state_size=256):
super().__init__(obs_space, action_space, num_outputs, model_config,
name)

self.obs_size = get_preprocessor(obs_space)(obs_space).size
self.fc_size = fc_size
self.lstm_state_size = lstm_state_size

# Build the Module from fc + LSTM + 2xfc (action + value outs).
self.fc1 = nn.Linear(self.obs_size, self.fc_size)
self.lstm = nn.LSTM(
self.fc_size, self.lstm_state_size, batch_first=True)
self.action_branch = nn.Linear(self.lstm_state_size, num_outputs)
self.value_branch = nn.Linear(self.lstm_state_size, 1)
# Store the value output to save an extra forward pass.
self._cur_value = None

@override(ModelV2)
def get_initial_state(self):
# make hidden states on same device as model
h = [
self.fc1.weight.new(1, self.lstm_state_size).zero_().squeeze(0),
self.fc1.weight.new(1, self.lstm_state_size).zero_().squeeze(0)
]
return h

@override(ModelV2)
def value_function(self):
assert self._cur_value is not None, "must call forward() first"
return self._cur_value

@override(RecurrentTorchModel)
def forward_rnn(self, inputs, state, seq_lens):
"""Feeds `inputs` (B x T x ..) through the Gru Unit.
Returns the resulting outputs as a sequence (B x T x ...).
Values are stored in self._cur_value in simple (B) shape (where B
contains both the B and T dims!).
Returns:
NN Outputs (B x T x ...) as sequence.
The state batches as a List of two items (h- and c-states).
"""
x = nn.functional.relu(self.fc1(inputs))
lstm_out = self.lstm(
x, [torch.unsqueeze(state[0], 0),
torch.unsqueeze(state[1], 0)])
action_out = self.action_branch(lstm_out[0])
self._cur_value = torch.reshape(self.value_branch(lstm_out[0]), [-1])
return action_out, [
torch.squeeze(lstm_out[1][0], 0),
torch.squeeze(lstm_out[1][1], 0)
]


if __name__ == "__main__":
args = parser.parse_args()

ray.init(num_cpus=args.num_cpus or None)
ModelCatalog.register_custom_model("rnn", RNNModel)
tune.register_env(
"repeat_initial", lambda _: RepeatInitialEnv(episode_len=100))
tune.register_env(
"repeat_after_me", lambda _: RepeatAfterMeEnv({"repeat_delay": 1}))
tune.register_env("cartpole_stateless", lambda _: CartPoleStatelessEnv())

config = {
"env": args.env,
"use_pytorch": True,
"num_workers": 0,
"num_envs_per_worker": 20,
"gamma": 0.9,
"entropy_coeff": 0.0001,
"model": {
"custom_model": "rnn",
"max_seq_len": 20,
"lstm_use_prev_action_reward": "store_true",
"custom_options": {
"fc_size": args.fc_size,
"lstm_state_size": args.lstm_cell_size,
}
},
"lr": 3e-4,
"num_sgd_iter": 5,
"vf_loss_coeff": 0.0003,
}

tune.run(
args.run,
stop={
"episode_reward_mean": args.stop,
"timesteps_total": 100000
},
config=config,
)
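
For readers new to batch-first LSTMs in torch, here is a hedged, standalone shape check (plain torch, no RLlib, toy sizes) of the flow used in forward_rnn above: fc layer, then LSTM, then separate action and value heads, with the value output flattened to (B*T,):

import torch
import torch.nn as nn

B, T, obs_size, fc_size, cell_size = 4, 20, 2, 64, 256
fc1 = nn.Linear(obs_size, fc_size)
lstm = nn.LSTM(fc_size, cell_size, batch_first=True)
action_branch = nn.Linear(cell_size, 2)
value_branch = nn.Linear(cell_size, 1)

x = torch.relu(fc1(torch.randn(B, T, obs_size)))   # (B, T, fc_size)
h0 = torch.zeros(1, B, cell_size)                  # initial hidden state
c0 = torch.zeros(1, B, cell_size)                  # initial cell state
out, (h, c) = lstm(x, (h0, c0))                    # out: (B, T, cell_size)
print(action_branch(out).shape)                    # torch.Size([4, 20, 2])
print(value_branch(out).reshape(-1).shape)         # torch.Size([80]) == (B*T,)
print(h.squeeze(0).shape, c.squeeze(0).shape)      # (B, cell_size) each
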