@@ -57,6 +57,7 @@ def train(params: ModelParams):
57
57
optimizer = torch .optim .SGD (model .parameters (), lr = params .lr , momentum = params .momentum )
58
58
loss_fn = torch .nn .BCEWithLogitsLoss ()
59
59
train_loader = DataLoader (XORDataset (), batch_size = params .batch_size , shuffle = True )
60
+ test_loader = DataLoader (XORDataset (train = False ), batch_size = params .batch_size )
60
61
61
62
step = 0
62
63
@@ -75,31 +76,27 @@ def train(params: ModelParams):
75
76
optimizer .step ()
76
77
step += 1
77
78
78
- loss_val = loss .item ()
79
- accuracy_val = ((predictions > 0.5 ) == (targets > 0.5 )).type (torch .FloatTensor ).mean ()
79
+ accuracy = ((predictions > 0.5 ) == (targets > 0.5 )).type (torch .FloatTensor ).mean ()
80
80
81
81
if step % 500 == 0 :
82
- print (f'epoch { epoch } , step { step } , loss { loss_val :.{4 }f} , accuracy { accuracy_val :.{3 }f} ' )
82
+ print (f'epoch { epoch } , step { step } , loss { loss . item () :.{4 }f} , accuracy { accuracy :.{3 }f} ' )
83
83
84
84
# evaluate per epoch
85
- evaluate (model )
86
-
85
+ evaluate (model , test_loader )
87
86
88
- def evaluate (model ):
89
- test_loader = DataLoader (XORDataset (train = False ), batch_size = params .batch_size )
90
87
91
- prediction_is_correct = np .array ([])
88
+ def evaluate (model , loader ):
89
+ is_correct = np .array ([])
92
90
93
- for inputs , targets in test_loader :
91
+ for inputs , targets in loader :
94
92
inputs = inputs .to (params .device )
95
93
targets = targets .to (params .device )
96
94
with torch .no_grad ():
97
95
logits , predictions = model (inputs )
98
- prediction_is_correct = np .append (prediction_is_correct ,
99
- ((predictions > 0.5 ) == (targets > 0.5 )))
96
+ is_correct = np .append (is_correct , ((predictions > 0.5 ) == (targets > 0.5 )))
100
97
101
- accuracy_val = prediction_is_correct .mean ()
102
- print (f'test accuracy { accuracy_val :.{3 }f} ' )
98
+ accuracy = is_correct .mean ()
99
+ print (f'test accuracy { accuracy :.{3 }f} ' )
103
100
104
101
105
102
def get_arguments ():
0 commit comments