[RLlib] Add differentiable neural computer example #14844
Conversation
This is awesome! Is this learning on some smaller env so we could add a learning test/example for this?
Yes, I will add a basic test using the repeatafterme env before pushing.
I've added passing unit tests, but it's not clear this is actually learning in the cartpole environment. I'll let it run a bit longer and see what happens.
@smorad could you also run the LINTer?
The other tests look ok (we fixed the attention learning test, it's not related to your PR).
This doesn't appear to learn stateless cartpole using PPO at 5M timesteps. It's possible DNC learning in my env was the result of not using the state for learning. Until I find a bug (or acceptable hyperparameters such that DNC learns stateless cartpole) I'm hesitant to commit. I will run the linter again and squash after fixes.
Hey @smorad, were you able to find the cause of the learning issue you mentioned above?
I honestly cannot track it down. I've tried multiple hyperparameters for up to 5M timesteps and was unable to get it above 50 (I think, can't recall for sure) mean reward on stateless cartpole. I wrote more unit tests to ensure the state shapes are handled correctly. Every two weeks or so I spend a day on this trying to figure out what's wrong! If you have any better ideas for debugging this, please let me know.
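For illustration only (these are not the PR's actual tests), a minimal shape sanity check along those lines might look like the sketch below. It assumes the dnc pip package is installed and that the DNCMemory class shown later in this thread is importable from a hypothetical dnc_memory module:

```python
import gym

# Hypothetical import path; DNCMemory is the class shown later in this thread.
from dnc_memory import DNCMemory


def test_initial_state_shapes():
    obs_space = gym.spaces.Box(low=-1.0, high=1.0, shape=(4,))
    act_space = gym.spaces.Discrete(2)
    model = DNCMemory(obs_space, act_space, num_outputs=2,
                      model_config={}, name="dnc_test")
    state = model.get_initial_state()
    # Controller hidden state (2 tensors) + read vectors (1) + memory (6).
    assert len(state) == 9
    cfg = model.cfg
    assert state[0].shape == (cfg["num_hidden_layers"], cfg["hidden_size"])
    assert state[2].shape == (cfg["cell_size"] * cfg["read_heads"],)


if __name__ == "__main__":
    test_initial_state_shapes()
    print("state shape check passed")
```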
I thought this was cool too and have spent some cycles trying to get it to work, but did not have any success either. I did find the r2d2 issue I posted last week while trying to get this to work, so it paid off in another way at least. That issue should not be affecting this, though. @smorad, you got further than me; I could not get it above 26 mean episode reward. I was using A2C.
It trains! I set a horizon of 250. I noticed that when it was reaching 150 mean episode reward, the max episode reward would be ~500 and the min would be 10. I wanted to see if I could force it to do better than that on the minimum. It did not help.

I found one bug fix:

```python
ctrl_hidden = [
    torch.zeros(2, self.cfg["hidden_size"]),
    torch.zeros(2, self.cfg["hidden_size"]),
]
```

needed to be this so that the number of hidden layers can actually be changed (see also the small LSTM shape check after the full script below):

```python
ctrl_hidden = [
    torch.zeros(self.cfg["num_hidden_layers"], self.cfg["hidden_size"]),
    torch.zeros(self.cfg["num_hidden_layers"], self.cfg["hidden_size"]),
]
```

I got it to train with A2C and PPO, but it did not work with R2D2 or IMPALA. I did not try tuning any hyperparameters, so it could just be a bad configuration. In the plot below the shorter runs (<200k timesteps) are PPO and the longer ones (>1e9) are A2C.

Full script:

```python
import unittest
import torch
from ray.rllib.models import ModelCatalog
from torch import nn
import ray
import gym
from ray import tune
from ray.rllib.examples.env.stateless_cartpole import StatelessCartPole
from typing import Union, Dict, List, Tuple
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.models.torch.misc import SlimFC
from ray.rllib.utils.typing import ModelConfigDict, TensorType
try:
from dnc import DNC
except ModuleNotFoundError:
print("dnc module not found. Did you forget to 'pip install dnc'?")
raise
class DNCMemory(TorchModelV2, nn.Module):
"""Differentiable Neural Computer wrapper around ixaxaar's DNC implementation,
see https://github.com/ixaxaar/pytorch-dnc"""
DEFAULT_CONFIG = {
"dnc_model": DNC,
# Number of controller hidden layers
"num_hidden_layers": 2,
# Number of weights per controller hidden layer
"hidden_size": 128,
# Number of LSTM units
"num_layers": 1,
# Number of read heads, i.e. how many addrs are read at once
"read_heads": 4,
# Number of memory cells in the controller
"nr_cells": 32,
# Size of each cell
"cell_size": 16,
# LSTM activation function
"nonlinearity": "tanh",
"fc1_dim": 64,
}
MEMORY_KEYS = [
"memory",
"link_matrix",
"precedence",
"read_weights",
"write_weights",
"usage_vector",
]
def __init__(
self,
obs_space: gym.spaces.Space,
action_space: gym.spaces.Space,
num_outputs: int,
model_config: ModelConfigDict,
name: str,
**custom_model_kwargs,
):
nn.Module.__init__(self)
super(DNCMemory, self).__init__(obs_space, action_space, num_outputs,
model_config, name)
self.num_outputs = num_outputs
self.obs_dim = gym.spaces.utils.flatdim(obs_space)
self.act_dim = gym.spaces.utils.flatdim(action_space)
self.cfg = dict(self.DEFAULT_CONFIG, **custom_model_kwargs)
assert self.cfg['num_layers'] == 1, "num_layers != 1 has not been implemented yet"
self.cur_val = None
self.fc1 = SlimFC(
in_size=self.obs_dim,
out_size=self.cfg["fc1_dim"],
activation_fn=None,
initializer=torch.nn.init.xavier_uniform_,
)
self.logit_branch = SlimFC(
in_size=self.cfg["fc1_dim"],
out_size=self.num_outputs,
activation_fn=None,
initializer=torch.nn.init.xavier_uniform_,
)
self.value_branch = SlimFC(
in_size=self.cfg["fc1_dim"],
out_size=1,
activation_fn=None,
initializer=torch.nn.init.xavier_uniform_,
)
self.dnc: Union[None, DNC] = None
def get_initial_state(self) -> List[TensorType]:
ctrl_hidden = [
torch.zeros(self.cfg["num_hidden_layers"], self.cfg["hidden_size"]),
torch.zeros(self.cfg["num_hidden_layers"], self.cfg["hidden_size"]),
]
m = self.cfg["nr_cells"]
r = self.cfg["read_heads"]
w = self.cfg["cell_size"]
memory = [
torch.zeros(m, w), # memory
torch.zeros(1, m, m), # link_matrix
torch.zeros(1, m), # precedence
torch.zeros(r, m), # read_weights
torch.zeros(1, m), # write_weights
torch.zeros(m), # usage_vector
]
read_vecs = torch.zeros(w * r)
state = [*ctrl_hidden, read_vecs, *memory]
assert len(state) == 9
return state
def value_function(self) -> TensorType:
assert self.cur_val is not None, "must call forward() first"
return self.cur_val
def unpack_state(self, state: List[TensorType],
) -> Tuple[List[Tuple[TensorType, TensorType]],
Dict[str, TensorType], TensorType]:
"""Given a list of tensors, reformat for self.dnc input"""
assert len(state) == 9, "Failed to verify unpacked state"
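        # RLlib hands the state back batch-major ([B, layers, hidden]); the DNC's
        # LSTM controller expects layer-major ([layers, B, hidden]), hence the permutes.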
ctrl_hidden: List[Tuple[TensorType, TensorType]] = [(
state[0].permute(1, 0, 2).contiguous(),
state[1].permute(1, 0, 2).contiguous(),
)]
read_vecs: TensorType = state[2]
memory: List[TensorType] = state[3:]
memory_dict: Dict[str, TensorType] = dict(
zip(self.MEMORY_KEYS, memory))
return ctrl_hidden, memory_dict, read_vecs
def pack_state(
self,
ctrl_hidden: List[Tuple[TensorType, TensorType]],
memory_dict: Dict[str, TensorType],
read_vecs: TensorType,
) -> List[TensorType]:
"""Given the dnc output, pack it into a list of tensors
for rllib state. Order is ctrl_hidden, read_vecs, memory_dict"""
state = []
ctrl_hidden = [
ctrl_hidden[0][0].permute(1, 0, 2),
ctrl_hidden[0][1].permute(1, 0, 2),
]
state += ctrl_hidden # len 2
state.append(read_vecs) # len 3
state += memory_dict.values() # len 9
assert len(state) == 9, "Failed to verify packed state"
return state
def validate_unpack(self, dnc_output, unpacked_state):
"""Ensure the unpacked state shapes match the DNC output"""
s_ctrl_hidden, s_memory_dict, s_read_vecs = unpacked_state
ctrl_hidden, memory_dict, read_vecs = dnc_output
for i in range(len(ctrl_hidden)):
for j in range(len(ctrl_hidden[i])):
assert s_ctrl_hidden[i][j].shape == ctrl_hidden[i][j].shape, (
"Controller state mismatch: got "
f"{s_ctrl_hidden[i][j].shape} should be "
f"{ctrl_hidden[i][j].shape}")
for k in memory_dict:
assert s_memory_dict[k].shape == memory_dict[k].shape, (
"Memory state mismatch at key "
f"{k}: got {s_memory_dict[k].shape} should be "
f"{memory_dict[k].shape}")
assert s_read_vecs.shape == read_vecs.shape, (
"Read state mismatch: got "
f"{s_read_vecs.shape} should be "
f"{read_vecs.shape}")
def forward(
self,
input_dict: Dict[str, TensorType],
state: List[TensorType],
seq_lens: TensorType,
) -> Tuple[TensorType, List[TensorType]]:
flat = input_dict["obs_flat"]
flat = self.fc1(flat)
# Batch and Time
# Forward expects outputs as [B, T, logits]
B = len(seq_lens)
T = flat.shape[0] // B
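        # RLlib flattens [B, T] into a single batch dimension; seq_lens has one
        # entry per sequence, so T is recovered by dividing total rows by B.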
#logits = torch.zeros(B, T, self.num_outputs, device=flat.device)
#values = torch.zeros(B, T, 1, device=flat.device)
# Deconstruct batch into batch and time dimensions: [B, T, feats]
flat = torch.reshape(flat, [-1, T] + list(flat.shape[1:]))
# First run
if self.dnc is None:
(ctrl_hidden, read_vecs, memory_dict) = (None, None, None)
gpu_id = flat.device.index if flat.device.index is not None else -1
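            # The DNC implementation expects gpu_id=-1 when there is no device
            # index (CPU), otherwise the CUDA device index.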
self.dnc = self.cfg["dnc_model"](
input_size=self.cfg["fc1_dim"],
hidden_size=self.cfg["hidden_size"],
num_layers=self.cfg["num_layers"],
num_hidden_layers=self.cfg["num_hidden_layers"],
read_heads=self.cfg["read_heads"],
cell_size=self.cfg["cell_size"],
nr_cells=self.cfg["nr_cells"],
nonlinearity=self.cfg["nonlinearity"],
gpu_id=gpu_id,
)
else:
ctrl_hidden, memory_dict, read_vecs = self.unpack_state(state)
output, (ctrl_hidden, memory_dict, read_vecs) = self.dnc(
flat, (ctrl_hidden, memory_dict, read_vecs))
packed_state = self.pack_state(ctrl_hidden, memory_dict, read_vecs)
# Compute action/value from output
logits = self.logit_branch(output.view(B * T, -1))
values = self.value_branch(output.view(B * T, -1))
self.cur_val = values.squeeze(1)
return logits, packed_state
if __name__ == "__main__":
ray.init(num_cpus=4)
ModelCatalog.register_custom_model("dnc", DNCMemory)
config = {
"env": StatelessCartPole,
"gamma": 0.99,
"num_envs_per_worker": 5,
"framework": "torch",
"num_workers": 1,
"num_gpus": 1,
"horizon": 250,
"entropy_coeff": 0.0005,
"lr": 0.01,
#"vf_loss_coeff": 1e-5,
#"num_sgd_iter": 5,
"model": {
"custom_model": "dnc",
"max_seq_len": 10,
"custom_model_config": {
"nr_cells": 10,
"read_heads": 2,
"cell_size": 4,
"num_layers": 1,
"hidden_size": 64,
"num_hidden_layers": 1,
"fc1_dim": 64,
},
},
}
tune.run("A2C", name="DNC_DEMO5", num_samples=4, config=config, stop={"episode_reward_mean": 150.0,"timesteps_total": 5000000,})
tune.run("PPO", name="DNC_DEMO5", num_samples=4, config=config, stop={"episode_reward_mean": 150.0,"timesteps_total": 5000000,}) |
Great work @mvindiola1! Now that you've verified this actually learns, I'll get to work on merging in my final changes and rebasing.
@sven1977 this requires the |
I'm also finding that … All the runs are … @sven1977, is it possible there is a bug in …? Once you're ok with the changes, I think we are ready to merge.
We seem to be approaching the preprocessing differently. I am using the preprocessing layer to expand the values stored in the DNC, whereas you are storing entries that match the size of the observation space. In the plots below I changed the code to do the first approach and ran with yesterday's nightly build. Specifically, I made these changes:

```python
self.preprocessor = torch.nn.Sequential(
    torch.nn.Linear(self.obs_dim, self.cfg["preprocessor_output_size"]))
...
self.logit_branch = SlimFC(
    in_size=self.cfg["preprocessor_output_size"],
...
self.value_branch = SlimFC(
    in_size=self.cfg["preprocessor_output_size"],
...
self.dnc = self.cfg["dnc_model"](
    input_size=self.cfg["preprocessor_output_size"],
```

Here are results for PPO, A2C, IMPALA (with vtrace), and IMPALA (w/o vtrace). [training-curve plots attached]
Should we not be consistent with LSTM/GTrXL, where we place a 2-layer MLP before the memory module (in this case, the DNC)? I think the issue was that we were missing an MLP/nonlinear activation layer before the DNC. Anyways, you're right that there is no reason to use the observation space as the hidden size.
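For concreteness, a minimal sketch (hypothetical, not the committed code) of a 2-layer MLP preprocessor placed in front of the DNC, mirroring how RLlib's LSTM/attention wrappers put fully connected layers before the memory cell; the sizes here are purely illustrative:

```python
import torch
from torch import nn

obs_dim = 4            # e.g. stateless CartPole observation size
preproc_hidden = 64    # assumed hidden width
preproc_out = 64       # width fed into the DNC

# Hypothetical 2-layer MLP preprocessor before the memory module.
preprocessor = nn.Sequential(
    nn.Linear(obs_dim, preproc_hidden),
    nn.Tanh(),
    nn.Linear(preproc_hidden, preproc_out),
    nn.Tanh(),
)

# The DNC (and the logit/value branches) would then consume
# preproc_out-sized features instead of raw observations.
dummy_obs = torch.zeros(8, obs_dim)    # [batch, obs_dim]
features = preprocessor(dummy_obs)     # [batch, preproc_out]
print(features.shape)                  # torch.Size([8, 64])
```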
The 2-layer MLP seems like overkill for this task, and we have evidence that 1 layer is sufficient, but as long as it learns I have no real opinion on this. What I do think matters is that the MLP expands the hidden size to something larger than 2.
Manny
Good catch, I'll update with changes tomorrow.
I re-ran it where I kept the 2-layer MLP but removed that last bottleneck layer:

```python
self.preprocessor = torch.nn.Sequential(
    torch.nn.Linear(self.obs_dim, self.cfg["preprocessor_input_size"]),
    self.cfg["preprocessor"])
self.logit_branch = SlimFC(
    in_size=self.cfg["preprocessor_output_size"],
    out_size=self.num_outputs,
    activation_fn=None,
    initializer=torch.nn.init.xavier_uniform_,
)
self.value_branch = SlimFC(
    in_size=self.cfg["preprocessor_output_size"],
    out_size=1,
```

In these results PPO reaches the target mean episode reward in the fewest iterations, but it takes the longest wall time to do it: 45-75 min for PPO versus 1-2 min for A2C and IMPALA.
Awesome! This is super valuable, thanks @smorad !
:)
The committed version still has that last bottleneck layer in the preprocessing network. My test runs are showing that layer significantly increases training time and, in the case of IMPALA, prevents learning entirely.
Why are these changes needed?
DeepMind's DNC is a pretty interesting solution to the partial observability problem and I think it should be included with RLlib.
Related issue number
Checks
- I've run scripts/format.sh to lint the changes in this PR.
- Shapes are verified at initialization and there are plenty of asserts.