@@ -30,7 +30,7 @@ def getModel(net_type, inputs, outputs, layer_type, activation_type):
30
30
raise ValueError ('Network should be either [LeNet / AlexNet / 3Conv3FC' )
31
31
32
32
33
- def train_model (net , optimizer , criterion , trainloader , num_ens = 1 , beta_type = 0.1 ):
33
+ def train_model (net , optimizer , criterion , trainloader , num_ens = 1 , beta_type = 0.1 , epoch = None , num_epochs = None ):
34
34
net .train ()
35
35
training_loss = 0.0
36
36
accs = []
@@ -59,7 +59,7 @@ def train_model(net, optimizer, criterion, trainloader, num_ens=1, beta_type=0.1
59
59
kl_list .append (kl .item ())
60
60
log_outputs = utils .logmeanexp (outputs , dim = 2 )
61
61
62
- beta = metrics .get_beta (i - 1 , len (trainloader ), beta_type )
62
+ beta = metrics .get_beta (i - 1 , len (trainloader ), beta_type , epoch , num_epochs )
63
63
loss = criterion (log_outputs , labels , kl , beta )
64
64
loss .backward ()
65
65
optimizer .step ()
@@ -69,7 +69,7 @@ def train_model(net, optimizer, criterion, trainloader, num_ens=1, beta_type=0.1
69
69
return training_loss / len (trainloader ), np .mean (accs ), np .mean (kl_list )
70
70
71
71
72
- def validate_model (net , criterion , validloader , num_ens = 1 ):
72
+ def validate_model (net , criterion , validloader , num_ens = 1 , beta_type = 0.1 , epoch = None , num_epochs = None ):
73
73
"""Calculate ensemble accuracy and NLL Loss"""
74
74
net .train ()
75
75
valid_loss = 0.0
@@ -86,7 +86,7 @@ def validate_model(net, criterion, validloader, num_ens=1):
86
86
87
87
log_outputs = utils .logmeanexp (outputs , dim = 2 )
88
88
89
- beta = metrics .get_beta (i - 1 , len (validloader ), 0.1 )
89
+ beta = metrics .get_beta (i - 1 , len (validloader ), beta_type , epoch , num_epochs )
90
90
valid_loss += criterion (log_outputs , labels , kl , beta ).item ()
91
91
accs .append (metrics .acc (log_outputs , labels ))
92
92
@@ -126,8 +126,8 @@ def run(dataset, net_type):
126
126
for epoch in range (n_epochs ): # loop over the dataset multiple times
127
127
cfg .curr_epoch_no = epoch
128
128
129
- train_loss , train_acc , train_kl = train_model (net , optimizer , criterion , train_loader , num_ens = train_ens , beta_type = beta_type )
130
- valid_loss , valid_acc = validate_model (net , criterion , valid_loader , num_ens = valid_ens )
129
+ train_loss , train_acc , train_kl = train_model (net , optimizer , criterion , train_loader , num_ens = train_ens , beta_type = beta_type , epoch = epoch , num_epochs = n_epochs )
130
+ valid_loss , valid_acc = validate_model (net , criterion , valid_loader , num_ens = valid_ens , beta_type = beta_type , epoch = epoch , num_epochs = n_epochs )
131
131
lr_sched .step (valid_loss )
132
132
133
133
print ('Epoch: {} \t Training Loss: {:.4f} \t Training Accuracy: {:.4f} \t Validation Loss: {:.4f} \t Validation Accuracy: {:.4f} \t train_kl_div: {:.4f}' .format (
0 commit comments