RF LSTM cleanup and fixes
albertz committed Apr 18, 2023
1 parent 526ff23 commit 081fd5e
Showing 4 changed files with 92 additions and 60 deletions.
14 changes: 9 additions & 5 deletions returnn/frontend/_backend.py
@@ -6,9 +6,7 @@
from typing import TYPE_CHECKING, Optional, Any, Union, TypeVar, Generic, Type, Sequence, Dict, Tuple
import contextlib
import numpy

import returnn.frontend as rf
from . import State

if TYPE_CHECKING:
from returnn.tensor import Tensor, Dim
@@ -911,25 +909,31 @@ def pool(
@staticmethod
def lstm(
source: Tensor,
state: State,
*,
state_c: Tensor,
state_h: Tensor,
ff_weights: Tensor,
ff_biases: Tensor,
rec_weights: Tensor,
rec_biases: Tensor,
spatial_dim: Dim,
in_dim: Dim,
out_dim: Dim,
) -> Tuple[Tensor, State]:
) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:
"""
Functional LSTM.
:param source: Tensor of shape [*, in_dim].
:param state: State of the LSTM.
:param state_c: Cell state of the LSTM, shape [*, out_dim].
:param state_h: Hidden state of the LSTM, shape [*, out_dim].
:param ff_weights: Parameters for the weights of the feed-forward part.
:param ff_biases: Parameters for the biases of the feed-forward part.
:param rec_weights: Parameters for the weights of the recurrent part.
:param rec_biases: Parameters for the biases of the recurrent part.
:param spatial_dim: Dimension in which the LSTM operates.
:param in_dim: Input feature dimension.
:param out_dim: Output (hidden) feature dimension.
:return: output, (state_h, state_c)
"""
raise NotImplementedError

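For reference, the packed parameter layout this interface implies can be written out as a plain LSTM step. The sketch below is purely illustrative: the helper names are made up, it works on single vectors rather than RF Tensors with dims, and it assumes PyTorch's (input, forget, cell, output) gate packing, which the Torch backend further down relies on.

import numpy as np

def _sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def lstm_step(x, h, c, ff_weights, ff_biases, rec_weights, rec_biases):
    # x: [in_dim], h/c: [out_dim]; weight matrices are packed as (4 * out_dim, ...),
    # gates in (i, f, g, o) order -- an assumption following the PyTorch convention.
    gates = ff_weights @ x + ff_biases + rec_weights @ h + rec_biases  # [4 * out_dim]
    i, f, g, o = np.split(gates, 4)
    c_new = _sigmoid(f) * c + _sigmoid(i) * np.tanh(g)
    h_new = _sigmoid(o) * np.tanh(c_new)
    return h_new, c_new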
37 changes: 24 additions & 13 deletions returnn/frontend/lstm.py
@@ -7,13 +7,12 @@
from typing import Tuple, TypeVar

import returnn.frontend as rf
from returnn.frontend import State
from returnn.tensor import Tensor, Dim


T = TypeVar("T")

__all__ = ["LSTM"]
__all__ = ["LSTM", "LstmState"]


class LSTM(rf.Module):
@@ -36,27 +35,27 @@ def __init__(
self.in_dim = in_dim
self.out_dim = out_dim

self.ff_weights = rf.Parameter((self.out_dim * 4, self.in_dim)) # type: Tensor[T]
self.ff_weights = rf.Parameter((4 * self.out_dim, self.in_dim)) # type: Tensor[T]
self.ff_weights.initial = rf.init.Glorot()
self.recurrent_weights = rf.Parameter((self.out_dim * 4, self.out_dim)) # type: Tensor[T]
self.recurrent_weights = rf.Parameter((4 * self.out_dim, self.out_dim)) # type: Tensor[T]
self.recurrent_weights.initial = rf.init.Glorot()

self.ff_biases = None
self.recurrent_biases = None
if with_bias:
self.ff_biases = rf.Parameter((self.out_dim * 4,)) # type: Tensor[T]
self.ff_biases = rf.Parameter((4 * self.out_dim,)) # type: Tensor[T]
self.ff_biases.initial = 0.0
self.recurrent_biases = rf.Parameter((self.out_dim * 4,)) # type: Tensor[T]
self.recurrent_biases = rf.Parameter((4 * self.out_dim,)) # type: Tensor[T]
self.recurrent_biases.initial = 0.0

def __call__(self, source: Tensor[T], state: State) -> Tuple[Tensor, State]:
def __call__(self, source: Tensor[T], *, state: LstmState, spatial_dim: Dim) -> Tuple[Tensor, LstmState]:
"""
Forward call of the LSTM.
:param source: Tensor of size ``[*, in_dim]``.
:param state: State of the LSTM. Contains two :class:`Tensor`: ``state.h`` as the hidden state,
and ``state.c`` as the cell state. Both are of shape ``[out_dim]``.
:return: Output of forward as a :class:`Tensor` of size ``[*, out_dim]``, and next LSTM state.
:param source: Tensor of size {...,in_dim} if spatial_dim is single_step_dim else {...,spatial_dim,in_dim}.
:param state: State of the LSTM. Both h and c are of shape {...,out_dim}.
:return: output of shape {...,out_dim} if spatial_dim is single_step_dim else {...,spatial_dim,out_dim},
and new state of the LSTM.
"""
if not state.h or not state.c:
raise ValueError(f"{self}: state {state} needs attributes ``h`` (hidden) and ``c`` (cell).")
Expand All @@ -66,13 +65,25 @@ def __call__(self, source: Tensor[T], state: State) -> Tuple[Tensor, State]:
# noinspection PyProtectedMember
result, new_state = source._raw_backend.lstm(
source=source,
state=state,
state_c=state.c,
state_h=state.h,
ff_weights=self.ff_weights,
ff_biases=self.ff_biases,
rec_weights=self.recurrent_weights,
rec_biases=self.recurrent_biases,
spatial_dim=self.in_dim,
spatial_dim=spatial_dim,
in_dim=self.in_dim,
out_dim=self.out_dim,
)
new_state = LstmState(*new_state)

return result, new_state


class LstmState(rf.State):
"""LSTM state"""

def __init__(self, h: Tensor, c: Tensor):
super().__init__()
self.h = h
self.c = c
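To make the new calling convention concrete, here is a minimal usage sketch of the updated module API, mirroring the updated test further below. The dim sizes are made up, and it assumes an RF run context (as set up by run_model in the tests) so that parameters and batch_dim are initialized.

import returnn.frontend as rf
from returnn.tensor import Dim, batch_dim

time_dim = Dim(7, name="time")  # hypothetical spatial dim
in_dim = Dim(9, name="in")
out_dim = Dim(13, name="out")

lstm = rf.LSTM(in_dim, out_dim)

# The state now carries h and c explicitly, each of shape [batch, out_dim]
# (no extra num_layers axis, unlike the raw Torch state).
state = rf.LstmState(
    h=rf.random(distribution="normal", dims=[batch_dim, out_dim], dtype="float32"),
    c=rf.random(distribution="normal", dims=[batch_dim, out_dim], dtype="float32"),
)

x = rf.random(distribution="normal", dims=[batch_dim, time_dim, in_dim], dtype="float32")
out, new_state = lstm(x, state=state, spatial_dim=time_dim)  # out: [batch, time, out_dim]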
82 changes: 49 additions & 33 deletions returnn/torch/frontend/_backend.py
@@ -13,7 +13,7 @@

# noinspection PyProtectedMember
from returnn.frontend._backend import Backend
from returnn.frontend import RawTensorTypes, State
from returnn.frontend import RawTensorTypes
import returnn.frontend as rf


@@ -1142,26 +1142,20 @@ def pool(
@staticmethod
def lstm(
source: _TT,
state: State,
*,
state_h: _TT,
state_c: _TT,
ff_weights: _TT,
ff_biases: Optional[_TT],
rec_weights: _TT,
rec_biases: Optional[_TT],
spatial_dim: Dim,
in_dim: Dim,
out_dim: Dim,
) -> Tuple[_TT, State]:
) -> Tuple[_TT, Tuple[_TT, _TT]]:
"""
Wraps the functional LSTM from PyTorch.
:param source: Tensor of shape ``[*, in_dim]``.
:param state: State of the LSTM.
:param ff_weights: Parameters for the weights of the feed-forward part.
:param ff_biases: Parameters for the biases of the feed-forward part.
:param rec_weights: Parameters for the weights of the recurrent part.
:param rec_biases: Parameters for the biases of the recurrent part.
:param spatial_dim: Dimension in which the LSTM operates.
:param out_dim:
:return: Tuple consisting of two elements: the result as a :class:`Tensor`
and the new state as a :class:`State` (different from the previous one).
"""
@@ -1176,33 +1170,55 @@ def lstm(
# or torch LSTMCell: https://github.com/pytorch/pytorch/blob/4bead64/aten/src/ATen/native/RNN.cpp#L1458
lstm_params = (ff_weights.raw_tensor, rec_weights.raw_tensor, ff_biases.raw_tensor, rec_biases.raw_tensor)
has_biases = True
batch_first = source.dims[0].is_batch_dim()

raw_result, raw_new_hidden_state, raw_new_cell_state = torch.lstm(
source.raw_tensor,
(state.h.raw_tensor, state.c.raw_tensor),
batch_dims = [d for d in source.dims if d != spatial_dim and d != in_dim]
source = source.copy_transpose([spatial_dim] + batch_dims + [in_dim])
state_h = state_h.copy_transpose(batch_dims + [out_dim])
state_c = state_c.copy_transpose(batch_dims + [out_dim])

source_raw = source.raw_tensor
state_h_raw = state_h.raw_tensor
state_c_raw = state_c.raw_tensor
batch_dim = torch.prod(torch.tensor([d.get_dim_value() for d in batch_dims])) if batch_dims else 1
if len(batch_dims) != 1:
# Torch LSTM expects (seq_len, batch, input_size) as shape.
# We need to merge all batch dims together.
source_raw = torch.reshape(
source_raw, [spatial_dim.get_dim_value()] + [batch_dim] + [in_dim.get_dim_value()]
)
# Torch LSTM expects (num_layers * num_directions, batch, hidden_size) as shape.
state_h_raw = torch.reshape(state_h_raw, [1, batch_dim, out_dim.get_dim_value()])
state_c_raw = torch.reshape(state_c_raw, [1, batch_dim, out_dim.get_dim_value()])

out_raw, new_state_h_raw, new_state_c_raw = torch.lstm(
source_raw,
(state_h_raw, state_c_raw),
lstm_params,
has_biases=has_biases,
num_layers=1,
dropout=0.0,
train=rf.get_run_ctx().train_flag,
bidirectional=False,
batch_first=batch_first,
batch_first=False,
)

result_dims = list(source.dims)
index_in_dim = [i for i, dim in enumerate(result_dims) if dim == spatial_dim]
assert len(index_in_dim) == 1, "There are multiple spatial dimensions for the input tensor."
result_dims[index_in_dim[0]] = out_dim
result = Tensor(name="lstm", dims=result_dims, raw_tensor=raw_result, dtype="float32")
result.feature_dim = out_dim

new_state = State()
new_hidden_state = state.h.copy_template()
new_hidden_state.raw_tensor = raw_new_hidden_state
new_state.h = new_hidden_state
new_cell_state = state.c.copy_template()
new_cell_state.raw_tensor = raw_new_cell_state
new_state.c = new_cell_state

return result, new_state
if len(batch_dims) != 1:
out_raw = torch.reshape(
out_raw,
[spatial_dim.get_dim_value()] + [d.get_dim_value() for d in batch_dims] + [out_dim.get_dim_value()],
)
new_state_h_raw = torch.reshape(new_state_h_raw, [d.get_dim_value() for d in state_h.dims])
new_state_c_raw = torch.reshape(new_state_c_raw, [d.get_dim_value() for d in state_c.dims])

out = source.copy_template_replace_dim_tag(axis=-1, new_dim_tag=out_dim, name="lstm")
out.feature_dim = out_dim
out.raw_tensor = out_raw

new_state_h = state_h.copy_template()
new_state_h.raw_tensor = new_state_h_raw
new_state_h.feature_dim = out_dim
new_state_c = state_c.copy_template()
new_state_c.raw_tensor = new_state_c_raw
new_state_c.feature_dim = out_dim

return out, (new_state_h, new_state_c)
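The reshaping above exists because Torch's LSTM accepts exactly one batch axis: all batch dims are merged into one before calling torch.lstm and split back afterwards, and the state gets the leading num_layers * num_directions axis that Torch expects. A standalone sketch of the same merge/restore pattern, using torch.nn.LSTM for brevity and made-up sizes:

import torch

batch, beam, time, in_dim, out_dim = 3, 4, 7, 5, 6  # hypothetical sizes, two batch-like dims

lstm = torch.nn.LSTM(input_size=in_dim, hidden_size=out_dim, batch_first=False)

x = torch.randn(time, batch, beam, in_dim)   # source with two batch dims
h0 = torch.zeros(1, batch * beam, out_dim)   # (num_layers * num_directions, merged batch, hidden)
c0 = torch.zeros(1, batch * beam, out_dim)

# Torch only takes (seq_len, batch, input_size), so merge the batch dims first.
x_merged = x.reshape(time, batch * beam, in_dim)
out_merged, (h_n, c_n) = lstm(x_merged, (h0, c0))

# Restore the original batch dims on the output and the new state.
out = out_merged.reshape(time, batch, beam, out_dim)
h_n = h_n.reshape(batch, beam, out_dim)
c_n = c_n.reshape(batch, beam, out_dim)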
19 changes: 10 additions & 9 deletions tests/test_rf_lstm.py
@@ -6,7 +6,6 @@
from typing import Tuple
import _setup_test_env # noqa
import returnn.frontend as rf
from returnn.frontend import State
from returnn.tensor import Tensor, Dim, TensorDict, batch_dim
from rf_utils import run_model

@@ -26,16 +25,18 @@ def __init__(self):
super().__init__()
self.lstm = rf.LSTM(in_dim, out_dim)

def __call__(self, x: Tensor, s: State) -> Tuple[Tensor, State]:
return self.lstm(x, s)
def __call__(self, x: Tensor, *, spatial_dim: Dim, state: rf.LstmState) -> Tuple[Tensor, rf.LstmState]:
return self.lstm(x, state=state, spatial_dim=spatial_dim)

# noinspection PyShadowingNames
def _forward_step(*, model: _Net, extern_data: TensorDict):
first_dim_state = Dim(1, name="blstm_times_nlayers")
state = State()
state.h = rf.random(distribution="normal", dims=[first_dim_state, batch_dim, out_dim], dtype="float32")
state.c = rf.random(distribution="normal", dims=[first_dim_state, batch_dim, out_dim], dtype="float32")
out, new_state = model(extern_data["data"], state)
out.mark_as_default_output()
state = rf.LstmState(
h=rf.random(distribution="normal", dims=[batch_dim, out_dim], dtype="float32"),
c=rf.random(distribution="normal", dims=[batch_dim, out_dim], dtype="float32"),
)
out, new_state = model(extern_data["data"], state=state, spatial_dim=time_dim)
out.mark_as_output("out", shape=(batch_dim, time_dim, out_dim))
new_state.h.mark_as_output("h", shape=(batch_dim, out_dim))
new_state.c.mark_as_output("c", shape=(batch_dim, out_dim))

run_model(extern_data, lambda *, epoch, step: _Net(), _forward_step)
