[Feature] Qmix - Gru Net #599
Changes from 9 commits
@@ -29,6 +29,7 @@
    "DdpgMlpActor",
    "DdpgMlpQNet",
    "LSTMNet",
    "GRUNet",
]
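With GRUNet added to __all__, the new class is exported alongside the existing models. Assuming this file is re-exported the way LSTMNet is in torchrl, the import would look like:

    from torchrl.modules import GRUNet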
@@ -1034,3 +1035,191 @@ def forward(
        input = self.mlp(input)
        return self._lstm(input, hidden0_in, hidden1_in)

class GRUNet(nn.Module):
    """The GRUNet is a neural network composed of a GRU layer sandwiched between two MLPs.

    It supports batched or unbatched sequences as input; the time dimension is
    always the one immediately preceding the feature dimension.

    Args:
        in_features (int): number of input features.
        hidden_size (int): number of hidden features for the GRU.
        out_features (int): number of output features.
        mlp_input_kwargs (dict, optional): kwargs for the MLP before the GRU.
        mlp_output_kwargs (dict, optional): kwargs for the MLP after the GRU.
        gru_kwargs (dict, optional): kwargs for the GRU. The GRU is enforced to be
            batch_first. GRUNet supports stacked GRUs but not bidirectional ones.

    By default, the GRUNet is:
        MLP(in_features=in_features, out_features=hidden_size, depth=0)
        GRU(input_size=hidden_size, hidden_size=hidden_size, batch_first=True)
        MLP(in_features=hidden_size, out_features=out_features, depth=0)

    If provided, entries in the kwargs dicts take precedence over in_features,
    hidden_size and out_features. They must satisfy two constraints:
        - the size of output features in mlp_input must be the same as the input size of the GRU;
        - the size of input features in mlp_output must be the same as the last hidden size of the GRU.

    Below, N = batch size and L = sequence length.

    Inputs:
        x: Tensor of shape [L, in_features] or [N, L, in_features].
        h_0: Initial hidden state of shape [D*num_layers, hidden_size] or
            [D*num_layers, N, hidden_size] (zeros if None is provided).

    Outputs:
        mlp_out: Tensor of shape [L, out_features] or [N, L, out_features].
        last_h: Tensor of shape [D*num_layers, hidden_size] or [D*num_layers, N, hidden_size].
            D = 1 always (no bidirectional GRU) and num_layers = 1 by default
            (the number of stacked GRU layers).

    Examples:
        >>> net = GRUNet(in_features=11, hidden_size=13, out_features=3)
        >>> print(net)
        GRUNet(
          (mlp_in): MLP(
            (0): Linear(in_features=11, out_features=13, bias=True)
          )
          (gru): GRU(13, 13, batch_first=True)
          (mlp_out): MLP(
            (0): Linear(in_features=13, out_features=3, bias=True)
          )
        )
        >>> x_no_batch = torch.randn(7, 11)
        >>> out_no_batch, h_no_batch = net(x_no_batch)
        >>> print(out_no_batch.shape, h_no_batch.shape)
        torch.Size([7, 3]) torch.Size([1, 13])
        >>> x_batch = torch.randn(5, 7, 11)
        >>> out_batch, h_batch = net(x_batch)
        >>> print(out_batch.shape, h_batch.shape)
        torch.Size([5, 7, 3]) torch.Size([1, 5, 13])
        >>> net2 = GRUNet(
        ...     in_features=11,
        ...     hidden_size=13,
        ...     out_features=3,
        ...     mlp_input_kwargs={
        ...         "depth": 0,
        ...         "activation_class": nn.ReLU,
        ...         "activate_last_layer": True,
        ...     },
        ... )
        >>> print(net2)
        GRUNet(
          (mlp_in): MLP(
            (0): Linear(in_features=11, out_features=13, bias=True)
            (1): ReLU()
          )
          (gru): GRU(13, 13, batch_first=True)
          (mlp_out): MLP(
            (0): Linear(in_features=13, out_features=3, bias=True)
          )
        )
        >>> net_stacked = GRUNet(
        ...     in_features=123,
        ...     hidden_size=456,
        ...     out_features=3,
        ...     mlp_input_kwargs={
        ...         "in_features": 11,
        ...         "out_features": 13,
        ...     },
        ...     gru_kwargs={
        ...         "input_size": 13,
        ...         "hidden_size": 13,
        ...         "num_layers": 2,
        ...     },
        ...     mlp_output_kwargs={
        ...         "in_features": 13,
        ...         "depth": 0,
        ...     },
        ... )
        >>> print(net_stacked)
        GRUNet(
          (mlp_in): MLP(
            (0): Linear(in_features=11, out_features=32, bias=True)
            (1): Tanh()
            (2): Linear(in_features=32, out_features=32, bias=True)
            (3): Tanh()
            (4): Linear(in_features=32, out_features=32, bias=True)
            (5): Tanh()
            (6): Linear(in_features=32, out_features=13, bias=True)
          )
          (gru): GRU(13, 13, num_layers=2, batch_first=True)
          (mlp_out): MLP(
            (0): Linear(in_features=13, out_features=3, bias=True)
          )
        )
        >>> x_batch = torch.randn(5, 7, 11)
        >>> out_batch_stacked, h_batch_stacked = net_stacked(x_batch)
        >>> print(out_batch_stacked.shape, h_batch_stacked.shape)
        torch.Size([5, 7, 3]) torch.Size([2, 5, 13])

    """

    def __init__(
        self,
        in_features: int,
        hidden_size: int,
        out_features: int,
        mlp_input_kwargs: Optional[dict] = None,
        gru_kwargs: Optional[dict] = None,
        mlp_output_kwargs: Optional[dict] = None,
        device: DEVICE_TYPING = "cpu",
    ) -> None:
        super().__init__()
        if mlp_input_kwargs is None:
            # Default config
            mlp_input_kwargs = {
                "in_features": in_features,
                "out_features": hidden_size,
                "depth": 0,
            }
        else:
            # User-provided kwargs take precedence over in_features / hidden_size
            mlp_input_kwargs.setdefault("in_features", in_features)
            mlp_input_kwargs.setdefault("out_features", hidden_size)

        if gru_kwargs is None:
            # Default config
            gru_kwargs = {"input_size": hidden_size, "hidden_size": hidden_size}
        else:
            # User-provided kwargs take precedence over hidden_size
            gru_kwargs.setdefault("input_size", hidden_size)
            gru_kwargs.setdefault("hidden_size", hidden_size)

        if mlp_output_kwargs is None:
            # Default config
            mlp_output_kwargs = {
                "in_features": hidden_size,
                "out_features": out_features,
                "depth": 0,
            }
        else:
            # User-provided kwargs take precedence over hidden_size / out_features
            mlp_output_kwargs.setdefault("in_features", hidden_size)
            mlp_output_kwargs.setdefault("out_features", out_features)

        if mlp_input_kwargs["out_features"] != gru_kwargs["input_size"]:
            raise ValueError(
                "The size of output features in mlp_input must be the same as the input size of the GRU."
            )
        if mlp_output_kwargs["in_features"] != gru_kwargs["hidden_size"]:
            raise ValueError(
                "The size of input features in mlp_output must be the same as the last hidden size of the GRU."
            )
        if gru_kwargs.get("bidirectional", False):
            raise NotImplementedError("bidirectional GRU is not yet implemented.")
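        # Illustration (not part of the PR): the first check above would reject a
        # mismatched override such as
        #     GRUNet(in_features=11, hidden_size=13, out_features=3,
        #            mlp_input_kwargs={"out_features": 16})
        # because mlp_input would emit 16 features while the GRU expects 13.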

        self.device = device

Review comment (on self.device): we should avoid having …

        mlp_input_kwargs.update({"device": self.device})
        gru_kwargs.update({"device": self.device, "batch_first": True})
        mlp_output_kwargs.update({"device": self.device})
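        # Usage note (illustrative, not part of the PR): since `device` is threaded
        # into each submodule's kwargs, the whole stack is built on the target
        # device in one call, e.g.
        #     device = "cuda" if torch.cuda.is_available() else "cpu"
        #     net = GRUNet(in_features=11, hidden_size=13, out_features=3, device=device)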

        self.mlp_in = MLP(**mlp_input_kwargs)
        self.gru = nn.GRU(**gru_kwargs)
        self.mlp_out = MLP(**mlp_output_kwargs)

    def forward(self, x, h_0=None):
        if x.dim() not in (2, 3):
            raise RuntimeError("Input must have 2 or 3 dimensions.")
        mlp_in = self.mlp_in(x)
        all_h, last_h = self.gru(mlp_in, h_0)
        mlp_out = self.mlp_out(all_h)
        return mlp_out, last_h

Review comment (on the dimension check): maybe if we have more than 3 dims, we could flatten and then unflatten the first dims (only if the rnn batch_first is True)?
Reply: Since we enforce the gru to be batch first, I guess we can.
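A minimal sketch of that suggestion, not part of the PR: collapse all leading batch dimensions into one before the GRU and restore them afterwards. The helper names are illustrative and assume batch_first=True:

    import torch

    def flatten_leading_dims(x):
        # Collapse all leading batch dims into one: [B1, B2, L, F] -> [B1*B2, L, F].
        batch_dims = x.shape[:-2]
        return x.reshape(-1, *x.shape[-2:]), batch_dims

    def unflatten_leading_dims(out, batch_dims):
        # Restore the original leading batch dims: [B1*B2, L, F'] -> [B1, B2, L, F'].
        return out.reshape(*batch_dims, *out.shape[-2:])

    # With net = GRUNet(in_features=11, hidden_size=13, out_features=3):
    # x = torch.randn(4, 5, 7, 11)                        # two batch dims, L=7
    # x_flat, batch_dims = flatten_leading_dims(x)        # [20, 7, 11]
    # out_flat, h = net(x_flat)                           # [20, 7, 3]
    # out = unflatten_leading_dims(out_flat, batch_dims)  # [4, 5, 7, 3]

Separately, because forward returns the final hidden state, a long sequence can be processed in chunks by feeding last_h back in as h_0 (a usage sketch against the PR's API, reusing net from above):

    x = torch.randn(5, 14, 11)               # N=5, L=14
    out_a, h = net(x[:, :7])                 # first 7 steps, h_0 defaults to zeros
    out_b, h = net(x[:, 7:], h)              # next 7 steps, resumed from h
    out = torch.cat([out_a, out_b], dim=1)   # matches net(x)[0] up to float error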