
Commit

Update Evaluate.py
hyunjp authored Jul 17, 2020
1 parent 69a37dc commit f02dd0a
Showing 1 changed file with 6 additions and 8 deletions.
14 changes: 6 additions & 8 deletions Evaluate.py
@@ -32,18 +32,15 @@
 parser.add_argument('--gpus', nargs='+', type=str, help='gpus')
 parser.add_argument('--batch_size', type=int, default=4, help='batch size for training')
 parser.add_argument('--test_batch_size', type=int, default=1, help='batch size for test')
-parser.add_argument('--epochs', type=int, default=60, help='number of epochs for training')
-parser.add_argument('--loss_compact', type=float, default=0.01, help='weight of the feature compactness loss')
-parser.add_argument('--loss_separate', type=float, default=0.01, help='weight of the feature separateness loss')
 parser.add_argument('--h', type=int, default=256, help='height of input images')
 parser.add_argument('--w', type=int, default=256, help='width of input images')
 parser.add_argument('--c', type=int, default=3, help='channel of input images')
-parser.add_argument('--lr', type=float, default=2e-4, help='initial learning rate')
 parser.add_argument('--t_length', type=int, default=5, help='length of the frame sequences')
 parser.add_argument('--fdim', type=int, default=512, help='channel dimension of the features')
 parser.add_argument('--mdim', type=int, default=512, help='channel dimension of the memory items')
 parser.add_argument('--msize', type=int, default=10, help='number of the memory items')
-parser.add_argument('--alpha', type=float, default=0.7, help='weight for the anomality score')
+parser.add_argument('--alpha', type=float, default=0.6, help='weight for the anomality score')
+parser.add_argument('--th', type=float, default=0.01, help='threshold for test updating')
 parser.add_argument('--num_workers', type=int, default=2, help='number of workers for the train loader')
 parser.add_argument('--num_workers_test', type=int, default=1, help='number of workers for the test loader')
 parser.add_argument('--dataset_type', type=str, default='ped2', help='type of dataset: ped2, avenue, shanghai')
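The substantive change in this hunk is the --alpha default dropping from 0.7 to 0.6, with the training-only options removed and --th introduced for the test-time update used further down. --alpha weights the two per-frame scores that score_sum combines later in this file: the PSNR-based reconstruction score and the memory feature-distance score. A minimal sketch of that weighting, assuming the usual convex-combination form (the repo's actual score_sum in its utils may differ):

def score_sum_sketch(psnr_scores, feat_dist_scores, alpha):
    # alpha weights the PSNR-based score; (1 - alpha) weights the memory
    # feature-distance score. The commit shifts the default from 0.7 to 0.6,
    # i.e. slightly toward the feature-distance term.
    return [alpha * s1 + (1 - alpha) * s2
            for s1, s2 in zip(psnr_scores, feat_dist_scores)]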
@@ -77,6 +74,7 @@
 test_batch = data.DataLoader(test_dataset, batch_size = args.test_batch_size,
                              shuffle=False, num_workers=args.num_workers_test, drop_last=False)

+loss_func_mse = nn.MSELoss(reduction='none')

 # Loading the trained model
 model = torch.load(args.model_dir)
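The added line creates an MSE loss with reduction='none', which keeps the per-pixel error map instead of averaging it; that map presumably feeds the per-frame PSNR computed further down in the evaluation loop (not shown in this diff). Note also that torch.load(args.model_dir) deserializes the full model object, so the model and memory classes must be importable when running this script. A minimal sketch of deriving a per-frame PSNR from the per-pixel map, assuming images rescaled from [-1, 1] to [0, 1] (the repo's own psnr helper may differ):

import torch

def psnr_sketch(mse_map):
    # mse_map: per-pixel squared error from loss_func_mse(output, target)
    # on images rescaled to [0, 1], so the peak signal value is 1.0.
    mse = torch.mean(mse_map)
    return (10 * torch.log10(1.0 / mse)).item()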
@@ -102,7 +100,7 @@
 print('Evaluation of', args.dataset_type)

 # Setting for video anomaly detection
-for video in sorted(videos_new):
+for video in sorted(videos_list):
     video_name = video.split('/')[-1]
     labels_list = np.append(labels_list, labels[0][4+label_length:videos[video_name]['length']+label_length])
     label_length += videos[video_name]['length']
@@ -131,7 +129,7 @@
     # Calculating the threshold for updating at the test time
     point_sc = point_score(outputs, imgs[:,3*4:])

-    if point_sc < threshold:
+    if point_sc < args.th:
        query = F.normalize(feas, dim=1)
        query = query.permute(0,2,3,1) # b X h X w X d
        m_items_test = model.memory.update(query, m_items_test, False)
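This hunk replaces the bare name threshold with the --th argument (default 0.01), making the test-time memory update configurable. As far as the visible code shows, the memory items are only refreshed with the current query features when the frame's point score is low enough to be treated as normal, so anomalous frames do not overwrite the stored normal patterns. A hypothetical sketch of such a point score, assuming it is a scalar reconstruction-error measure that is small for well-reconstructed frames (the repo's real point_score lives in its utils and is more elaborate):

import torch

def point_score_sketch(outputs, imgs):
    # Scalar per-frame reconstruction error; small values mean the frame
    # looks normal, so `point_sc < args.th` permits a memory update.
    return torch.mean((outputs - imgs) ** 2).item()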
@@ -142,7 +140,7 @@

 # Measuring the abnormality score and the AUC
 anomaly_score_total_list = []
-for video in sorted(videos_new):
+for video in sorted(videos_list):
     video_name = video.split('/')[-1]
     anomaly_score_total_list += score_sum(anomaly_score_list(psnr_list[video_name]),
                                           anomaly_score_list_inv(feature_distance_list[video_name]), args.alpha)
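After this loop, the per-video scores form one frame-level list that is evaluated against the ground-truth labels; the standard metric for ped2, avenue, and shanghai is the frame-level AUC. A minimal sketch of that final step, assuming the repo's convention that a higher combined score means a more normal frame and that labels_list marks anomalous frames with 1 (the repo's own AUC helper may differ):

import numpy as np
from sklearn.metrics import roc_auc_score

scores = np.asarray(anomaly_score_total_list)
# Flip the scores so larger values mean "more anomalous", matching labels_list.
auc = roc_auc_score(labels_list, 1 - scores)
print('AUC: {:.2f}%'.format(auc * 100))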
