test/test_modules.py (98 additions, 0 deletions)
@@ -15,6 +15,7 @@
from torchrl.modules import (
ActorValueOperator,
CEMPlanner,
GRUNet,
LSTMNet,
ProbabilisticActor,
QValueActor,
@@ -305,6 +306,103 @@ def test_lstm_net_nobatch(device, out_features, hidden_size):
torch.testing.assert_close(tds_vec["hidden1_out"][-1], tds_loop["hidden1_out"][-1])


@pytest.mark.parametrize("device", get_available_devices())
@pytest.mark.parametrize("num_layers", [1]) # TODO: test stacked gru
@pytest.mark.parametrize(
"bidirectional", [False]
)  # TODO: change once bidirectional is implemented
def test_gru_net(device, num_layers, bidirectional):
batch_size = 5
seq_len = 7
in_features = 11
hidden_size = 13
out_features = 3

net = GRUNet(in_features, hidden_size, out_features, device=device)

# Test a whole sequence
# Test with batch size
x = torch.randn(batch_size, seq_len, in_features, device=device)
x_out, h = net(x)
assert x_out.size() == torch.Size([batch_size, seq_len, out_features])
assert h.size() == torch.Size([num_layers, batch_size, hidden_size])

# Test without batch_size
x = torch.randn(seq_len, in_features, device=device)
x_out, h = net(x)
assert x_out.size() == torch.Size([seq_len, out_features])
assert h.size() == torch.Size([num_layers, hidden_size])

# Test a sequence with intermediate hidden state
# Test with batch size
seq_len = 1
x = torch.randn(batch_size, seq_len, in_features, device=device)
x_out, h = net(x)
for i in range(5):
x_out, h = net(x, h)
assert x_out.size() == torch.Size([batch_size, seq_len, out_features])
assert h.size() == torch.Size([num_layers, batch_size, hidden_size])

# Test without batch size
x = torch.randn(seq_len, in_features, device=device)
x_out, h = net(x)
for i in range(5):
x_out, h = net(x, h)
assert x_out.size() == torch.Size([seq_len, out_features])
assert h.size() == torch.Size([num_layers, hidden_size])

# Test instantiation safety

# Test mlp_input_kwargs["out_features"] != gru_kwargs["input_size"]
with pytest.raises(ValueError):
gru_kwargs = {"input_size": int(hidden_size / 2)}
GRUNet(
in_features, hidden_size, out_features, gru_kwargs=gru_kwargs, device=device
)
with pytest.raises(ValueError):
mlp_input_kwargs = {"out_features": int(hidden_size / 2)}
GRUNet(
in_features,
hidden_size,
out_features,
mlp_input_kwargs=mlp_input_kwargs,
device=device,
)

# Test mlp_output_kwargs["in_features"] != gru_kwargs["hidden_size"]
with pytest.raises(ValueError):
gru_kwargs = {"hidden_size": int(hidden_size / 2)}
GRUNet(
in_features, hidden_size, out_features, gru_kwargs=gru_kwargs, device=device
)
with pytest.raises(ValueError):
mlp_output_kwargs = {"in_features": int(hidden_size / 2)}
GRUNet(
in_features,
hidden_size,
out_features,
mlp_output_kwargs=mlp_output_kwargs,
device=device,
)

# Test gru_kwargs["bidirectional"]
with pytest.raises(NotImplementedError):
gru_kwargs = {
"bidirectional": True,
}
GRUNet(
in_features, hidden_size, out_features, gru_kwargs=gru_kwargs, device=device
)

# Test that an error is raised if the input has an unexpected shape
with pytest.raises(RuntimeError):
x = torch.randn(1, device=device)
net(x)
with pytest.raises(RuntimeError):
x = torch.randn(1, 1, 1, 1, device=device)
net(x)


class TestFunctionalModules:
def test_func_seq(self):
module = nn.Sequential(nn.Linear(3, 4), nn.Linear(4, 3))
torchrl/modules/models/models.py (189 additions, 0 deletions)
@@ -29,6 +29,7 @@
"DdpgMlpActor",
"DdpgMlpQNet",
"LSTMNet",
"GRUNet",
]


@@ -1034,3 +1035,191 @@ def forward(

input = self.mlp(input)
return self._lstm(input, hidden0_in, hidden1_in)


class GRUNet(nn.Module):
"""The GRUNet is a neural network composed of a GRU layer encapsulated between two MLPs.
It supports batched or unbatched sequences as input; the time dimension is always the one immediately preceding the feature dimension.

Args:
in_features (int): number of input features.
hidden_size (int): number of hidden features for the GRU.
out_features (int): number of output features.
mlp_input_kwargs (dict, optional): kwargs for the MLP before the GRU.
mlp_output_kwargs (dict, optional): kwargs for the MLP after the GRU.
gru_kwargs (dict, optional): kwargs for the GRU. The GRU is enforced to be batch_first.
GRUNet supports stacked GRUs but not bidirectional ones.

By default, the GRUNet is composed of:
MLP(in_features=in_features, out_features=hidden_size, depth=0)
GRU(input_size=hidden_size, hidden_size=hidden_size, batch_first=True)
MLP(in_features=hidden_size, out_features=out_features, depth=0)

If provided, entries in these kwargs dicts take precedence over in_features, hidden_size and out_features.
They must satisfy the following constraints:
- The size of output features in mlp_input must be the same as the input size of the GRU.
- The size of input features in mlp_output must be the same as the last hidden size of the GRU.

N = batch size and L = sequence length.

Inputs:
x : Tensor of shape [L, in_features] or [N, L, in_features].
h_0 : Initial hidden state of shape [D*num_layers, hidden_size] or [D*num_layers, N, hidden_size] (zeros if None is provided).

Outputs:
mlp_out : Tensor of shape [L, out_features] or [N, L, out_features].
last_h: Tensor of shape [D*num_layers, hidden_size] or [D*num_layers, N, hidden_size].
D = 1 always (bidirectional is not supported) and num_layers = 1 by default (the number of stacked GRU layers).

Examples:
>>> net = GRUNet(in_features=11, hidden_size=13, out_features=3)
>>> print(net)
GRUNet(
(mlp_in): MLP(
(0): Linear(in_features=11, out_features=13, bias=True)
)
(gru): GRU(13, 13, batch_first=True)
(mlp_out): MLP(
(0): Linear(in_features=13, out_features=3, bias=True)
)
)
>>> x_no_batch = torch.randn(7, 11)
>>> out_no_batch, h_no_batch = net(x_no_batch)
>>> print(out_no_batch.shape, h_no_batch.shape)
torch.Size([7, 3]) torch.Size([1, 13])
>>> x_batch = torch.randn(5, 7, 11)
>>> out_batch, h_batch = net(x_batch)
>>> print(out_batch.shape, h_batch.shape)
torch.Size([5, 7, 3]) torch.Size([1, 5, 13])
>>> net2 = GRUNet(
... in_features=11,
... hidden_size=13,
... out_features=3,
... mlp_input_kwargs={
... "depth": 0,
... "activation_class": nn.ReLU,
... "activate_last_layer": True,
... },
... )
>>> print(net2)
GRUNet(
(mlp_in): MLP(
(0): Linear(in_features=11, out_features=13, bias=True)
(1): ReLU()
)
(gru): GRU(13, 13, batch_first=True)
(mlp_out): MLP(
(0): Linear(in_features=13, out_features=3, bias=True)
)
)
>>> net_stacked = GRUNet(
... in_features=123,
... hidden_size=456,
... out_features=3,
... mlp_input_kwargs={
... "in_features": 11,
... "out_features": 13,
... },
... gru_kwargs={
... "input_size": 13,
... "hidden_size": 13,
... "num_layers": 2,
... },
... mlp_output_kwargs={
... "in_features": 13,
... "depth": 0,
... },
... )
>>> print(net_stacked)
GRUNet(
(mlp_in): MLP(
(0): Linear(in_features=11, out_features=32, bias=True)
(1): Tanh()
(2): Linear(in_features=32, out_features=32, bias=True)
(3): Tanh()
(4): Linear(in_features=32, out_features=32, bias=True)
(5): Tanh()
(6): Linear(in_features=32, out_features=13, bias=True)
)
(gru): GRU(13, 13, num_layers=2, batch_first=True)
(mlp_out): MLP(
(0): Linear(in_features=13, out_features=3, bias=True)
)
)
>>> x_batch = torch.randn(5, 7, 11)
>>> out_batch_stacked, h_batch_stacked = net_stacked(x_batch)
>>> print(out_batch_stacked.shape, h_batch_stacked.shape)
torch.Size([5, 7, 3]) torch.Size([2, 5, 13])

"""

def __init__(
self,
in_features: int,
hidden_size: int,
out_features: int,
mlp_input_kwargs: Optional[dict] = None,
gru_kwargs: Optional[dict] = None,
mlp_output_kwargs: Optional[dict] = None,
device: DEVICE_TYPING = "cpu",
) -> None:
super().__init__()
if mlp_input_kwargs is None:
# Default config
mlp_input_kwargs = {
"in_features": in_features,
"out_features": hidden_size,
"depth": 0,
}
else:
# Fall back to in_features / hidden_size for keys not explicitly provided
mlp_input_kwargs.setdefault("in_features", in_features)
mlp_input_kwargs.setdefault("out_features", hidden_size)

if gru_kwargs is None:
# Default config
gru_kwargs = {"input_size": hidden_size, "hidden_size": hidden_size}
else:
# Fall back to hidden_size for keys not explicitly provided
gru_kwargs.setdefault("input_size", hidden_size)
gru_kwargs.setdefault("hidden_size", hidden_size)

if mlp_output_kwargs is None:
# Default config
mlp_output_kwargs = {
"in_features": hidden_size,
"out_features": out_features,
"depth": 0,
}
else:
# Fall back to hidden_size / out_features for keys not explicitly provided
mlp_output_kwargs.setdefault("in_features", hidden_size)
mlp_output_kwargs.setdefault("out_features", out_features)

if mlp_input_kwargs["out_features"] != gru_kwargs["input_size"]:
raise ValueError(
"The size of output features in mlp_input must be the same as the input size of the GRU."
)
if mlp_output_kwargs["in_features"] != gru_kwargs["hidden_size"]:
raise ValueError(
"The size of input features in mlp_output must be the same as the last hidden size of the GRU."
)
if "bidirectional" in gru_kwargs and gru_kwargs["bidirectional"]:
raise NotImplementedError("bidirectional GRU is not yet implemented.")

self.device = device
Collaborator:
we should avoid having self.device
I know we have some remaining in the lib but it's bad practice as more and more models are dispatched over multiple devices nowadays.
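
A minimal, self-contained sketch of the suggested pattern (illustration only, using a hypothetical TinyGRUBlock and assuming a PyTorch version whose modules accept the device factory kwarg): the device is consumed at construction time and, when needed at runtime, read back from the parameters instead of being cached as self.device.

import torch
from torch import nn

class TinyGRUBlock(nn.Module):
    def __init__(self, in_features: int, hidden_size: int, device="cpu") -> None:
        super().__init__()
        # Forward `device` straight to the submodules; nothing is stored on self.
        self.proj = nn.Linear(in_features, hidden_size, device=device)
        self.gru = nn.GRU(hidden_size, hidden_size, batch_first=True, device=device)

    def forward(self, x, h_0=None):
        # If a device is needed at runtime, recover it from the parameters;
        # this stays correct after the module is moved with .to(...).
        device = next(self.parameters()).device
        return self.gru(self.proj(x.to(device)), h_0)

block = TinyGRUBlock(11, 13)
out, h = block(torch.randn(5, 7, 11))
print(out.shape, h.shape)  # torch.Size([5, 7, 13]) torch.Size([1, 5, 13])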

mlp_input_kwargs.update({"device": self.device})
gru_kwargs.update({"device": self.device, "batch_first": True})
mlp_output_kwargs.update({"device": self.device})

self.mlp_in = MLP(**mlp_input_kwargs)
self.gru = nn.GRU(**gru_kwargs)
self.mlp_out = MLP(**mlp_output_kwargs)

def forward(self, x, h_0=None):
if not 2 <= len(x.size()) <= 3:
Collaborator:
maybe if we have more than 3 dims, we could flatten and then unflatten the first dims (only if the rnn batch_first is True)?

Contributor Author:
Since we enforce the gru to be batch first, I guess we can.
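
A minimal sketch of the flatten/unflatten idea, assuming every leading dimension before [L, features] is a batch dimension and the GRU is batch_first; the helper name run_gru_any_batch is made up for illustration.

import torch
from torch import nn

def run_gru_any_batch(gru: nn.GRU, x: torch.Tensor, h_0=None):
    # x: [*batch_dims, L, features]; gru is assumed to be batch_first.
    *batch_dims, seq_len, n_features = x.shape
    x_flat = x.reshape(-1, seq_len, n_features)  # collapse leading dims into one batch dim
    out_flat, h = gru(x_flat, h_0)  # an h_0, if given, must be flattened the same way
    out = out_flat.reshape(*batch_dims, seq_len, out_flat.shape[-1])  # restore leading dims
    return out, h

gru = nn.GRU(input_size=11, hidden_size=13, batch_first=True)
out, h = run_gru_any_batch(gru, torch.randn(2, 5, 7, 11))  # two batch dims
print(out.shape, h.shape)  # torch.Size([2, 5, 7, 13]) torch.Size([1, 10, 13])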

raise RuntimeError("Input size must be of size 2 or 3.")
mlp_in = self.mlp_in(x)
all_h, last_h = self.gru(mlp_in, h_0)
mlp_out = self.mlp_out(all_h)
return mlp_out, last_h