
[RLlib] Add differentiable neural computer example #14844

Merged: 2 commits merged into ray-project:master on May 19, 2021

Conversation

smorad
Contributor

@smorad smorad commented Mar 22, 2021

Why are these changes needed?

DeepMind's DNC is a pretty interesting solution to the partial observability problem, and I think it should be included with RLlib.

Related issue number

Checks

  • I've run scripts/format.sh to lint the changes in this PR.
  • I've included any doc changes needed for https://docs.ray.io/en/master/.
  • I've made sure the tests are passing. Note that there might be a few flaky tests, see the recent failures at https://flakey-tests.ray.io/
  • Testing Strategy
    • Unit tests
    • Release tests
    • This PR is not tested :(

Shapes are verified at initialization and there are plenty of asserts.

Contributor

@sven1977 sven1977 left a comment


This is awesome! Is this learning on some smaller env so we could add a learning test/example for this?

@smorad
Contributor Author

smorad commented Mar 24, 2021

Yes, I will add a basic test using the RepeatAfterMe env before pushing.
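For illustration, here is a rough sketch of what such a learning test could look like, using RLlib's RepeatAfterMeEnv and the DNCMemory model from this PR (the stop criteria and env_config values below are placeholders, not the exact test that ended up in the PR):

import ray
from ray import tune
from ray.rllib.models import ModelCatalog
from ray.rllib.examples.env.repeat_after_me_env import RepeatAfterMeEnv

ray.init()
# DNCMemory is the custom model class from this PR's example file.
# Register it so the config below can reference it by name.
ModelCatalog.register_custom_model("dnc", DNCMemory)
tune.run(
    "A2C",
    stop={"episode_reward_mean": 50.0, "timesteps_total": 500000},
    config={
        "env": RepeatAfterMeEnv,
        "env_config": {"repeat_delay": 1},
        "framework": "torch",
        "num_workers": 1,
        "model": {
            "custom_model": "dnc",
            "max_seq_len": 10,
        },
    },
)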

@smorad
Contributor Author

smorad commented Mar 25, 2021

I've added passing unit tests, but it's not clear this is actually learning in the cartpole environment. I'll let it run a bit longer and see what happens.

@sven1977 sven1977 added the @author-action-required The PR author is responsible for the next step. Remove tag to send back to the reviewer. label Mar 29, 2021
@sven1977
Contributor

@smorad could you also run the linter?
Simply do:

cd ray
ci/travis/format.sh

The other tests look ok (we fixed the attention learning test, it's not related to your PR).

@smorad
Contributor Author

smorad commented Mar 29, 2021

This doesn't appear to learn stateless CartPole using PPO at 5M timesteps. It's possible the DNC learning in my env was the result of not using the state for learning.

Until I find a bug (or acceptable hyperparameters such that the DNC learns stateless CartPole), I'm hesitant to commit. I will run the linter again and squash after the fixes.

@sven1977
Contributor

sven1977 commented May 4, 2021

Hey @smorad , were you able to find the cause of the learning issue you mentioned above?
Would love to have this example in the lib!

@smorad
Contributor Author

smorad commented May 4, 2021

I honestly cannot track it down. The dnc module this uses has unit tests and is used by others, so it should work.

I've tried multiple hyperparameters for up to 5M timesteps and was unable to get it above a mean reward of 50 (I think; I can't recall for sure) on stateless CartPole. I wrote more unit tests to ensure pack(unpack(state)) == state. I tried using PPO and IMPALA.
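For reference, this is roughly the kind of round-trip check I mean (a minimal sketch with illustrative spaces and sizes, not the exact unit test in this PR):

import unittest

import gym
import torch


class TestDNCPackUnpack(unittest.TestCase):
    def test_pack_unpack_roundtrip(self):
        obs_space = gym.spaces.Box(low=-1.0, high=1.0, shape=(4,))
        action_space = gym.spaces.Discrete(2)
        # DNCMemory is the custom model class from this PR's example file.
        model = DNCMemory(obs_space, action_space, 2, {}, "dnc_test")
        # RLlib hands the model a batched state, so add a batch dim of 1.
        state = [s.unsqueeze(0) for s in model.get_initial_state()]
        repacked = model.pack_state(*model.unpack_state(state))
        self.assertEqual(len(state), len(repacked))
        for original, roundtripped in zip(state, repacked):
            self.assertTrue(torch.equal(original, roundtripped))


if __name__ == "__main__":
    unittest.main()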

Every two weeks or so I spend a day on this trying to figure out what's wrong! If you have any better ideas for debugging this, please let me know.

@mvindiola1
Contributor

mvindiola1 commented May 7, 2021

Hi @sven1977 and @smorad,

I thought this was cool too and have spent some cycles trying to get it to work, but did not have any success either.

I did find the R2D2 issue I posted last week while trying to get this to work, so it paid off in another way at least. That issue should not be affecting this, though.

@smorad you got further than me; I could not get it above a mean episode reward (MER) of 26. I was using A2C.

@smorad
Contributor Author

smorad commented May 10, 2021

So although I'm still struggling with CartPole, I think it's starting to work on my navigation problem:

(screenshot omitted)

With CartPole it's not clear whether it's learning:

(screenshot omitted)

So I'm thinking at this point I may have fixed the issue, but the DNC could be very picky with regard to hyperparameters.

@mvindiola1
Contributor

mvindiola1 commented May 17, 2021

@smorad,

It trains!
I had to change two things to get it to train:
a. I had to add a fully connected layer before the DNC, or it would not reach even 50 MER.
b. I had to reduce num_hidden_layers to 1, or it would not learn past 80 MER.

I also made two other tweaks:
c. I set a horizon of 250. I noticed that when it was reaching 150 MER, the max episode reward would be ~500 and the min would be 10. I wanted to see if I could force it to do better than that on the minimum; it did not help.
d. I lowered the envs per worker to 5, which seemed to lower the number of timesteps needed to reach 150 MER.

I found 1 bug fix:

       ctrl_hidden = [
           torch.zeros(2, self.cfg["hidden_size"]),
           torch.zeros(2, self.cfg["hidden_size"]),
       ]

needed to become the following so that the number of hidden layers is configurable:

       ctrl_hidden = [
           torch.zeros(self.cfg["num_hidden_layers"], self.cfg["hidden_size"]),
           torch.zeros(self.cfg["num_hidden_layers"], self.cfg["hidden_size"]),
       ]

Other than those changes and the bug fix above, I did not modify any of your original implementation.

I got it to train with A2C and PPO, but it did not work with R2D2 or IMPALA. I did not try tuning any hyperparameters, so it could just be a bad configuration.

In the plot below, the shorter runs (<200k) are PPO and the longer ones (>1e9) are A2C.

(plot omitted)

import unittest
import torch
from ray.rllib.models import ModelCatalog
from torch import nn
import ray
import gym
from ray import tune
from ray.rllib.examples.env.stateless_cartpole import StatelessCartPole
from typing import Union, Dict, List, Tuple
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.models.torch.misc import SlimFC
from ray.rllib.utils.typing import ModelConfigDict, TensorType

try:
   from dnc import DNC
except ModuleNotFoundError:
   print("dnc module not found. Did you forget to 'pip install dnc'?")
   raise


class DNCMemory(TorchModelV2, nn.Module):
   """Differentiable Neural Computer wrapper around ixaxaar's DNC implementation,
   see https://github.com/ixaxaar/pytorch-dnc"""

   DEFAULT_CONFIG = {
       "dnc_model": DNC,
       # Number of controller hidden layers
       "num_hidden_layers": 2,
       # Number of weights per controller hidden layer
       "hidden_size": 128,
       # Number of recurrent (LSTM) layers in the controller
       "num_layers": 1,
       # Number of read heads, i.e. how many addresses are read at once
       "read_heads": 4,
       # Number of memory cells in the external memory
       "nr_cells": 32,
       # Size of each cell
       "cell_size": 16,
       # LSTM activation function
       "nonlinearity": "tanh",
       "fc1_dim": 64,
   }

   MEMORY_KEYS = [
       "memory",
       "link_matrix",
       "precedence",
       "read_weights",
       "write_weights",
       "usage_vector",
   ]

   def __init__(
           self,
           obs_space: gym.spaces.Space,
           action_space: gym.spaces.Space,
           num_outputs: int,
           model_config: ModelConfigDict,
           name: str,
           **custom_model_kwargs,
   ):
       nn.Module.__init__(self)
       super(DNCMemory, self).__init__(obs_space, action_space, num_outputs,
                                       model_config, name)
       self.num_outputs = num_outputs
       self.obs_dim = gym.spaces.utils.flatdim(obs_space)
       self.act_dim = gym.spaces.utils.flatdim(action_space)

       self.cfg = dict(self.DEFAULT_CONFIG, **custom_model_kwargs)
       assert self.cfg['num_layers'] == 1, "num_layers != 1 has not been implemented yet"
       self.cur_val = None

       self.fc1 = SlimFC(
           in_size=self.obs_dim,
           out_size=self.cfg["fc1_dim"],
           activation_fn=None,
           initializer=torch.nn.init.xavier_uniform_,
       )

       self.logit_branch = SlimFC(
           in_size=self.cfg["fc1_dim"],
           out_size=self.num_outputs,
           activation_fn=None,
           initializer=torch.nn.init.xavier_uniform_,
       )

       self.value_branch = SlimFC(
           in_size=self.cfg["fc1_dim"],
           out_size=1,
           activation_fn=None,
           initializer=torch.nn.init.xavier_uniform_,
       )

       self.dnc: Union[None, DNC] = None

   def get_initial_state(self) -> List[TensorType]:
       ctrl_hidden = [
           torch.zeros(self.cfg["num_hidden_layers"], self.cfg["hidden_size"]),
           torch.zeros(self.cfg["num_hidden_layers"], self.cfg["hidden_size"]),
       ]
       m = self.cfg["nr_cells"]
       r = self.cfg["read_heads"]
       w = self.cfg["cell_size"]
       memory = [
           torch.zeros(m, w),  # memory
           torch.zeros(1, m, m),  # link_matrix
           torch.zeros(1, m),  # precedence
           torch.zeros(r, m),  # read_weights
           torch.zeros(1, m),  # write_weights
           torch.zeros(m),  # usage_vector
       ]

       read_vecs = torch.zeros(w * r)

       state = [*ctrl_hidden, read_vecs, *memory]
       assert len(state) == 9
       return state

   def value_function(self) -> TensorType:
       assert self.cur_val is not None, "must call forward() first"
       return self.cur_val

   def unpack_state(self, state: List[TensorType],
                    ) -> Tuple[List[Tuple[TensorType, TensorType]],
                               Dict[str, TensorType], TensorType]:
       """Given a list of tensors, reformat for self.dnc input"""
       assert len(state) == 9, "Failed to verify unpacked state"
       ctrl_hidden: List[Tuple[TensorType, TensorType]] = [(
           state[0].permute(1, 0, 2).contiguous(),
           state[1].permute(1, 0, 2).contiguous(),
       )]
       read_vecs: TensorType = state[2]
       memory: List[TensorType] = state[3:]
       memory_dict: Dict[str, TensorType] = dict(
           zip(self.MEMORY_KEYS, memory))

       return ctrl_hidden, memory_dict, read_vecs

   def pack_state(
           self,
           ctrl_hidden: List[Tuple[TensorType, TensorType]],
           memory_dict: Dict[str, TensorType],
           read_vecs: TensorType,
   ) -> List[TensorType]:
       """Given the dnc output, pack it into a list of tensors
       for rllib state. Order is ctrl_hidden, read_vecs, memory_dict"""
       state = []
       ctrl_hidden = [
           ctrl_hidden[0][0].permute(1, 0, 2),
           ctrl_hidden[0][1].permute(1, 0, 2),
       ]
       state += ctrl_hidden  # len 2
       state.append(read_vecs)  # len 3
       state += memory_dict.values()  # len 9
       assert len(state) == 9, "Failed to verify packed state"
       return state

   def validate_unpack(self, dnc_output, unpacked_state):
       """Ensure the unpacked state shapes match the DNC output"""
       s_ctrl_hidden, s_memory_dict, s_read_vecs = unpacked_state
       ctrl_hidden, memory_dict, read_vecs = dnc_output

       for i in range(len(ctrl_hidden)):
           for j in range(len(ctrl_hidden[i])):
               assert s_ctrl_hidden[i][j].shape == ctrl_hidden[i][j].shape, (
                   "Controller state mismatch: got "
                   f"{s_ctrl_hidden[i][j].shape} should be "
                   f"{ctrl_hidden[i][j].shape}")

       for k in memory_dict:
           assert s_memory_dict[k].shape == memory_dict[k].shape, (
               "Memory state mismatch at key "
               f"{k}: got {s_memory_dict[k].shape} should be "
               f"{memory_dict[k].shape}")

       assert s_read_vecs.shape == read_vecs.shape, (
           "Read state mismatch: got "
           f"{s_read_vecs.shape} should be "
           f"{read_vecs.shape}")

   def forward(
           self,
           input_dict: Dict[str, TensorType],
           state: List[TensorType],
           seq_lens: TensorType,
   ) -> Tuple[TensorType, List[TensorType]]:

       flat = input_dict["obs_flat"]
       flat = self.fc1(flat)
       # Batch and Time
       # Forward expects outputs as [B, T, logits]
       B = len(seq_lens)
       T = flat.shape[0] // B

       #logits = torch.zeros(B, T, self.num_outputs, device=flat.device)
       #values = torch.zeros(B, T, 1, device=flat.device)
       # Deconstruct batch into batch and time dimensions: [B, T, feats]
       flat = torch.reshape(flat, [-1, T] + list(flat.shape[1:]))

       # First run
       if self.dnc is None:
           (ctrl_hidden, read_vecs, memory_dict) = (None, None, None)
           gpu_id = flat.device.index if flat.device.index is not None else -1
           self.dnc = self.cfg["dnc_model"](
               input_size=self.cfg["fc1_dim"],
               hidden_size=self.cfg["hidden_size"],
               num_layers=self.cfg["num_layers"],
               num_hidden_layers=self.cfg["num_hidden_layers"],
               read_heads=self.cfg["read_heads"],
               cell_size=self.cfg["cell_size"],
               nr_cells=self.cfg["nr_cells"],
               nonlinearity=self.cfg["nonlinearity"],
               gpu_id=gpu_id,
           )

       else:
           ctrl_hidden, memory_dict, read_vecs = self.unpack_state(state)

       output, (ctrl_hidden, memory_dict, read_vecs) = self.dnc(
           flat, (ctrl_hidden, memory_dict, read_vecs))

       packed_state = self.pack_state(ctrl_hidden, memory_dict, read_vecs)

       # Compute action/value from output
       logits = self.logit_branch(output.view(B * T, -1))
       values = self.value_branch(output.view(B * T, -1))

       self.cur_val = values.squeeze(1)

       return logits, packed_state


if __name__ == "__main__":

   ray.init(num_cpus=4)

   ModelCatalog.register_custom_model("dnc", DNCMemory)
   config = {
       "env": StatelessCartPole,
       "gamma": 0.99,
       "num_envs_per_worker": 5,
       "framework": "torch",
       "num_workers": 1,
       "num_gpus": 1,
       "horizon": 250,
       "entropy_coeff": 0.0005,
       "lr": 0.01,
       #"vf_loss_coeff": 1e-5,
       #"num_sgd_iter": 5,
       "model": {
           "custom_model": "dnc",
           "max_seq_len": 10,
           "custom_model_config": {
               "nr_cells": 10,
               "read_heads": 2,
               "cell_size": 4,
               "num_layers": 1,
               "hidden_size": 64,
               "num_hidden_layers": 1,
               "fc1_dim": 64,
           },
       },
   }
   tune.run("A2C", name="DNC_DEMO5", num_samples=4, config=config, stop={"episode_reward_mean": 150.0,"timesteps_total": 5000000,})
   tune.run("PPO", name="DNC_DEMO5", num_samples=4, config=config, stop={"episode_reward_mean": 150.0,"timesteps_total": 5000000,})

@smorad
Contributor Author

smorad commented May 17, 2021

Great work @mvindiola1! Now that you've verified this actually learns, I'll get to work on merging in my final changes and rebasing.

@smorad
Contributor Author

smorad commented May 17, 2021

@sven1977 this requires the dnc package. Should I include it in the RLlib requirements.txt, or would you rather not? The unit test won't pass if the package is not installed.
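One common pattern for an optional test dependency (just a sketch of the idea, not necessarily how the RLlib CI handles it) is to skip the test when the package is missing instead of failing:

import unittest

try:
    import dnc  # noqa: F401
    DNC_INSTALLED = True
except ImportError:
    DNC_INSTALLED = False


@unittest.skipIf(not DNC_INSTALLED, "dnc package not installed; skipping DNC example test")
class TestDNCModel(unittest.TestCase):
    def test_dnc_example(self):
        # Build the config and run a short training loop here, e.g. via
        # tune.run(...) as in the example script above.
        pass


if __name__ == "__main__":
    unittest.main()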

@smorad
Contributor Author

smorad commented May 17, 2021

I'm also finding that IMPALA is nearly useless in solving this, but A2C seems to be doing much better. On master, PPO just hangs indefinitely and never completes a training loop.

All the runs are IMPALA except for one; see if you can guess which!

(plot omitted)

@sven1977 Is it possible there is a bug in IMPALA dealing with recurrent models? A2C and IMPALA implementations are quite similar, aren't they?

Once you're ok with the changes, I think we are ready to merge.

@mvindiola1
Contributor

mvindiola1 commented May 18, 2021

@smorad,

We seem to be approaching the preprocessing differently. I am using the preprocessing layer to expand the values stored in the DNC, whereas you are storing entries that match the size of the observation space.
In the plots below I changed the code to use the first approach and ran with yesterday's nightly build.

Specifically, I made these changes:

self.preprocessor = torch.nn.Sequential(
    torch.nn.Linear(self.obs_dim, self.cfg["preprocessor_output_size"]))
...
self.logit_branch = SlimFC(
            in_size=self.cfg["preprocessor_output_size"],
 ...
self.value_branch = SlimFC(
            in_size=self.cfg["preprocessor_output_size"],
...
self.dnc = self.cfg["dnc_model"](
            input_size=self.cfg["preprocessor_output_size"],

Here are results for:

PPO: (plot omitted)

A2C: (plot omitted)

IMPALA (with vtrace): (plot omitted)

IMPALA (w/o vtrace): (plot omitted)

@smorad
Contributor Author

smorad commented May 18, 2021 via email

@mvindiola1
Contributor

The 2-layer MLP seems like overkill for this task, and we have evidence that 1 layer is sufficient, but as long as it learns I have no real opinion on this.

What I do think matters is that the MLP expands the hidden size to something larger than 2.

Manny

@smorad
Contributor Author

smorad commented May 18, 2021

Good catch, I'll update with changes tomorrow.

@mvindiola1
Contributor

mvindiola1 commented May 18, 2021

I re-ran it keeping the 2-layer MLP but removing that last bottleneck layer.

self.preprocessor = torch.nn.Sequential(
            torch.nn.Linear(self.obs_dim, self.cfg["preprocessor_input_size"]),
            self.cfg["preprocessor"])

        self.logit_branch = SlimFC(
            in_size=self.cfg["preprocessor_output_size"],
            out_size=self.num_outputs,
            activation_fn=None,
            initializer=torch.nn.init.xavier_uniform_,
        )

        self.value_branch = SlimFC(
            in_size=self.cfg["preprocessor_output_size"],
            out_size=1,

In these results, PPO reaches the target mean episode reward in the fewest iterations, but it takes the longest wall time to do it: 45-75 min for PPO versus 1-2 min for A2C and IMPALA.

PPO: (plot omitted)

A2C: (plot omitted)

IMPALA (with vtrace): (plot omitted)

IMPALA (w/o vtrace): (plot omitted)

Times: (plot omitted)

@sven1977 sven1977 removed the @author-action-required The PR author is responsible for the next step. Remove tag to send back to the reviewer. label May 19, 2021
Contributor

@sven1977 sven1977 left a comment


Awesome! This is super valuable, thanks @smorad! :)

Review comment on doc/source/rllib-examples.rst (outdated, resolved)
@sven1977 sven1977 changed the title [rllib] Add differentiable neural computer example [RLlib] Add differentiable neural computer example May 19, 2021
@sven1977 sven1977 merged commit d8eed68 into ray-project:master May 19, 2021
@mvindiola1
Contributor

@smorad,

The committed version still has that last bottleneck layer in the preprocessing network. My test runs show that layer significantly increases training time and, in the case of IMPALA, prevents learning entirely.
