
Commit 081fd5e

RF LSTM cleanup and fixes
#1120 (comment)

1 parent 526ff23

4 files changed: +92 -60 lines

returnn/frontend/_backend.py (+9 -5)
@@ -6,9 +6,7 @@
 from typing import TYPE_CHECKING, Optional, Any, Union, TypeVar, Generic, Type, Sequence, Dict, Tuple
 import contextlib
 import numpy
-
 import returnn.frontend as rf
-from . import State

 if TYPE_CHECKING:
     from returnn.tensor import Tensor, Dim

@@ -911,25 +909,31 @@ def pool(
     @staticmethod
     def lstm(
         source: Tensor,
-        state: State,
+        *,
+        state_c: Tensor,
+        state_h: Tensor,
         ff_weights: Tensor,
         ff_biases: Tensor,
         rec_weights: Tensor,
         rec_biases: Tensor,
         spatial_dim: Dim,
+        in_dim: Dim,
         out_dim: Dim,
-    ) -> Tuple[Tensor, State]:
+    ) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:
         """
         Functional LSTM.

         :param source: Tensor of shape [*, in_dim].
-        :param state: State of the LSTM.
+        :param state_c:
+        :param state_h:
         :param ff_weights: Parameters for the weights of the feed-forward part.
         :param ff_biases: Parameters for the biases of the feed-forward part.
         :param rec_weights: Parameters for the weights of the recurrent part.
         :param rec_biases: Parameters for the biases of the recurrent part.
         :param spatial_dim: Dimension in which the LSTM operates.
+        :param in_dim:
         :param out_dim:
+        :return: output, (state_h, state_c)
         """
         raise NotImplementedError
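The backend contract now takes the hidden and cell state as two separate tensors, receives in_dim explicitly, and returns the new state as a plain (state_h, state_c) tuple instead of an rf.State object. For orientation only, here is a minimal numpy sketch (not part of the commit) of what one step of such a functional LSTM computes, assuming PyTorch's i, f, g, o gate order and the stacked [4*out_dim, ...] weight layout that rf.LSTM uses for its parameters:

import numpy


def _sigmoid(v):
    return 1.0 / (1.0 + numpy.exp(-v))


def lstm_step(x, h, c, ff_weights, ff_biases, rec_weights, rec_biases):
    """x: [in_dim], h and c: [out_dim], weights: [4*out_dim, in_dim] / [4*out_dim, out_dim], biases: [4*out_dim]."""
    gates = ff_weights @ x + ff_biases + rec_weights @ h + rec_biases  # all four gates stacked, [4*out_dim]
    i, f, g, o = numpy.split(gates, 4)  # assumed gate order: input, forget, cell, output (as in PyTorch)
    new_c = _sigmoid(f) * c + _sigmoid(i) * numpy.tanh(g)
    new_h = _sigmoid(o) * numpy.tanh(new_c)
    return new_h, (new_h, new_c)  # output, (state_h, state_c), matching the contract above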

returnn/frontend/lstm.py (+24 -13)
@@ -7,13 +7,12 @@
 from typing import Tuple, TypeVar

 import returnn.frontend as rf
-from returnn.frontend import State
 from returnn.tensor import Tensor, Dim


 T = TypeVar("T")

-__all__ = ["LSTM"]
+__all__ = ["LSTM", "LstmState"]


 class LSTM(rf.Module):

@@ -36,27 +35,27 @@ def __init__(
         self.in_dim = in_dim
         self.out_dim = out_dim

-        self.ff_weights = rf.Parameter((self.out_dim * 4, self.in_dim))  # type: Tensor[T]
+        self.ff_weights = rf.Parameter((4 * self.out_dim, self.in_dim))  # type: Tensor[T]
         self.ff_weights.initial = rf.init.Glorot()
-        self.recurrent_weights = rf.Parameter((self.out_dim * 4, self.out_dim))  # type: Tensor[T]
+        self.recurrent_weights = rf.Parameter((4 * self.out_dim, self.out_dim))  # type: Tensor[T]
         self.recurrent_weights.initial = rf.init.Glorot()

         self.ff_biases = None
         self.recurrent_biases = None
         if with_bias:
-            self.ff_biases = rf.Parameter((self.out_dim * 4,))  # type: Tensor[T]
+            self.ff_biases = rf.Parameter((4 * self.out_dim,))  # type: Tensor[T]
             self.ff_biases.initial = 0.0
-            self.recurrent_biases = rf.Parameter((self.out_dim * 4,))  # type: Tensor[T]
+            self.recurrent_biases = rf.Parameter((4 * self.out_dim,))  # type: Tensor[T]
             self.recurrent_biases.initial = 0.0

-    def __call__(self, source: Tensor[T], state: State) -> Tuple[Tensor, State]:
+    def __call__(self, source: Tensor[T], *, state: LstmState, spatial_dim: Dim) -> Tuple[Tensor, LstmState]:
         """
         Forward call of the LSTM.

-        :param source: Tensor of size ``[*, in_dim]``.
-        :param state: State of the LSTM. Contains two :class:`Tensor`: ``state.h`` as the hidden state,
-            and ``state.c`` as the cell state. Both are of shape ``[out_dim]``.
-        :return: Output of forward as a :class:`Tensor` of size ``[*, out_dim]``, and next LSTM state.
+        :param source: Tensor of size {...,in_dim} if spatial_dim is single_step_dim else {...,spatial_dim,in_dim}.
+        :param state: State of the LSTM. Both h and c are of shape {...,out_dim}.
+        :return: output of shape {...,out_dim} if spatial_dim is single_step_dim else {...,spatial_dim,out_dim},
+            and new state of the LSTM.
         """
         if not state.h or not state.c:
             raise ValueError(f"{self}: state {state} needs attributes ``h`` (hidden) and ``c`` (cell).")

@@ -66,13 +65,25 @@ def __call__(self, source: Tensor[T], state: State) -> Tuple[Tensor, State]:
         # noinspection PyProtectedMember
         result, new_state = source._raw_backend.lstm(
             source=source,
-            state=state,
+            state_c=state.c,
+            state_h=state.h,
             ff_weights=self.ff_weights,
             ff_biases=self.ff_biases,
             rec_weights=self.recurrent_weights,
             rec_biases=self.recurrent_biases,
-            spatial_dim=self.in_dim,
+            spatial_dim=spatial_dim,
+            in_dim=self.in_dim,
             out_dim=self.out_dim,
         )
+        new_state = LstmState(*new_state)

         return result, new_state
+
+
+class LstmState(rf.State):
+    """LSTM state"""
+
+    def __init__(self, h: Tensor, c: Tensor):
+        super().__init__()
+        self.h = h
+        self.c = c
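With LstmState in place, the user-facing call changes from lstm(x, state) to lstm(x, state=state, spatial_dim=...). A condensed usage sketch, mirroring the updated test further below (the function and dim names here are illustrative, and it is assumed to run inside a forward step set up by e.g. rf_utils.run_model, where batch_dim and the data's time dim are defined):

import returnn.frontend as rf
from returnn.tensor import Tensor, Dim, batch_dim


def forward(data: Tensor, *, time_dim: Dim, lstm: rf.LSTM, out_dim: Dim):
    # Initial state must provide both h and c, each of shape {batch, out_dim}.
    state = rf.LstmState(
        h=rf.random(distribution="normal", dims=[batch_dim, out_dim], dtype="float32"),
        c=rf.random(distribution="normal", dims=[batch_dim, out_dim], dtype="float32"),
    )
    # spatial_dim selects the axis of `data` over which the LSTM iterates.
    out, new_state = lstm(data, state=state, spatial_dim=time_dim)
    # out: {batch, time, out_dim}; new_state.h and new_state.c: {batch, out_dim}
    return out, new_state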

returnn/torch/frontend/_backend.py (+49 -33)
@@ -13,7 +13,7 @@

 # noinspection PyProtectedMember
 from returnn.frontend._backend import Backend
-from returnn.frontend import RawTensorTypes, State
+from returnn.frontend import RawTensorTypes
 import returnn.frontend as rf


@@ -1142,26 +1142,20 @@ def pool(
     @staticmethod
     def lstm(
         source: _TT,
-        state: State,
+        *,
+        state_h: _TT,
+        state_c: _TT,
         ff_weights: _TT,
         ff_biases: Optional[_TT],
         rec_weights: _TT,
         rec_biases: Optional[_TT],
         spatial_dim: Dim,
+        in_dim: Dim,
         out_dim: Dim,
-    ) -> Tuple[_TT, State]:
+    ) -> Tuple[_TT, Tuple[_TT, _TT]]:
         """
         Wraps the functional LSTM from PyTorch.

-        :param source: Tensor of shape ``[*, in_dim]``.
-        :param state: State of the LSTM.
-        :param ff_weights: Parameters for the weights of the feed-forward part.
-        :param ff_biases: Parameters for the biases of the feed-forward part.
-        :param rec_weights: Parameters for the weights of the recurrent part.
-        :param rec_biases: Parameters for the biases of the recurrent part.
-        :param spatial_dim: Dimension in which the LSTM operates.
-        :param out_dim:
-
         :return: Tuple consisting of two elements: the result as a :class:`Tensor`
             and the new state as a :class:`State` (different from the previous one).
         """

@@ -1176,33 +1170,55 @@ def lstm(
         # or torch LSTMCell: https://github.com/pytorch/pytorch/blob/4bead64/aten/src/ATen/native/RNN.cpp#L1458
         lstm_params = (ff_weights.raw_tensor, rec_weights.raw_tensor, ff_biases.raw_tensor, rec_biases.raw_tensor)
         has_biases = True
-        batch_first = source.dims[0].is_batch_dim()

-        raw_result, raw_new_hidden_state, raw_new_cell_state = torch.lstm(
-            source.raw_tensor,
-            (state.h.raw_tensor, state.c.raw_tensor),
+        batch_dims = [d for d in source.dims if d != spatial_dim and d != in_dim]
+        source = source.copy_transpose([spatial_dim] + batch_dims + [in_dim])
+        state_h = state_h.copy_transpose(batch_dims + [out_dim])
+        state_c = state_c.copy_transpose(batch_dims + [out_dim])
+
+        source_raw = source.raw_tensor
+        state_h_raw = state_h.raw_tensor
+        state_c_raw = state_c.raw_tensor
+        batch_dim = torch.prod(torch.tensor([d.get_dim_value() for d in batch_dims])) if batch_dims else 1
+        if len(batch_dims) != 1:
+            # Torch LSTM expects (seq_len, batch, input_size) as shape.
+            # We need to merge all batch dims together.
+            source_raw = torch.reshape(
+                source_raw, [spatial_dim.get_dim_value()] + [batch_dim] + [in_dim.get_dim_value()]
+            )
+        # Torch LSTM expects (num_layers * num_directions, batch, hidden_size) as shape.
+        state_h_raw = torch.reshape(state_h_raw, [1, batch_dim, out_dim.get_dim_value()])
+        state_c_raw = torch.reshape(state_c_raw, [1, batch_dim, out_dim.get_dim_value()])
+
+        out_raw, new_state_h_raw, new_state_c_raw = torch.lstm(
+            source_raw,
+            (state_h_raw, state_c_raw),
             lstm_params,
             has_biases=has_biases,
             num_layers=1,
             dropout=0.0,
             train=rf.get_run_ctx().train_flag,
             bidirectional=False,
-            batch_first=batch_first,
+            batch_first=False,
         )

-        result_dims = list(source.dims)
-        index_in_dim = [i for i, dim in enumerate(result_dims) if dim == spatial_dim]
-        assert len(index_in_dim) == 1, "There are multiple spatial dimensions for the input tensor."
-        result_dims[index_in_dim[0]] = out_dim
-        result = Tensor(name="lstm", dims=result_dims, raw_tensor=raw_result, dtype="float32")
-        result.feature_dim = out_dim
-
-        new_state = State()
-        new_hidden_state = state.h.copy_template()
-        new_hidden_state.raw_tensor = raw_new_hidden_state
-        new_state.h = new_hidden_state
-        new_cell_state = state.c.copy_template()
-        new_cell_state.raw_tensor = raw_new_cell_state
-        new_state.c = new_cell_state
-
-        return result, new_state
+        if len(batch_dims) != 1:
+            out_raw = torch.reshape(
+                out_raw,
+                [spatial_dim.get_dim_value()] + [d.get_dim_value() for d in batch_dims] + [out_dim.get_dim_value()],
+            )
+        new_state_h_raw = torch.reshape(new_state_h_raw, [d.get_dim_value() for d in state_h.dims])
+        new_state_c_raw = torch.reshape(new_state_c_raw, [d.get_dim_value() for d in state_c.dims])
+
+        out = source.copy_template_replace_dim_tag(axis=-1, new_dim_tag=out_dim, name="lstm")
+        out.feature_dim = out_dim
+        out.raw_tensor = out_raw
+
+        new_state_h = state_h.copy_template()
+        new_state_h.raw_tensor = new_state_h_raw
+        new_state_h.feature_dim = out_dim
+        new_state_c = state_c.copy_template()
+        new_state_c.raw_tensor = new_state_c_raw
+        new_state_c.feature_dim = out_dim
+
+        return out, (new_state_h, new_state_c)
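The PyTorch backend now always calls torch.lstm time-major (batch_first=False): any number of batch dims is flattened into one, and the states get the leading num_layers * num_directions axis that PyTorch expects. For reference, a small standalone sketch of that shape convention, using the public torch.nn.LSTM module rather than the internal torch.lstm op, with made-up sizes:

import torch

seq_len, batch, input_size, hidden_size = 7, 3, 10, 16
lstm = torch.nn.LSTM(input_size, hidden_size, num_layers=1, batch_first=False)

source = torch.randn(seq_len, batch, input_size)  # (seq_len, batch, input_size)
h0 = torch.zeros(1, batch, hidden_size)           # (num_layers * num_directions, batch, hidden_size)
c0 = torch.zeros(1, batch, hidden_size)

out, (h_n, c_n) = lstm(source, (h0, c0))
print(out.shape)  # (seq_len, batch, hidden_size); the backend reshapes this back to [spatial, *batch_dims, out_dim]
print(h_n.shape)  # (1, batch, hidden_size); reshaped back to the state's original dims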

tests/test_rf_lstm.py (+10 -9)
@@ -6,7 +6,6 @@
 from typing import Tuple
 import _setup_test_env  # noqa
 import returnn.frontend as rf
-from returnn.frontend import State
 from returnn.tensor import Tensor, Dim, TensorDict, batch_dim
 from rf_utils import run_model

@@ -26,16 +25,18 @@ def __init__(self):
             super().__init__()
             self.lstm = rf.LSTM(in_dim, out_dim)

-        def __call__(self, x: Tensor, s: State) -> Tuple[Tensor, State]:
-            return self.lstm(x, s)
+        def __call__(self, x: Tensor, *, spatial_dim: Dim, state: rf.LstmState) -> Tuple[Tensor, rf.LstmState]:
+            return self.lstm(x, state=state, spatial_dim=spatial_dim)

     # noinspection PyShadowingNames
     def _forward_step(*, model: _Net, extern_data: TensorDict):
-        first_dim_state = Dim(1, name="blstm_times_nlayers")
-        state = State()
-        state.h = rf.random(distribution="normal", dims=[first_dim_state, batch_dim, out_dim], dtype="float32")
-        state.c = rf.random(distribution="normal", dims=[first_dim_state, batch_dim, out_dim], dtype="float32")
-        out, new_state = model(extern_data["data"], state)
-        out.mark_as_default_output()
+        state = rf.LstmState(
+            h=rf.random(distribution="normal", dims=[batch_dim, out_dim], dtype="float32"),
+            c=rf.random(distribution="normal", dims=[batch_dim, out_dim], dtype="float32"),
+        )
+        out, new_state = model(extern_data["data"], state=state, spatial_dim=time_dim)
+        out.mark_as_output("out", shape=(batch_dim, time_dim, out_dim))
+        new_state.h.mark_as_output("h", shape=(batch_dim, out_dim))
+        new_state.c.mark_as_output("c", shape=(batch_dim, out_dim))

     run_model(extern_data, lambda *, epoch, step: _Net(), _forward_step)
