-
Notifications
You must be signed in to change notification settings - Fork 6
Expand file tree
/
Copy pathevaluate.py
More file actions
97 lines (78 loc) · 3.05 KB
/
evaluate.py
File metadata and controls
97 lines (78 loc) · 3.05 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import torch
def mean_average_precision(query_code,
database_code,
query_labels,
database_labels,
device,
topk=None,
):
"""
Calculate mean average precision(map).
Args:
query_code (torch.Tensor): Query data hash code.
database_code (torch.Tensor): Database data hash code.
query_labels (torch.Tensor): Query data targets, one-hot
database_labels (torch.Tensor): Database data targets, one-host
device (torch.device): Using CPU or GPU.
topk (int): Calculate top k data map.
Returns:
meanAP (float): Mean Average Precision.
"""
num_query = query_labels.shape[0]
mean_AP = 0.0
for i in range(num_query):
# Retrieve images from database
retrieval = (query_labels[i, :] @ database_labels.t() > 0).float()
# Calculate hamming distance
hamming_dist = 0.5 * (database_code.shape[1] - query_code[i, :] @ database_code.t())
# Arrange position according to hamming distance
retrieval = retrieval[torch.argsort(hamming_dist)][:topk]
# Retrieval count
retrieval_cnt = retrieval.sum().int().item()
# Can not retrieve images
if retrieval_cnt == 0:
continue
# Generate score for every position
score = torch.linspace(1, retrieval_cnt, retrieval_cnt).to(device)
# Acquire index
index = (torch.nonzero(retrieval == 1).squeeze() + 1.0).float()
mean_AP += (score / index).mean()
mean_AP = mean_AP / num_query
return mean_AP
def pr_curve(query_code, retrieval_code, query_targets, retrieval_targets, device):
"""
P-R curve.
Args
query_code(torch.Tensor): Query hash code.
retrieval_code(torch.Tensor): Retrieval hash code.
query_targets(torch.Tensor): Query targets.
retrieval_targets(torch.Tensor): Retrieval targets.
device (torch.device): Using CPU or GPU.
Returns
P(torch.Tensor): Precision.
R(torch.Tensor): Recall.
"""
num_query = query_code.shape[0]
num_bit = query_code.shape[1]
P = torch.zeros(num_query, num_bit + 1).to(device)
R = torch.zeros(num_query, num_bit + 1).to(device)
for i in range(num_query):
gnd = (query_targets[i].unsqueeze(0).mm(retrieval_targets.t()) > 0).float().squeeze()
tsum = torch.sum(gnd)
if tsum == 0:
continue
hamm = 0.5 * (retrieval_code.shape[1] - query_code[i, :] @ retrieval_code.t())
tmp = (hamm <= torch.arange(0, num_bit + 1).reshape(-1, 1).float().to(device)).float()
total = tmp.sum(dim=-1)
total = total + (total == 0).float() * 0.1
t = gnd * tmp
count = t.sum(dim=-1)
p = count / total
r = count / tsum
P[i] = p
R[i] = r
mask = (P > 0).float().sum(dim=0)
mask = mask + (mask == 0).float() * 0.1
P = P.sum(dim=0) / mask
R = R.sum(dim=0) / mask
return P, R