From eafd574468110d38be9e570132295ee4bf4c8a67 Mon Sep 17 00:00:00 2001
From: Andrew Zhao
Date: Tue, 8 Dec 2020 21:11:40 +0800
Subject: [PATCH] Modify MoCo to match the official implementation

---
 moco/main.py         | 13 ++++++++-----
 moco/model.py        | 27 ++++++++++++++++++-----------
 moco/utils.py        | 32 ++++++++++++++------------------
 networks/layers.py   | 18 ++++--------------
 simsiam/main.py      |  2 +-
 utils/contrastive.py |  4 ++--
 6 files changed, 45 insertions(+), 51 deletions(-)

diff --git a/moco/main.py b/moco/main.py
index cd2b9b1..5ed84d0 100644
--- a/moco/main.py
+++ b/moco/main.py
@@ -77,6 +77,8 @@
 f_q = torch.nn.DataParallel(MoCo(args), device_ids=[0, 1]).to(device)
 f_k = get_momentum_encoder(f_q)
+for name, _ in f_q.named_modules():
+    print(name)
 criterion = MoCoLoss(args.temperature)
 optimizer = torch.optim.SGD(f_q.parameters(), lr=args.lr,
@@ -98,9 +100,12 @@
 pbar = tqdm(range(start_epoch, args.epochs))
 for epoch in pbar:
     train_losses = []
+    f_q.train()
+    f_k.train()
     for x1, x2 in train_loader:
         q1, q2 = f_q(x1), f_q(x2)
         with torch.no_grad():
+            momentum_update(f_k, f_q, args.m)
             k1, k2 = f_k(x1), f_k(x2)
         loss = criterion(q1, k2, memo_bank) + criterion(q2, k1, memo_bank)
         optimizer.zero_grad()
@@ -108,8 +113,6 @@
         optimizer.step()
         k = torch.cat([k1, k2], dim=0)
         memo_bank.dequeue_and_enqueue(k)
-        with torch.no_grad():
-            momentum_update(f_k, f_q, args.m)
         train_losses.append(loss.item())
         pbar.set_postfix({'Loss': loss.item(),
                           'Learning Rate': scheduler.get_last_lr()[0]})
@@ -118,13 +121,13 @@
     f_q.eval()

     # extract features as training data
-    feature_bank, feature_labels = get_feature_label(f_q, momentum_loader, device, normalize=True)
+    feature_bank, feature_labels = get_feature_label(f_q, momentum_loader, device, normalize=False)

-    linear_classifier = Linear_Probe(len(momentum_data.classes), hidden_dim=args.feature_dim).to(device)
+    linear_classifier = Linear_Probe(num_classes=len(momentum_data.classes), in_features=f_q.module.out_features).to(device)
     linear_classifier.fit(feature_bank, feature_labels)

     # using linear classifier to predict test data
-    y_preds, y_trues = get_feature_label(f_q, test_loader, device, normalize=True, predictor=linear_classifier)
+    y_preds, y_trues = get_feature_label(f_q, test_loader, device, normalize=False, predictor=linear_classifier)

     top1acc = y_trues.eq(y_preds).sum().item() / y_preds.size(0)
     writer.add_scalar('Top Acc @ 1', top1acc, global_step=epoch)
diff --git a/moco/model.py b/moco/model.py
index 6f527f2..3d061fb 100644
--- a/moco/model.py
+++ b/moco/model.py
@@ -3,13 +3,6 @@
 import torchvision
 from networks.layers import ConvNormAct

-class Print(nn.Module):
-    def __init__(self):
-        super().__init__()
-
-    def forward(self, x):
-        print(x.size())
-        return x

 class MoCo(nn.Module):
     def __init__(self, args):
@@ -30,17 +23,29 @@ def __init__(self, args):
             ConvNormAct(3, 32, mode='down'),
             ConvNormAct(32, 64, mode='down'),
             ConvNormAct(64, 128, mode='down'),
             nn.AdaptiveAvgPool2d(1),
             nn.Flatten()
         )
-        self.encoder.fc = nn.Linear(128, 128)
         else:
             raise NotImplementedError
-        fc = [nn.Linear(self.encoder.fc.in_features, args.feature_dim)]
+
+        # drop the backbone's classification head, saving its input width for inference
+        if args.backbone != 'basic':
+            self.out_features = self.encoder.fc.in_features
+            self.encoder.fc = nn.Identity()
+        else:
+            self.out_features = 128
+
+        # projector after the feature extractor
+        fc = [nn.Linear(self.out_features, args.feature_dim)]
         if args.mlp:
             fc.extend([nn.ReLU(), nn.Linear(args.feature_dim, args.feature_dim)])
-        self.encoder.fc = nn.Sequential(*fc)
+        self.projector = nn.Sequential(*fc)

     def forward(self, x):
         feature = self.encoder(x)
-        return F.normalize(feature, dim=-1)
+
+        # only project during training; return backbone features otherwise
+        if self.training:
+            feature = self.projector(feature)
+        return F.normalize(feature, dim=1)
diff --git a/moco/utils.py b/moco/utils.py
index 766e28b..3fc9736 100644
--- a/moco/utils.py
+++ b/moco/utils.py
@@ -17,7 +17,7 @@ def forward(self, q, k, memo_bank):
         N = q.size(0)
         k = k.detach()
         pos_logits = torch.einsum('ij,ij->i', [q, k]).unsqueeze(-1)
-        neg_logits = torch.einsum('ij,kj->ik', [q, memo_bank.queue.clone()])
+        neg_logits = torch.einsum('ij,kj->ik', [q, memo_bank.queue.clone().detach()])
         logits = torch.cat([pos_logits, neg_logits], dim=1)

         # zero is the positive "class"
@@ -51,40 +51,36 @@ class MemoryBank:
     """https://github.com/peisuke/MomentumContrast.pytorch"""
-    def __init__(self, model_k, device, loader, K=4096):
+    def __init__(self, f_k, device, loader, K=4096):
         self.K = K
-        self.queue = torch.zeros((0, 128), dtype=torch.float)
+        self.queue = torch.empty((0, 128), dtype=torch.float)
         self.queue = self.queue.to(device)

-        # initialize Q with 10 datapoints
-        for x, _ in loader:
-            x = x.to(device)
-            k = model_k(x)
-            self.queue = self.queue_data(k)
-            self.queue = self.dequeue_data(10)
-            break
+        # initialize the queue with 32 features
+        x, _ = next(iter(loader))
+        x = x.to(device)
+        k = f_k(x)
+        self.dequeue_and_enqueue(k=k, K=32)

     def queue_data(self, k):
         k = k.detach()
-        return torch.cat([self.queue, k], dim=0)
+        self.queue = torch.cat([k, self.queue], dim=0)

     def dequeue_data(self, K=None):
         if K is None:
             K = self.K
         assert isinstance(K, int)
         if len(self.queue) > K:
-            return self.queue[-K:]
-        else:
-            return self.queue
+            self.queue = self.queue[:K]

-    def dequeue_and_enqueue(self, k):
-        self.queue = self.queue_data(k)
-        self.queue = self.dequeue_data()
+    def dequeue_and_enqueue(self, k, K=None):
+        self.queue_data(k)
+        self.dequeue_data(K=K)


 def momentum_update(f_k, f_q, m):
     for param_q, param_k in zip(f_q.parameters(), f_k.parameters()):
         param_k.data = param_k.data * m + param_q.data * (1. - m)


 def get_momentum_encoder(f_q):
diff --git a/networks/layers.py b/networks/layers.py
index ae7c18d..c82456e 100644
--- a/networks/layers.py
+++ b/networks/layers.py
@@ -86,21 +86,11 @@ def forward(self, x):
         return self.resblock(x)


-class Reshape(nn.Module):
-    def __init__(self, *args):
-        super().__init__()
-        self.shape = args
-
-    def forward(self, x):
-        return x.view(self.shape)
-
-
 class Linear_Probe(nn.Module):
-    def __init__(self, num_classes, hidden_dim=256, lr=1e-3):
+    def __init__(self, num_classes, in_features=256, lr=30):
         super().__init__()
-        self.fc = nn.Linear(hidden_dim, num_classes)
-        self.optimizer = torch.optim.SGD(self.parameters(), lr=lr,
-                                         momentum=0.9, weight_decay=0.0001)
+        self.fc = nn.Linear(in_features, num_classes)
+        self.optimizer = torch.optim.SGD(self.parameters(), lr=lr, weight_decay=0)
         self.criterion = nn.CrossEntropyLoss()
         self.scheduler = torch.optim.lr_scheduler.MultiplicativeLR(
             self.optimizer,
@@ -112,7 +102,7 @@ def forward(self, x):
         return self.fc(x)

     def loss(self, y_hat, y):
         return self.criterion(y_hat, y)

-    def fit(self, x, y, epochs=500):
+    def fit(self, x, y, epochs=100):
         dataset = SimpleDataset(x, y)
         loader = DataLoader(dataset, batch_size=2056, shuffle=True)
         self.train()
diff --git a/simsiam/main.py b/simsiam/main.py
index 52080ee..019e197 100644
--- a/simsiam/main.py
+++ b/simsiam/main.py
@@ -111,7 +111,7 @@
         # extract features as training data
         feature_bank, feature_labels = get_feature_label(model, feature_loader, device, normalize=True)

-        linear_classifier = Linear_Probe(len(feature_data.classes), hidden_dim=args.hidden_dim).to(device)
+        linear_classifier = Linear_Probe(len(feature_data.classes), in_features=args.hidden_dim).to(device)
         linear_classifier.fit(feature_bank, feature_labels)

         # using linear classifier to predict test data
diff --git a/utils/contrastive.py b/utils/contrastive.py
index 91beaed..f771eb0 100644
--- a/utils/contrastive.py
+++ b/utils/contrastive.py
@@ -11,9 +11,9 @@ def get_feature_label(feature_extractor, feature_loader, device, normalize=True,
         if normalize:
             feature = F.normalize(feature, dim=1)
         if predictor is None:
-            transformed_features.append(feature.clone())
+            transformed_features.append(feature)
         else:
-            transformed_features.append(predictor.predict(feature.clone()))
+            transformed_features.append(predictor.predict(feature))
         targets.append(target)
     transformed_features = torch.cat(transformed_features, dim=0)
     targets = torch.cat(targets, dim=0)
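
Note (illustrative, not part of the patch): the moco/main.py reordering follows the official MoCo pseudocode, where the key encoder is momentum-updated under torch.no_grad() before the current batch's keys are encoded, rather than after the optimizer step, and the newest keys are enqueued at the front of the queue while the oldest fall off the back. Below is a minimal, self-contained sketch of one training step under that ordering; the toy encoders, sizes, and the info_nce helper are stand-ins for illustration, not the repo's MoCo, MemoryBank, or MoCoLoss API:

    import copy
    import torch
    import torch.nn.functional as F

    torch.manual_seed(0)
    m, T, K, dim = 0.999, 0.07, 256, 32

    f_q = torch.nn.Linear(16, dim)      # toy query encoder
    f_k = copy.deepcopy(f_q)            # key encoder starts as an exact copy
    for p in f_k.parameters():
        p.requires_grad = False         # keys are never updated by the optimizer
    optimizer = torch.optim.SGD(f_q.parameters(), lr=0.03)
    queue = F.normalize(torch.randn(K, dim), dim=1)   # toy memory bank

    def info_nce(q, k, queue, T):
        # one positive logit per query, K negative logits from the queue
        pos = torch.einsum('ij,ij->i', q, k).unsqueeze(-1)
        neg = torch.einsum('ij,kj->ik', q, queue)
        logits = torch.cat([pos, neg], dim=1) / T
        labels = torch.zeros(q.size(0), dtype=torch.long)   # index 0 is the positive
        return F.cross_entropy(logits, labels)

    for step in range(3):
        x1, x2 = torch.randn(8, 16), torch.randn(8, 16)   # two toy "views"
        q1, q2 = F.normalize(f_q(x1), dim=1), F.normalize(f_q(x2), dim=1)
        with torch.no_grad():
            # momentum-update the key encoder BEFORE encoding this batch's keys
            for p_k, p_q in zip(f_k.parameters(), f_q.parameters()):
                p_k.data = p_k.data * m + p_q.data * (1. - m)
            k1, k2 = F.normalize(f_k(x1), dim=1), F.normalize(f_k(x2), dim=1)

        loss = info_nce(q1, k2, queue, T) + info_nce(q2, k1, queue, T)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # newest keys go to the front of the queue, oldest are trimmed from the back
        queue = torch.cat([k1, k2, queue], dim=0)[:K]
        print(step, round(loss.item(), 4))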
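
Note (illustrative, not part of the patch): the evaluation changes train the linear probe on unnormalized encoder features at the encoder's own output width (normalize=False, in_features from the encoder), using plain SGD at lr=30 with no weight decay, as in MoCo's linear-evaluation protocol. A toy sketch of that probe-only protocol; the encoder, data, and sizes here are made up for illustration and are not the repo's API:

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    torch.manual_seed(0)
    encoder = nn.Sequential(nn.Linear(16, 128), nn.ReLU())   # toy frozen backbone
    probe = nn.Linear(128, 10)          # in_features matches the encoder output width
    opt = torch.optim.SGD(probe.parameters(), lr=30, weight_decay=0)

    x, y = torch.randn(512, 16), torch.randint(0, 10, (512,))
    with torch.no_grad():
        feats = encoder(x)              # features extracted once, NOT L2-normalized

    for _ in range(100):                # probe-only training; the backbone stays frozen
        loss = F.cross_entropy(probe(feats), y)
        opt.zero_grad()
        loss.backward()
        opt.step()

    top1 = (probe(feats).argmax(dim=1) == y).float().mean().item()
    print(f'train top-1: {top1:.3f}')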