Skip to content

Commit

Permalink
Fixed a bug with batch_forward
Browse files Browse the repository at this point in the history
Issue: #3
  • Loading branch information
rahulvigneswaran authored Jan 11, 2022
1 parent 85c09ac commit 7b65b32
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions libs/core/core_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,7 +393,7 @@ def train(self, retrain=False):

with torch.set_grad_enabled(True):
# If training, forward with loss, and no top 5 accuracy calculation
self.batch_forward(inputs, labels, phase="train", retrain=retrain)
self.batch_forward(inputs)
self.batch_loss(labels)
self.batch_backward()

Expand Down Expand Up @@ -654,7 +654,7 @@ def eval(self, phase='val'):
with torch.set_grad_enabled(False):

# In validation or testing
self.batch_forward(inputs, labels, phase=phase)
self.batch_forward(inputs)
self.batch_loss(labels)
minibatch_loss_total.append(self.loss.item())

Expand Down Expand Up @@ -856,4 +856,4 @@ def resume_run(self, saved_dict):

# This is there so that we can use source_import from the utils to import model
def get_core(*args):
return model(*args)
return model(*args)

0 comments on commit 7b65b32

Please sign in to comment.