135 changes: 135 additions & 0 deletions test/test_modules.py
@@ -15,6 +15,7 @@
from torchrl.modules import (
ActorValueOperator,
CEMPlanner,
GRUNet,
LSTMNet,
ProbabilisticActor,
QValueActor,
@@ -305,6 +306,140 @@ def test_lstm_net_nobatch(device, out_features, hidden_size):
torch.testing.assert_close(tds_vec["hidden1_out"][-1], tds_loop["hidden1_out"][-1])


@pytest.mark.parametrize("device", get_available_devices())
@pytest.mark.parametrize("num_layers", [1]) # TODO: test stacked gru
@pytest.mark.parametrize(
"bidirectional", [False]
) # Todo: change if bidirectional implemented
def test_gru_net(device, num_layers, bidirectional):
batch_size = 5
seq_len = 7
in_features = 11
hidden_size = 13
out_features = 3

net = GRUNet(in_features, hidden_size, out_features, device=device)

# Test a whole sequence
# Test with dim = 1
x = torch.randn(in_features, device=device)
x_out, h = net(x)
assert x_out.size() == torch.Size([out_features])
assert h.size() == torch.Size([num_layers, hidden_size])

# Test with dim = 2
x = torch.randn(seq_len, in_features, device=device)
x_out, h = net(x)
assert x_out.size() == torch.Size([seq_len, out_features])
assert h.size() == torch.Size([num_layers, hidden_size])

# Test with dim = 3
x = torch.randn(batch_size, seq_len, in_features, device=device)
x_out, h = net(x)
assert x_out.size() == torch.Size([batch_size, seq_len, out_features])
assert h.size() == torch.Size([num_layers, batch_size, hidden_size])

# Test with dim > 3
x = torch.randn(2, 3, batch_size, seq_len, in_features, device=device)
x_out, h = net(x)
assert x_out.size() == torch.Size([2, 3, batch_size, seq_len, out_features])
assert h.size() == torch.Size([num_layers, 2, 3, batch_size, hidden_size])

# Test a sequence with intermediate hidden state
# Test with dim = 1
x = torch.randn(in_features, device=device)
x_out, h = net(x)
for i in range(5):
x_out, h = net(x, h)
assert x_out.size() == torch.Size([out_features])
assert h.size() == torch.Size([num_layers, hidden_size])

# Test with dim = 2
x = torch.randn(seq_len, in_features, device=device)
x_out, h = net(x)
for i in range(5):
x_out, h = net(x, h)
assert x_out.size() == torch.Size([seq_len, out_features])
assert h.size() == torch.Size([num_layers, hidden_size])

# Test with dim = 3
seq_len = 1
x = torch.randn(batch_size, seq_len, in_features, device=device)
x_out, h = net(x)
for i in range(5):
x_out, h = net(x, h)
assert x_out.size() == torch.Size([batch_size, seq_len, out_features])
assert h.size() == torch.Size([num_layers, batch_size, hidden_size])

# Test with dim > 3
seq_len = 1
x = torch.randn(2, 3, batch_size, seq_len, in_features, device=device)
x_out, h = net(x)
for i in range(5):
x_out, h = net(x, h)
assert x_out.size() == torch.Size([2, 3, batch_size, seq_len, out_features])
assert h.size() == torch.Size([num_layers, 2, 3, batch_size, hidden_size])

# Test instantiation safety
# Test mlp_input_kwargs["out_features"] != gru_kwargs["input_size"]
with pytest.raises(ValueError):
gru_kwargs = {"input_size": int(hidden_size / 2)}
GRUNet(
in_features, hidden_size, out_features, gru_kwargs=gru_kwargs, device=device
)
with pytest.raises(ValueError):
mlp_input_kwargs = {"out_features": int(hidden_size / 2)}
GRUNet(
in_features,
hidden_size,
out_features,
mlp_input_kwargs=mlp_input_kwargs,
device=device,
)

# Test mlp_output_kwargs["in_features"] != gru_kwargs["hidden_size"]
with pytest.raises(ValueError):
gru_kwargs = {"hidden_size": int(hidden_size / 2)}
GRUNet(
in_features, hidden_size, out_features, gru_kwargs=gru_kwargs, device=device
)
with pytest.raises(ValueError):
mlp_output_kwargs = {"in_features": int(hidden_size / 2)}
GRUNet(
in_features,
hidden_size,
out_features,
mlp_output_kwargs=mlp_output_kwargs,
device=device,
)

# Test gru_kwargs["bidirectional"]
with pytest.raises(NotImplementedError):
gru_kwargs = {
"bidirectional": True,
}
GRUNet(
in_features, hidden_size, out_features, gru_kwargs=gru_kwargs, device=device
)

    # Test that an error is raised if the input has an unexpected shape
with pytest.raises(RuntimeError):
x = torch.randn(1, device=device)
net(x)
with pytest.raises(RuntimeError):
x = torch.randn(1, 1, 1, 1, device=device)
net(x)

    # Test that a warning is raised when batch_first=False is requested
with pytest.warns(UserWarning):
gru_kwargs = {
"batch_first": False,
}
GRUNet(
in_features, hidden_size, out_features, gru_kwargs=gru_kwargs, device=device
)

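# A sketch for the stacked-GRU TODO above (illustrative test name and parametrization):
# num_layers is routed through gru_kwargs as in the GRUNet docstring, so the returned
# hidden state gains a leading num_layers dimension.
@pytest.mark.parametrize("device", get_available_devices())
@pytest.mark.parametrize("num_layers", [2, 3])
def test_gru_net_stacked(device, num_layers):
    batch_size = 5
    seq_len = 7
    in_features = 11
    hidden_size = 13
    out_features = 3

    net = GRUNet(
        in_features,
        hidden_size,
        out_features,
        gru_kwargs={"num_layers": num_layers},
        device=device,
    )

    x = torch.randn(batch_size, seq_len, in_features, device=device)
    x_out, h = net(x)
    assert x_out.size() == torch.Size([batch_size, seq_len, out_features])
    assert h.size() == torch.Size([num_layers, batch_size, hidden_size])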

class TestFunctionalModules:
def test_func_seq(self):
module = nn.Sequential(nn.Linear(3, 4), nn.Linear(4, 3))
219 changes: 218 additions & 1 deletion torchrl/modules/models/models.py
@@ -2,7 +2,7 @@
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import warnings
from numbers import Number
from typing import Dict, List, Optional, Sequence, Tuple, Type, Union

@@ -29,6 +29,7 @@
"DdpgMlpActor",
"DdpgMlpQNet",
"LSTMNet",
"GRUNet",
]


@@ -1034,3 +1035,219 @@ def forward(

input = self.mlp(input)
return self._lstm(input, hidden0_in, hidden1_in)


class GRUNet(nn.Module):
"""The GRUNet is a neural network composed of a GRU layer encapsulated between two MLPs.
It supports unbatched or batched (of any dim) sequences as input.
The time dimension is always the one preceding the features dimensions.
Unsequenced inputs are accepted but only unbatched (dim=1).

Args:
in_features (int): number of input features.
hidden_size (int): number of hidden features for the GRU.
out_features (int): number of output features.
mlp_input_kwargs (dict, optional): kwargs for the MLP before the GRU.
mlp_output_kwargs (dict, optional): kwargs for the MLP after the GRU.
        gru_kwargs (dict, optional): kwargs for the GRU. The GRU is enforced to be batch_first.
            GRUNet supports stacked GRUs but not bidirectional ones.

    By default, the GRUNet is:
        MLP(in_features=in_features, out_features=hidden_size, depth=0)
        GRU(input_size=hidden_size, hidden_size=hidden_size, batch_first=True)
        MLP(in_features=hidden_size, out_features=out_features, depth=0)

    If provided, the entries of these kwargs dicts take precedence over in_features, hidden_size and out_features.
    They must satisfy the following constraints:
- The size of output features in mlp_input must be the same as the input size of the GRU.
- The size of input features in mlp_output must be the same as the last hidden size of the GRU.

    Inputs:
        x : Tensor of shape [..., L, in_features], where L is the sequence length.
            An input of dim 1 is treated as an input of shape [1, in_features].
        h_0 : Initial hidden state of shape [D*num_layers, ..., hidden_size] (zeros if None is provided).
            D is always 1 (bidirectional is not supported) and num_layers defaults to 1 (the number of stacked GRU layers).

Outputs:
mlp_out : Tensor of shape [..., L, out_features].
last_h: Tensor of shape [D*num_layers, ..., hidden_size].

Examples:
>>> net = GRUNet(in_features=11, hidden_size=13, out_features=3)
>>> print(net)
GRUNet(
(mlp_in): MLP(
(0): Linear(in_features=11, out_features=13, bias=True)
)
(gru): GRU(13, 13, batch_first=True)
(mlp_out): MLP(
(0): Linear(in_features=13, out_features=3, bias=True)
)
)
>>> x_no_batch = torch.randn(11)
>>> out_no_batch, h_no_batch = net(x_no_batch)
>>> print(out_no_batch.shape, h_no_batch.shape)
torch.Size([3]) torch.Size([1, 13])
>>> x_no_batch = torch.randn(7, 11)
>>> out_no_batch, h_no_batch = net(x_no_batch)
>>> print(out_no_batch.shape, h_no_batch.shape)
torch.Size([7, 3]) torch.Size([1, 13])
>>> x_batch = torch.randn(5, 7, 11)
>>> out_batch, h_batch = net(x_batch)
>>> print(out_batch.shape, h_batch.shape)
torch.Size([5, 7, 3]) torch.Size([1, 5, 13])
>>> x_batch = torch.randn(3, 5, 7, 11)
>>> out_batch, h_batch = net(x_batch)
>>> print(out_batch.shape, h_batch.shape)
torch.Size([3, 5, 7, 3]) torch.Size([1, 3, 5, 13])
>>> net2 = GRUNet(
... in_features=11,
... hidden_size=13,
... out_features=3,
... mlp_input_kwargs={
... "depth": 0,
... "activation_class": nn.ReLU,
... "activate_last_layer": True,
... },
... )
>>> print(net2)
GRUNet(
(mlp_in): MLP(
(0): Linear(in_features=11, out_features=13, bias=True)
(1): ReLU()
)
(gru): GRU(13, 13, batch_first=True)
(mlp_out): MLP(
(0): Linear(in_features=13, out_features=3, bias=True)
)
)
>>> net_stacked = GRUNet(
... in_features=123,
... hidden_size=456,
... out_features=3,
... mlp_input_kwargs={
... "in_features": 11,
... "out_features": 13,
... },
... gru_kwargs={
... "input_size": 13,
... "hidden_size": 13,
... "num_layers": 2,
... },
... mlp_output_kwargs={
... "in_features": 13,
... "depth": 0
... },
... )
>>> print(net_stacked)
GRUNet(
(mlp_in): MLP(
(0): Linear(in_features=11, out_features=32, bias=True)
(1): Tanh()
(2): Linear(in_features=32, out_features=32, bias=True)
(3): Tanh()
(4): Linear(in_features=32, out_features=32, bias=True)
(5): Tanh()
(6): Linear(in_features=32, out_features=13, bias=True)
)
(gru): GRU(13, 13, num_layers=2, batch_first=True)
(mlp_out): MLP(
(0): Linear(in_features=13, out_features=3, bias=True)
)
)
>>> x_batch = torch.randn(5, 7, 11)
>>> out_batch_stacked, h_batch_stacked = net_stacked(x_batch)
>>> print(out_batch_stacked.shape, h_batch_stacked.shape)
torch.Size([5, 7, 3]) torch.Size([2, 5, 13])
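        The hidden state returned by one call can be fed back on the next call to
        continue the same sequence (a minimal sketch reusing the modules above):

        >>> x_next = torch.randn(5, 1, 11)
        >>> out_next, h_next = net_stacked(x_next, h_batch_stacked)
        >>> print(out_next.shape, h_next.shape)
        torch.Size([5, 1, 3]) torch.Size([2, 5, 13])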

"""

def __init__(
self,
in_features: int,
hidden_size: int,
out_features: int,
mlp_input_kwargs: Optional[dict] = None,
gru_kwargs: Optional[dict] = None,
mlp_output_kwargs: Optional[dict] = None,
device: DEVICE_TYPING = "cpu",
) -> None:
super().__init__()
if mlp_input_kwargs is None:
# Default config
mlp_input_kwargs = {
"in_features": in_features,
"out_features": hidden_size,
"depth": 0,
}
else:
            # User-provided keys take precedence; fill in any missing ones
mlp_input_kwargs.setdefault("in_features", in_features)
mlp_input_kwargs.setdefault("out_features", hidden_size)

if gru_kwargs is None:
# Default config
gru_kwargs = {"input_size": hidden_size, "hidden_size": hidden_size}
else:
            # User-provided keys take precedence; fill in any missing ones
gru_kwargs.setdefault("input_size", hidden_size)
gru_kwargs.setdefault("hidden_size", hidden_size)

if mlp_output_kwargs is None:
# Default config
mlp_output_kwargs = {
"in_features": hidden_size,
"out_features": out_features,
"depth": 0,
}
else:
            # User-provided keys take precedence; fill in any missing ones
mlp_output_kwargs.setdefault("in_features", hidden_size)
mlp_output_kwargs.setdefault("out_features", out_features)

if mlp_input_kwargs["out_features"] != gru_kwargs["input_size"]:
raise ValueError(
"The size of output features in mlp_input must be the same as the input size of the GRU."
)
if mlp_output_kwargs["in_features"] != gru_kwargs["hidden_size"]:
raise ValueError(
"The size of input features in mlp_output must be the same as the last hidden size of the GRU."
)
if "bidirectional" in gru_kwargs and gru_kwargs["bidirectional"]:
raise NotImplementedError("bidirectional GRU is not yet implemented.")

if not gru_kwargs.get("batch_first", True):
warnings.warn(
"You set batch_first to False, but GRUNet enforces batch_first to be True, gru_kwargs will be updated."
)

mlp_input_kwargs.update({"device": device})
gru_kwargs.update({"device": device, "batch_first": True})
mlp_output_kwargs.update({"device": device})

self.mlp_in = MLP(**mlp_input_kwargs)
self.gru = nn.GRU(**gru_kwargs)
self.mlp_out = MLP(**mlp_output_kwargs)

    def forward(self, x, h_0=None):
        x_shape = x.shape
        x_dim = len(x_shape)

        # Unsequenced, unbatched input: add a sequence dimension of length 1
        if x_dim < 2:
            x = x.unsqueeze(0)
        # More than one batch dimension: flatten all batch dimensions into one
        # so the GRU sees a [batch, seq, features] input, and flatten h_0 to match
        if x_dim > 3:
            x = x.flatten(end_dim=-3)
            if h_0 is not None:
                h_0 = h_0.flatten(start_dim=1, end_dim=-2)

        mlp_in = self.mlp_in(x)
        all_h, last_h = self.gru(mlp_in, h_0)
        mlp_out = self.mlp_out(all_h)

        # Restore the original leading dimensions
        if x_dim < 2:
            mlp_out = mlp_out.squeeze(0)
        if x_dim > 3:
            mlp_out = mlp_out.unflatten(0, x_shape[:-2])
            last_h = last_h.unflatten(1, x_shape[:-2])

        return mlp_out, last_h
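A minimal usage sketch, assuming only the constructor and forward signature shown above: the network can consume a sequence one timestep at a time by feeding the returned hidden state back in, mirroring the intermediate-hidden-state checks in test_gru_net.

import torch
from torchrl.modules import GRUNet

net = GRUNet(in_features=11, hidden_size=13, out_features=3)

h = None  # a None hidden state lets the GRU start from zeros
for _ in range(10):
    x_t = torch.randn(4, 1, 11)  # [batch=4, seq=1, in_features=11]
    out_t, h = net(x_t, h)  # out_t: [4, 1, 3], h: [1, 4, 13]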