diff --git a/moco/main.py b/moco/main.py index 0e1d7d9..8777f9d 100644 --- a/moco/main.py +++ b/moco/main.py @@ -24,7 +24,7 @@ # training configs parser.add_argument('--lr', default=0.03, type=float, help='initial learning rate') parser.add_argument('--epochs', default=200, type=int, metavar='N', help='number of total epochs to run') -parser.add_argument('--batch_size', default=128, type=int, metavar='N', help='mini-batch size') +parser.add_argument('--batch_size', default=256, type=int, metavar='N', help='mini-batch size') parser.add_argument('--wd', default=0.0001, type=float, metavar='W', help='weight decay') parser.add_argument('--momentum', default=0.9, type=float, help='momentum for optimizer') parser.add_argument('--temperature', default=0.07, type=float, help='temperature for loss fn') @@ -76,7 +76,7 @@ Path(args.check_point.split('/')[0]).mkdir(parents=True, exist_ok=True) Path(args.logs_root).mkdir(parents=True, exist_ok=True) - f_q = MoCo(args).to(device) + f_q = MoCo(args).cuda() f_k = get_momentum_encoder(f_q) criterion = MoCoLoss(args.temperature) @@ -87,11 +87,11 @@ memo_bank = MemoryBank(f_k, device, train_loader, args.K) writer = SummaryWriter(args.logs_root) - pbar = tqdm(args.epochs) + pbar = tqdm(range(args.epochs)) for epoch in pbar: train_losses = [] for x1, x2 in train_loader: - x1, x2 = x1.to(device), x2.to(device) + x1, x2 = x1.cuda(), x2.cuda() q1, q2 = f_q(x1), f_q(x2) with torch.no_grad(): k1, k2 = f_k(x1), f_k(x2) @@ -112,7 +112,7 @@ feature_bank, feature_labels = [], [] for data, target in momentum_loader: - data = data.to(device) + data = data.cuda() with torch.no_grad(): features = f_q(data) feature_bank.append(features) @@ -125,7 +125,7 @@ y_preds, y_trues = [], [] for data, target in test_loader: - data = data.to(device) + data = data.cuda() with torch.no_grad(): feature = f_q(data).cpu().numpy() y_preds.extend(linear_classifier.predict(feature).tolist()) diff --git a/moco/utils.py b/moco/utils.py index 0548ad3..ab3420a 100644 --- a/moco/utils.py +++ b/moco/utils.py @@ -54,8 +54,7 @@ def __init__(self, model_k, device, loader, K=4096): self.queue = self.queue.to(device) for data, _ in loader: - x_k = data[1] - x_k = x_k.to(device) + x_k = data.to(device) k = model_k(x_k) k = k.detach() self.queue = self.queue_data(k) @@ -77,7 +76,7 @@ def dequeue_data(self, K=None): def dequeue_and_enqueue(self, k): self.queue_data(k) - return self.dequeue_data() + self.queue = self.dequeue_data() def momentum_update(f_k, f_q, m):