Proper handling of the optimizer for warm starts
cgpotts committed Apr 1, 2020
1 parent f09d481 commit 1f89324
Showing 5 changed files with 30 additions and 25 deletions.
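
All five files apply the same fix: the optimizer is now constructed inside the same `if not self.warm_start or not hasattr(self, "model")` guard that builds the model, and it is stored on the estimator as `self.opt`. A warm-started call to `fit` therefore reuses both the learned parameters and the existing optimizer state (for optimizers such as Adam, that state includes the per-parameter moment estimates), instead of silently resetting the optimizer on every call. A side effect visible in the diffs is that torch_autoencoder.py and torch_tree_nn.py now pass `weight_decay=self.l2_strength` when building the optimizer, which they previously omitted.

Below is a minimal sketch of the shared pattern, not the repository's exact code: attribute names follow the diffs, the graph builder is `define_graph` in some modules and `build_graph` in others, and the batching helper here is hypothetical.

    import torch.nn as nn

    def fit(self, X, y):
        # Build the model and the optimizer together, and only when we are
        # not warm-starting an already-fitted estimator:
        if not self.warm_start or not hasattr(self, "model"):
            self.model = self.define_graph()   # `build_graph()` in some modules
            self.opt = self.optimizer(
                self.model.parameters(),
                lr=self.eta,
                weight_decay=self.l2_strength)
        self.model.to(self.device)
        self.model.train()
        loss = nn.CrossEntropyLoss()
        for iteration in range(1, self.max_iter + 1):
            epoch_error = 0.0
            for X_batch, y_batch in self._build_dataloader(X, y):  # hypothetical helper
                err = loss(self.model(X_batch), y_batch)
                epoch_error += err.item()
                self.opt.zero_grad()   # stored optimizer, so its state persists across fit() calls
                err.backward()
                self.opt.step()
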
torch_autoencoder.py: 9 changes (6 additions, 3 deletions)

@@ -72,11 +72,14 @@ def fit(self, X):
         # Graph
         if not self.warm_start or not hasattr(self, "model"):
             self.model = self.define_graph()
+            self.opt = self.optimizer(
+                self.model.parameters(),
+                lr=self.eta,
+                weight_decay=self.l2_strength)
         self.model.to(self.device)
         self.model.train()
         # Optimization:
         loss = nn.MSELoss()
-        optimizer = self.optimizer(self.model.parameters(), lr=self.eta)
         # Train:
         for iteration in range(1, self.max_iter+1):
             epoch_error = 0.0
@@ -86,9 +89,9 @@ def fit(self, X):
                 batch_preds = self.model(X_batch)
                 err = loss(batch_preds, y_batch)
                 epoch_error += err.item()
-                optimizer.zero_grad()
+                self.opt.zero_grad()
                 err.backward()
-                optimizer.step()
+                self.opt.step()
             self.errors.append(epoch_error)
             progress_bar(
                 "Finished epoch {} of {}; error is {}".format(
torch_color_describer.py: 13 changes (6 additions, 7 deletions)

@@ -347,6 +347,10 @@ def fit(self, color_seqs, word_seqs):
 
         if not self.warm_start or not hasattr(self, "model"):
             self.model = self.build_graph()
+            self.opt = self.optimizer(
+                self.model.parameters(),
+                lr=self.eta,
+                weight_decay=self.l2_strength)
 
         # Make sure that these attributes are aligned -- important
         # where a supplied pretrained embedding has determined
@@ -370,11 +374,6 @@ def fit(self, color_seqs, word_seqs):
 
         loss = nn.CrossEntropyLoss()
 
-        optimizer = self.optimizer(
-            self.model.parameters(),
-            lr=self.eta,
-            weight_decay=self.l2_strength)
-
         for iteration in range(1, self.max_iter+1):
             epoch_error = 0.0
             for batch_colors, batch_words, batch_lens, targets in dataloader:
@@ -392,9 +391,9 @@ def fit(self, color_seqs, word_seqs):
 
                 err = loss(output, targets)
                 epoch_error += err.item()
-                optimizer.zero_grad()
+                self.opt.zero_grad()
                 err.backward()
-                optimizer.step()
+                self.opt.step()
 
             utils.progress_bar("Epoch {}; err = {}".format(iteration, epoch_error))
 
torch_rnn_classifier.py: 12 changes (6 additions, 6 deletions)

@@ -222,17 +222,17 @@ def fit(self, X, y, **kwargs):
         # Graph:
         if not self.warm_start or not hasattr(self, "model"):
             self.model = self.build_graph()
+            self.opt = self.optimizer(
+                self.model.parameters(),
+                lr=self.eta,
+                weight_decay=self.l2_strength)
         self.model.to(self.device)
         self.model.train()
         # Make sure this value is up-to-date; self.`model` might change
         # it if it creates an embedding:
         self.embed_dim = self.model.embed_dim
         # Optimization:
         loss = nn.CrossEntropyLoss()
-        optimizer = self.optimizer(
-            self.model.parameters(),
-            lr=self.eta,
-            weight_decay=self.l2_strength)
         # Train:
         for iteration in range(1, self.max_iter+1):
             epoch_error = 0.0
@@ -242,9 +242,9 @@ def fit(self, X, y, **kwargs):
                 err = loss(batch_preds, y_batch)
                 epoch_error += err.item()
                 # Backprop:
-                optimizer.zero_grad()
+                self.opt.zero_grad()
                 err.backward()
-                optimizer.step()
+                self.opt.step()
             # Incremental predictions where possible:
             if X_dev is not None and iteration > 0 and iteration % dev_iter == 0:
                 self.dev_predictions[iteration] = self.predict(X_dev)
torch_shallow_neural_classifier.py: 12 changes (6 additions, 6 deletions)

@@ -87,14 +87,14 @@ def fit(self, X, y, **kwargs):
         # Graph:
         if not self.warm_start or not hasattr(self, "model"):
             self.model = self.define_graph()
+            self.opt = self.optimizer(
+                self.model.parameters(),
+                lr=self.eta,
+                weight_decay=self.l2_strength)
         self.model.to(self.device)
         self.model.train()
         # Optimization:
         loss = nn.CrossEntropyLoss()
-        optimizer = self.optimizer(
-            self.model.parameters(),
-            lr=self.eta,
-            weight_decay=self.l2_strength)
         # Train:
         for iteration in range(1, self.max_iter+1):
             epoch_error = 0.0
@@ -104,9 +104,9 @@ def fit(self, X, y, **kwargs):
                 batch_preds = self.model(X_batch)
                 err = loss(batch_preds, y_batch)
                 epoch_error += err.item()
-                optimizer.zero_grad()
+                self.opt.zero_grad()
                 err.backward()
-                optimizer.step()
+                self.opt.step()
             # Incremental predictions where possible:
             if X_dev is not None and iteration > 0 and iteration % dev_iter == 0:
                 self.dev_predictions[iteration] = self.predict(X_dev)
torch_tree_nn.py: 9 changes (6 additions, 3 deletions)

@@ -121,12 +121,15 @@ def fit(self, X, y=None, **kwargs):
         # Model:
         if not self.warm_start or not hasattr(self, "model"):
             self.model = self.build_graph()
+            self.opt = self.optimizer(
+                self.model.parameters(),
+                lr=self.eta,
+                weight_decay=self.l2_strength)
         self.model.to(self.device)
         self.model.train()
 
         # Optimization:
         loss = nn.CrossEntropyLoss()
-        optimizer = self.optimizer(self.model.parameters(), lr=self.eta)
 
         # Train:
         dataset = list(zip(X, y))
@@ -138,9 +141,9 @@ def fit(self, X, y=None, **kwargs):
                 label = self.convert_label(label)
                 err = loss(pred, label)
                 epoch_error += err.item()
-                optimizer.zero_grad()
+                self.opt.zero_grad()
                 err.backward()
-                optimizer.step()
+                self.opt.step()
             # Incremental predictions where possible:
             if X_dev is not None and iteration > 0 and iteration % dev_iter == 0:
                 self.dev_predictions[iteration] = self.predict(X_dev)
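
For context, a hedged usage sketch of what the fix enables. The module name comes from the files changed above; the class name and constructor keywords are assumptions inferred from the attributes in the diffs, and X_train / y_train stand in for your features and labels.

    from torch_shallow_neural_classifier import TorchShallowNeuralClassifier

    model = TorchShallowNeuralClassifier(warm_start=True, max_iter=10)
    model.fit(X_train, y_train)   # builds self.model and self.opt
    model.fit(X_train, y_train)   # warm start: reuses parameters and optimizer state

Before this commit, the second call rebuilt the optimizer from scratch even though the model parameters were retained, so optimizers with internal state (such as Adam) lost their accumulated statistics between calls.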
