Skip to content

Commit 9b90558

Browse files
committed
added SGD linear evaluator
1 parent 15b8848 commit 9b90558

File tree

4 files changed

+18
-17
lines changed

4 files changed

+18
-17
lines changed

simsiam/data.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,8 @@ def __call__(self, x):
3636
class SimpleDataset(Dataset):
3737
def __init__(self, x, y):
3838
super().__init__()
39-
self.x = torch.tensor(x)
40-
self.y = torch.tensor(y)
39+
self.x = x
40+
self.y = y
4141

4242
def __getitem__(self, idx):
4343
return self.x[idx], self.y[idx]

simsiam/main.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -107,11 +107,11 @@
107107
feature_bank, targets = [], []
108108
# get current feature maps & fit LR
109109
for data, target in feature_loader:
110-
data = data.to(device)
110+
data, target = data.to(device), target.to(device)
111111
with torch.no_grad():
112-
feature = model(data, istrain=False)
112+
feature = model(data)
113113
feature = F.normalize(feature, dim=1)
114-
feature_bank.append(feature)
114+
feature_bank.append(feature.clone().detach())
115115
targets.append(target)
116116
feature_bank = torch.cat(feature_bank, dim=0)
117117
feature_labels = torch.cat(targets, dim=0)
@@ -121,20 +121,19 @@
121121

122122
y_preds, y_trues = [], []
123123
for data, target in test_loader:
124-
data = data.to(device)
124+
data, target = data.to(device), target.to(device)
125125
with torch.no_grad():
126-
feature = model(data, istrain=False)
126+
feature = model(data)
127127
feature = F.normalize(feature, dim=1)
128-
y_preds.append(linear_classifier.predict(feature))
128+
y_preds.append(linear_classifier.predict(feature.detach()))
129129
y_trues.append(target)
130130
y_trues = torch.cat(y_trues, dim=0)
131131
y_preds = torch.cat(y_preds, dim=0)
132-
top1acc = (y_trues == y_preds).sum() / y_preds.size(0)
133-
132+
top1acc = y_trues.eq(y_preds).sum().item() / y_preds.size(0)
134133
writer.add_scalar('Top Acc @ 1', top1acc, global_step=epoch)
135134
writer.add_scalar('Representation Standard Deviation', feature_bank.std(), global_step=epoch)
136135

137-
tqdm.write('#########################################################')
136+
tqdm.write('###################################################################')
138137
tqdm.write(
139138
f'Epoch {epoch + 1}/{args.epochs}, \
140139
Train Loss: {sum(train_losses) / len(train_losses):.3f}, \

simsiam/model.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
"""https://github.com/facebookresearch/moco"""
2-
31
import torch
42
from torch import nn
53
from torch.nn import functional as F
@@ -9,12 +7,15 @@
97

108

119
class Linear_Classifier(nn.Module):
12-
def __init__(self, args, num_classes, epochs=2000, lr=1e-3):
10+
def __init__(self, args, num_classes, epochs=500, lr=1e-3):
1311
super().__init__()
1412
self.fc = nn.Linear(args.hidden_dim, num_classes)
1513
self.epochs = epochs
1614
self.optimizer = torch.optim.Adam(self.parameters(), lr=lr)
1715
self.criterion = nn.CrossEntropyLoss()
16+
self.scheduler = torch.optim.lr_scheduler.MultiplicativeLR(
17+
self.optimizer,
18+
lr_lambda=lambda lr: 0.995)
1819

1920
def forward(self, x):
2021
return self.fc(x)
@@ -33,6 +34,7 @@ def fit(self, x, y):
3334
self.optimizer.zero_grad()
3435
loss.backward()
3536
self.optimizer.step()
37+
self.scheduler.step()
3638

3739
def predict(self, x):
3840
self.eval()
@@ -79,8 +81,8 @@ def __init__(self, args):
7981
nn.Linear(args.bottleneck_dim, args.hidden_dim),
8082
)
8183

82-
def forward(self, x1, x2=None, istrain=True):
83-
if istrain:
84+
def forward(self, x1, x2=None):
85+
if self.training:
8486
z1, z2 = self.encoder(x1), self.encoder(x2)
8587
p1, p2 = self.projector(z1), self.projector(z2)
8688
return z1, z2, p1, p2

simsiam/vis.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ def get_features(loader, model, device):
1717
for img, target in loader:
1818
img = img.to(device)
1919
with torch.no_grad():
20-
feature = model(img, istrain=False)
20+
feature = model(img)
2121
targets.extend(target.cpu().numpy().tolist())
2222
features.append(feature.cpu())
2323
features = torch.cat(features).numpy()

0 commit comments

Comments
 (0)