|
107 | 107 | feature_bank, targets = [], []
|
108 | 108 | # get current feature maps & fit LR
|
109 | 109 | for data, target in feature_loader:
|
110 |
| - data = data.to(device) |
| 110 | + data, target = data.to(device), target.to(device) |
111 | 111 | with torch.no_grad():
|
112 |
| - feature = model(data, istrain=False) |
| 112 | + feature = model(data) |
113 | 113 | feature = F.normalize(feature, dim=1)
|
114 |
| - feature_bank.append(feature) |
| 114 | + feature_bank.append(feature.clone().detach()) |
115 | 115 | targets.append(target)
|
116 | 116 | feature_bank = torch.cat(feature_bank, dim=0)
|
117 | 117 | feature_labels = torch.cat(targets, dim=0)
|
|
121 | 121 |
|
122 | 122 | y_preds, y_trues = [], []
|
123 | 123 | for data, target in test_loader:
|
124 |
| - data = data.to(device) |
| 124 | + data, target = data.to(device), target.to(device) |
125 | 125 | with torch.no_grad():
|
126 |
| - feature = model(data, istrain=False) |
| 126 | + feature = model(data) |
127 | 127 | feature = F.normalize(feature, dim=1)
|
128 |
| - y_preds.append(linear_classifier.predict(feature)) |
| 128 | + y_preds.append(linear_classifier.predict(feature.detach())) |
129 | 129 | y_trues.append(target)
|
130 | 130 | y_trues = torch.cat(y_trues, dim=0)
|
131 | 131 | y_preds = torch.cat(y_preds, dim=0)
|
132 |
| - top1acc = (y_trues == y_preds).sum() / y_preds.size(0) |
133 |
| - |
| 132 | + top1acc = y_trues.eq(y_preds).sum().item() / y_preds.size(0) |
134 | 133 | writer.add_scalar('Top Acc @ 1', top1acc, global_step=epoch)
|
135 | 134 | writer.add_scalar('Representation Standard Deviation', feature_bank.std(), global_step=epoch)
|
136 | 135 |
|
137 |
| - tqdm.write('#########################################################') |
| 136 | + tqdm.write('###################################################################') |
138 | 137 | tqdm.write(
|
139 | 138 | f'Epoch {epoch + 1}/{args.epochs}, \
|
140 | 139 | Train Loss: {sum(train_losses) / len(train_losses):.3f}, \
|
|
0 commit comments