-
Notifications
You must be signed in to change notification settings - Fork 3
/
train.py
128 lines (94 loc) · 4.33 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
import torch
import torch.nn.functional as F
import os
import argparse
from tqdm import tqdm
from data import get_dataloader
from models import get_model
from utils import save
def train(net, data_loader, optimizer, scheduler, device=None):
    """Run one training epoch, then advance the learning-rate scheduler once.

    Args:
        net: Model to optimize; it is switched to training mode.
        data_loader: Iterable yielding ``(images, targets)`` batches.
        optimizer: Optimizer; gradients are zeroed and a step is taken for
            every batch.
        scheduler: Learning-rate scheduler, stepped once after the last batch.
        device: Device the batches are moved to. When ``None`` (default) the
            module-level ``args.device`` set by the command-line entry point
            is used, so existing four-argument calls keep working.

    Returns:
        The accumulated loss of the epoch: each batch's cross-entropy loss
        multiplied by its batch size, summed over all batches. Note this is
        the running total, not the average shown in the progress bar.
    """
    if device is None:
        # Fall back to the script-level configuration for legacy callers.
        device = args.device

    net.train()
    sum_loss = 0
    num_samples = 0
    # The context manager closes the progress bar even if a batch raises.
    with tqdm(data_loader) as pbar:
        for images, targets in pbar:
            images, targets = images.to(device), targets.to(device)

            optimizer.zero_grad()
            logits = net(images)
            loss = F.cross_entropy(logits, targets)
            loss.backward()
            optimizer.step()

            # Weight the batch loss by the batch size so that a smaller final
            # batch does not skew the running average.
            batch_size = images.shape[0]
            sum_loss += loss.item() * batch_size
            num_samples += batch_size
            pbar.set_description('Average Loss: %.2f' % (sum_loss / num_samples))

    # The scheduler is advanced per epoch, not per batch.
    scheduler.step()
    return sum_loss
def val(net, data_loader, device=None):
    """Compute the top-1 accuracy of ``net`` over ``data_loader`` in percent.

    Args:
        net: Model to evaluate; it is switched to evaluation mode.
        data_loader: Iterable yielding ``(images, targets)`` batches.
        device: Device the batches are moved to. When ``None`` (default) the
            module-level ``args.device`` set by the command-line entry point
            is used, so existing two-argument calls keep working.

    Returns:
        Accuracy as a plain Python ``float`` in the range [0, 100], or NaN
        when ``data_loader`` yields no samples.
    """
    if device is None:
        # Fall back to the script-level configuration for legacy callers.
        device = args.device

    net.eval()
    n_correct = 0
    n_total = 0
    # Gradients are never needed for evaluation; disabling them avoids
    # building an autograd graph for every batch.
    with torch.no_grad():
        for images, targets in data_loader:
            images, targets = images.to(device), targets.to(device)
            prediction = net(images).argmax(-1)
            # .item() keeps the counter a Python int, so the result is a
            # float rather than a 0-dim tensor living on the device.
            n_correct += (prediction == targets).sum().item()
            n_total += targets.shape[0]

    if n_total == 0:
        # No samples: report "undefined" instead of dividing by zero.
        return float('nan')
    return n_correct / n_total * 100
def main(args):
    """Train a model, report its test metrics and save the result.

    The model is trained for ``args.epochs`` epochs with SGD and a cosine
    annealing schedule. Validation accuracy is printed every 10th epoch.
    After training, accuracy on the clean test loader and on the poisoned
    test loader (reported as attack success rate) is printed, and the network
    together with the trigger is handed to ``save``.

    Args:
        args: Parsed command-line namespace. This function reads ``model``,
            ``device``, ``lr``, ``momentum``, ``weight_decay`` and ``epochs``;
            the whole namespace is also forwarded to ``get_dataloader`` and
            ``save``.
    """
    print(args)
    # The hold-out loader is returned by get_dataloader but not used here.
    (num_classes, train_loader, val_loader, _holdout_loader,
     test_clean_loader, test_poisoned_loader, trigger) = get_dataloader(args)

    net = get_model(args.model, num_classes).to(args.device)
    optimizer = torch.optim.SGD(
        net.parameters(),
        lr=args.lr,
        momentum=args.momentum,
        weight_decay=args.weight_decay,
    )
    # One cosine cycle spans the whole run; train() steps it once per epoch.
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, args.epochs)

    print('Start Training')
    for epoch in range(args.epochs):
        print('Epoch %d:' % epoch)
        train(net, train_loader, optimizer, scheduler)
        # Validate only every 10th epoch.
        if (epoch + 1) % 10 == 0:
            acc = val(net, val_loader)
            print('Validation accuracy: %.2f' % acc)

    # Final evaluation: clean accuracy and attack success rate.
    acc = val(net, test_clean_loader)
    asr = val(net, test_poisoned_loader)
    print('Test clean accuracy: %.2f' % acc)
    print('Test attack success rate: %.2f' % asr)

    save(net, trigger, args)
    print('Training finished')
if __name__ == '__main__':
    # Command-line entry point: parse the options, seed the random number
    # generators, then run training.
    parser = argparse.ArgumentParser(description='PyTorch Backdoor Training')
    parser.add_argument('--model', default='resnet18', type=str,
                        help='network structure choice')
    parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                        help='number of data loading workers (default: 4)')

    # Optimization options
    parser.add_argument('--epochs', default=150, type=int, metavar='N',
                        help='number of total epochs to run')
    parser.add_argument('--batch-size', default=512, type=int, metavar='N',
                        help='batch size')
    parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,
                        metavar='LR', help='initial learning rate')
    parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                        help='momentum')
    parser.add_argument('--weight-decay', '--wd', default=5e-4, type=float,
                        metavar='W', help='weight decay (default: 5e-4)')

    # Checkpoints
    parser.add_argument('-c', '--checkpoint', default='./ckpt', type=str, metavar='PATH',
                        help='path to save checkpoint (default: ./ckpt)')

    # Miscs
    parser.add_argument('--manual-seed', default=0, type=int, help='manual seed')

    # Device options
    parser.add_argument('--device', default='cuda:0', type=str,
                        help='device used for training')

    # Data path
    parser.add_argument('--dataset', type=str, default='cifar10')
    parser.add_argument('--dataset-dir', type=str, default='../data/cifar10')

    # Backdoor setting
    parser.add_argument('--attack-type', type=str, default='badnets')
    # '--target-label' matches the hyphenated style of the other flags; the
    # original '--target_label' spelling is kept as an alias. Either way the
    # value is stored as args.target_label.
    parser.add_argument('--target-label', '--target_label', type=int, default=0,
                        help='backdoor target label.')
    parser.add_argument('--poisoning-rate', type=float, default=0.1,
                        help='backdoor training sample ratio.')
    parser.add_argument('--trigger-size', type=int, default=3,
                        help='size of square backdoor trigger.')

    # NOTE: `args` must stay a module-level name; train() and val() look up
    # args.device from module scope.
    args = parser.parse_args()

    # Seed the CPU and CUDA generators and ask cuDNN for deterministic
    # algorithms.
    torch.manual_seed(args.manual_seed)
    torch.cuda.manual_seed(args.manual_seed)
    torch.backends.cudnn.deterministic = True

    main(args)