PyTorch models: warm_start option and serialization methods
cgpotts committed Jan 21, 2020
1 parent 9f32588 commit d70ccd9
Showing 5 changed files with 79 additions and 5 deletions.
7 changes: 6 additions & 1 deletion torch_autoencoder.py
@@ -32,6 +32,10 @@ class TorchAutoencoder(TorchModelBase):
L2 regularization strength. Default 0 is no regularization.
device : 'cpu' or 'cuda'
The default is to use 'cuda' iff available
warm_start : bool
If True, calling `fit` will resume training with previously
defined trainable parameters. If False, calling `fit` will
reinitialize all trainable parameters. Default: False.
"""
def __init__(self, **kwargs):
@@ -66,7 +70,8 @@ def fit(self, X):
dataset, batch_size=self.batch_size, shuffle=True,
pin_memory=True)
# Graph
self.model = self.define_graph()
if not self.warm_start or not hasattr(self, "model"):
self.model = self.define_graph()
self.model.to(self.device)
self.model.train()
# Optimization:
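With this guard, calling `fit` repeatedly on the same instance continues training from the current parameters when `warm_start=True`, and rebuilds the graph from scratch otherwise. A minimal usage sketch, assuming a hypothetical 2d float array `X`:

    from torch_autoencoder import TorchAutoencoder

    # `X` is assumed to be a 2d float array; hypothetical data.
    ae = TorchAutoencoder(hidden_dim=50, max_iter=100, warm_start=True)
    ae.fit(X)    # builds the graph and trains
    ae.fit(X)    # warm start: keeps training the same parameters

    cold = TorchAutoencoder(hidden_dim=50, max_iter=100)  # warm_start=False
    cold.fit(X)
    cold.fit(X)  # reinitializes all trainable parameters before training again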
53 changes: 53 additions & 0 deletions torch_model_base.py
@@ -1,3 +1,4 @@
import pickle
import torch
import torch.nn as nn

@@ -14,6 +15,7 @@ def __init__(self,
eta=0.01,
optimizer=torch.optim.Adam,
l2_strength=0,
warm_start=False,
device=None):
self.hidden_dim = hidden_dim
self.hidden_activation = hidden_activation
@@ -22,6 +24,7 @@ def __init__(self,
self.eta = eta
self.optimizer = optimizer
self.l2_strength = l2_strength
self.warm_start = warm_start
if device is None:
device = "cuda" if torch.cuda.is_available() else "cpu"
self.device = torch.device(device)
@@ -49,6 +52,56 @@ def set_params(self, **params):
setattr(self, key, val)
return self

def to_pickle(self, output_filename):
"""Serialize the entire class instance. Importantly, this
is different from using the standard `torch.save` method:
torch.save(self.model.state_dict(), output_filename)
The above stores only the underlying model parameters. In
contrast, the current method ensures that all of the model
parameters are on the CPU and then stores the full instance.
This is necessary to ensure that we retain all the information
needed to read new examples and make predictions.
Parameters
----------
output_filename : str
Full path for the output file.
"""
self.model = self.model.cpu()
with open(output_filename, 'wb') as f:
pickle.dump(self, f)

@staticmethod
def from_pickle(src_filename):
"""Load an entire class instance onto the CPU. This also sets
`self.warm_start = True` so that the loaded parameters are used
if `fit` is called.
Importantly, this is different from recommended PyTorch method:
self.model.load_state_dict(torch.load(src_filename))
We cannot reliably do this with new instances, because we need
to see new examples in order to set some of the model
dimensionalities and obtain information about what the class
labels are. Thus, the current method loads an entire serialized
class as created by `to_pickle`.
The training and prediction code move the model parameters to
`self.device`.
Parameters
----------
src_filename : str
Full path to the serialized model file.
"""
with open(src_filename, 'rb') as f:
return pickle.load(f)

def __repr__(self):
param_str = ["{}={}".format(a, getattr(self, a)) for a in self.params]
param_str = ",\n\t".join(param_str)
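Together, `to_pickle` and `from_pickle` give a simple save/load round trip for a fitted model. A minimal sketch, assuming a fitted subclass such as `TorchShallowNeuralClassifier`, hypothetical arrays `X_train`, `y_train`, and `X_test`, and a hypothetical output path:

    from torch_shallow_neural_classifier import TorchShallowNeuralClassifier

    mod = TorchShallowNeuralClassifier(max_iter=10)
    mod.fit(X_train, y_train)

    # Serialize the full instance; parameters are moved to the CPU first:
    mod.to_pickle("shallow.pkl")

    # Restore the full instance later, with no need to re-infer input
    # dimensionalities or class labels from new data:
    mod2 = TorchShallowNeuralClassifier.from_pickle("shallow.pkl")
    preds = mod2.predict(X_test)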
9 changes: 8 additions & 1 deletion torch_rnn_classifier.py
@@ -145,6 +145,10 @@ class TorchRNNClassifier(TorchModelBase):
L2 regularization strength. Default 0 is no regularization.
device : 'cpu' or 'cuda'
The default is to use 'cuda' iff available
warm_start : bool
If True, calling `fit` will resume training with previously
defined trainable parameters. If False, calling `fit` will
reinitialize all trainable parameters. Default: False.
"""
def __init__(self,
@@ -215,7 +219,8 @@ def fit(self, X, y, **kwargs):
# Infer `embed_dim` from `X` in this case:
self.embed_dim = X[0][0].shape[0]
# Graph:
self.model = self.build_graph()
if not self.warm_start or not hasattr(self, "model"):
self.model = self.build_graph()
self.model.to(self.device)
self.model.train()
# Make sure this value is up-to-date; `self.model` might change
@@ -242,6 +247,7 @@ def fit(self, X, y, **kwargs):
# Incremental predictions where possible:
if X_dev is not None and iteration > 0 and iteration % dev_iter == 0:
self.dev_predictions[iteration] = self.predict(X_dev)
self.model.train()
self.errors.append(epoch_error)
progress_bar("Finished epoch {} of {}; error is {}".format(
iteration, self.max_iter, epoch_error))
@@ -261,6 +267,7 @@ def predict_proba(self, X):
"""
self.model.eval()
with torch.no_grad():
self.model.to(self.device)
X, seq_lengths = self._prepare_dataset(X)
preds = self.model(X, seq_lengths)
preds = torch.softmax(preds, dim=1).cpu().numpy()
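Two details in this file are easy to miss: `predict_proba` now moves the model to `self.device` before inference (useful when a model restored by `from_pickle` is sitting on the CPU), and the training loop restores train mode after the incremental dev-set predictions, because `predict` puts the model into eval mode. A stripped-down illustration of that second point, using a hypothetical toy model:

    import torch
    import torch.nn as nn

    model = nn.Sequential(nn.Linear(10, 16), nn.Dropout(0.5), nn.Linear(16, 2))
    X_dev = torch.randn(4, 10)   # hypothetical dev examples

    model.train()                # training mode: dropout active
    # ... an epoch of optimization would run here ...

    model.eval()                 # `predict` switches to eval mode for dev predictions
    with torch.no_grad():
        dev_preds = model(X_dev)

    model.train()                # restore train mode; otherwise later epochs
                                 # would silently train with dropout disabled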
9 changes: 8 additions & 1 deletion torch_shallow_neural_classifier.py
@@ -34,6 +34,10 @@ class TorchShallowNeuralClassifier(TorchModelBase):
L2 regularization strength. Default 0 is no regularization.
device : 'cpu' or 'cuda'
The default is to use 'cuda' iff available
warm_start : bool
If True, calling `fit` will resume training with previously
defined trainable parameters. If False, calling `fit` will
reinitialize all trainable parameters. Default: False.
"""
def __init__(self, **kwargs):
@@ -81,7 +85,8 @@ def fit(self, X, y, **kwargs):
dataset, batch_size=self.batch_size, shuffle=True,
pin_memory=True)
# Graph:
self.model = self.define_graph()
if not self.warm_start or not hasattr(self, "model"):
self.model = self.define_graph()
self.model.to(self.device)
self.model.train()
# Optimization:
@@ -105,6 +110,7 @@ def fit(self, X, y, **kwargs):
# Incremental predictions where possible:
if X_dev is not None and iteration > 0 and iteration % dev_iter == 0:
self.dev_predictions[iteration] = self.predict(X_dev)
self.model.train()
self.errors.append(epoch_error)
progress_bar(
"Finished epoch {} of {}; error is {}".format(
@@ -125,6 +131,7 @@ def predict_proba(self, X):
"""
self.model.eval()
with torch.no_grad():
self.model.to(self.device)
X = torch.tensor(X, dtype=torch.float).to(self.device)
preds = self.model(X)
return torch.softmax(preds, dim=1).cpu().numpy()
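The same warm-start guard here also means a classifier restored from disk can continue training rather than being reinitialized, provided `warm_start` is set on the loaded instance. A minimal sketch, assuming the hypothetical pickle file from above and hypothetical arrays `X_more`, `y_more` with the same label set:

    from torch_shallow_neural_classifier import TorchShallowNeuralClassifier

    mod = TorchShallowNeuralClassifier.from_pickle("shallow.pkl")
    mod.warm_start = True      # ensure `fit` reuses the loaded parameters
    mod.fit(X_more, y_more)    # resumes training; `fit` moves the model to self.device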
6 changes: 4 additions & 2 deletions torch_tree_nn.py
@@ -78,7 +78,7 @@ def __init__(self, vocab, embedding=None, embed_dim=50, **kwargs):
self.device = 'cpu'

def build_graph(self):
self.model = TorchTreeNNModel(
return TorchTreeNNModel(
vocab=self.vocab,
embedding=self.embedding,
embed_dim=self.embed_dim,
@@ -111,7 +111,8 @@ def fit(self, X, **kwargs):
self.n_classes_ = len(self.classes_)
self.class2index = dict(zip(self.classes_, range(self.n_classes_)))
# Model:
self.build_graph()
if not self.warm_start or not hasattr(self, "model"):
self.model = self.build_graph()
self.model.to(self.device)
self.model.train()
# Optimization:
@@ -132,6 +133,7 @@ def fit(self, X, **kwargs):
# Incremental predictions where possible:
if X_dev is not None and iteration > 0 and iteration % dev_iter == 0:
self.dev_predictions[iteration] = self.predict(X_dev)
self.model.train()
self.errors.append(epoch_error)
progress_bar(
"Finished epoch {} of {}; error is {}".format(
