Skip to content

Commit

Permalink
modified moco training
Browse files Browse the repository at this point in the history
  • Loading branch information
zsdonghao committed Dec 2, 2020
1 parent a32bc9c commit e75a67b
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 9 deletions.
12 changes: 6 additions & 6 deletions moco/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
# training configs
parser.add_argument('--lr', default=0.03, type=float, help='initial learning rate')
parser.add_argument('--epochs', default=200, type=int, metavar='N', help='number of total epochs to run')
parser.add_argument('--batch_size', default=128, type=int, metavar='N', help='mini-batch size')
parser.add_argument('--batch_size', default=256, type=int, metavar='N', help='mini-batch size')
parser.add_argument('--wd', default=0.0001, type=float, metavar='W', help='weight decay')
parser.add_argument('--momentum', default=0.9, type=float, help='momentum for optimizer')
parser.add_argument('--temperature', default=0.07, type=float, help='temperature for loss fn')
Expand Down Expand Up @@ -76,7 +76,7 @@
Path(args.check_point.split('/')[0]).mkdir(parents=True, exist_ok=True)
Path(args.logs_root).mkdir(parents=True, exist_ok=True)

f_q = MoCo(args).to(device)
f_q = MoCo(args).cuda()
f_k = get_momentum_encoder(f_q)

criterion = MoCoLoss(args.temperature)
Expand All @@ -87,11 +87,11 @@
memo_bank = MemoryBank(f_k, device, train_loader, args.K)
writer = SummaryWriter(args.logs_root)

pbar = tqdm(args.epochs)
pbar = tqdm(range(args.epochs))
for epoch in pbar:
train_losses = []
for x1, x2 in train_loader:
x1, x2 = x1.to(device), x2.to(device)
x1, x2 = x1.cuda(), x2.cuda()
q1, q2 = f_q(x1), f_q(x2)
with torch.no_grad():
k1, k2 = f_k(x1), f_k(x2)
Expand All @@ -112,7 +112,7 @@

feature_bank, feature_labels = [], []
for data, target in momentum_loader:
data = data.to(device)
data = data.cuda()
with torch.no_grad():
features = f_q(data)
feature_bank.append(features)
Expand All @@ -125,7 +125,7 @@

y_preds, y_trues = [], []
for data, target in test_loader:
data = data.to(device)
data = data.cuda()
with torch.no_grad():
feature = f_q(data).cpu().numpy()
y_preds.extend(linear_classifier.predict(feature).tolist())
Expand Down
5 changes: 2 additions & 3 deletions moco/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,7 @@ def __init__(self, model_k, device, loader, K=4096):
self.queue = self.queue.to(device)

for data, _ in loader:
x_k = data[1]
x_k = x_k.to(device)
x_k = data.to(device)
k = model_k(x_k)
k = k.detach()
self.queue = self.queue_data(k)
Expand All @@ -77,7 +76,7 @@ def dequeue_data(self, K=None):

def dequeue_and_enqueue(self, k):
self.queue_data(k)
return self.dequeue_data()
self.queue = self.dequeue_data()


def momentum_update(f_k, f_q, m):
Expand Down

0 comments on commit e75a67b

Please sign in to comment.