diff --git a/torch_autoencoder.py b/torch_autoencoder.py
index acef314..8c8e455 100644
--- a/torch_autoencoder.py
+++ b/torch_autoencoder.py
@@ -72,11 +72,14 @@ def fit(self, X):
         # Graph
         if not self.warm_start or not hasattr(self, "model"):
             self.model = self.define_graph()
+            self.opt = self.optimizer(
+                self.model.parameters(),
+                lr=self.eta,
+                weight_decay=self.l2_strength)
         self.model.to(self.device)
         self.model.train()
         # Optimization:
         loss = nn.MSELoss()
-        optimizer = self.optimizer(self.model.parameters(), lr=self.eta)
         # Train:
         for iteration in range(1, self.max_iter+1):
             epoch_error = 0.0
@@ -86,9 +89,9 @@ def fit(self, X):
                 batch_preds = self.model(X_batch)
                 err = loss(batch_preds, y_batch)
                 epoch_error += err.item()
-                optimizer.zero_grad()
+                self.opt.zero_grad()
                 err.backward()
-                optimizer.step()
+                self.opt.step()
             self.errors.append(epoch_error)
             progress_bar(
                 "Finished epoch {} of {}; error is {}".format(
diff --git a/torch_color_describer.py b/torch_color_describer.py
index e4a2633..fa6045d 100644
--- a/torch_color_describer.py
+++ b/torch_color_describer.py
@@ -347,6 +347,10 @@ def fit(self, color_seqs, word_seqs):
 
         if not self.warm_start or not hasattr(self, "model"):
             self.model = self.build_graph()
+            self.opt = self.optimizer(
+                self.model.parameters(),
+                lr=self.eta,
+                weight_decay=self.l2_strength)
 
         # Make sure that these attributes are aligned -- important
         # where a supplied pretrained embedding has determined
@@ -370,11 +374,6 @@ def fit(self, color_seqs, word_seqs):
 
         loss = nn.CrossEntropyLoss()
 
-        optimizer = self.optimizer(
-            self.model.parameters(),
-            lr=self.eta,
-            weight_decay=self.l2_strength)
-
         for iteration in range(1, self.max_iter+1):
             epoch_error = 0.0
             for batch_colors, batch_words, batch_lens, targets in dataloader:
@@ -392,9 +391,9 @@ def fit(self, color_seqs, word_seqs):
 
                 err = loss(output, targets)
                 epoch_error += err.item()
 
-                optimizer.zero_grad()
+                self.opt.zero_grad()
                 err.backward()
-                optimizer.step()
+                self.opt.step()
 
             utils.progress_bar("Epoch {}; err = {}".format(iteration, epoch_error))
diff --git a/torch_rnn_classifier.py b/torch_rnn_classifier.py
index 253c3e0..b23db37 100644
--- a/torch_rnn_classifier.py
+++ b/torch_rnn_classifier.py
@@ -222,6 +222,10 @@ def fit(self, X, y, **kwargs):
         # Graph:
         if not self.warm_start or not hasattr(self, "model"):
             self.model = self.build_graph()
+            self.opt = self.optimizer(
+                self.model.parameters(),
+                lr=self.eta,
+                weight_decay=self.l2_strength)
         self.model.to(self.device)
         self.model.train()
         # Make sure this value is up-to-date; self.`model` might change
@@ -229,10 +233,6 @@ def fit(self, X, y, **kwargs):
         self.embed_dim = self.model.embed_dim
         # Optimization:
         loss = nn.CrossEntropyLoss()
-        optimizer = self.optimizer(
-            self.model.parameters(),
-            lr=self.eta,
-            weight_decay=self.l2_strength)
         # Train:
         for iteration in range(1, self.max_iter+1):
             epoch_error = 0.0
@@ -242,9 +242,9 @@ def fit(self, X, y, **kwargs):
                 err = loss(batch_preds, y_batch)
                 epoch_error += err.item()
                 # Backprop:
-                optimizer.zero_grad()
+                self.opt.zero_grad()
                 err.backward()
-                optimizer.step()
+                self.opt.step()
             # Incremental predictions where possible:
             if X_dev is not None and iteration > 0 and iteration % dev_iter == 0:
                 self.dev_predictions[iteration] = self.predict(X_dev)
diff --git a/torch_shallow_neural_classifier.py b/torch_shallow_neural_classifier.py
index ab6f373..4ef58b2 100644
--- a/torch_shallow_neural_classifier.py
+++ b/torch_shallow_neural_classifier.py
@@ -87,14 +87,14 @@ def fit(self, X, y, **kwargs):
         # Graph:
         if not self.warm_start or not hasattr(self, "model"):
             self.model = self.define_graph()
+            self.opt = self.optimizer(
+                self.model.parameters(),
+                lr=self.eta,
+                weight_decay=self.l2_strength)
         self.model.to(self.device)
         self.model.train()
         # Optimization:
         loss = nn.CrossEntropyLoss()
-        optimizer = self.optimizer(
-            self.model.parameters(),
-            lr=self.eta,
-            weight_decay=self.l2_strength)
         # Train:
         for iteration in range(1, self.max_iter+1):
             epoch_error = 0.0
@@ -104,9 +104,9 @@ def fit(self, X, y, **kwargs):
                 batch_preds = self.model(X_batch)
                 err = loss(batch_preds, y_batch)
                 epoch_error += err.item()
-                optimizer.zero_grad()
+                self.opt.zero_grad()
                 err.backward()
-                optimizer.step()
+                self.opt.step()
             # Incremental predictions where possible:
             if X_dev is not None and iteration > 0 and iteration % dev_iter == 0:
                 self.dev_predictions[iteration] = self.predict(X_dev)
diff --git a/torch_tree_nn.py b/torch_tree_nn.py
index d211294..480f154 100644
--- a/torch_tree_nn.py
+++ b/torch_tree_nn.py
@@ -121,12 +121,15 @@ def fit(self, X, y=None, **kwargs):
 
         # Model:
         if not self.warm_start or not hasattr(self, "model"):
             self.model = self.build_graph()
+            self.opt = self.optimizer(
+                self.model.parameters(),
+                lr=self.eta,
+                weight_decay=self.l2_strength)
         self.model.to(self.device)
         self.model.train()
         # Optimization:
         loss = nn.CrossEntropyLoss()
-        optimizer = self.optimizer(self.model.parameters(), lr=self.eta)
 
         # Train:
         dataset = list(zip(X, y))
@@ -138,9 +141,9 @@ def fit(self, X, y=None, **kwargs):
                 label = self.convert_label(label)
                 err = loss(pred, label)
                 epoch_error += err.item()
-                optimizer.zero_grad()
+                self.opt.zero_grad()
                 err.backward()
-                optimizer.step()
+                self.opt.step()
             # Incremental predictions where possible:
             if X_dev is not None and iteration > 0 and iteration % dev_iter == 0:
                 self.dev_predictions[iteration] = self.predict(X_dev)