@@ -54,7 +54,7 @@ def __init__(self, sentence_length, input_len, hidden_len, output_len):
 
     def build(self, layer='LSTM', mapping='m2m', nb_layers=2, dropout=0.2):
         """
-        Stacked LSTM with specified dropout rate (default 0.2), built with
+        Stacked RNN with specified dropout rate (default 0.2), built with
         softmax activation, cross entropy loss and rmsprop optimizer.
 
        Arguments:
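
The layer argument is what makes the docstring change above accurate: the stack is generic over recurrent cell types, not LSTM-only. As a minimal sketch (illustrative names, not the repo's code), the string could be mapped to a Keras 1.x recurrent class like this:

    # Hedged sketch: map build()'s layer string to a Keras 1.x class.
    # RECURRENT_LAYERS and pick_layer are assumed names for illustration.
    from keras.layers.recurrent import SimpleRNN, GRU, LSTM

    RECURRENT_LAYERS = {'RNN': SimpleRNN, 'GRU': GRU, 'LSTM': LSTM}

    def pick_layer(name='LSTM'):
        # raises KeyError for unsupported layer names
        return RECURRENT_LAYERS[name]
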
@@ -122,7 +122,9 @@ class LAYER(GRU):
 
         self.model.add(Activation('softmax'))
 
-        self.model.compile(loss='categorical_crossentropy', optimizer='rmsprop')
+        self.model.compile(loss='categorical_crossentropy',
+                           optimizer='rmsprop',
+                           metrics=['accuracy'])
 
     def save_model(self, filename):
         """
@@ -184,7 +186,7 @@ def on_epoch_end(self, epoch, logs={}): # pylint: disable=W0102
         A method starting at the begining of the training.
 
         Arguments:
-            epoch: {integer}, the current epoch
+            epoch: {integer}, the current epoch.
             logs: {dictionary}, recording the training and validation
                 losses and accuracy of every epoch.
        """
@@ -383,6 +385,45 @@ def predict(sequence, input_len, analyzer, nb_predictions=80,
        print "\n"
 
 
+def train(analyzer, train_data, nb_training_samples,
+          val_data, nb_validation_samples,
+          nb_epoch=50, nb_iterations=4):
+    """
+    Trains the network.
+
+    Arguments:
+        analyzer: {SequenceAnalyzer}.
+        train_data: {tuple}, training data (X_train, y_train).
+        val_data: {tuple}, validation data (X_val, y_val).
+        nb_training_samples: {integer}, the number of training samples.
+        nb_validation_samples: {integer}, the number of validation samples.
+        nb_iterations: {integer}, number of iterations.
+        nb_epoch: {integer}, the number of epochs per iteration.
+    """
+    for iteration in range(1, nb_iterations + 1):
+        print ""
+        print "------------------------ Start Training ------------------------"
+        print "Iteration: ", iteration
+        print "Number of epochs per iteration: ", nb_epoch
+
+        # history of losses and accuracy
+        history = History()
+
+        # saves the model weights after each epoch
+        # if the validation loss decreased
+        checkpointer = ModelCheckpoint(filepath="weights.hdf5",
+                                       verbose=1, save_best_only=True)
+
+        # train the model with data generator
+        analyzer.model.fit_generator(train_data,
+                                     samples_per_epoch=nb_training_samples,
+                                     nb_epoch=nb_epoch, verbose=1,
+                                     callbacks=[history, checkpointer],
+                                     validation_data=val_data,
+                                     nb_val_samples=nb_validation_samples)
+
+        analyzer.save_model("weights-after-iteration.hdf5")
+
 
 def detect(sequence, input_len, analyzer, mapping='m2m', sentence_length=40):
     """
@@ -441,47 +482,6 @@ def detect(sequence, input_len, analyzer, mapping='m2m', sentence_length=40):
     return prob
 
 
-
-def train(analyzer, train_data, nb_training_samples,
-          val_data, nb_validation_samples,
-          nb_epoch=50, nb_iterations=4):
-    """
-    Trains the network.
-
-    Arguments:
-        analyzer: {SequenceAnalyzer}.
-        train_data: {tuple}, training data (X_train, y_train).
-        val_data: {tuple}, validation data (X_val, y_val).
-        nb_training_samples: {integer}, the number training samples.
-        nb_validation_samples: {integer}, the number validation samples.
-        nb_iterations: {integer}, number of iterations.
-        sentence_length: {integer}, the length of each training sentence.
-    """
-    for iteration in range(1, nb_iterations + 1):
-        print ""
-        print "------------------------ Start Training ------------------------"
-        print "Iteration: ", iteration
-        print "Number of epoch per iteration: ", nb_epoch
-
-        # history of losses and accuracy
-        history = History()
-
-        # saves the model weights after each epoch
-        # if the validation loss decreased
-        checkpointer = ModelCheckpoint(filepath="weights.hdf5",
-                                       verbose=1, save_best_only=True)
-
-        # train the model with data generator
-        analyzer.model.fit_generator(train_data,
-                                     samples_per_epoch=nb_training_samples,
-                                     nb_epoch=nb_epoch, verbose=1,
-                                     callbacks=[history, checkpointer],
-                                     validation_data=val_data,
-                                     nb_val_samples=nb_validation_samples)
-
-        analyzer.save_model("weights-after-iteration.hdf5")
-
-
 def run(hidden_len=512, batch_size=128, nb_batch=200, nb_epoch=50,
         nb_iterations=4, lr=0.001, validation_split=0.05, nb_predictions=20,
         mapping='m2m', sentence_length=80, step=80, mode='train'):
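
The trailing context shows run()'s defaults; a hypothetical invocation overriding a few of them might look like the following (keyword names come straight from the signature above):

    # Hypothetical call for illustration only.
    run(mode='train', nb_epoch=50, nb_iterations=4, sentence_length=80)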