-
Notifications
You must be signed in to change notification settings - Fork 0
/
NIMA.py
62 lines (51 loc) · 1.81 KB
/
NIMA.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
"""
file - model.py
Implements the aesthetic model and EMD loss used in the paper.
Copyright (C) Yunxiao Shi 2017 - 2020
NIMA is released under the MIT license. See LICENSE for the full license text.
"""
import torch
import torch.nn as nn
class NIMA(nn.Module):
    """Neural IMage Assessment model by Google.

    Reuses the ``features`` sub-module of a backbone network and appends a
    dropout -> linear -> softmax head.  The head outputs, per sample, a
    probability distribution over ``num_classes`` buckets.

    Args:
        base_model: backbone object exposing a ``features`` module; only
            ``base_model.features`` is used.
        num_classes: number of buckets in the output distribution.
        in_features: size of the flattened ``features`` output.  The default
            25088 (= 512 * 7 * 7) presumably matches a VGG16 trunk on
            224x224 inputs -- override it for other backbones/input sizes.
        dropout: dropout probability applied before the linear layer.
    """

    def __init__(self, base_model, num_classes=10, *, in_features=25088,
                 dropout=0.75):
        super().__init__()
        self.features = base_model.features
        self.classifier = nn.Sequential(
            nn.Dropout(p=dropout),
            nn.Linear(in_features=in_features, out_features=num_classes),
            # Explicit dim: the implicit-dim Softmax is deprecated and warns.
            nn.Softmax(dim=1),
        )

    def forward(self, x):
        """Return a (batch, num_classes) tensor of per-sample probabilities."""
        out = self.features(x)
        # Flatten all but the batch dimension; unlike ``view`` this also
        # works when the feature map is not contiguous.
        out = torch.flatten(out, 1)
        return self.classifier(out)
def single_emd_loss(p, q, r=2):
    """
    Earth Mover's Distance of one sample.

    Computes ((1/N) * sum_k |CDF_p(k) - CDF_q(k)|^r) ** (1/r), where the CDFs
    are cumulative sums along the first axis and N = p.shape[0].

    Args:
        p: true distribution of shape num_classes × 1 (a 1-D tensor of length
            num_classes also works and yields a 0-d result)
        q: estimated distribution, same shape as p
        r: norm parameter
    Returns:
        Tensor of shape p.shape[1:] holding the distance.
    Raises:
        ValueError: if p and q differ in shape, or if they are empty.
    """
    # Raise rather than assert so the check survives `python -O`; otherwise
    # mismatched shapes would silently broadcast.
    if p.shape != q.shape:
        raise ValueError("Length of the two distribution must be the same")
    length = p.shape[0]
    if length == 0:
        raise ValueError("Distributions must have at least one class")
    # CDF difference at every class index in a single vectorized pass; the
    # previous per-prefix Python `sum` re-added each prefix (O(N^2)).
    cdf_diff = torch.cumsum(p - q, dim=0)
    total = torch.sum(torch.abs(cdf_diff) ** r, dim=0)
    return (total / length) ** (1. / r)
def emd_loss(p, q, r=2):
    """
    Earth Mover's Distance on a batch, averaged over the samples.

    Per sample this is ((1/N) * sum_k |CDF_p(k) - CDF_q(k)|^r) ** (1/r), with
    CDFs taken as cumulative sums along dim 1 and N = p.shape[1]; the result
    is the mean of the per-sample distances over dim 0.

    Args:
        p: true distribution of shape mini_batch_size × num_classes × 1
            (mini_batch_size × num_classes also works and yields a 0-d result)
        q: estimated distribution, same shape as p
        r: norm parameter
    Returns:
        Tensor of shape p.shape[2:] holding the batch-mean distance.
    Raises:
        ValueError: if p and q differ in shape, or if they are empty.
    """
    # Raise rather than assert so the check survives `python -O`.
    if p.shape != q.shape:
        raise ValueError("Shape of the two distribution batches must be the same.")
    if p.numel() == 0:
        raise ValueError("Distribution batches must be non-empty")
    # Vectorized over the whole batch instead of a Python loop per sample.
    cdf_diff = torch.cumsum(p - q, dim=1)
    per_sample = torch.mean(torch.abs(cdf_diff) ** r, dim=1) ** (1. / r)
    return per_sample.mean(dim=0)