Skip to content

Commit 4e28828

Browse files
committed
modified training
1 parent 842051a commit 4e28828

File tree

3 files changed

+15
-18
lines changed

3 files changed

+15
-18
lines changed

moco/main.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -37,9 +37,9 @@
3737
parser.add_argument('--mlp', default=True, type=bool, help='feature dimension')
3838

3939
# misc.
40-
parser.add_argument('--data_root', default='../data', type=str, help='path to data')
41-
parser.add_argument('--logs_root', default='logs', type=str, help='path to logs')
42-
parser.add_argument('--check_point', default='check_point/moco.pth', type=str, help='path to model weights')
40+
parser.add_argument('--data_root', default='data', type=str, help='path to data')
41+
parser.add_argument('--logs_root', default='moco/logs', type=str, help='path to logs')
42+
parser.add_argument('--check_point', default='moco/check_point/moco.pth', type=str, help='path to model weights')
4343

4444
args = parser.parse_args()
4545

@@ -73,8 +73,8 @@
7373
test_data = CIFAR10(root=args.data_root, train=False, transform=test_transform, download=True)
7474
test_loader = DataLoader(test_data, batch_size=args.batch_size, shuffle=False, num_workers=28)
7575

76-
Path(args.check_point.split('/')[0]).mkdir(parents=True, exist_ok=True)
77-
Path(args.logs_root).mkdir(parents=True, exist_ok=True)
76+
Path(args.check_point.split('/')[1]).mkdir(parents=True, exist_ok=True)
77+
Path(args.logs_root.split('/')[1]).mkdir(parents=True, exist_ok=True)
7878

7979
f_q = MoCo(args).cuda()
8080
f_k = get_momentum_encoder(f_q)
@@ -104,7 +104,7 @@
104104
with torch.no_grad():
105105
momentum_update(f_k, f_q, args.m)
106106
train_losses.append(loss.item())
107-
pbar.set_postfix({'Loss': loss.item(), 'Learning Rate': scheduler.get_last_lr()})
107+
pbar.set_postfix({'Loss': loss.item(), 'Learning Rate': scheduler.get_last_lr()[0]})
108108

109109
writer.add_scalar('Train Loss', sum(train_losses) / len(train_losses), global_step=epoch)
110110
torch.save(f_q.state_dict(), args.check_point)

moco/utils.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,12 @@ def __init__(self, T=0.07):
1414
self.criterion = nn.CrossEntropyLoss()
1515

1616
def forward(self, q, k, memo_bank):
17+
N = q.size(0)
1718
k = k.detach()
1819
pos_logits = torch.einsum('ij,ij->i', [q, k]).unsqueeze(-1)
19-
neg_logits = torch.einsum('ij,jk->ik', [q, memo_bank.queue.clone()])
20+
neg_logits = torch.einsum('ij,kj->ik', [q, memo_bank.queue.clone()])
2021
logits = torch.cat([pos_logits, neg_logits], dim=1)
21-
labels = torch.zeros_like(logits, dtype=torch.long, device=logits.device)
22+
labels = torch.zeros(N, dtype=torch.long, device=logits.device)
2223
return self.criterion(logits / self.T, labels)
2324

2425

@@ -75,7 +76,7 @@ def dequeue_data(self, K=None):
7576
return self.queue
7677

7778
def dequeue_and_enqueue(self, k):
78-
self.queue_data(k)
79+
self.queue = self.queue_data(k)
7980
self.queue = self.dequeue_data()
8081

8182

simsiam/main.py

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -33,14 +33,10 @@
3333
parser.add_argument('--bottleneck_dim', default=512, type=int, help='bottleneck dimension')
3434
parser.add_argument('--num_encoder_fcs', default=2, type=int, help='number of layers of fcs for encoder')
3535

36-
# knn monitor
37-
parser.add_argument('--knn-k', default=200, type=int, help='k in kNN monitor')
38-
parser.add_argument('--knn-t', default=0.1, type=float, help='softmax temperature in kNN monitor')
39-
4036
# misc.
41-
parser.add_argument('--data_root', default='../data', type=str, help='path to data')
42-
parser.add_argument('--logs_root', default='logs', type=str, help='path to logs')
43-
parser.add_argument('--check_point', default='check_point/simsiam.pth', type=str, help='path to model weights')
37+
parser.add_argument('--data_root', default='data', type=str, help='path to data')
38+
parser.add_argument('--logs_root', default='simsiam/logs', type=str, help='path to logs')
39+
parser.add_argument('--check_point', default='simsiam/check_point/simsiam.pth', type=str, help='path to model weights')
4440

4541
args = parser.parse_args()
4642

@@ -82,8 +78,8 @@ def cosine_loss(p, z):
8278

8379
writer = SummaryWriter(args.logs_root)
8480
model = SimSiam(args).to(device)
85-
Path(args.check_point.split('/')[0]).mkdir(parents=True, exist_ok=True)
86-
Path(args.logs_root).mkdir(parents=True, exist_ok=True)
81+
Path(args.check_point.split('/')[1]).mkdir(parents=True, exist_ok=True)
82+
Path(args.logs_root.split('/')[1]).mkdir(parents=True, exist_ok=True)
8783

8884
optimizer = torch.optim.SGD(model.parameters(), lr=args.lr,
8985
momentum=args.momentum, weight_decay=args.wd)

0 commit comments

Comments
 (0)