Skip to content

Commit 1be6330

Browse files
authored
Merge pull request kumar-shridhar#47 from tuero/beta-update-parameter-fix
Beta update has required parameters, train/validate reflect changes
2 parents 286ee55 + 1a62b43 commit 1be6330

File tree

2 files changed

+9
-7
lines changed

2 files changed

+9
-7
lines changed

main_bayesian.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ def getModel(net_type, inputs, outputs, layer_type, activation_type):
3030
raise ValueError('Network should be either [LeNet / AlexNet / 3Conv3FC')
3131

3232

33-
def train_model(net, optimizer, criterion, trainloader, num_ens=1, beta_type=0.1):
33+
def train_model(net, optimizer, criterion, trainloader, num_ens=1, beta_type=0.1, epoch=None, num_epochs=None):
3434
net.train()
3535
training_loss = 0.0
3636
accs = []
@@ -59,7 +59,7 @@ def train_model(net, optimizer, criterion, trainloader, num_ens=1, beta_type=0.1
5959
kl_list.append(kl.item())
6060
log_outputs = utils.logmeanexp(outputs, dim=2)
6161

62-
beta = metrics.get_beta(i-1, len(trainloader), beta_type)
62+
beta = metrics.get_beta(i-1, len(trainloader), beta_type, epoch, num_epochs)
6363
loss = criterion(log_outputs, labels, kl, beta)
6464
loss.backward()
6565
optimizer.step()
@@ -69,7 +69,7 @@ def train_model(net, optimizer, criterion, trainloader, num_ens=1, beta_type=0.1
6969
return training_loss/len(trainloader), np.mean(accs), np.mean(kl_list)
7070

7171

72-
def validate_model(net, criterion, validloader, num_ens=1):
72+
def validate_model(net, criterion, validloader, num_ens=1, beta_type=0.1, epoch=None, num_epochs=None):
7373
"""Calculate ensemble accuracy and NLL Loss"""
7474
net.train()
7575
valid_loss = 0.0
@@ -86,7 +86,7 @@ def validate_model(net, criterion, validloader, num_ens=1):
8686

8787
log_outputs = utils.logmeanexp(outputs, dim=2)
8888

89-
beta = metrics.get_beta(i-1, len(validloader), 0.1)
89+
beta = metrics.get_beta(i-1, len(validloader), beta_type, epoch, num_epochs)
9090
valid_loss += criterion(log_outputs, labels, kl, beta).item()
9191
accs.append(metrics.acc(log_outputs, labels))
9292

@@ -126,8 +126,8 @@ def run(dataset, net_type):
126126
for epoch in range(n_epochs): # loop over the dataset multiple times
127127
cfg.curr_epoch_no = epoch
128128

129-
train_loss, train_acc, train_kl = train_model(net, optimizer, criterion, train_loader, num_ens=train_ens, beta_type=beta_type)
130-
valid_loss, valid_acc = validate_model(net, criterion, valid_loader, num_ens=valid_ens)
129+
train_loss, train_acc, train_kl = train_model(net, optimizer, criterion, train_loader, num_ens=train_ens, beta_type=beta_type, epoch=epoch, num_epochs=n_epochs)
130+
valid_loss, valid_acc = validate_model(net, criterion, valid_loader, num_ens=valid_ens, beta_type=beta_type, epoch=epoch, num_epochs=n_epochs)
131131
lr_sched.step(valid_loss)
132132

133133
print('Epoch: {} \tTraining Loss: {:.4f} \tTraining Accuracy: {:.4f} \tValidation Loss: {:.4f} \tValidation Accuracy: {:.4f} \ttrain_kl_div: {:.4f}'.format(

metrics.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,13 +29,15 @@ def calculate_kl(mu_p, sig_p, mu_q, sig_q):
2929
return kl
3030

3131

32-
def get_beta(batch_idx, m, beta_type):
32+
def get_beta(batch_idx, m, beta_type, epoch, num_epochs):
3333
if type(beta_type) is float:
3434
return beta_type
3535

3636
if beta_type == "Blundell":
3737
beta = 2 ** (m - (batch_idx + 1)) / (2 ** m - 1)
3838
elif beta_type == "Soenderby":
39+
if epoch is None or num_epochs is None:
40+
raise ValueError('Soenderby method requires both epoch and num_epochs to be passed.')
3941
beta = min(epoch / (num_epochs // 4), 1)
4042
elif beta_type == "Standard":
4143
beta = 1 / m

0 commit comments

Comments
 (0)