 
 # noinspection PyProtectedMember
 from returnn.frontend._backend import Backend
-from returnn.frontend import RawTensorTypes, State
+from returnn.frontend import RawTensorTypes
 import returnn.frontend as rf
 
 
@@ -1142,26 +1142,20 @@ def pool(
     @staticmethod
     def lstm(
         source: _TT,
-        state: State,
+        *,
+        state_h: _TT,
+        state_c: _TT,
         ff_weights: _TT,
         ff_biases: Optional[_TT],
         rec_weights: _TT,
         rec_biases: Optional[_TT],
         spatial_dim: Dim,
+        in_dim: Dim,
         out_dim: Dim,
-    ) -> Tuple[_TT, State]:
+    ) -> Tuple[_TT, Tuple[_TT, _TT]]:
         """
         Wraps the functional LSTM from PyTorch.
 
-        :param source: Tensor of shape ``[*, in_dim]``.
-        :param state: State of the LSTM.
-        :param ff_weights: Parameters for the weights of the feed-forward part.
-        :param ff_biases: Parameters for the biases of the feed-forward part.
-        :param rec_weights: Parameters for the weights of the recurrent part.
-        :param rec_biases: Parameters for the biases of the recurrent part.
-        :param spatial_dim: Dimension in which the LSTM operates.
-        :param out_dim:
-
         :return: Tuple consisting of two elements: the result as a :class:`Tensor`
             and the new state as a :class:`State` (different from the previous one).
         """
@@ -1176,33 +1170,55 @@ def lstm(
         # or torch LSTMCell: https://github.com/pytorch/pytorch/blob/4bead64/aten/src/ATen/native/RNN.cpp#L1458
         lstm_params = (ff_weights.raw_tensor, rec_weights.raw_tensor, ff_biases.raw_tensor, rec_biases.raw_tensor)
         has_biases = True
-        batch_first = source.dims[0].is_batch_dim()
 
-        raw_result, raw_new_hidden_state, raw_new_cell_state = torch.lstm(
-            source.raw_tensor,
-            (state.h.raw_tensor, state.c.raw_tensor),
+        batch_dims = [d for d in source.dims if d != spatial_dim and d != in_dim]
+        source = source.copy_transpose([spatial_dim] + batch_dims + [in_dim])
+        state_h = state_h.copy_transpose(batch_dims + [out_dim])
+        state_c = state_c.copy_transpose(batch_dims + [out_dim])
+
+        source_raw = source.raw_tensor
+        state_h_raw = state_h.raw_tensor
+        state_c_raw = state_c.raw_tensor
+        batch_dim = torch.prod(torch.tensor([d.get_dim_value() for d in batch_dims])) if batch_dims else 1
+        if len(batch_dims) != 1:
+            # Torch LSTM expects (seq_len, batch, input_size) as shape.
+            # We need to merge all batch dims together.
+            source_raw = torch.reshape(
+                source_raw, [spatial_dim.get_dim_value()] + [batch_dim] + [in_dim.get_dim_value()]
+            )
+            # Torch LSTM expects (num_layers * num_directions, batch, hidden_size) as shape.
+            state_h_raw = torch.reshape(state_h_raw, [1, batch_dim, out_dim.get_dim_value()])
+            state_c_raw = torch.reshape(state_c_raw, [1, batch_dim, out_dim.get_dim_value()])
+
+        out_raw, new_state_h_raw, new_state_c_raw = torch.lstm(
+            source_raw,
+            (state_h_raw, state_c_raw),
             lstm_params,
             has_biases=has_biases,
             num_layers=1,
             dropout=0.0,
             train=rf.get_run_ctx().train_flag,
             bidirectional=False,
-            batch_first=batch_first,
+            batch_first=False,
         )
 
-        result_dims = list(source.dims)
-        index_in_dim = [i for i, dim in enumerate(result_dims) if dim == spatial_dim]
-        assert len(index_in_dim) == 1, "There are multiple spatial dimensions for the input tensor."
-        result_dims[index_in_dim[0]] = out_dim
-        result = Tensor(name="lstm", dims=result_dims, raw_tensor=raw_result, dtype="float32")
-        result.feature_dim = out_dim
-
-        new_state = State()
-        new_hidden_state = state.h.copy_template()
-        new_hidden_state.raw_tensor = raw_new_hidden_state
-        new_state.h = new_hidden_state
-        new_cell_state = state.c.copy_template()
-        new_cell_state.raw_tensor = raw_new_cell_state
-        new_state.c = new_cell_state
-
-        return result, new_state
+        if len(batch_dims) != 1:
+            out_raw = torch.reshape(
+                out_raw,
+                [spatial_dim.get_dim_value()] + [d.get_dim_value() for d in batch_dims] + [out_dim.get_dim_value()],
+            )
+        new_state_h_raw = torch.reshape(new_state_h_raw, [d.get_dim_value() for d in state_h.dims])
+        new_state_c_raw = torch.reshape(new_state_c_raw, [d.get_dim_value() for d in state_c.dims])
+
+        out = source.copy_template_replace_dim_tag(axis=-1, new_dim_tag=out_dim, name="lstm")
+        out.feature_dim = out_dim
+        out.raw_tensor = out_raw
+
+        new_state_h = state_h.copy_template()
+        new_state_h.raw_tensor = new_state_h_raw
+        new_state_h.feature_dim = out_dim
+        new_state_c = state_c.copy_template()
+        new_state_c.raw_tensor = new_state_c_raw
+        new_state_c.feature_dim = out_dim
+
+        return out, (new_state_h, new_state_c)
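
For context, a minimal sketch of a call site for the reworked keyword-only signature. This is only an illustration of the interface changed in this diff: `backend`, the dim objects (`time_dim`, `batch_dim`, `in_dim`, `out_dim`), the parameter tensors, and the state tensors `h_0`/`c_0` are assumed to exist elsewhere and are not part of the diff.

# Hypothetical usage of the new signature; all names below are illustrative assumptions.
# `backend` stands for the PyTorch backend class whose `lstm` staticmethod is changed above.
out, (h_1, c_1) = backend.lstm(
    source,                # Tensor with dims {time_dim, batch_dim, in_dim}, any order
    state_h=h_0,           # previous hidden state, dims {batch_dim, out_dim}
    state_c=c_0,           # previous cell state, dims {batch_dim, out_dim}
    ff_weights=ff_w,       # feed-forward weights
    ff_biases=ff_b,        # feed-forward biases
    rec_weights=rec_w,     # recurrent weights
    rec_biases=rec_b,      # recurrent biases
    spatial_dim=time_dim,
    in_dim=in_dim,
    out_dim=out_dim,
)
# `out` carries out_dim in place of in_dim; the (h_1, c_1) tuple replaces the former State object.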