Skip to content

Commit

Permalink
added SGD linear evaluator
Browse files Browse the repository at this point in the history
  • Loading branch information
zsdonghao committed Dec 7, 2020
1 parent 15b8848 commit 9b90558
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 17 deletions.
4 changes: 2 additions & 2 deletions simsiam/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,8 @@ def __call__(self, x):
class SimpleDataset(Dataset):
def __init__(self, x, y):
super().__init__()
self.x = torch.tensor(x)
self.y = torch.tensor(y)
self.x = x
self.y = y

def __getitem__(self, idx):
return self.x[idx], self.y[idx]
Expand Down
17 changes: 8 additions & 9 deletions simsiam/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,11 +107,11 @@
feature_bank, targets = [], []
# get current feature maps & fit LR
for data, target in feature_loader:
data = data.to(device)
data, target = data.to(device), target.to(device)
with torch.no_grad():
feature = model(data, istrain=False)
feature = model(data)
feature = F.normalize(feature, dim=1)
feature_bank.append(feature)
feature_bank.append(feature.clone().detach())
targets.append(target)
feature_bank = torch.cat(feature_bank, dim=0)
feature_labels = torch.cat(targets, dim=0)
Expand All @@ -121,20 +121,19 @@

y_preds, y_trues = [], []
for data, target in test_loader:
data = data.to(device)
data, target = data.to(device), target.to(device)
with torch.no_grad():
feature = model(data, istrain=False)
feature = model(data)
feature = F.normalize(feature, dim=1)
y_preds.append(linear_classifier.predict(feature))
y_preds.append(linear_classifier.predict(feature.detach()))
y_trues.append(target)
y_trues = torch.cat(y_trues, dim=0)
y_preds = torch.cat(y_preds, dim=0)
top1acc = (y_trues == y_preds).sum() / y_preds.size(0)

top1acc = y_trues.eq(y_preds).sum().item() / y_preds.size(0)
writer.add_scalar('Top Acc @ 1', top1acc, global_step=epoch)
writer.add_scalar('Representation Standard Deviation', feature_bank.std(), global_step=epoch)

tqdm.write('#########################################################')
tqdm.write('###################################################################')
tqdm.write(
f'Epoch {epoch + 1}/{args.epochs}, \
Train Loss: {sum(train_losses) / len(train_losses):.3f}, \
Expand Down
12 changes: 7 additions & 5 deletions simsiam/model.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
"""https://github.com/facebookresearch/moco"""

import torch
from torch import nn
from torch.nn import functional as F
Expand All @@ -9,12 +7,15 @@


class Linear_Classifier(nn.Module):
def __init__(self, args, num_classes, epochs=2000, lr=1e-3):
def __init__(self, args, num_classes, epochs=500, lr=1e-3):
super().__init__()
self.fc = nn.Linear(args.hidden_dim, num_classes)
self.epochs = epochs
self.optimizer = torch.optim.Adam(self.parameters(), lr=lr)
self.criterion = nn.CrossEntropyLoss()
self.scheduler = torch.optim.lr_scheduler.MultiplicativeLR(
self.optimizer,
lr_lambda=lambda lr: 0.995)

def forward(self, x):
return self.fc(x)
Expand All @@ -33,6 +34,7 @@ def fit(self, x, y):
self.optimizer.zero_grad()
loss.backward()
self.optimizer.step()
self.scheduler.step()

def predict(self, x):
self.eval()
Expand Down Expand Up @@ -79,8 +81,8 @@ def __init__(self, args):
nn.Linear(args.bottleneck_dim, args.hidden_dim),
)

def forward(self, x1, x2=None, istrain=True):
if istrain:
def forward(self, x1, x2=None):
if self.training:
z1, z2 = self.encoder(x1), self.encoder(x2)
p1, p2 = self.projector(z1), self.projector(z2)
return z1, z2, p1, p2
Expand Down
2 changes: 1 addition & 1 deletion simsiam/vis.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def get_features(loader, model, device):
for img, target in loader:
img = img.to(device)
with torch.no_grad():
feature = model(img, istrain=False)
feature = model(img)
targets.extend(target.cpu().numpy().tolist())
features.append(feature.cpu())
features = torch.cat(features).numpy()
Expand Down

0 comments on commit 9b90558

Please sign in to comment.