-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathtrain.py
79 lines (54 loc) · 2.19 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
import torch
import Model
import torch.nn as nn
import numpy as np
import torch.optim as optim
from Model import MyDataset
from torch.utils.data import DataLoader
def trainModel(model, file_name, epochs=10, batch_size=8, lr=0.001,
               momentum=0.9, log_every=100):
    """Train a classifier for one dataset and save its weights next to the data.

    The directory part of ``file_name`` is used to locate two inputs,
    ``ClusterList.npy`` and ``train_with_label.npy``, and to write the
    output ``net_latest.pth`` (a ``state_dict``).

    Args:
        model: Callable (typically an ``nn.Module`` subclass) invoked as
            ``model(boundary)``, where ``boundary`` is the length of the
            array stored in ``ClusterList.npy`` (presumably the number of
            clusters / output classes -- confirm against the Model module).
        file_name: Path of a file inside the dataset directory, using ``/``
            as separator. Only its directory prefix is used.
        epochs: Number of passes over the training set.
        batch_size: Mini-batch size handed to the ``DataLoader``.
        lr: SGD learning rate.
        momentum: SGD momentum.
        log_every: Print the mean loss every ``log_every`` mini-batches;
            must be a positive integer.
    """
    # Keep everything up to and including the last "/" ("" when there is none),
    # so plain concatenation below yields paths inside the dataset directory.
    head, sep, _ = file_name.rpartition("/")
    save_dir = head + sep

    # NOTE(review): allow_pickle=True executes pickled code on load; only use
    # this with ClusterList.npy files produced by a trusted pipeline.
    boundary = len(np.load(save_dir + "ClusterList.npy", allow_pickle=True))

    # Single device handle instead of duplicated CUDA / CPU branches.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    net = model(boundary).to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(net.parameters(), lr=lr, momentum=momentum)

    trainset = MyDataset(save_dir + "train_with_label.npy")
    trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True,
                             num_workers=0)

    for epoch in range(epochs):
        running_loss = 0.0
        # Each batch from the loader is an (inputs, labels) pair.
        for step, (inputs, labels) in enumerate(trainloader, 1):
            inputs, labels = inputs.to(device), labels.to(device)

            # Reset gradients accumulated by the previous step.
            optimizer.zero_grad()

            # Forward pass, loss, backward pass, parameter update.
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # Report the mean loss over the last `log_every` mini-batches.
            running_loss += loss.item()
            if step % log_every == 0:
                print('[%d, %5d] loss: %.3f'
                      % (epoch + 1, step, running_loss / log_every))
                running_loss = 0.0

    # Leave the network in inference mode, then persist its weights.
    net.eval()
    torch.save(net.state_dict(), save_dir + "net_latest.pth")
if __name__ == '__main__':
    # Each query model is paired with the dataset file whose directory holds
    # its training data; keeping them as pairs prevents the two sequences
    # from drifting out of sync when an entry is added or removed.
    training_jobs = [
        (Model.AudioQuery, "Zipf/audio/datasetKnn.hdf5"),
        (Model.SunQuery, "Zipf/sun/datasetKnn.hdf5"),
        (Model.EnronQuery, "Zipf/enron/datasetKnn.hdf5"),
        (Model.NuswideQuery, "Zipf/nuswide/datasetKnn.hdf5"),
        (Model.NotreQuery, "Zipf/notre/datasetKnn.hdf5"),
    ]
    for model, file_name in training_jobs:
        print("train", file_name)
        trainModel(model, file_name)