Skip to content

Commit

Permalink
Merge pull request cgpotts#23 from zijwang/fix_model_eval
Browse files Browse the repository at this point in the history
fix model.eval() and explicitly add model.train()
  • Loading branch information
cgpotts authored Apr 9, 2019
2 parents 57a01cc + 0ee1437 commit 2e37c94
Show file tree
Hide file tree
Showing 5 changed files with 9 additions and 0 deletions.
2 changes: 2 additions & 0 deletions torch_autoencoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ def fit(self, X):
# Graph
self.model = self.define_graph()
self.model.to(self.device)
self.model.train()
# Optimization:
loss = nn.MSELoss()
optimizer = self.optimizer(self.model.parameters(), lr=self.eta)
Expand Down Expand Up @@ -105,6 +106,7 @@ def predict(self, X):
This will have the same shape as `X`.
"""
self.model.eval()
with torch.no_grad():
X_tensor = self.convert_input_to_tensor(X)
X_pred = self.model(X_tensor)
Expand Down
2 changes: 2 additions & 0 deletions torch_rnn_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,7 @@ def fit(self, X, y, **kwargs):
# Graph:
self.model = self.build_graph()
self.model.to(self.device)
self.model.train()
# Make sure this value is up-to-date; self.`model` might change
# it if it creates an embedding:
self.embed_dim = self.model.embed_dim
Expand Down Expand Up @@ -257,6 +258,7 @@ def predict_proba(self, X):
np.array with shape (len(X), self.n_classes_)
"""
self.model.eval()
with torch.no_grad():
X, seq_lengths = self._prepare_dataset(X)
preds = self.model(X, seq_lengths)
Expand Down
2 changes: 2 additions & 0 deletions torch_shallow_neural_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ def fit(self, X, y, **kwargs):
# Graph:
self.model = self.define_graph()
self.model.to(self.device)
self.model.train()
# Optimization:
loss = nn.CrossEntropyLoss()
optimizer = self.optimizer(
Expand Down Expand Up @@ -122,6 +123,7 @@ def predict_proba(self, X):
np.array with shape (len(X), self.n_classes_)
"""
self.model.eval()
with torch.no_grad():
X = torch.tensor(X, dtype=torch.float).to(self.device)
preds = self.model(X)
Expand Down
1 change: 1 addition & 0 deletions torch_subtree_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ def predict_proba(self, X):
"""Returns a list of lists of prediction vectors, one list of
vectors per tree in `X`.
"""
self.model.eval()
with torch.no_grad():
preds = []
for tree in X:
Expand Down
2 changes: 2 additions & 0 deletions torch_tree_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@ def fit(self, X, **kwargs):
# Model:
self.build_graph()
self.model.to(self.device)
self.model.train()
# Optimization:
loss = nn.CrossEntropyLoss()
optimizer = self.optimizer(self.model.parameters(), lr=self.eta)
Expand Down Expand Up @@ -179,6 +180,7 @@ def predict_proba(self, X):
np.array with shape (len(X), self.n_classes_)
"""
self.model.eval()
with torch.no_grad():
preds = []
for tree in X:
Expand Down

0 comments on commit 2e37c94

Please sign in to comment.