-
Notifications
You must be signed in to change notification settings - Fork 14
/
unmix_c100.py
454 lines (366 loc) · 18.1 KB
/
unmix_c100.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
# Code adapted from moco_demo (https://colab.research.google.com/github/facebookresearch/moco/blob/colab-notebook/colab/moco_cifar10_demo.ipynb)
from datetime import datetime
from functools import partial
from PIL import Image
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import CIFAR100
from torchvision.models import resnet
from tqdm import tqdm
import argparse
import json
import math
import os
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import utils
# Command-line configuration. Defaults follow the MoCo CIFAR demo; several values are
# overridden below for notebook use.
parser = argparse.ArgumentParser(description='Train MoCo on CIFAR-100')
parser.add_argument('-a', '--arch', default='resnet18')

# lr: 0.06 for batch 512 (or 0.03 for batch 256)
parser.add_argument('--lr', '--learning-rate', default=0.06, type=float, metavar='LR', help='initial learning rate', dest='lr')
parser.add_argument('--epochs', default=1000, type=int, metavar='N', help='number of total epochs to run')
parser.add_argument('--schedule', default=[600, 900], nargs='*', type=int, help='learning rate schedule (when to drop lr by 10x); does not take effect if --cos is on')
parser.add_argument('--cos', action='store_true', help='use cosine lr schedule')

parser.add_argument('--batch-size', default=512, type=int, metavar='N', help='mini-batch size')
parser.add_argument('--wd', default=5e-4, type=float, metavar='W', help='weight decay')

# moco specific configs:
parser.add_argument('--moco-dim', default=128, type=int, help='feature dimension')
parser.add_argument('--moco-k', default=4096, type=int, help='queue size; number of negative keys')
parser.add_argument('--moco-m', default=0.99, type=float, help='moco momentum of updating key encoder')
parser.add_argument('--moco-t', default=0.1, type=float, help='softmax temperature')
parser.add_argument('--prob', default=0.5, type=float, help='prob for choosing region or global mixture')

parser.add_argument('--bn-splits', default=8, type=int, help='simulate multi-gpu behavior of BatchNorm in one gpu; 1 is SyncBatchNorm in multi-gpu')

parser.add_argument('--symmetric', action='store_true', help='use a symmetric loss function that backprops to both crops')

# knn monitor
parser.add_argument('--knn-k', default=200, type=int, help='k in kNN monitor')
parser.add_argument('--knn-t', default=0.1, type=float, help='softmax temperature in kNN monitor; could be different with moco-t')

# utils
parser.add_argument('--resume', default='', type=str, metavar='PATH', help='path to latest checkpoint (default: none)')
parser.add_argument('--results-dir', default='', type=str, metavar='PATH', help='path to cache (default: none)')

# Running in a notebook: parse an empty argument list so sys.argv is ignored.
# To run from the command line instead, use: args = parser.parse_args()
args = parser.parse_args('')

# set command line arguments here when running in ipynb
args.epochs = 1000
args.cos = True
args.schedule = []  # cos in use
args.symmetric = True
args.results_dir = './UnMix_symmetric_C100'

if args.results_dir == '':
    # Fall back to a timestamped cache directory when no results dir was given.
    args.results_dir = './cache-' + datetime.now().strftime("%Y-%m-%d-%H-%M-%S-moco")

print(args)
class CIFAR100Pair(CIFAR100):
    """CIFAR-100 dataset that yields two views of each image for contrastive training.

    ``__getitem__`` returns ``(im_1, im_2)``: two separate calls of ``self.transform``
    on the same image, so random augmentations are drawn independently per view.
    The class label is not returned.
    """

    def __getitem__(self, index):
        img = Image.fromarray(self.data[index])

        if self.transform is None:
            # Previously this path fell through to an UnboundLocalError; without a
            # transform both "views" are simply the untransformed image.
            return img, img

        # Evaluated left to right: im_1 first, then im_2 (same order as before).
        return self.transform(img), self.transform(img)
# Per-channel normalization shared by the training and evaluation pipelines.
# NOTE(review): these mean/std values match the ones commonly quoted for CIFAR-10,
# not CIFAR-100 statistics; kept as-is for reproducibility -- confirm this is intended.
_NORM_MEAN = [0.4914, 0.4822, 0.4465]
_NORM_STD = [0.2023, 0.1994, 0.2010]

# Training augmentation: random crop/flip/color jitter/grayscale, applied per view.
train_transform = transforms.Compose([
    transforms.RandomResizedCrop(32),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)], p=0.8),
    transforms.RandomGrayscale(p=0.2),
    transforms.ToTensor(),
    transforms.Normalize(_NORM_MEAN, _NORM_STD),
])

# Evaluation pipeline: no augmentation, only tensor conversion and normalization.
test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(_NORM_MEAN, _NORM_STD),
])
# data prepare
# Settings shared by every loader below.
_loader_kwargs = dict(batch_size=args.batch_size, num_workers=16, pin_memory=True)

# Two-view training pairs; drop_last keeps every training batch full-sized.
train_data = CIFAR100Pair(root='data', train=True, transform=train_transform, download=True)
train_loader = DataLoader(train_data, shuffle=True, drop_last=True, **_loader_kwargs)

# Training images with the evaluation transform: the kNN monitor's feature bank.
memory_data = CIFAR100(root='data', train=True, transform=test_transform, download=True)
memory_loader = DataLoader(memory_data, shuffle=False, **_loader_kwargs)

# Held-out test split scored by the kNN monitor.
test_data = CIFAR100(root='data', train=False, transform=test_transform, download=True)
test_loader = DataLoader(test_data, shuffle=False, **_loader_kwargs)
# SplitBatchNorm: simulate multi-gpu behavior of BatchNorm in one gpu by splitting along the batch dimension
# implementation adapted from https://github.com/davidcpage/cifar10-fast/blob/master/torch_backend.py
class SplitBatchNorm(nn.BatchNorm2d):
    """BatchNorm2d that normalizes ``num_splits`` sub-batches with separate statistics.

    In training mode (or when running stats are not tracked), an (N, C, H, W) input is
    viewed as (N / num_splits, C * num_splits, H, W). Split ``j`` therefore holds
    samples j, j + num_splits, j + 2 * num_splits, ... and gets its own batch
    statistics. Each split updates its own copy of the running stats; those copies
    are averaged back into the module's single running mean/var. N must be divisible
    by ``num_splits``, otherwise ``view`` raises. In eval mode this is plain BatchNorm
    using the running statistics.
    """

    def __init__(self, num_features, num_splits, **kw):
        super().__init__(num_features, **kw)
        self.num_splits = num_splits

    def forward(self, input):
        N, C, H, W = input.shape
        use_batch_stats = self.training or not self.track_running_stats

        if not use_batch_stats:
            # Inference: ordinary BatchNorm with the accumulated running statistics.
            return nn.functional.batch_norm(
                input, self.running_mean, self.running_var,
                self.weight, self.bias, False, self.momentum, self.eps)

        splits = self.num_splits
        # One copy of the running stats (and affine params) per split.
        tiled_mean = self.running_mean.repeat(splits)
        tiled_var = self.running_var.repeat(splits)
        tiled_weight = self.weight.repeat(splits)
        tiled_bias = self.bias.repeat(splits)

        folded = input.view(-1, C * splits, H, W)
        normalized = nn.functional.batch_norm(
            folded, tiled_mean, tiled_var, tiled_weight, tiled_bias,
            True, self.momentum, self.eps).view(N, C, H, W)

        # batch_norm updated tiled_mean/tiled_var in place; average the per-split copies.
        self.running_mean.data.copy_(tiled_mean.view(splits, C).mean(dim=0))
        self.running_var.data.copy_(tiled_var.view(splits, C).mean(dim=0))
        return normalized
class ModelBase(nn.Module):
    """
    Common CIFAR ResNet recipe.
    Comparing with ImageNet ResNet recipe, it:
    (i) replaces conv1 with kernel=3, str=1
    (ii) removes pool1

    Args:
        feature_dim: output size of the final linear layer (the feature dimension).
        arch: name of a torchvision ``resnet`` constructor, e.g. ``'resnet18'``;
            ``None`` selects ``'resnet18'`` (previously ``None`` crashed in ``getattr``).
        bn_splits: if > 1, BatchNorm layers are ``SplitBatchNorm`` with this many splits.
    """

    def __init__(self, feature_dim=128, arch=None, bn_splits=16):
        super(ModelBase, self).__init__()
        if arch is None:
            arch = 'resnet18'  # same default as ModelMoCo and the --arch flag

        # use split batchnorm
        norm_layer = partial(SplitBatchNorm, num_splits=bn_splits) if bn_splits > 1 else nn.BatchNorm2d
        resnet_arch = getattr(resnet, arch)
        net = resnet_arch(num_classes=feature_dim, norm_layer=norm_layer)

        layers = []
        for name, module in net.named_children():
            if name == 'conv1':
                # 3x3 stride-1 stem for 32x32 inputs; keep the original output width
                # so the following bn1 still matches.
                module = nn.Conv2d(3, module.out_channels, kernel_size=3, stride=1, padding=1, bias=False)
            if isinstance(module, nn.MaxPool2d):
                continue  # drop pool1
            if isinstance(module, nn.Linear):
                # named_children skips the functional flatten in ResNet.forward.
                layers.append(nn.Flatten(1))
            layers.append(module)
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        x = self.net(x)
        # note: not normalized here
        return x
class ModelMoCo(nn.Module):
    """MoCo (momentum contrast) model with the additional UnMix mixed-image losses.

    Holds a query encoder trained by backprop and a key encoder updated as a momentum
    (moving-average) copy of it. A queue of ``K`` L2-normalized keys serves as negatives.

    Args:
        dim: feature dimension produced by both encoders.
        K: queue size (number of negative keys); must be a multiple of the number of
            keys enqueued per step.
        m: momentum of the key-encoder update.
        T: softmax temperature applied to the logits.
        arch: ResNet name forwarded to ``ModelBase``.
        bn_splits: BatchNorm split count forwarded to ``ModelBase``.
        symmetric: if True, compute the loss in both directions (im1->im2 and im2->im1).
    """

    def __init__(self, dim=128, K=4096, m=0.99, T=0.1, arch='resnet18', bn_splits=8, symmetric=True):
        super(ModelMoCo, self).__init__()

        self.K = K
        self.m = m
        self.T = T
        self.symmetric = symmetric

        # create the encoders
        self.encoder_q = ModelBase(feature_dim=dim, arch=arch, bn_splits=bn_splits)
        self.encoder_k = ModelBase(feature_dim=dim, arch=arch, bn_splits=bn_splits)

        for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()):
            param_k.data.copy_(param_q.data)  # initialize
            param_k.requires_grad = False  # not update by gradient

        # create the queue: [dim, K], one L2-normalized key per column
        self.register_buffer("queue", torch.randn(dim, K))
        self.queue = nn.functional.normalize(self.queue, dim=0)

        self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long))

    @torch.no_grad()
    def _momentum_update_key_encoder(self):
        """Momentum update of the key encoder: k <- m * k + (1 - m) * q."""
        for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()):
            param_k.data = param_k.data * self.m + param_q.data * (1. - self.m)

    @torch.no_grad()
    def _dequeue_and_enqueue(self, keys):
        """Write ``keys`` (one key per row) into the queue at the pointer, then advance it.

        Raises:
            ValueError: if the queue size ``K`` is not a multiple of ``keys.shape[0]``.
        """
        batch_size = keys.shape[0]
        ptr = int(self.queue_ptr)
        # Explicit check (not ``assert``) so it is not stripped under ``python -O``.
        if self.K % batch_size != 0:
            raise ValueError(
                'queue size K={} must be a multiple of the number of keys ({})'.format(self.K, batch_size))

        # replace the keys at ptr (dequeue and enqueue)
        self.queue[:, ptr:ptr + batch_size] = keys.t()  # transpose
        ptr = (ptr + batch_size) % self.K  # move pointer

        self.queue_ptr[0] = ptr

    @torch.no_grad()
    def _batch_shuffle_single_gpu(self, x):
        """Batch shuffle, for making use of BatchNorm. Returns (shuffled x, inverse index)."""
        # Permutation is drawn on the CPU generator (as before) and moved to x's device,
        # so this also runs when the model is on the CPU.
        idx_shuffle = torch.randperm(x.shape[0]).to(x.device)

        # index for restoring
        idx_unshuffle = torch.argsort(idx_shuffle)

        return x[idx_shuffle], idx_unshuffle

    @torch.no_grad()
    def _batch_unshuffle_single_gpu(self, x, idx_unshuffle):
        """Undo batch shuffle."""
        return x[idx_unshuffle]

    def _info_nce(self, query, k, queue):
        """InfoNCE loss of normalized ``query`` rows against positives ``k`` and negatives ``queue``.

        The positive logit is placed in column 0, so every row's target class is 0.
        """
        l_pos = torch.einsum('nc,nc->n', [query, k]).unsqueeze(-1)  # positive logits: Nx1
        l_neg = torch.einsum('nc,ck->nk', [query, queue])  # negative logits: NxK
        logits = torch.cat([l_pos, l_neg], dim=1) / self.T  # Nx(1+K), temperature-scaled
        labels = torch.zeros(logits.shape[0], dtype=torch.long, device=logits.device)
        return nn.functional.cross_entropy(logits, labels)

    def contrastive_loss(self, im_q, im_k, im1_mixed, im1_mixed_re):
        """Compute the plain and mixed-image contrastive losses for one direction.

        Returns:
            (loss, q, k, loss_m1, loss_m2): the loss of ``im_q`` vs key ``im_k``; the
            normalized query and key features; and the losses of the mixed batch and
            of its batch-reversed features against the same keys.
        """
        # compute query features (normalized right after encoding)
        q = nn.functional.normalize(self.encoder_q(im_q), dim=1)  # queries: NxC

        q_mixed = self.encoder_q(im1_mixed)
        # alternative implementation: q_mixed_flip = self.encoder_q(im1_mixed_re)
        q_mixed_flip = torch.flip(q_mixed, (0,))
        q_mixed = nn.functional.normalize(q_mixed, dim=1)
        q_mixed_flip = nn.functional.normalize(q_mixed_flip, dim=1)

        # compute key features
        with torch.no_grad():  # no gradient to keys
            # shuffle for making use of BN
            im_k_, idx_unshuffle = self._batch_shuffle_single_gpu(im_k)
            k = nn.functional.normalize(self.encoder_k(im_k_), dim=1)  # keys: NxC
            # undo shuffle
            k = self._batch_unshuffle_single_gpu(k, idx_unshuffle)

        # Snapshot the queue once; all three losses use the same negatives.
        queue = self.queue.clone().detach()
        loss = self._info_nce(q, k, queue)
        loss_m1 = self._info_nce(q_mixed, k, queue)
        loss_m2 = self._info_nce(q_mixed_flip, k, queue)
        return loss, q, k, loss_m1, loss_m2

    def forward(self, im1, im2, im1_mixed, im1_mixed_re, lam):
        """
        Input:
            im1, im2: the two augmented views of a batch (as produced by CIFAR100Pair)
            im1_mixed: im1 mixed with its batch-reversed copy (built in train())
            im1_mixed_re: im1_mixed in reversed batch order (not used by contrastive_loss)
            lam: weight of im1 in the mixture; weights the two mixed losses
        Output:
            loss: scalar training loss
        """
        # update the key encoder
        with torch.no_grad():  # no gradient to keys
            self._momentum_update_key_encoder()

        # compute loss
        if self.symmetric:  # symmetric loss
            loss_12, _, k2, loss_m11, loss_m12 = self.contrastive_loss(im1, im2, im1_mixed, im1_mixed_re)
            loss_21, _, k1, loss_m21, loss_m22 = self.contrastive_loss(im2, im1, im1_mixed, im1_mixed_re)
            loss = loss_12 + loss_21 + lam * loss_m11 + (1 - lam) * loss_m12 + lam * loss_m21 + (1 - lam) * loss_m22
            k = torch.cat([k1, k2], dim=0)
        else:  # asymmetric loss
            loss_0, _, k, loss_m11, loss_m12 = self.contrastive_loss(im1, im2, im1_mixed, im1_mixed_re)
            loss = loss_0 + lam * loss_m11 + (1 - lam) * loss_m12

        self._dequeue_and_enqueue(k)

        return loss
# create model: the MoCo wrapper configured from the parsed arguments, placed on the GPU.
_moco_config = dict(
    dim=args.moco_dim,
    K=args.moco_k,
    m=args.moco_m,
    T=args.moco_t,
    arch=args.arch,
    bn_splits=args.bn_splits,
    symmetric=args.symmetric,
)
model = ModelMoCo(**_moco_config).cuda()
# Show the query-encoder architecture.
print(model.encoder_q)
# train for one epoch
def _mix_batch(images, prob, beta):
    """Build the UnMix mixed batch from ``images`` and its batch-reversed copy.

    With probability ``prob`` the two are blended pixel-wise (global mixture).
    Otherwise a box from the reversed batch is pasted in (region mixture; the box
    comes from ``utils.rand_bbox``). The mixing weight is drawn from Beta(beta, beta).

    Returns:
        (mixed, mixed_flip, lam): the mixed batch, the same batch in reversed order,
        and the fraction contributed by ``images``. For the region mixture this is
        recomputed from the pasted box area.
    """
    # Same RNG call order as before: uniform draw first, then the Beta sample.
    r = np.random.rand(1)
    lam = np.random.beta(beta, beta)
    images_reverse = torch.flip(images, (0,))
    if r < prob:
        mixed = lam * images + (1 - lam) * images_reverse
    else:
        mixed = images.clone()
        bbx1, bby1, bbx2, bby2 = utils.rand_bbox(images.size(), lam)
        mixed[:, :, bbx1:bbx2, bby1:bby2] = images_reverse[:, :, bbx1:bbx2, bby1:bby2]
        # adjust lambda to exactly match pixel ratio
        lam = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / (images.size()[-1] * images.size()[-2]))
    return mixed, torch.flip(mixed, (0,)), lam


def train(net, data_loader, train_optimizer, epoch, args):
    """Run one UnMix-MoCo training epoch and return the mean per-sample loss.

    Args:
        net: the ModelMoCo-style module; called as ``net(im_1, im_2, mixed, mixed_flip, lam)``.
        data_loader: yields ``(im_1, im_2)`` view pairs.
        train_optimizer: optimizer whose learning rate is set for ``epoch`` and then stepped.
        epoch: current epoch number (used for the LR schedule and progress display).
        args: config; reads ``prob``, ``epochs``, the LR-schedule fields and optionally
            ``beta`` (Beta-distribution parameter for the mixing weight, default 1.0).
    """
    net.train()
    # Use the optimizer we were given, not the module-level global ``optimizer``.
    adjust_learning_rate(train_optimizer, epoch, args)

    # Previously hard-coded by overwriting args.beta = 1.0 on every iteration.
    beta = getattr(args, 'beta', 1.0)

    total_loss, total_num, train_bar = 0.0, 0, tqdm(data_loader)
    for im_1, im_2 in train_bar:
        im_1, im_2 = im_1.cuda(non_blocking=True), im_2.cuda(non_blocking=True)
        mixed_images, mixed_images_flip, lam = _mix_batch(im_1, args.prob, beta)

        loss = net(im_1, im_2, mixed_images, mixed_images_flip, lam)

        train_optimizer.zero_grad()
        loss.backward()
        train_optimizer.step()

        # Count the samples actually in this batch (robust to a short final batch).
        batch_size = im_1.size(0)
        total_num += batch_size
        total_loss += loss.item() * batch_size
        train_bar.set_description('Train Epoch: [{}/{}], lr: {:.6f}, Loss: {:.4f}'.format(epoch, args.epochs, train_optimizer.param_groups[0]['lr'], total_loss / total_num))

    # An empty loader yields 0.0 instead of raising ZeroDivisionError.
    return total_loss / total_num if total_num else 0.0
# lr scheduler for training
def adjust_learning_rate(optimizer, epoch, args):
    """Set the learning rate of every param group in ``optimizer`` for ``epoch``.

    With ``args.cos`` the base rate ``args.lr`` is scaled by a half-cosine that reaches
    0 at ``args.epochs``. Otherwise it is divided by 10 once for each milestone in
    ``args.schedule`` that ``epoch`` has reached.
    """
    lr = args.lr
    if args.cos:
        # cosine lr schedule
        lr *= 0.5 * (1. + math.cos(math.pi * epoch / args.epochs))
    else:
        # stepwise lr schedule: one 10x drop per milestone already passed
        passed = [milestone for milestone in args.schedule if epoch >= milestone]
        for _ in passed:
            lr *= 0.1

    for group in optimizer.param_groups:
        group['lr'] = lr
# test using a knn monitor
def test(net, memory_data_loader, test_data_loader, epoch, args):
    """Score ``net`` with a weighted kNN classifier; return top-1 accuracy in percent.

    Features of ``memory_data_loader`` (whose dataset must expose ``classes`` and
    ``targets``) form the bank. Each test sample is classified by ``knn_predict``
    using ``args.knn_k`` and ``args.knn_t``.
    """
    net.eval()
    # Run on the network's own device instead of hard-coding .cuda().
    device = next(net.parameters()).device
    classes = len(memory_data_loader.dataset.classes)
    total_top1, total_num, feature_bank = 0.0, 0, []
    with torch.no_grad():
        # generate feature bank
        for data, _ in tqdm(memory_data_loader, desc='Feature extracting'):
            feature = F.normalize(net(data.to(device, non_blocking=True)), dim=1)
            feature_bank.append(feature)
        # [D, N]
        feature_bank = torch.cat(feature_bank, dim=0).t().contiguous()
        # [N]
        feature_labels = torch.tensor(memory_data_loader.dataset.targets, device=feature_bank.device)
        # loop test data to predict the label by weighted knn search
        test_bar = tqdm(test_data_loader)
        for data, target in test_bar:
            data, target = data.to(device, non_blocking=True), target.to(device, non_blocking=True)
            feature = F.normalize(net(data), dim=1)

            pred_labels = knn_predict(feature, feature_bank, feature_labels, classes, args.knn_k, args.knn_t)

            total_num += data.size(0)
            total_top1 += (pred_labels[:, 0] == target).float().sum().item()
            test_bar.set_description('Test Epoch: [{}/{}] Acc@1:{:.2f}%'.format(epoch, args.epochs, total_top1 / total_num * 100))

    # An empty test loader yields 0.0 instead of raising ZeroDivisionError.
    return total_top1 / total_num * 100 if total_num else 0.0
# knn monitor as in InstDisc https://arxiv.org/abs/1805.01978
# implementation follows http://github.com/zhirongw/lemniscate.pytorch and https://github.com/leftthomas/SimCLR
def knn_predict(feature, feature_bank, feature_labels, classes, knn_k, knn_t):
    """Rank classes by a temperature-weighted k-nearest-neighbour vote.

    Args:
        feature: [B, D] query features.
        feature_bank: [D, N] bank features, one column per stored sample.
        feature_labels: [N] integer class labels of the bank samples.
        classes: number of classes C.
        knn_k: number of neighbours; clamped to N when the bank is smaller
            (``topk`` would otherwise raise).
        knn_t: temperature; each neighbour votes with weight exp(similarity / knn_t).

    Returns:
        [B, C] class indices sorted by descending vote score; column 0 is the prediction.
    """
    # dot product with the bank ---> [B, N]; cosine similarity when both sides are
    # L2-normalized (as test() does)
    sim_matrix = torch.mm(feature, feature_bank)
    k = min(knn_k, feature_bank.size(1))
    # [B, K]
    sim_weight, sim_indices = sim_matrix.topk(k=k, dim=-1)
    # [B, K]
    sim_labels = torch.gather(feature_labels.expand(feature.size(0), -1), dim=-1, index=sim_indices)
    sim_weight = (sim_weight / knn_t).exp()

    # counts for each class
    one_hot_label = torch.zeros(feature.size(0) * k, classes, device=sim_labels.device)
    # [B*K, C]
    one_hot_label = one_hot_label.scatter(dim=-1, index=sim_labels.view(-1, 1), value=1.0)
    # weighted score ---> [B, C]
    pred_scores = torch.sum(one_hot_label.view(feature.size(0), -1, classes) * sim_weight.unsqueeze(dim=-1), dim=1)

    return pred_scores.argsort(dim=-1, descending=True)
# define optimizer
optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, weight_decay=args.wd, momentum=0.9)

# load model if resume
epoch_start = 1
# Truthiness check: the former ``args.resume is not ''`` compared object identity with
# a literal (SyntaxWarning on Python 3.8+, correct only by string interning).
if args.resume:
    # NOTE: torch.load unpickles the file -- only resume from checkpoints you trust.
    checkpoint = torch.load(args.resume)
    model.load_state_dict(checkpoint['state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    epoch_start = checkpoint['epoch'] + 1
    print('Loaded from: {}'.format(args.resume))
# logging
results = {'train_loss': [], 'test_acc@1': []}
# makedirs also creates missing parent directories and tolerates an existing one,
# unlike the previous exists()/mkdir() pair (which failed for nested paths).
os.makedirs(args.results_dir, exist_ok=True)
# dump args
with open(args.results_dir + '/args.json', 'w') as fid:
    json.dump(args.__dict__, fid, indent=2)
# training loop: train, evaluate, then rewrite the log and the latest checkpoint each epoch
for epoch in range(epoch_start, args.epochs + 1):
    results['train_loss'].append(train(model, train_loader, optimizer, epoch, args))
    results['test_acc@1'].append(test(model.encoder_q, memory_loader, test_loader, epoch, args))

    # save statistics (the whole history is rewritten every epoch)
    epochs_so_far = range(epoch_start, epoch + 1)
    pd.DataFrame(data=results, index=epochs_so_far).to_csv(args.results_dir + '/log.csv', index_label='epoch')

    # save model (overwrites the previous "last" checkpoint)
    latest_state = {
        'epoch': epoch,
        'state_dict': model.state_dict(),
        'optimizer': optimizer.state_dict(),
    }
    torch.save(latest_state, args.results_dir + '/model_last.pth')