Skip to content

Commit

Permalink
Merge branch 'neuraxle-refactor-tf2-wip' of github.com:Neuraxio/seq2s…
Browse files Browse the repository at this point in the history
…eq-signal-prediction into neuraxle-refactor-tf2-wip
  • Loading branch information
guillaume-chevalier committed Jan 16, 2020
2 parents 33deec7 + c089c8c commit 836de3b
Showing 1 changed file with 4 additions and 0 deletions.
4 changes: 4 additions & 0 deletions model.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,11 +50,15 @@ def _create_decoder(step: Tensorflow2ModelStep, last_encoder_outputs, last_encod
decoder_lstm = RNN(cell=_create_stacked_rnn_cells(step), return_sequences=True, return_state=False)

last_encoder_output = tf.expand_dims(last_encoder_outputs, axis=1)
# last encoder output shape: (batch_size, 1, hidden_dim)

replicated_last_encoder_output = tf.repeat(
input=last_encoder_output,
repeats=step.hyperparams['window_size_future'],
axis=1
)
# replicated last encoder output shape: (batch_size, window_size_future, hidden_dim)

decoder_outputs = decoder_lstm(replicated_last_encoder_output, initial_state=last_encoders_states)
decoder_dense = Dense(step.hyperparams['output_dim'])

Expand Down

0 comments on commit 836de3b

Please sign in to comment.