diff --git a/simsiam/data.py b/simsiam/data.py
index 2c7807f..9e6369a 100644
--- a/simsiam/data.py
+++ b/simsiam/data.py
@@ -36,8 +36,8 @@ def __call__(self, x):
 class SimpleDataset(Dataset):
     def __init__(self, x, y):
         super().__init__()
-        self.x = torch.tensor(x)
-        self.y = torch.tensor(y)
+        self.x = x
+        self.y = y
 
     def __getitem__(self, idx):
         return self.x[idx], self.y[idx]
diff --git a/simsiam/main.py b/simsiam/main.py
index 963bcc8..3514cf0 100644
--- a/simsiam/main.py
+++ b/simsiam/main.py
@@ -107,11 +107,11 @@
     feature_bank, targets = [], []
     # get current feature maps & fit LR
     for data, target in feature_loader:
-        data = data.to(device)
+        data, target = data.to(device), target.to(device)
         with torch.no_grad():
-            feature = model(data, istrain=False)
+            feature = model(data)
             feature = F.normalize(feature, dim=1)
-        feature_bank.append(feature)
+        feature_bank.append(feature.clone().detach())
         targets.append(target)
     feature_bank = torch.cat(feature_bank, dim=0)
     feature_labels = torch.cat(targets, dim=0)
@@ -121,20 +121,19 @@
     y_preds, y_trues = [], []
     for data, target in test_loader:
-        data = data.to(device)
+        data, target = data.to(device), target.to(device)
         with torch.no_grad():
-            feature = model(data, istrain=False)
+            feature = model(data)
             feature = F.normalize(feature, dim=1)
-        y_preds.append(linear_classifier.predict(feature))
+        y_preds.append(linear_classifier.predict(feature.detach()))
         y_trues.append(target)
     y_trues = torch.cat(y_trues, dim=0)
     y_preds = torch.cat(y_preds, dim=0)
-    top1acc = (y_trues == y_preds).sum() / y_preds.size(0)
-
+    top1acc = y_trues.eq(y_preds).sum().item() / y_preds.size(0)
     writer.add_scalar('Top Acc @ 1', top1acc, global_step=epoch)
     writer.add_scalar('Representation Standard Deviation', feature_bank.std(), global_step=epoch)
-    tqdm.write('#########################################################')
+    tqdm.write('###################################################################')
     tqdm.write(
         f'Epoch {epoch + 1}/{args.epochs}, \
         Train Loss: {sum(train_losses) / len(train_losses):.3f}, \
diff --git a/simsiam/model.py b/simsiam/model.py
index a987aeb..1ff69ab 100644
--- a/simsiam/model.py
+++ b/simsiam/model.py
@@ -1,5 +1,3 @@
-"""https://github.com/facebookresearch/moco"""
-
 import torch
 from torch import nn
 from torch.nn import functional as F
@@ -9,12 +7,15 @@ class Linear_Classifier(nn.Module):
-    def __init__(self, args, num_classes, epochs=2000, lr=1e-3):
+    def __init__(self, args, num_classes, epochs=500, lr=1e-3):
         super().__init__()
         self.fc = nn.Linear(args.hidden_dim, num_classes)
         self.epochs = epochs
         self.optimizer = torch.optim.Adam(self.parameters(), lr=lr)
         self.criterion = nn.CrossEntropyLoss()
+        self.scheduler = torch.optim.lr_scheduler.MultiplicativeLR(
+            self.optimizer,
+            lr_lambda=lambda lr: 0.995)
 
     def forward(self, x):
         return self.fc(x)
@@ -33,6 +34,7 @@ def fit(self, x, y):
             self.optimizer.zero_grad()
             loss.backward()
             self.optimizer.step()
+            self.scheduler.step()
 
     def predict(self, x):
         self.eval()
@@ -79,8 +81,8 @@ def __init__(self, args):
             nn.Linear(args.bottleneck_dim, args.hidden_dim),
         )
 
-    def forward(self, x1, x2=None, istrain=True):
-        if istrain:
+    def forward(self, x1, x2=None):
+        if self.training:
             z1, z2 = self.encoder(x1), self.encoder(x2)
             p1, p2 = self.projector(z1), self.projector(z2)
             return z1, z2, p1, p2
diff --git a/simsiam/vis.py b/simsiam/vis.py
index a291ffc..3803657 100644
--- a/simsiam/vis.py
+++ b/simsiam/vis.py
@@ -17,7 +17,7 @@ def get_features(loader, model, device):
     for img, target in loader:
         img = img.to(device)
         with torch.no_grad():
-            feature = model(img, istrain=False)
+            feature = model(img)
         targets.extend(target.cpu().numpy().tolist())
         features.append(feature.cpu())
     features = torch.cat(features).numpy()