diff --git a/model.py b/model.py index dac80eb..a9a1efa 100644 --- a/model.py +++ b/model.py @@ -50,11 +50,15 @@ def _create_decoder(step: Tensorflow2ModelStep, last_encoder_outputs, last_encod decoder_lstm = RNN(cell=_create_stacked_rnn_cells(step), return_sequences=True, return_state=False) last_encoder_output = tf.expand_dims(last_encoder_outputs, axis=1) + # last encoder output shape: (batch_size, 1, hidden_dim) + replicated_last_encoder_output = tf.repeat( input=last_encoder_output, repeats=step.hyperparams['window_size_future'], axis=1 ) + # replicated last encoder output shape: (batch_size, window_size_future, hidden_dim) + decoder_outputs = decoder_lstm(replicated_last_encoder_output, initial_state=last_encoders_states) decoder_dense = Dense(step.hyperparams['output_dim'])