From 4e28828e946addf2742008749ca81f05499cccec Mon Sep 17 00:00:00 2001 From: zsdonghao Date: Wed, 2 Dec 2020 23:35:27 +0800 Subject: [PATCH] modified training --- moco/main.py | 12 ++++++------ moco/utils.py | 7 ++++--- simsiam/main.py | 14 +++++--------- 3 files changed, 15 insertions(+), 18 deletions(-) diff --git a/moco/main.py b/moco/main.py index 6702ee5..6c0cc56 100644 --- a/moco/main.py +++ b/moco/main.py @@ -37,9 +37,9 @@ parser.add_argument('--mlp', default=True, type=bool, help='feature dimension') # misc. -parser.add_argument('--data_root', default='../data', type=str, help='path to data') -parser.add_argument('--logs_root', default='logs', type=str, help='path to logs') -parser.add_argument('--check_point', default='check_point/moco.pth', type=str, help='path to model weights') +parser.add_argument('--data_root', default='data', type=str, help='path to data') +parser.add_argument('--logs_root', default='moco/logs', type=str, help='path to logs') +parser.add_argument('--check_point', default='moco/check_point/moco.pth', type=str, help='path to model weights') args = parser.parse_args() @@ -73,8 +73,8 @@ test_data = CIFAR10(root=args.data_root, train=False, transform=test_transform, download=True) test_loader = DataLoader(test_data, batch_size=args.batch_size, shuffle=False, num_workers=28) - Path(args.check_point.split('/')[0]).mkdir(parents=True, exist_ok=True) - Path(args.logs_root).mkdir(parents=True, exist_ok=True) + Path(args.check_point.split('/')[1]).mkdir(parents=True, exist_ok=True) + Path(args.logs_root.split('/')[1]).mkdir(parents=True, exist_ok=True) f_q = MoCo(args).cuda() f_k = get_momentum_encoder(f_q) @@ -104,7 +104,7 @@ with torch.no_grad(): momentum_update(f_k, f_q, args.m) train_losses.append(loss.item()) - pbar.set_postfix({'Loss': loss.item(), 'Learning Rate': scheduler.get_last_lr()}) + pbar.set_postfix({'Loss': loss.item(), 'Learning Rate': scheduler.get_last_lr()[0]}) writer.add_scalar('Train Loss', sum(train_losses) / len(train_losses), global_step=epoch) torch.save(f_q.state_dict(), args.check_point) diff --git a/moco/utils.py b/moco/utils.py index ab3420a..706c391 100644 --- a/moco/utils.py +++ b/moco/utils.py @@ -14,11 +14,12 @@ def __init__(self, T=0.07): self.criterion = nn.CrossEntropyLoss() def forward(self, q, k, memo_bank): + N = q.size(0) k = k.detach() pos_logits = torch.einsum('ij,ij->i', [q, k]).unsqueeze(-1) - neg_logits = torch.einsum('ij,jk->ik', [q, memo_bank.queue.clone()]) + neg_logits = torch.einsum('ij,kj->ik', [q, memo_bank.queue.clone()]) logits = torch.cat([pos_logits, neg_logits], dim=1) - labels = torch.zeros_like(logits, dtype=torch.long, device=logits.device) + labels = torch.zeros(N, dtype=torch.long, device=logits.device) return self.criterion(logits / self.T, labels) @@ -75,7 +76,7 @@ def dequeue_data(self, K=None): return self.queue def dequeue_and_enqueue(self, k): - self.queue_data(k) + self.queue = self.queue_data(k) self.queue = self.dequeue_data() diff --git a/simsiam/main.py b/simsiam/main.py index 6f54dab..07cbf25 100644 --- a/simsiam/main.py +++ b/simsiam/main.py @@ -33,14 +33,10 @@ parser.add_argument('--bottleneck_dim', default=512, type=int, help='bottleneck dimension') parser.add_argument('--num_encoder_fcs', default=2, type=int, help='number of layers of fcs for encoder') -# knn monitor -parser.add_argument('--knn-k', default=200, type=int, help='k in kNN monitor') -parser.add_argument('--knn-t', default=0.1, type=float, help='softmax temperature in kNN monitor') - # misc. -parser.add_argument('--data_root', default='../data', type=str, help='path to data') -parser.add_argument('--logs_root', default='logs', type=str, help='path to logs') -parser.add_argument('--check_point', default='check_point/simsiam.pth', type=str, help='path to model weights') +parser.add_argument('--data_root', default='data', type=str, help='path to data') +parser.add_argument('--logs_root', default='simsiam/logs', type=str, help='path to logs') +parser.add_argument('--check_point', default='simsiam/check_point/simsiam.pth', type=str, help='path to model weights') args = parser.parse_args() @@ -82,8 +78,8 @@ def cosine_loss(p, z): writer = SummaryWriter(args.logs_root) model = SimSiam(args).to(device) - Path(args.check_point.split('/')[0]).mkdir(parents=True, exist_ok=True) - Path(args.logs_root).mkdir(parents=True, exist_ok=True) + Path(args.check_point.split('/')[1]).mkdir(parents=True, exist_ok=True) + Path(args.logs_root.split('/')[1]).mkdir(parents=True, exist_ok=True) optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.wd)