Skip to content

Commit

Permalink
using softmax classifier as evaluator
Browse files Browse the repository at this point in the history
  • Loading branch information
zsdonghao committed Dec 5, 2020
1 parent 5c088b9 commit 4656d16
Show file tree
Hide file tree
Showing 3 changed files with 4 additions and 4 deletions.
2 changes: 1 addition & 1 deletion moco/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@
feature_bank = torch.cat(feature_bank).cpu().numpy()
feature_labels = torch.cat(feature_labels).numpy()

linear_classifier = LogisticRegression()
linear_classifier = LogisticRegression(multi_class='multinomial', solver='lbfgs')
linear_classifier.fit(feature_bank, feature_labels)

y_preds, y_trues = [], []
Expand Down
4 changes: 2 additions & 2 deletions moco/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,9 @@ def forward(self, q, k, memo_bank):
pos_logits = torch.einsum('ij,ij->i', [q, k]).unsqueeze(-1)
neg_logits = torch.einsum('ij,kj->ik', [q, memo_bank.queue.clone()])
logits = torch.cat([pos_logits, neg_logits], dim=1)

# zero is the positive "class"
labels = torch.new_zeros(N)
labels = logits.new_zeros(N, dtype=torch.long)
return self.criterion(logits / self.T, labels)


Expand Down
2 changes: 1 addition & 1 deletion simsiam/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@
feature_bank = torch.cat(feature_bank, dim=0).cpu().numpy()
feature_labels = torch.cat(targets, dim=0).numpy()

linear_classifier = LogisticRegression()
linear_classifier = LogisticRegression(multi_class='multinomial', solver='lbfgs')
linear_classifier.fit(feature_bank, feature_labels)

y_preds, y_trues = [], []
Expand Down

0 comments on commit 4656d16

Please sign in to comment.