train_clean_withTrans_imagenet.py
126 lines (99 loc) · 4.17 KB
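"""Train a benign (clean) classifier: standard supervised training with SGD and a
step learning-rate schedule, wrapped in DataParallel. The checkpoint with the best
clean test accuracy is saved to ./saved/benign_model/<dataset>/<model>/best.tar and
per-epoch loss/accuracy is logged to benign.csv in the same folder."""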
import sys
import os
import torch
import torchvision
from tqdm import tqdm
import numpy as np
from torch import nn
import torchvision.transforms as transforms
import csv
from utils import args
from utils.utils import save_checkpoint_optimizer, progress_bar
from utils.dataloader import get_dataloader
from utils.network import get_network


def adjust_learning_rate(lr, optimizer, epoch, args):
    # Step decay: multiply the learning rate by args.gamma at every epoch listed in args.schedule.
    if epoch in args.schedule:
        lr *= args.gamma
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr
    return lr
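
# Example (hypothetical values): with args.schedule = [30, 60] and args.gamma = 0.1,
# an initial lr of 0.1 drops to 0.01 at epoch 30 and 0.001 at epoch 60; the caller
# feeds the returned lr back in each epoch, so the decay compounds.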


def train_epoch(arg, trainloader, model, optimizer, criterion, epoch):
    model.train()
    total_clean, total_clean_correct = 0, 0
    train_loss = 0
    for i, (inputs, labels) in enumerate(trainloader):
        inputs, labels = inputs.to(arg.device), labels.to(arg.device)
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Accumulate loss and running clean accuracy over the epoch.
        train_loss += loss.item()
        total_clean_correct += torch.sum(torch.argmax(outputs, dim=1) == labels)
        total_clean += inputs.shape[0]
        avg_acc_clean = total_clean_correct * 100.0 / total_clean
        progress_bar(i, len(trainloader), 'Epoch: %d | Loss: %.3f | Train ACC: %.3f%% (%d/%d)' % (
            epoch, train_loss / (i + 1), avg_acc_clean, total_clean_correct, total_clean))
    return train_loss / (i + 1), avg_acc_clean


def test_epoch(arg, testloader, model, criterion, epoch):
    # Evaluate on the clean test set and report running loss and accuracy.
    model.eval()
    total_clean = 0
    total_clean_correct = 0
    test_loss = 0
    for i, (inputs, labels) in enumerate(testloader):
        inputs, labels = inputs.to(arg.device), labels.to(arg.device)
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        test_loss += loss.item()
        total_clean_correct += torch.sum(torch.argmax(outputs, dim=1) == labels)
        total_clean += inputs.shape[0]
        avg_acc_clean = total_clean_correct * 100.0 / total_clean
        progress_bar(i, len(testloader), 'Epoch: %d | Loss: %.3f | Test ACC: %.3f%% (%d/%d)' % (
            epoch, test_loss / (i + 1), avg_acc_clean, total_clean_correct, total_clean))
    return test_loss / (i + 1), avg_acc_clean


def main():
    global arg
    arg = args.get_args()

    # Dataset
    trainloader = get_dataloader(arg, True)
    testloader = get_dataloader(arg, False)

    # Prepare model and optimizer
    model = get_network(arg)
    model = torch.nn.DataParallel(model).cuda()
    optimizer = torch.optim.SGD(model.parameters(), lr=arg.lr, momentum=0.9, weight_decay=1e-4)

    # Resume from a checkpoint (expects 'model', 'optimizer' and 'epoch' keys) or start from scratch.
    if arg.checkpoint_load is not None:
        checkpoint = torch.load(arg.checkpoint_load)
        print("Continue training...")
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        start_epoch = checkpoint['epoch'] + 1
    else:
        print("Training from scratch...")
        start_epoch = 0

    # Training and testing
    best_acc = 0
    criterion = nn.CrossEntropyLoss()
    lr = arg.lr

    # Output locations: best checkpoint and per-epoch CSV log.
    save_folder_path = os.path.join('./saved/benign_model/', arg.dataset, arg.model)
    if not os.path.exists(save_folder_path):
        os.makedirs(save_folder_path)
    arg.checkpoint_save = os.path.join(save_folder_path, 'best.tar')
    arg.log = os.path.join(save_folder_path, 'benign.csv')
    f_name = arg.log
    csvFile = open(f_name, 'a', newline='')
    writer = csv.writer(csvFile)
    writer.writerow(['Epoch', 'Train_Loss', 'Train_ACC', 'Test_Loss', 'Test_ACC'])

    for epoch in tqdm(range(start_epoch, arg.epochs)):
        # Set learning rate for this epoch
        lr = adjust_learning_rate(lr, optimizer, epoch, arg)
        print('\nEpoch: [%d | %d] LR: %f' % (epoch + 1, arg.epochs, lr))
        train_loss, train_acc = train_epoch(arg, trainloader, model, optimizer, criterion, epoch)
        test_loss, test_acc = test_epoch(arg, testloader, model, criterion, epoch)

        # Keep only the checkpoint with the best clean test accuracy.
        if test_acc > best_acc:
            best_acc = test_acc
            save_checkpoint_optimizer(arg.checkpoint_save, epoch, model, optimizer)
        writer.writerow([epoch, train_loss, train_acc.item(), test_loss, test_acc.item()])
    csvFile.close()


if __name__ == '__main__':
    main()