diff --git a/moco/main.py b/moco/main.py index 69512a0..c8ed3f4 100644 --- a/moco/main.py +++ b/moco/main.py @@ -119,7 +119,7 @@ feature_bank = torch.cat(feature_bank).cpu().numpy() feature_labels = torch.cat(feature_labels).numpy() - linear_classifier = LogisticRegression() + linear_classifier = LogisticRegression(multi_class='multinomial', solver='lbfgs') linear_classifier.fit(feature_bank, feature_labels) y_preds, y_trues = [], [] diff --git a/moco/utils.py b/moco/utils.py index 29c3a5b..766e28b 100644 --- a/moco/utils.py +++ b/moco/utils.py @@ -19,9 +19,9 @@ def forward(self, q, k, memo_bank): pos_logits = torch.einsum('ij,ij->i', [q, k]).unsqueeze(-1) neg_logits = torch.einsum('ij,kj->ik', [q, memo_bank.queue.clone()]) logits = torch.cat([pos_logits, neg_logits], dim=1) - + # zero is the positive "class" - labels = torch.new_zeros(N) + labels = logits.new_zeros(N, dtype=torch.long) return self.criterion(logits / self.T, labels) diff --git a/simsiam/main.py b/simsiam/main.py index cc15ee3..9d33523 100644 --- a/simsiam/main.py +++ b/simsiam/main.py @@ -113,7 +113,7 @@ feature_bank = torch.cat(feature_bank, dim=0).cpu().numpy() feature_labels = torch.cat(targets, dim=0).numpy() - linear_classifier = LogisticRegression() + linear_classifier = LogisticRegression(multi_class='multinomial', solver='lbfgs') linear_classifier.fit(feature_bank, feature_labels) y_preds, y_trues = [], []