Conversation

@PaLeroy (Contributor) commented Oct 23, 2022

Description

Adds a GRUNet to control agents in QMIX.

Motivation and Context

Types of changes

What types of changes does your code introduce? Remove all that do not apply:

  • New feature (non-breaking change which adds core functionality)

Checklist

Go over all the following points, and put an x in all the boxes that apply.
If you are unsure about any of these, don't hesitate to ask. We are here to help!

  • I have read the CONTRIBUTION guide (required)
  • My change requires a change to the documentation.
  • I have updated the tests accordingly (required for a bug fix or a new feature).
  • I have updated the documentation accordingly.

@facebook-github-bot added the CLA Signed label Oct 23, 2022
@codecov bot commented Oct 23, 2022

Codecov Report

Merging #599 (b58ac7f) into qmix (8a7ed22) will increase coverage by 0.06%.
The diff coverage is 100.00%.

@@            Coverage Diff             @@
##             qmix     #599      +/-   ##
==========================================
+ Coverage   86.79%   86.86%   +0.06%     
==========================================
  Files         121      121              
  Lines       21667    21791     +124     
==========================================
+ Hits        18806    18928     +122     
- Misses       2861     2863       +2     
Flag                Coverage Δ
linux-cpu           85.19% <100.00%> (+0.08%) ⬆️
linux-gpu           86.65% <100.00%> (+0.09%) ⬆️
linux-outdeps-gpu   75.41% <100.00%> (+0.15%) ⬆️
linux-stable-cpu    85.19% <100.00%> (+0.10%) ⬆️
linux-stable-gpu    86.65% <100.00%> (+0.09%) ⬆️
macos-cpu           84.95% <100.00%> (+0.07%) ⬆️

Flags with carried forward coverage won't be shown.

Impacted Files                      Coverage Δ
test/test_modules.py                99.22% <100.00%> (+0.32%) ⬆️
torchrl/modules/models/models.py    95.78% <100.00%> (+0.71%) ⬆️
test/test_trainer.py                97.87% <0.00%> (-1.07%) ⬇️


@vmoens added the enhancement label Oct 25, 2022
"""

def __init__(
self, mlp_input_kwargs: Dict, gru_kwargs: Dict, mlp_output_kwargs: Dict

Collaborator:

It would be nice to have a device argument here; almost all torch modules now support instantiation on a device.
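
For illustration, a minimal sketch of how such a device argument could be threaded through (hypothetical; torchrl's MLP and torch.nn.GRUCell both accept a device keyword):

    import torch
    from torch import nn
    from torchrl.modules import MLP

    class GRUNet(nn.Module):
        def __init__(self, mlp_input_kwargs, gru_kwargs, mlp_output_kwargs, device=None):
            super().__init__()
            # create each submodule directly on the requested device
            self.mlp_in = MLP(device=device, **mlp_input_kwargs)
            self.gru = nn.GRUCell(device=device, **gru_kwargs)
            self.mlp_out = MLP(device=device, **mlp_output_kwargs)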


    def forward(self, inputs, hidden_state):
        mlp_in = self.mlp_in(inputs)
        hidden_state = self.gru(mlp_in, hidden_state)

Collaborator:

I expect that if you give a sequence as input, this module will not behave as expected, right?
Should we enforce that the input shapes match some requirement, like no 3d input?
The problem with that is that 3d inputs are only an indirect indicator that someone is passing time-stamped tensors; it might as well be a 2d batch shape. Plus, a 1d batch might be time-stamped too :p

Perhaps a good docstring would go a long way. Can we put in the docstring of this class what kind of input we expect?
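
For illustration, a sketch of the kind of shape contract such a docstring could state at this (GRUCell-based) stage; the wording is hypothetical:

    class GRUNet(nn.Module):
        """Recurrent net: MLP -> GRUCell -> MLP.

        The forward pass processes a single time step per call:
            inputs: tensor of shape (batch, features), with no time dimension;
            hidden_state: tensor of shape (batch, hidden_size).
        Sequences must be fed step by step; inputs carrying a time
        dimension are not supported.
        """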

Contributor Author:

I never worked with a GRU that takes full sequences... It appears that I missed a nice feature.
Indeed, the current implementation will not work well when taking a sequence as input.

I'll rework this as it is done in GRU, using GRU instead of GRUCell (and enforcing batch_first):

    input: tensor of shape :math:`(L, H_{in})` for unbatched input, or
    :math:`(N, L, H_{in})` when ``batch_first=True``, containing the features
    of the input sequence.

This should lead to something closer to the LSTMNet already implemented.
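
For reference, a minimal sketch of the torch.nn.GRU sequence API with batch_first=True (sizes hypothetical):

    import torch
    from torch import nn

    gru = nn.GRU(input_size=32, hidden_size=64, batch_first=True)
    x = torch.randn(8, 10, 32)   # (N, L, H_in): 8 sequences of 10 steps
    h0 = torch.zeros(1, 8, 64)   # (num_layers, N, H_out)
    out, hn = gru(x, h0)         # out: (N, L, H_out); hn: (num_layers, N, H_out)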

@PaLeroy (Contributor Author) commented Oct 25, 2022

I changed to GRU instead of GRUCell.
GRUNet now supports batched or unbatched inputs, but the sequence length must be specified -> only 2d and 3d inputs are accepted.
Maybe the docstring is a bit overloaded?
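
Concretely, in the style of the docstring examples, the two accepted layouts would look like this (a sketch; grunet and sizes hypothetical):

    >>> y = grunet(torch.randn(10, 32))      # unbatched: (L, H_in)
    >>> y = grunet(torch.randn(8, 10, 32))   # batched:   (N, L, H_in), batch_first=True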

@vmoens (Collaborator) left a comment:

Almost there! It's still a bit fuzzy to me how we're supposed to work with these modules in a way that handles single steps and trajectories interchangeably (e.g. using the same module at loss-computation time and during policy execution, while benefiting from the cuDNN kernel at loss time).

But that's more an open question than a comment about this PR.
One option could be to use first-class dimensions; I'll explore that and keep you posted.


    Args:
        mlp_input_kwargs (dict): kwargs for the MLP before the GRU
        mlp_output_kwargs (dict): kwargs for the MLP after the GRU

Collaborator:

It could be nice to have default values for those three modules, but that's open to discussion.
For GRU, in torchrl we usually use batch_first=True, and the time dimension is always the one preceding the last dim; at least that should be the default.
Can we put an example in the docstring?

    Examples:
        >>> grunet = GRUNet(
        ...             mlp_input_kwargs={"feature_in": ...}, 
        ...             **etc)
        >>> foo = torch.randn(*shape)
        >>> bar = grunet(foo)

Contributor Author:

Another possibility is to make input_features, hidden_size, and output_features non-optional arguments that create a default GRUNet, such as the one in the example I just pushed, and to keep these three param dictionaries optional, used only if a different network configuration is required.
What do you think?
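
A sketch of that alternative signature (names and defaults hypothetical):

    from typing import Dict, Optional

    class GRUNet(nn.Module):
        def __init__(
            self,
            input_features: int,
            hidden_size: int,
            output_features: int,
            mlp_input_kwargs: Optional[Dict] = None,   # only needed for
            gru_kwargs: Optional[Dict] = None,         # non-default network
            mlp_output_kwargs: Optional[Dict] = None,  # configurations
        ):
            ...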

Collaborator:

Sounds great!

Comment on lines 1058 to 1065
    N = batch size
    L = sequence size
    mlp_input_in = input size of the MLP_input
    mlp_input_out = output size of the MLP_input
    H_in = input size of the GRU
    H_out = output size of the GRU
    mlp_output_in = input size of the MLP_output
    mlp_output_out = output size of the MLP_output

Collaborator:

What are those fields? I don't see them anywhere.

Contributor Author:

They are only present in the docstring, to document the sizes of inputs and outputs.
They would become irrelevant with the suggestion I just made to change the arguments of GRUNet.

@vmoens changed the title from "Qmix - Gru Net" to "[Feature] Qmix - Gru Net" Oct 28, 2022
@vmoens (Collaborator) left a comment:

Wonderful, just a couple of things to fix and we're good!

if "bidirectional" in gru_kwargs and gru_kwargs["bidirectional"]:
raise NotImplementedError("bidirectional GRU is not yet implemented.")

self.device = device

Collaborator:

We should avoid having self.device.
I know we have some remaining in the lib, but it's bad practice, as more and more models are dispatched over multiple devices nowadays.
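
A common workaround is to derive the device from the parameters when it is actually needed rather than caching it (a sketch; assumes the module owns at least one parameter, and still presumes a single device):

    @property
    def device(self) -> torch.device:
        # looked up lazily, so .to(...) and later device moves keep working
        return next(self.parameters()).device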

        self.mlp_out = MLP(**mlp_output_kwargs)

    def forward(self, x, h_0=None):
        if not 2 <= len(x.size()) <= 3:  # guard: only 2d or 3d inputs are supported

Collaborator:

Maybe if we have more than 3 dims, we could flatten and then unflatten the first dims (only if the RNN's batch_first is True)?
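
A minimal sketch of that flatten/unflatten idea (hypothetical helper; assumes batch_first=True, so the last two dims are time and features):

    def _gru_any_batch(gru: nn.GRU, x: torch.Tensor) -> torch.Tensor:
        # x: (*batch_dims, L, H_in), with any number of leading batch dims
        batch_dims = x.shape[:-2]
        x_flat = x.reshape(-1, *x.shape[-2:])   # (prod(batch_dims), L, H_in)
        out, _ = gru(x_flat)                    # (prod(batch_dims), L, H_out)
        return out.reshape(*batch_dims, *out.shape[-2:])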

Contributor Author:

Since we enforce the GRU to be batch-first, I guess we can.

@PaLeroy (Contributor Author) commented Nov 1, 2022

Finally, I think we can allow inputs of any dimension.
As you suggested for dim > 3, we can also unsqueeze here for dim = 1.
Since we enforce batch_first, I added a warning if someone tries to use GRUNet without batch_first.
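
For reference, a usage sketch of the resulting behavior (constructor and sizes hypothetical, following the signature discussed above):

    grunet = GRUNet(input_features=32, hidden_size=64, output_features=16)
    y = grunet(torch.randn(32))             # 1d: unsqueezed internally
    y = grunet(torch.randn(10, 32))         # 2d (L, H_in): unbatched sequence
    y = grunet(torch.randn(8, 10, 32))      # 3d (N, L, H_in): batched sequences
    y = grunet(torch.randn(4, 8, 10, 32))   # >3d: leading dims flattened, then unflattened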

@vmoens (Collaborator) left a comment:

This is absolutely wonderful!
I can help make a TensorDict wrapper around that, if needed.

@vmoens merged commit d439075 into pytorch:qmix Nov 4, 2022