-
Notifications
You must be signed in to change notification settings - Fork 17
/
models.py
44 lines (34 loc) · 1.23 KB
/
models.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
import torch
import math
import torchvision
import torch.nn as nn
import torch.nn.functional as F
class ImgNet(nn.Module):
    """Image branch: pretrained AlexNet features followed by a tanh-squashed code layer.

    Args:
        code_len: Number of output units of ``fc_encode`` (the code length).
    """

    def __init__(self, code_len):
        super(ImgNet, self).__init__()
        # torchvision >= 0.13 replaces ``pretrained=True`` with the ``weights=``
        # enum (the old flag is deprecated there). IMAGENET1K_V1 is the checkpoint
        # that ``pretrained=True`` loads; older releases lack the enum, so fall back.
        weights_enum = getattr(torchvision.models, "AlexNet_Weights", None)
        if weights_enum is not None:
            self.alexnet = torchvision.models.alexnet(weights=weights_enum.IMAGENET1K_V1)
        else:
            self.alexnet = torchvision.models.alexnet(pretrained=True)
        # Keep only the first six classifier layers. In torchvision's AlexNet this
        # drops the final 4096 -> 1000 Linear, so the classifier emits 4096 features,
        # which is the width fc_encode expects.
        self.alexnet.classifier = nn.Sequential(*list(self.alexnet.classifier.children())[:6])
        self.fc_encode = nn.Linear(4096, code_len)
        # Scale applied inside tanh; updated by set_alpha().
        self.alpha = 1.0

    def forward(self, x):
        """Encode a batch of images.

        Args:
            x: Image batch fed to AlexNet's feature extractor
                (presumably N x 3 x H x W — TODO confirm against the data loader).

        Returns:
            Tuple ``(feat, hid, code)``:
                feat: classifier output, shape (N, 4096).
                hid: pre-activation code, shape (N, code_len).
                code: ``tanh(alpha * hid)``, values in (-1, 1).
        """
        x = self.alexnet.features(x)
        # torchvision's AlexNet adaptively pools the feature map to 6x6 before the
        # classifier. This is the identity for 224/227 inputs, but without it any
        # other input size breaks the first classifier Linear. Old torchvision
        # builds have no avgpool, so apply it only when present.
        pool = getattr(self.alexnet, "avgpool", None)
        if pool is not None:
            x = pool(x)
        x = x.view(x.size(0), -1)
        feat = self.alexnet.classifier(x)
        hid = self.fc_encode(feat)
        # torch.tanh instead of F.tanh: older PyTorch releases warn that
        # nn.functional.tanh is deprecated.
        code = torch.tanh(self.alpha * hid)
        return feat, hid, code

    def set_alpha(self, epoch):
        """Set the tanh scale to sqrt(epoch + 1).

        The scale grows with the epoch, so tanh gets steeper and the codes are
        pushed closer to +/-1.
        """
        self.alpha = math.pow((1.0 * epoch + 1.0), 0.5)
class TxtNet(nn.Module):
    """Text branch: one hidden ReLU layer followed by a tanh-squashed code layer.

    Args:
        code_len: Number of output units of ``fc2`` (the code length).
        txt_feat_len: Width of the input text feature vector consumed by ``fc1``.
    """

    def __init__(self, code_len, txt_feat_len):
        super(TxtNet, self).__init__()
        self.fc1 = nn.Linear(txt_feat_len, 4096)
        self.fc2 = nn.Linear(4096, code_len)
        # Scale applied inside tanh; updated by set_alpha().
        self.alpha = 1.0

    def forward(self, x):
        """Encode a batch of text feature vectors.

        Args:
            x: Input features whose last dimension is ``txt_feat_len``.

        Returns:
            Tuple ``(feat, hid, code)``:
                feat: ReLU hidden activations (last dim 4096).
                hid: pre-activation code (last dim code_len).
                code: ``tanh(alpha * hid)``, values in (-1, 1).
        """
        feat = F.relu(self.fc1(x))
        hid = self.fc2(feat)
        # torch.tanh instead of F.tanh: older PyTorch releases warn that
        # nn.functional.tanh is deprecated.
        code = torch.tanh(self.alpha * hid)
        return feat, hid, code

    def set_alpha(self, epoch):
        """Set the tanh scale to sqrt(epoch + 1).

        The scale grows with the epoch, so tanh gets steeper and the codes are
        pushed closer to +/-1.
        """
        self.alpha = math.pow((1.0 * epoch + 1.0), 0.5)