-
Notifications
You must be signed in to change notification settings - Fork 17
/
model.py
131 lines (101 loc) · 3.3 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
import os
from keras.layers import Input
from keras.models import Model
from keras.callbacks import ModelCheckpoint
from utils import train_generator
class BaseModel(object):
    """Base Model Interface

    Wraps a Keras ``Model`` built by applying an inference function ``fn``
    to an ``Input`` layer. The model is compiled with the Adam optimizer,
    categorical cross-entropy loss and an accuracy metric, and its weights
    are loaded from ``model_path`` on construction when that file exists.

    Methods
    ----------
    fit(train_data, valid_data, epochs, batchsize, **kwargs)
    predict(X)
    evaluate(X, y)

    Examples
    ----------
    >>> model = BaseModel("example", inference, "model.h5")
    >>> model.fit([X_train, y_train], [X_val, y_val])
    """

    def __init__(self, name, fn, model_path, input_shape=(28, 28, 1)):
        """Constructor for BaseModel

        Parameters
        ----------
        name : str
            Name of this model
        fn : function
            Inference function, y = fn(X); receives the Keras input tensor
            and returns the output tensor of the network.
        model_path : str
            Path to a model.h5 file. Weights are loaded from it here if the
            file exists, and ``fit`` checkpoints the best model to it.
        input_shape : tuple of int, optional
            Per-sample input shape passed to ``keras.layers.Input``
            (batch axis excluded). Defaults to (28, 28, 1).
        """
        X = Input(shape=input_shape)
        y = fn(X)
        self.model = Model(X, y, name=name)
        self.model.compile(optimizer="adam",
                           loss="categorical_crossentropy",
                           metrics=["accuracy"])
        self.path = model_path
        self.name = name
        # Resume from an existing checkpoint, if any.
        self.load()

    @staticmethod
    def _num_steps(n_samples, batchsize):
        """Return the number of minibatches needed to cover n_samples.

        Uses integer ceil division so Keras receives an ``int`` step count
        and a trailing partial batch is still visited.
        """
        return int(-(-n_samples // batchsize))

    def fit(self, train_data, valid_data, epochs=10, batchsize=32, **kwargs):
        """Training function

        Evaluate at each epoch against validation data.
        Save the best model according to the validation loss.

        Parameters
        ----------
        train_data : tuple, (X_train, y_train)
            X_train.shape == (N, H, W, C)
            y_train.shape == (N, N_classes)
        valid_data : tuple
            (X_val, y_val)
        epochs : int
            Number of epochs to train
        batchsize : int
            Minibatch size
        **kwargs
            Keyword arguments for `fit_generator`. A ``callbacks`` entry is
            appended after the built-in best-model checkpoint callback.
        """
        # ModelCheckpoint monitors 'val_loss' by default; save_best_only
        # means self.path is only overwritten when that value improves.
        # NOTE: by default it writes the full model; load() reads back
        # only the weights from that file.
        callback_best_only = ModelCheckpoint(self.path, save_best_only=True)
        # Merge user callbacks instead of passing `callbacks` twice, which
        # would raise TypeError.
        extra_callbacks = kwargs.pop("callbacks", None) or []
        callbacks = [callback_best_only] + list(extra_callbacks)

        train_gen, val_gen = train_generator()
        X_train, y_train = train_data
        X_val, y_val = valid_data

        self.model.fit_generator(
            train_gen.flow(X_train, y_train, batchsize),
            # Keras expects integer step counts; ceil keeps the partial batch.
            steps_per_epoch=self._num_steps(X_train.shape[0], batchsize),
            validation_data=val_gen.flow(X_val, y_val, batchsize),
            validation_steps=self._num_steps(X_val.shape[0], batchsize),
            epochs=epochs,
            callbacks=callbacks,
            **kwargs)

    def save(self):
        """Save weights to self.path

        Should not be used manually; ``fit`` already checkpoints the best
        model to the same path.
        """
        self.model.save_weights(self.path)

    def load(self):
        """Load weights from self.path if the file exists.

        Prints whether a model was found; does nothing else when the file
        is missing.
        """
        if os.path.isfile(self.path):
            self.model.load_weights(self.path)
            print("Model loaded")
        else:
            print("No model is found")

    def predict(self, X):
        """Return probabilities for each classes

        Parameters
        ----------
        X : array-like (N, H, W, C)

        Returns
        ----------
        y : array-like (N, N_classes)
            Probability array
        """
        return self.model.predict(X)

    def evaluate(self, X, y):
        """Return the loss and accuracy on (X, y)

        Parameters
        ----------
        X : array-like (N, H, W, C)
        y : array-like (N, N_classes)

        Returns
        ----------
        scores : list of float
            ``[loss, accuracy]`` -- the model is compiled with one metric,
            so Keras returns the loss followed by the accuracy.
        """
        return self.model.evaluate(X, y)