[Feature] Qmix - Gru Net #599
Conversation
Codecov Report
@@ Coverage Diff @@
## qmix #599 +/- ##
==========================================
+ Coverage 86.79% 86.86% +0.06%
==========================================
Files 121 121
Lines 21667 21791 +124
==========================================
+ Hits 18806 18928 +122
- Misses 2861 2863 +2
torchrl/modules/models/models.py
Outdated
""" | ||
|
||
def __init__( | ||
self, mlp_input_kwargs: Dict, gru_kwargs: Dict, mlp_output_kwargs: Dict |
It would be nice to have a `device` argument here; almost all torch modules now support instantiation on a device.
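For illustration, a minimal sketch of what that could look like (the argument names are assumptions, and plain nn.Linear stands in for torchrl's MLP for brevity):

from typing import Dict, Optional

import torch
from torch import nn


class GRUNet(nn.Module):
    # Hypothetical constructor: forwards a `device` kwarg to every sub-module,
    # mirroring how most torch modules now accept instantiation on device.
    def __init__(
        self,
        mlp_input_kwargs: Dict,
        gru_kwargs: Dict,
        mlp_output_kwargs: Dict,
        device: Optional[torch.device] = None,
    ):
        super().__init__()
        self.mlp_in = nn.Linear(device=device, **mlp_input_kwargs)
        self.gru = nn.GRU(batch_first=True, device=device, **gru_kwargs)
        self.mlp_out = nn.Linear(device=device, **mlp_output_kwargs)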
torchrl/modules/models/models.py
Outdated
def forward(self, inputs, hidden_state):
    mlp_in = self.mlp_in(inputs)
    hidden_state = self.gru(mlp_in, hidden_state)
I expect that if you give a sequence as input, this module will not behave as expected, right?
Should we enforce that the input shapes match some requirement, like no 3d input?
The problem with that is that a 3d input is only an indirect indicator that someone is passing time-stamped tensors; it might as well be a 2d batch size, and a 1d batch size might be time-stamped too :p
Perhaps a good docstring would go a long way. Can we put in the docstring of this class what kind of input we expect?
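Something along these lines in the class docstring could do it (just a sketch; the shape conventions are assumptions, not what the PR currently implements):

from torch import nn


class GRUNet(nn.Module):
    """GRU-based network (docstring sketch only).

    Expected inputs:
        - a 2d tensor of shape (batch, features) for a single time step, or
        - a 3d tensor of shape (batch, time, features) with batch_first=True,
          in which case the GRU iterates over the time dimension internally.
    """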
I never worked with a GRU that takes full sequences... It appears that I missed a nice feature.
Indeed, the current implementation will not work well when taking a sequence as input.
I'll rework this the way it is done in GRU, using GRU instead of GRUCell (and enforcing batch first):
input: tensor of shape :math:`(L, H_{in})` for unbatched input, :math:`(N, L, H_{in})` when ``batch_first=True``, containing the features of the input sequence.
This should lead to something closer to the LSTMNet already implemented.
I changed to GRU instead of GRUCell.
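For reference, the difference in a nutshell, using standard torch.nn modules rather than this PR's code:

import torch
from torch import nn

B, T, F, H = 4, 10, 8, 16  # batch, time, features, hidden size
x = torch.randn(B, T, F)

# GRUCell handles one time step at a time: the loop over T is explicit.
cell = nn.GRUCell(input_size=F, hidden_size=H)
h = torch.zeros(B, H)
for t in range(T):
    h = cell(x[:, t], h)

# nn.GRU consumes the whole sequence at once (cuDNN-friendly) and returns
# the outputs at every step plus the final hidden state.
gru = nn.GRU(input_size=F, hidden_size=H, batch_first=True)
out, h_n = gru(x)  # out: (B, T, H), h_n: (num_layers, B, H)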
Almost there! It's still a bit fuzzy to me how we're supposed to work with these modules in a way that handles single steps and trajectories interchangeably (e.g. using the same module at loss-compute time and during policy execution while benefiting from the cuDNN kernel at loss time).
But it's more an open question than a comment about this PR.
One option could be to use first-class dimensions; I'll explore that and keep you posted.
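One common workaround, purely as an illustration and not what this PR does, is to view a single step as a length-1 sequence so the same nn.GRU serves both cases:

import torch
from torch import nn

gru = nn.GRU(input_size=8, hidden_size=16, batch_first=True)

# loss time: full trajectories of shape (batch, time, features)
traj = torch.randn(4, 10, 8)
traj_out, _ = gru(traj)

# policy execution: a single step of shape (batch, features)
step = torch.randn(4, 8)
h = torch.zeros(1, 4, 16)                # carried hidden state (num_layers, batch, hidden)
step_out, h = gru(step.unsqueeze(1), h)  # add a time dimension of length 1
step_out = step_out.squeeze(1)           # back to (batch, hidden)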
torchrl/modules/models/models.py
Outdated
Args:
    mlp_input_kwargs (dict): kwargs for the MLP before the GRU
    mlp_output_kwargs (dict): kwargs for the MLP after the GRU
Could be nice to have default values for those three modules, but that's open to discussion.
For GRU, usually in torchrl we're using batch_first=True, and the time dimension is always the one preceding the last dim. At least that should be the default.
Can we put an example in the docstring?
Examples:
    >>> grunet = GRUNet(
    ...     mlp_input_kwargs={"feature_in": ...},
    ...     **etc)
    >>> foo = torch.randn(*shape)
    >>> bar = grunet(foo)
Another possibility is to have input_features, hidden_size and output_features as non-optional arguments that would create a default GRUNet such as the one in the example I just pushed, and to have these three parameter dictionaries as optional, used only if a different network configuration is required.
What do you think?
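Concretely, the proposed signature could look roughly like this (a sketch; everything beyond the three integer arguments is an assumption, with nn.Linear standing in for torchrl's MLP):

from typing import Dict, Optional

from torch import nn


class GRUNet(nn.Module):
    def __init__(
        self,
        input_features: int,
        hidden_size: int,
        output_features: int,
        mlp_input_kwargs: Optional[Dict] = None,
        gru_kwargs: Optional[Dict] = None,
        mlp_output_kwargs: Optional[Dict] = None,
    ):
        super().__init__()
        # The three integers are enough to build a sensible default network;
        # the kwargs dicts only override that default when provided.
        mlp_input_kwargs = mlp_input_kwargs or {
            "in_features": input_features, "out_features": hidden_size}
        gru_kwargs = gru_kwargs or {
            "input_size": hidden_size, "hidden_size": hidden_size}
        mlp_output_kwargs = mlp_output_kwargs or {
            "in_features": hidden_size, "out_features": output_features}
        self.mlp_in = nn.Linear(**mlp_input_kwargs)
        self.gru = nn.GRU(batch_first=True, **gru_kwargs)
        self.mlp_out = nn.Linear(**mlp_output_kwargs)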
Sounds great!
torchrl/modules/models/models.py
Outdated
N = batch size
L = sequence size
mlp_input_in = input size of the MLP_input
mlp_input_out = output size of the MLP_input
H_in = input size of the GRU
H_out = output size of the GRU
mlp_output_in = input size of the MLP_output
mlp_output_out = output size of the MLP_output
What are those fields? I don't see them anywhere
They are only present in the docstring to document the sizes of inputs and outputs.
They would become irrelevant with the suggestion I just made to change the arguments of GRUNet.
Wonderful, just a couple of things to fix and we're good!
torchrl/modules/models/models.py
Outdated
if "bidirectional" in gru_kwargs and gru_kwargs["bidirectional"]: | ||
raise NotImplementedError("bidirectional GRU is not yet implemented.") | ||
|
||
self.device = device |
We should avoid having `self.device`. I know we have some remaining in the lib, but it's bad practice as more and more models are dispatched over multiple devices nowadays.
torchrl/modules/models/models.py
Outdated
self.mlp_out = MLP(**mlp_output_kwargs)

def forward(self, x, h_0=None):
    if not 2 <= len(x.size()) <= 3:
Maybe if we have more than 3 dims, we could flatten and then unflatten the first dims (only if the RNN's batch_first is True)?
Since we enforce the GRU to be batch first, I guess we can.
Finally, I think we can allow inputs of any dimension.
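The flatten/unflatten trick could be sketched like this (assuming batch_first=True and that the last two dims are time and features):

import torch
from torch import nn

gru = nn.GRU(input_size=8, hidden_size=16, batch_first=True)

# arbitrary leading batch dims, e.g. (envs, agents, time, features)
x = torch.randn(3, 5, 10, 8)

batch_dims = x.shape[:-2]                        # everything before (time, features)
flat = x.reshape(-1, *x.shape[-2:])              # (3 * 5, 10, 8)
out, h_n = gru(flat)
out = out.reshape(*batch_dims, *out.shape[-2:])  # back to (3, 5, 10, 16)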
This is absolutely wonderful!
I can help make a TensorDict wrapper around that, if needed.
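For the record, such a wrapper could be as thin as a TensorDictModule mapping keys to the module's inputs and outputs (a sketch assuming the tensordict package; the key names and the stand-in network are assumptions):

import torch
from torch import nn
from tensordict import TensorDict
from tensordict.nn import TensorDictModule


class TinyGRUNet(nn.Module):
    """Stand-in for the GRUNet discussed here, just to make the example runnable."""

    def __init__(self, in_features=8, hidden_size=16, out_features=4):
        super().__init__()
        self.gru = nn.GRU(in_features, hidden_size, batch_first=True)
        self.head = nn.Linear(hidden_size, out_features)

    def forward(self, obs, h):
        out, h_n = self.gru(obs, h)
        return self.head(out), h_n


# Map tensordict keys to the module's positional inputs and outputs.
wrapped = TensorDictModule(
    TinyGRUNet(),
    in_keys=["observation", "hidden_state"],
    out_keys=["logits", "hidden_state"],
)

td = TensorDict(
    {
        "observation": torch.randn(2, 10, 8),   # (batch, time, features)
        "hidden_state": torch.zeros(1, 2, 16),  # (num_layers, batch, hidden)
    },
    batch_size=[],  # left empty so entries can have different leading dims
)
td = wrapped(td)  # writes "logits" and the updated "hidden_state" back into td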
Description
Creation of the GRUNet to control agents in QMIX
Motivation and Context
Types of changes
What types of changes does your code introduce? Remove all that do not apply:
Checklist
Go over all the following points, and put an `x` in all the boxes that apply. If you are unsure about any of these, don't hesitate to ask. We are here to help!