From d70ccd95a6a0dfe2856fe1fd0e593a85ca6d563a Mon Sep 17 00:00:00 2001
From: Christopher Potts
Date: Mon, 20 Jan 2020 17:56:10 -0800
Subject: [PATCH] PyTorch models: warm_start option and serialization methods

---
 torch_autoencoder.py               |  7 +++-
 torch_model_base.py                | 53 ++++++++++++++++++++++++++++++
 torch_rnn_classifier.py            |  9 ++++-
 torch_shallow_neural_classifier.py |  9 ++++-
 torch_tree_nn.py                   |  6 ++--
 5 files changed, 79 insertions(+), 5 deletions(-)

diff --git a/torch_autoencoder.py b/torch_autoencoder.py
index 9883f36..acef314 100644
--- a/torch_autoencoder.py
+++ b/torch_autoencoder.py
@@ -32,6 +32,10 @@ class TorchAutoencoder(TorchModelBase):
         L2 regularization strength. Default 0 is no regularization.
     device : 'cpu' or 'cuda'
         The default is to use 'cuda' iff available
+    warm_start : bool
+        If True, calling `fit` will resume training with previously
+        defined trainable parameters. If False, calling `fit` will
+        reinitialize all trainable parameters. Default: False.

     """
     def __init__(self, **kwargs):
@@ -66,7 +70,8 @@ def fit(self, X):
             dataset, batch_size=self.batch_size, shuffle=True,
             pin_memory=True)
         # Graph
-        self.model = self.define_graph()
+        if not self.warm_start or not hasattr(self, "model"):
+            self.model = self.define_graph()
         self.model.to(self.device)
         self.model.train()
         # Optimization:
diff --git a/torch_model_base.py b/torch_model_base.py
index 57302aa..dc138b4 100644
--- a/torch_model_base.py
+++ b/torch_model_base.py
@@ -1,3 +1,4 @@
+import pickle
 import torch
 import torch.nn as nn

@@ -14,6 +15,7 @@ def __init__(self,
             eta=0.01,
             optimizer=torch.optim.Adam,
             l2_strength=0,
+            warm_start=False,
             device=None):
         self.hidden_dim = hidden_dim
         self.hidden_activation = hidden_activation
@@ -22,6 +24,7 @@ def __init__(self,
         self.eta = eta
         self.optimizer = optimizer
         self.l2_strength = l2_strength
+        self.warm_start = warm_start
         if device is None:
             device = "cuda" if torch.cuda.is_available() else "cpu"
         self.device = torch.device(device)
@@ -49,6 +52,56 @@ def set_params(self, **params):
             setattr(self, key, val)
         return self

+    def to_pickle(self, output_filename):
+        """Serialize the entire class instance. Importantly, this
+        is different from using the standard `torch.save` method:
+
+            torch.save(self.model.state_dict(), output_filename)
+
+        The above stores only the underlying model parameters. In
+        contrast, the current method ensures that all of the model
+        parameters are on the CPU and then stores the full instance.
+        This is necessary to ensure that we retain all the information
+        needed to read new examples and make predictions.
+
+        Parameters
+        ----------
+        output_filename : str
+            Full path for the output file.
+
+        """
+        self.model = self.model.cpu()
+        with open(output_filename, 'wb') as f:
+            pickle.dump(self, f)
+
+    @staticmethod
+    def from_pickle(src_filename):
+        """Load an entire class instance onto the CPU. This also sets
+        `self.warm_start = True` so that the loaded parameters are used
+        if `fit` is called.
+
+        Importantly, this is different from the recommended PyTorch method:
+
+            self.model.load_state_dict(torch.load(src_filename))
+
+        We cannot reliably do this with new instances, because we need
+        to see new examples in order to set some of the model
+        dimensionalities and obtain information about what the class
+        labels are. Thus, the current method loads an entire serialized
+        class as created by `to_pickle`.
+
+        The training and prediction code move the model parameters to
+        `self.device`.
+
+        Parameters
+        ----------
+        src_filename : str
+            Full path to the serialized model file.
+
+        """
+        with open(src_filename, 'rb') as f:
+            return pickle.load(f)
+
     def __repr__(self):
         param_str = ["{}={}".format(a, getattr(self, a)) for a in self.params]
         param_str = ",\n\t".join(param_str)
diff --git a/torch_rnn_classifier.py b/torch_rnn_classifier.py
index 1cfa9c3..5f9926b 100644
--- a/torch_rnn_classifier.py
+++ b/torch_rnn_classifier.py
@@ -145,6 +145,10 @@ class TorchRNNClassifier(TorchModelBase):
         L2 regularization strength. Default 0 is no regularization.
     device : 'cpu' or 'cuda'
         The default is to use 'cuda' iff available
+    warm_start : bool
+        If True, calling `fit` will resume training with previously
+        defined trainable parameters. If False, calling `fit` will
+        reinitialize all trainable parameters. Default: False.

     """
     def __init__(self,
@@ -215,7 +219,8 @@ def fit(self, X, y, **kwargs):
             # Infer `embed_dim` from `X` in this case:
             self.embed_dim = X[0][0].shape[0]
         # Graph:
-        self.model = self.build_graph()
+        if not self.warm_start or not hasattr(self, "model"):
+            self.model = self.build_graph()
         self.model.to(self.device)
         self.model.train()
         # Make sure this value is up-to-date; self.`model` might change
@@ -242,6 +247,7 @@ def fit(self, X, y, **kwargs):
             # Incremental predictions where possible:
             if X_dev is not None and iteration > 0 and iteration % dev_iter == 0:
                 self.dev_predictions[iteration] = self.predict(X_dev)
+                self.model.train()
             self.errors.append(epoch_error)
             progress_bar("Finished epoch {} of {}; error is {}".format(
                 iteration, self.max_iter, epoch_error))
@@ -261,6 +267,7 @@ def predict_proba(self, X):
         """
         self.model.eval()
         with torch.no_grad():
+            self.model.to(self.device)
             X, seq_lengths = self._prepare_dataset(X)
             preds = self.model(X, seq_lengths)
             preds = torch.softmax(preds, dim=1).cpu().numpy()
diff --git a/torch_shallow_neural_classifier.py b/torch_shallow_neural_classifier.py
index 92dd3cc..ab6f373 100644
--- a/torch_shallow_neural_classifier.py
+++ b/torch_shallow_neural_classifier.py
@@ -34,6 +34,10 @@ class TorchShallowNeuralClassifier(TorchModelBase):
         L2 regularization strength. Default 0 is no regularization.
     device : 'cpu' or 'cuda'
         The default is to use 'cuda' iff available
+    warm_start : bool
+        If True, calling `fit` will resume training with previously
+        defined trainable parameters. If False, calling `fit` will
+        reinitialize all trainable parameters. Default: False.

     """
     def __init__(self, **kwargs):
@@ -81,7 +85,8 @@ def fit(self, X, y, **kwargs):
             dataset, batch_size=self.batch_size, shuffle=True,
             pin_memory=True)
         # Graph:
-        self.model = self.define_graph()
+        if not self.warm_start or not hasattr(self, "model"):
+            self.model = self.define_graph()
         self.model.to(self.device)
         self.model.train()
         # Optimization:
@@ -105,6 +110,7 @@ def fit(self, X, y, **kwargs):
             # Incremental predictions where possible:
             if X_dev is not None and iteration > 0 and iteration % dev_iter == 0:
                 self.dev_predictions[iteration] = self.predict(X_dev)
+                self.model.train()
             self.errors.append(epoch_error)
             progress_bar(
                 "Finished epoch {} of {}; error is {}".format(
@@ -125,6 +131,7 @@ def predict_proba(self, X):
         """
         self.model.eval()
         with torch.no_grad():
+            self.model.to(self.device)
             X = torch.tensor(X, dtype=torch.float).to(self.device)
             preds = self.model(X)
             return torch.softmax(preds, dim=1).cpu().numpy()
diff --git a/torch_tree_nn.py b/torch_tree_nn.py
index 7d41430..5ee246a 100644
--- a/torch_tree_nn.py
+++ b/torch_tree_nn.py
@@ -78,7 +78,7 @@ def __init__(self, vocab, embedding=None, embed_dim=50, **kwargs):
         self.device = 'cpu'

     def build_graph(self):
-        self.model = TorchTreeNNModel(
+        return TorchTreeNNModel(
             vocab=self.vocab,
             embedding=self.embedding,
             embed_dim=self.embed_dim,
@@ -111,7 +111,8 @@ def fit(self, X, **kwargs):
         self.n_classes_ = len(self.classes_)
         self.class2index = dict(zip(self.classes_, range(self.n_classes_)))
         # Model:
-        self.build_graph()
+        if not self.warm_start or not hasattr(self, "model"):
+            self.model = self.build_graph()
         self.model.to(self.device)
         self.model.train()
         # Optimization:
@@ -132,6 +133,7 @@ def fit(self, X, **kwargs):
             # Incremental predictions where possible:
             if X_dev is not None and iteration > 0 and iteration % dev_iter == 0:
                 self.dev_predictions[iteration] = self.predict(X_dev)
+                self.model.train()
             self.errors.append(epoch_error)
             progress_bar(
                 "Finished epoch {} of {}; error is {}".format(