modified moco to match official implementation
Andrewzh112 committed Dec 8, 2020
1 parent e7dbefb commit eafd574
Showing 6 changed files with 46 additions and 53 deletions.
13 changes: 8 additions & 5 deletions moco/main.py
@@ -77,6 +77,8 @@

f_q = torch.nn.DataParallel(MoCo(args), device_ids=[0, 1]).to(device)
f_k = get_momentum_encoder(f_q)
for name, _ in f_q.named_modules():
print(name)

criterion = MoCoLoss(args.temperature)
optimizer = torch.optim.SGD(f_q.parameters(), lr=args.lr,
@@ -98,18 +100,19 @@
pbar = tqdm(range(start_epoch, args.epochs))
for epoch in pbar:
train_losses = []
f_q.train()
f_k.train()
for x1, x2 in train_loader:
q1, q2 = f_q(x1), f_q(x2)
with torch.no_grad():
momentum_update(f_k, f_q, args.m)
k1, k2 = f_k(x1), f_k(x2)
loss = criterion(q1, k2, memo_bank) + criterion(q2, k1, memo_bank)
optimizer.zero_grad()
loss.backward()
optimizer.step()
k = torch.cat([k1, k2], dim=0)
memo_bank.dequeue_and_enqueue(k)
with torch.no_grad():
momentum_update(f_k, f_q, args.m)
train_losses.append(loss.item())
pbar.set_postfix({'Loss': loss.item(), 'Learning Rate': scheduler.get_last_lr()[0]})

@@ -118,13 +121,13 @@

f_q.eval()
# extract features as training data
feature_bank, feature_labels = get_feature_label(f_q, momentum_loader, device, normalize=True)
feature_bank, feature_labels = get_feature_label(f_q, momentum_loader, device, normalize=False)

linear_classifier = Linear_Probe(len(momentum_data.classes), hidden_dim=args.feature_dim).to(device)
linear_classifier = Linear_Probe(num_classes=len(momentum_data.classes), in_features=f_q.out_features).to(device)
linear_classifier.fit(feature_bank, feature_labels)

# using linear classifier to predict test data
y_preds, y_trues = get_feature_label(f_q, test_loader, device, normalize=True, predictor=linear_classifier)
y_preds, y_trues = get_feature_label(f_q, test_loader, device, normalize=False, predictor=linear_classifier)
top1acc = y_trues.eq(y_preds).sum().item() / y_preds.size(0)

writer.add_scalar('Top Acc @ 1', top1acc, global_step=epoch)
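For readers skimming the interleaved old/new lines above: the substantive change in the training loop is when the key encoder is updated and the keys are computed. A minimal sketch of the new ordering, reusing the names from main.py (f_q, f_k, criterion, optimizer, memo_bank, args.m); it is illustrative, not the verbatim file contents.

for x1, x2 in train_loader:
    q1, q2 = f_q(x1), f_q(x2)                # queries keep gradients

    with torch.no_grad():
        momentum_update(f_k, f_q, args.m)    # EMA update happens before the keys are encoded
        k1, k2 = f_k(x1), f_k(x2)            # keys carry no gradients

    # symmetric InfoNCE: each view's query is contrasted with the other view's key
    loss = criterion(q1, k2, memo_bank) + criterion(q2, k1, memo_bank)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # only after the optimizer step are the fresh keys pushed into the queue of negatives
    memo_bank.dequeue_and_enqueue(torch.cat([k1, k2], dim=0))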
30 changes: 17 additions & 13 deletions moco/model.py
@@ -3,13 +3,6 @@
import torchvision
from networks.layers import ConvNormAct

class Print(nn.Module):
def __init__(self):
super().__init__()

def forward(self, x):
print(x.size())
return x

class MoCo(nn.Module):
def __init__(self, args):
@@ -30,17 +23,28 @@ def __init__(self, args):
ConvNormAct(3, 32, mode='down'),
ConvNormAct(32, 64, mode='down'),
ConvNormAct(64, 128, mode='down'),
nn.AdaptiveAvgPool2d(1),
nn.Flatten()
nn.AdaptiveAvgPool2d(1)
)
self.encoder.fc = nn.Linear(128, 128)
else:
raise NotImplementedError
fc = [nn.Linear(self.encoder.fc.in_features, args.feature_dim)]

# disabling last layers, also saving the feature dimension for inference
if args.backbone != 'basic':
self.out_features = self.encoder.fc.in_features
self.encoder.fc = nn.Identity()
else:
self.out_features = 128

# projector after feature extractor
fc = [nn.Linear(self.out_features, args.feature_dim)]
if args.mlp:
fc.extend([nn.ReLU(), nn.Linear(args.feature_dim, args.feature_dim)])
self.encoder.fc = nn.Sequential(*fc)
self.projector = nn.Sequential(*fc)

def forward(self, x):
feature = self.encoder(x)
return F.normalize(feature, dim=-1)

# only project when training, output features otherwise
if self.training:
feature = self.projector(feature)
return F.normalize(feature, dim=1)
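The rewritten model separates the backbone (self.encoder) from the projection head (self.projector), records out_features so downstream code knows the width of the raw features, and applies the projector only in training mode. A self-contained sketch of that pattern follows; the toy convolutional backbone and the dimensions are stand-ins for illustration, not the repo's ConvNormAct blocks.

import torch
import torch.nn as nn
import torch.nn.functional as F

class EncoderWithProjector(nn.Module):
    """Backbone plus projection head; project only during training, as in MoCo above."""
    def __init__(self, out_features=128, feature_dim=64, mlp=True):
        super().__init__()
        # stand-in backbone: any module that maps images to a flat feature vector
        self.encoder = nn.Sequential(
            nn.Conv2d(3, out_features, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
        )
        self.out_features = out_features             # width the linear probe will see
        fc = [nn.Linear(out_features, feature_dim)]
        if mlp:
            fc += [nn.ReLU(), nn.Linear(feature_dim, feature_dim)]
        self.projector = nn.Sequential(*fc)

    def forward(self, x):
        feature = self.encoder(x)
        if self.training:                             # contrastive loss sees projected features
            feature = self.projector(feature)
        return F.normalize(feature, dim=1)            # eval mode returns normalized backbone features

model = EncoderWithProjector()
x = torch.randn(4, 3, 32, 32)
print(model.train()(x).shape)   # torch.Size([4, 64])  -> projector output used by the loss
print(model.eval()(x).shape)    # torch.Size([4, 128]) -> backbone output used for linear probing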
32 changes: 14 additions & 18 deletions moco/utils.py
@@ -17,7 +17,7 @@ def forward(self, q, k, memo_bank):
N = q.size(0)
k = k.detach()
pos_logits = torch.einsum('ij,ij->i', [q, k]).unsqueeze(-1)
neg_logits = torch.einsum('ij,kj->ik', [q, memo_bank.queue.clone()])
neg_logits = torch.einsum('ij,kj->ik', [q, memo_bank.queue.clone().detach()])
logits = torch.cat([pos_logits, neg_logits], dim=1)

# zero is the positive "class"
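The hunk is cut off before the actual loss call, but the comment states the idea: the positive logit sits in column 0, so every sample's target class is 0 and the loss is ordinary cross-entropy over (1 + queue size) logits. A self-contained sketch of that construction under the standard MoCo formulation; the temperature and sizes here are illustrative, not the repo's exact values.

import torch
import torch.nn.functional as F

N, D, K, T = 8, 128, 4096, 0.07                 # batch size, feature dim, queue size, temperature
q = F.normalize(torch.randn(N, D), dim=1)       # queries from the gradient-tracked encoder
k = F.normalize(torch.randn(N, D), dim=1)       # keys from the momentum encoder
queue = F.normalize(torch.randn(K, D), dim=1)   # negatives held in the memory bank

pos_logits = torch.einsum('ij,ij->i', [q, k.detach()]).unsqueeze(-1)   # (N, 1) positive similarities
neg_logits = torch.einsum('ij,kj->ik', [q, queue])                     # (N, K) negative similarities
logits = torch.cat([pos_logits, neg_logits], dim=1) / T                # (N, 1 + K)

labels = torch.zeros(N, dtype=torch.long)       # zero is the positive "class"
loss = F.cross_entropy(logits, labels)
print(loss.item())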
@@ -51,40 +51,36 @@ def __call__(self, x):

class MemoryBank:
"""https://github.com/peisuke/MomentumContrast.pytorch"""
def __init__(self, model_k, device, loader, K=4096):
def __init__(self, f_k, device, loader, K=4096):
self.K = K
self.queue = torch.zeros((0, 128), dtype=torch.float)
self.queue = torch.empty(dtype=torch.float)
self.queue = self.queue.to(device)

# initialize Q with 10 datapoints
for x, _ in loader:
x = x.to(device)
k = model_k(x)
self.queue = self.queue_data(k)
self.queue = self.dequeue_data(10)
break
# initialize queue with 32 features
x, _ = next(iter(loader))
x = x.to(device)
k = f_k(x)
self.dequeue_and_enqueue(k=k, K=32)

def queue_data(self, k):
k = k.detach()
return torch.cat([self.queue, k], dim=0)
self.queue = torch.cat([k, self.queue], dim=0)

def dequeue_data(self, K=None):
if K is None:
K = self.K
assert isinstance(K, int)
if len(self.queue) > K:
return self.queue[-K:]
else:
return self.queue
self.queue = self.queue[:K]

def dequeue_and_enqueue(self, k):
self.queue = self.queue_data(k)
self.queue = self.dequeue_data()
def dequeue_and_enqueue(self, k, K=None):
self.queue_data(k)
self.dequeue_data(K=K)


def momentum_update(f_k, f_q, m):
for param_q, param_k in zip(f_q.parameters(), f_k.parameters()):
param_k.data = param_k.data * m + param_q.data * (1. - m)


def get_momentum_encoder(f_q):
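The rewritten MemoryBank prepends new keys and keeps only the first K rows, so the queue always holds the K most recently enqueued features. A tiny self-contained sketch of that first-in-first-out behaviour, using plain tensors in place of encoder outputs and toy sizes:

import torch

class TinyQueue:
    def __init__(self, K=6, dim=4):
        self.K = K
        self.queue = torch.zeros((0, dim))   # assumes an initially empty (0, dim) buffer

    def dequeue_and_enqueue(self, k, K=None):
        self.queue = torch.cat([k.detach(), self.queue], dim=0)   # newest keys go to the front
        self.queue = self.queue[: (K or self.K)]                  # drop the oldest beyond capacity

bank = TinyQueue()
for step in range(4):
    bank.dequeue_and_enqueue(torch.full((2, 4), float(step)))
print(bank.queue[:, 0])   # tensor([3., 3., 2., 2., 1., 1.]) -- the step-0 keys were dropped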
18 changes: 4 additions & 14 deletions networks/layers.py
@@ -86,21 +86,11 @@ def forward(self, x):
return self.resblock(x)


class Reshape(nn.Module):
def __init__(self, *args):
super().__init__()
self.shape = args

def forward(self, x):
return x.view(self.shape)


class Linear_Probe(nn.Module):
def __init__(self, num_classes, hidden_dim=256, lr=1e-3):
def __init__(self, num_classes, in_features=256, lr=30):
super().__init__()
self.fc = nn.Linear(hidden_dim, num_classes)
self.optimizer = torch.optim.SGD(self.parameters(), lr=lr,
momentum=0.9, weight_decay=0.0001)
self.fc = nn.Linear(in_features, num_classes)
self.optimizer = torch.optim.SGD(self.parameters(), lr=lr, weight_decay=0)
self.criterion = nn.CrossEntropyLoss()
self.scheduler = torch.optim.lr_scheduler.MultiplicativeLR(
self.optimizer,
@@ -112,7 +102,7 @@ def forward(self, x):
def loss(self, y_hat, y):
return self.criterion(y_hat, y)

def fit(self, x, y, epochs=500):
def fit(self, x, y, epochs=100):
dataset = SimpleDataset(x, y)
loader = DataLoader(dataset, batch_size=2056, shuffle=True)
self.train()
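Linear_Probe now takes the backbone's in_features directly and trains with plain SGD at a high learning rate and no weight decay, the usual linear-evaluation recipe. Since fit is truncated above, here is a self-contained sketch of fitting a linear probe on frozen, precomputed features; the data, batch size, and epoch count are made up for illustration and are not the repo's exact method.

import torch
import torch.nn as nn

# pretend these came from get_feature_label(...): frozen features and their labels
features = torch.randn(1000, 128)
labels = torch.randint(0, 10, (1000,))

probe = nn.Linear(128, 10)
optimizer = torch.optim.SGD(probe.parameters(), lr=30, weight_decay=0)
criterion = nn.CrossEntropyLoss()

for epoch in range(100):
    perm = torch.randperm(features.size(0))
    for i in range(0, features.size(0), 256):     # mini-batches over the cached features
        idx = perm[i:i + 256]
        loss = criterion(probe(features[idx]), labels[idx])
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

with torch.no_grad():
    acc = (probe(features).argmax(dim=1) == labels).float().mean().item()
print(f'probe accuracy on the cached features: {acc:.3f}')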
2 changes: 1 addition & 1 deletion simsiam/main.py
@@ -111,7 +111,7 @@
# extract features as training data
feature_bank, feature_labels = get_feature_label(model, feature_loader, device, normalize=True)

linear_classifier = Linear_Probe(len(feature_data.classes), hidden_dim=args.hidden_dim).to(device)
linear_classifier = Linear_Probe(len(feature_data.classes), in_features=args.hidden_dim).to(device)
linear_classifier.fit(feature_bank, feature_labels)

# using linear classifier to predict test data
4 changes: 2 additions & 2 deletions utils/contrastive.py
@@ -11,9 +11,9 @@ def get_feature_label(feature_extractor, feature_loader, device, normalize=True,
if normalize:
feature = F.normalize(feature, dim=1)
if predictor is None:
transformed_features.append(feature.clone())
transformed_features.append(feature)
else:
transformed_features.append(predictor.predict(feature.clone()))
transformed_features.append(predictor.predict(feature))
targets.append(target)
transformed_features = torch.cat(transformed_features, dim=0)
targets = torch.cat(targets, dim=0)
