Skip to content

Commit

Permalink
Merge branch 'neuraxle-refactor-tf2-wip' of github.com:Neuraxio/seq2s…
Browse files Browse the repository at this point in the history
…eq-signal-prediction into neuraxle-refactor-tf2-wip
  • Loading branch information
alexbrillant committed Jan 16, 2020
2 parents 100b95d + 3952257 commit c089c8c
Showing 1 changed file with 21 additions and 30 deletions.
51 changes: 21 additions & 30 deletions model.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from tensorflow_core.python.keras import Input, Model
from tensorflow_core.python.keras.layers import GRUCell, RNN, Dense
from tensorflow_core.python.training.rmsprop import RMSPropOptimizer
from tensorflow_core.python.training.adam import AdamOptimizer

from data_loading import generate_data
from neuraxle_tensorflow.tensorflow_v1 import TensorflowV1ModelStep
Expand Down Expand Up @@ -89,24 +90,7 @@ def create_loss(step: Tensorflow2ModelStep, expected_outputs, predicted_outputs)


def create_optimizer(step: TensorflowV1ModelStep):
return RMSPropOptimizer(
learning_rate=step.hyperparams['learning_rate'],
decay=step.hyperparams['lr_decay'],
momentum=step.hyperparams['momentum']
)


seq2seq_pipeline_hyperparams = HyperparameterSamples({
'hidden_dim': 35,
'layers_stacked_count': 2,
'lambda_loss_amount': 0.003,
'learning_rate': 0.006,
'lr_decay': 0.92,
'momentum': 0.5,
'window_size_future': 40,
'output_dim': 2,
'input_dim': 2
})
return AdamOptimizer(learning_rate=step.hyperparams['learning_rate'])


def metric_2d_to_3d_wrapper(metric_fun: Callable):
Expand Down Expand Up @@ -137,23 +121,30 @@ def main():
validation_size = 0.15
max_plotted_predictions = 10

seq2seq_pipeline_hyperparams = HyperparameterSamples({
'hidden_dim': 100,
'layers_stacked_count': 2,
'lambda_loss_amount': 0.0003,
'learning_rate': 0.009,
'window_size_future': sequence_length,
'output_dim': output_dim,
'input_dim': input_dim
})
metrics = {'mse': metric_2d_to_3d_wrapper(mean_squared_error)}

signal_prediction_pipeline = Pipeline([
ForEachDataInput(MeanStdNormalizer()),
ToNumpy(),
PlotPredictionsWrapper(Tensorflow2ModelStep(
create_model=create_model,
create_loss=create_loss,
create_optimizer=create_optimizer,
expected_outputs_dtype=tf.dtypes.float32,
data_inputs_dtype=tf.dtypes.float32,
print_loss=True
).set_hyperparams(seq2seq_pipeline_hyperparams).update_hyperparams(HyperparameterSamples({
'window_size_future': sequence_length,
'input_dim': input_dim,
'output_dim': output_dim
})))
PlotPredictionsWrapper(
Tensorflow2ModelStep(
create_model=create_model,
create_loss=create_loss,
create_optimizer=create_optimizer,
expected_outputs_dtype=tf.dtypes.float32,
data_inputs_dtype=tf.dtypes.float32,
print_loss=True
).set_hyperparams(seq2seq_pipeline_hyperparams)
)
]).set_name('SignalPrediction')

pipeline = Pipeline([EpochRepeater(
Expand Down

0 comments on commit c089c8c

Please sign in to comment.