Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions magpie/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def __init__(self, keras_model=None, word2vec_model=None, scaler=None,
else:
self.scaler = scaler

def train(self, train_dir, vocabulary, test_dir=None, callbacks=None,
def train(self, train_dir, vocabulary, test_dir=None, callbacks=None, k=1,
nn_model=NN_ARCHITECTURE, batch_size=BATCH_SIZE, test_ratio=0.0,
epochs=EPOCHS, verbose=1):
"""
Expand Down Expand Up @@ -82,7 +82,8 @@ def train(self, train_dir, vocabulary, test_dir=None, callbacks=None,
self.keras_model = get_nn_model(
nn_model,
embedding=self.word2vec_model.vector_size,
output_length=len(vocabulary)
output_length=len(vocabulary,
k=k)
)

(x_train, y_train), test_data = get_data_for_model(
Expand All @@ -108,7 +109,7 @@ def train(self, train_dir, vocabulary, test_dir=None, callbacks=None,
)

def batch_train(self, train_dir, vocabulary, test_dir=None, callbacks=None,
nn_model=NN_ARCHITECTURE, batch_size=BATCH_SIZE,
k=1, nn_model=NN_ARCHITECTURE, batch_size=BATCH_SIZE,
epochs=EPOCHS, verbose=1):
"""
Train the model on given data
Expand Down Expand Up @@ -150,7 +151,8 @@ def batch_train(self, train_dir, vocabulary, test_dir=None, callbacks=None,
self.keras_model = get_nn_model(
nn_model,
embedding=self.word2vec_model.vector_size,
output_length=len(vocabulary)
output_length=len(vocabulary),
k=k
)

train_generator, test_data = get_data_for_model(
Expand Down
19 changes: 12 additions & 7 deletions magpie/nn/models.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,25 @@
from keras.layers import Input, Dense, GRU, Dropout, BatchNormalization, \
MaxPooling1D, Conv1D, Flatten, Concatenate
from keras.models import Model
import keras.backend as K

from magpie.config import SAMPLE_LENGTH

def top_custom_categorical_accuracy(k=1):
def top_k_categorical_accuracy(y_true, y_pred):
return K.mean(K.in_top_k(y_pred, K.argmax(y_true, axis=-1), k), axis=-1)
return top_k_categorical_accuracy

def get_nn_model(nn_model, embedding, output_length):
def get_nn_model(nn_model, embedding, output_length, k=1):
if nn_model == 'cnn':
return cnn(embedding_size=embedding, output_length=output_length)
return cnn(embedding_size=embedding, output_length=output_length, k=k)
elif nn_model == 'rnn':
return rnn(embedding_size=embedding, output_length=output_length)
return rnn(embedding_size=embedding, output_length=output_length, k=k)
else:
raise ValueError("Unknown NN type: {}".format(nn_model))


def cnn(embedding_size, output_length):
def cnn(embedding_size, output_length, k):
""" Create and return a keras model of a CNN """

NB_FILTER = 256
Expand Down Expand Up @@ -47,13 +52,13 @@ def cnn(embedding_size, output_length):
model.compile(
loss='binary_crossentropy',
optimizer='adam',
metrics=['top_k_categorical_accuracy'],
metrics=[top_custom_categorical_accuracy(k)],
)

return model


def rnn(embedding_size, output_length):
def rnn(embedding_size, output_length, k):
""" Create and return a keras model of a RNN """
HIDDEN_LAYER_SIZE = 256

Expand All @@ -76,7 +81,7 @@ def rnn(embedding_size, output_length):
model.compile(
loss='binary_crossentropy',
optimizer='adam',
metrics=['top_k_categorical_accuracy'],
metrics=[top_custom_categorical_accuracy(k)],
)

return model