@@ -158,8 +158,9 @@ def train(dataloader, model, loss_fn, optimizer):
158 158    ##############################################################################
159 159    # We also check the model's performance against the test dataset to ensure it is learning.
160 160
161     -  def test(dataloader, model):
    161 +  def test(dataloader, model, loss_fn):
162 162        size = len(dataloader.dataset)
    163 +      num_batches = len(dataloader)
163 164        model.eval()
164 165        test_loss, correct = 0, 0
165 166        with torch.no_grad():
@@ -168,7 +169,7 @@ def test(dataloader, model):
168 169                pred = model(X)
169 170                test_loss += loss_fn(pred, y).item()
170 171                correct += (pred.argmax(1) == y).type(torch.float).sum().item()
171     -          test_loss /= size
    172 +          test_loss /= num_batches
172 173        correct /= size
173 174        print(f"Test Error: \n Accuracy: {(100 * correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")
174175
@@ -181,7 +182,7 @@ def test(dataloader, model):
181 182    for t in range(epochs):
182 183        print(f"Epoch {t + 1}\n-------------------------------")
183 184        train(train_dataloader, model, loss_fn, optimizer)
184     -      test(test_dataloader, model)
    185 +      test(test_dataloader, model, loss_fn)
185 186    print("Done!")
186187
187188######################################################################
0 commit comments