Skip to content

Commit

Permalink
modified training
Browse files Browse the repository at this point in the history
  • Loading branch information
zsdonghao committed Dec 2, 2020
1 parent 842051a commit 4e28828
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 18 deletions.
12 changes: 6 additions & 6 deletions moco/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,9 @@
parser.add_argument('--mlp', default=True, type=bool, help='feature dimension')

# misc.
parser.add_argument('--data_root', default='../data', type=str, help='path to data')
parser.add_argument('--logs_root', default='logs', type=str, help='path to logs')
parser.add_argument('--check_point', default='check_point/moco.pth', type=str, help='path to model weights')
parser.add_argument('--data_root', default='data', type=str, help='path to data')
parser.add_argument('--logs_root', default='moco/logs', type=str, help='path to logs')
parser.add_argument('--check_point', default='moco/check_point/moco.pth', type=str, help='path to model weights')

args = parser.parse_args()

Expand Down Expand Up @@ -73,8 +73,8 @@
test_data = CIFAR10(root=args.data_root, train=False, transform=test_transform, download=True)
test_loader = DataLoader(test_data, batch_size=args.batch_size, shuffle=False, num_workers=28)

Path(args.check_point.split('/')[0]).mkdir(parents=True, exist_ok=True)
Path(args.logs_root).mkdir(parents=True, exist_ok=True)
Path(args.check_point.split('/')[1]).mkdir(parents=True, exist_ok=True)
Path(args.logs_root.split('/')[1]).mkdir(parents=True, exist_ok=True)

f_q = MoCo(args).cuda()
f_k = get_momentum_encoder(f_q)
Expand Down Expand Up @@ -104,7 +104,7 @@
with torch.no_grad():
momentum_update(f_k, f_q, args.m)
train_losses.append(loss.item())
pbar.set_postfix({'Loss': loss.item(), 'Learning Rate': scheduler.get_last_lr()})
pbar.set_postfix({'Loss': loss.item(), 'Learning Rate': scheduler.get_last_lr()[0]})

writer.add_scalar('Train Loss', sum(train_losses) / len(train_losses), global_step=epoch)
torch.save(f_q.state_dict(), args.check_point)
Expand Down
7 changes: 4 additions & 3 deletions moco/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,12 @@ def __init__(self, T=0.07):
self.criterion = nn.CrossEntropyLoss()

def forward(self, q, k, memo_bank):
N = q.size(0)
k = k.detach()
pos_logits = torch.einsum('ij,ij->i', [q, k]).unsqueeze(-1)
neg_logits = torch.einsum('ij,jk->ik', [q, memo_bank.queue.clone()])
neg_logits = torch.einsum('ij,kj->ik', [q, memo_bank.queue.clone()])
logits = torch.cat([pos_logits, neg_logits], dim=1)
labels = torch.zeros_like(logits, dtype=torch.long, device=logits.device)
labels = torch.zeros(N, dtype=torch.long, device=logits.device)
return self.criterion(logits / self.T, labels)


Expand Down Expand Up @@ -75,7 +76,7 @@ def dequeue_data(self, K=None):
return self.queue

def dequeue_and_enqueue(self, k):
self.queue_data(k)
self.queue = self.queue_data(k)
self.queue = self.dequeue_data()


Expand Down
14 changes: 5 additions & 9 deletions simsiam/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,14 +33,10 @@
parser.add_argument('--bottleneck_dim', default=512, type=int, help='bottleneck dimension')
parser.add_argument('--num_encoder_fcs', default=2, type=int, help='number of layers of fcs for encoder')

# knn monitor
parser.add_argument('--knn-k', default=200, type=int, help='k in kNN monitor')
parser.add_argument('--knn-t', default=0.1, type=float, help='softmax temperature in kNN monitor')

# misc.
parser.add_argument('--data_root', default='../data', type=str, help='path to data')
parser.add_argument('--logs_root', default='logs', type=str, help='path to logs')
parser.add_argument('--check_point', default='check_point/simsiam.pth', type=str, help='path to model weights')
parser.add_argument('--data_root', default='data', type=str, help='path to data')
parser.add_argument('--logs_root', default='simsiam/logs', type=str, help='path to logs')
parser.add_argument('--check_point', default='simsiam/check_point/simsiam.pth', type=str, help='path to model weights')

args = parser.parse_args()

Expand Down Expand Up @@ -82,8 +78,8 @@ def cosine_loss(p, z):

writer = SummaryWriter(args.logs_root)
model = SimSiam(args).to(device)
Path(args.check_point.split('/')[0]).mkdir(parents=True, exist_ok=True)
Path(args.logs_root).mkdir(parents=True, exist_ok=True)
Path(args.check_point.split('/')[1]).mkdir(parents=True, exist_ok=True)
Path(args.logs_root.split('/')[1]).mkdir(parents=True, exist_ok=True)

optimizer = torch.optim.SGD(model.parameters(), lr=args.lr,
momentum=args.momentum, weight_decay=args.wd)
Expand Down

0 comments on commit 4e28828

Please sign in to comment.