-
Notifications
You must be signed in to change notification settings - Fork 324
/
main_bayesian.py
executable file
·142 lines (110 loc) · 5.42 KB
/
main_bayesian.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
from __future__ import print_function
import os
import argparse
import torch
import numpy as np
from torch.optim import Adam, lr_scheduler
from torch.nn import functional as F
import data
import utils
import metrics
import config_bayesian as cfg
from models.BayesianModels.Bayesian3Conv3FC import BBB3Conv3FC
from models.BayesianModels.BayesianAlexNet import BBBAlexNet
from models.BayesianModels.BayesianLeNet import BBBLeNet
# CUDA settings
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
def getModel(net_type, inputs, outputs, priors, layer_type, activation_type):
if (net_type == 'lenet'):
return BBBLeNet(outputs, inputs, priors, layer_type, activation_type)
elif (net_type == 'alexnet'):
return BBBAlexNet(outputs, inputs, priors, layer_type, activation_type)
elif (net_type == '3conv3fc'):
return BBB3Conv3FC(outputs, inputs, priors, layer_type, activation_type)
else:
raise ValueError('Network should be either [LeNet / AlexNet / 3Conv3FC')
def train_model(net, optimizer, criterion, trainloader, num_ens=1, beta_type=0.1, epoch=None, num_epochs=None):
net.train()
training_loss = 0.0
accs = []
kl_list = []
for i, (inputs, labels) in enumerate(trainloader, 1):
optimizer.zero_grad()
inputs, labels = inputs.to(device), labels.to(device)
outputs = torch.zeros(inputs.shape[0], net.num_classes, num_ens).to(device)
kl = 0.0
for j in range(num_ens):
net_out, _kl = net(inputs)
kl += _kl
outputs[:, :, j] = F.log_softmax(net_out, dim=1)
kl = kl / num_ens
kl_list.append(kl.item())
log_outputs = utils.logmeanexp(outputs, dim=2)
beta = metrics.get_beta(i-1, len(trainloader), beta_type, epoch, num_epochs)
loss = criterion(log_outputs, labels, kl, beta)
loss.backward()
optimizer.step()
accs.append(metrics.acc(log_outputs.data, labels))
training_loss += loss.cpu().data.numpy()
return training_loss/len(trainloader), np.mean(accs), np.mean(kl_list)
def validate_model(net, criterion, validloader, num_ens=1, beta_type=0.1, epoch=None, num_epochs=None):
"""Calculate ensemble accuracy and NLL Loss"""
net.train()
valid_loss = 0.0
accs = []
for i, (inputs, labels) in enumerate(validloader):
inputs, labels = inputs.to(device), labels.to(device)
outputs = torch.zeros(inputs.shape[0], net.num_classes, num_ens).to(device)
kl = 0.0
for j in range(num_ens):
net_out, _kl = net(inputs)
kl += _kl
outputs[:, :, j] = F.log_softmax(net_out, dim=1).data
log_outputs = utils.logmeanexp(outputs, dim=2)
beta = metrics.get_beta(i-1, len(validloader), beta_type, epoch, num_epochs)
valid_loss += criterion(log_outputs, labels, kl, beta).item()
accs.append(metrics.acc(log_outputs, labels))
return valid_loss/len(validloader), np.mean(accs)
def run(dataset, net_type):
# Hyper Parameter settings
layer_type = cfg.layer_type
activation_type = cfg.activation_type
priors = cfg.priors
train_ens = cfg.train_ens
valid_ens = cfg.valid_ens
n_epochs = cfg.n_epochs
lr_start = cfg.lr_start
num_workers = cfg.num_workers
valid_size = cfg.valid_size
batch_size = cfg.batch_size
beta_type = cfg.beta_type
trainset, testset, inputs, outputs = data.getDataset(dataset)
train_loader, valid_loader, test_loader = data.getDataloader(
trainset, testset, valid_size, batch_size, num_workers)
net = getModel(net_type, inputs, outputs, priors, layer_type, activation_type).to(device)
ckpt_dir = f'checkpoints/{dataset}/bayesian'
ckpt_name = f'checkpoints/{dataset}/bayesian/model_{net_type}_{layer_type}_{activation_type}.pt'
if not os.path.exists(ckpt_dir):
os.makedirs(ckpt_dir, exist_ok=True)
criterion = metrics.ELBO(len(trainset)).to(device)
optimizer = Adam(net.parameters(), lr=lr_start)
lr_sched = lr_scheduler.ReduceLROnPlateau(optimizer, patience=6, verbose=True)
valid_loss_max = np.Inf
for epoch in range(n_epochs): # loop over the dataset multiple times
train_loss, train_acc, train_kl = train_model(net, optimizer, criterion, train_loader, num_ens=train_ens, beta_type=beta_type, epoch=epoch, num_epochs=n_epochs)
valid_loss, valid_acc = validate_model(net, criterion, valid_loader, num_ens=valid_ens, beta_type=beta_type, epoch=epoch, num_epochs=n_epochs)
lr_sched.step(valid_loss)
print('Epoch: {} \tTraining Loss: {:.4f} \tTraining Accuracy: {:.4f} \tValidation Loss: {:.4f} \tValidation Accuracy: {:.4f} \ttrain_kl_div: {:.4f}'.format(
epoch, train_loss, train_acc, valid_loss, valid_acc, train_kl))
# save model if validation accuracy has increased
if valid_loss <= valid_loss_max:
print('Validation loss decreased ({:.6f} --> {:.6f}). Saving model ...'.format(
valid_loss_max, valid_loss))
torch.save(net.state_dict(), ckpt_name)
valid_loss_max = valid_loss
if __name__ == '__main__':
parser = argparse.ArgumentParser(description = "PyTorch Bayesian Model Training")
parser.add_argument('--net_type', default='lenet', type=str, help='model')
parser.add_argument('--dataset', default='MNIST', type=str, help='dataset = [MNIST/CIFAR10/CIFAR100]')
args = parser.parse_args()
run(args.dataset, args.net_type)