"""train.py: training script for Texture_model, using a class-balanced
(weighted) binary cross-entropy loss."""
from model_bn import Texture_model
from dataset import Texture_dataset_train, Texture_dataset_val
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import numpy as np
from sys import stdout
import time
import os


class WCE_loss(nn.Module):
    """Class-balanced binary cross-entropy: the foreground and background
    terms are each averaged over their own pixel counts, so sparse masks
    are not dominated by the background term."""

    def __init__(self):
        super(WCE_loss, self).__init__()

    def sum_ij(self, x):
        # Sum over the spatial dimensions (H, W), keeping batch and channel.
        return torch.sum(torch.sum(x, dim=3), dim=2)

    def forward(self, pred, gt):
        N_fg = self.sum_ij(gt)      # foreground pixels per map
        N_bg = self.sum_ij(1 - gt)  # background pixels per map
        # The 1e-16 epsilon guards against log(0) at saturated predictions.
        L_fg = -self.sum_ij(torch.log(pred + 1e-16) * gt) / N_fg
        L_bg = -self.sum_ij(torch.log(1 - pred + 1e-16) * (1 - gt)) / N_bg
        return torch.mean(L_fg + L_bg)
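

# A quick sanity check for WCE_loss (an illustrative addition, not part of the
# original script): a near-perfect prediction should score close to zero even
# though the foreground and background pixel counts are unbalanced.
def _wce_sanity_check():
    loss_fn = WCE_loss()
    gt = torch.zeros(1, 1, 4, 4)
    gt[..., :2, :2] = 1.0                         # 4 fg pixels vs 12 bg pixels
    near_perfect = gt * 0.999 + (1 - gt) * 0.001  # confident, almost correct
    print(loss_fn(near_perfect, gt).item())       # expected: ~0.002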


def train():
    # Hyperparameters and checkpoint bookkeeping.
    num_epoch = 20000
    cur_lr = initial_lr = 1e-5
    steps_to_decay_lr = 500   # halve the learning rate every 500 epochs
    steps_to_save = 5         # write a periodic checkpoint every 5 epochs
    num_max_model = 5         # keep at most this many checkpoints per list
    saved_best_model = []
    saved_model = []
    best_loss = np.inf
    save_model_path = '/fast_data/one_shot_texture_models/'
    os.makedirs(save_model_path, exist_ok=True)  # ensure the checkpoint dir exists

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model = Texture_model()
    model.to(device)

    train_dataset = Texture_dataset_train(200, 'train_texture.npy')
    val_dataset = Texture_dataset_val(240, 'val_texture.npy')
    train_dataloader = DataLoader(train_dataset, batch_size=8, shuffle=True, num_workers=4)
    val_dataloader = DataLoader(val_dataset, batch_size=8, shuffle=True, num_workers=4)

    loss_func = WCE_loss()
    opt = torch.optim.Adam(model.parameters(), lr=initial_lr, weight_decay=1e-3)

    for epoch in range(num_epoch):
        tic = time.time()
        training_loss = 0
        testing_loss = 0

        # One pass over the training set.
        model.train()
        for i, batch in enumerate(train_dataloader):
            x, y, x_ref = batch
            x, y, x_ref = x.to(device), y.to(device), x_ref.to(device)
            opt.zero_grad()
            loss = loss_func(model(x, x_ref), y)
            loss.backward()
            opt.step()
            training_loss += loss.item()
            stdout.write('\r%d' % i)  # lightweight in-place progress counter
            stdout.flush()
        training_loss /= (i + 1)  # i is zero-based, so there are i + 1 batches

        # One pass over the validation set, without gradients.
        model.eval()
        with torch.no_grad():
            for i, batch in enumerate(val_dataloader):
                x, y, x_ref = batch
                x, y, x_ref = x.to(device), y.to(device), x_ref.to(device)
                loss = loss_func(model(x, x_ref), y)
                testing_loss += loss.item()
        testing_loss /= (i + 1)

        toc = time.time()
        print('\r%5d/%5d training loss: %.5f, validation loss: %.5f (%d sec)'
              % (epoch + 1, num_epoch, training_loss, testing_loss, toc - tic))
        # Keep the best-by-validation-loss checkpoints, rotating out the
        # oldest once more than num_max_model of them have been written.
        if testing_loss < best_loss:
            best_loss = testing_loss
            model_name = "model_%.5f.pt" % best_loss
            saved_best_model.append(model_name)
            torch.save(model.state_dict(), os.path.join(save_model_path, model_name))
            if len(saved_best_model) > num_max_model:
                rm = saved_best_model.pop(0)
                if os.path.exists(os.path.join(save_model_path, rm)):
                    os.remove(os.path.join(save_model_path, rm))
                else:
                    print('Cannot find file', os.path.join(save_model_path, rm))

        # Periodic checkpoint every steps_to_save epochs, same rotation policy.
        if (epoch + 1) % steps_to_save == 0:
            model_name = 'model_%03d.pt' % (epoch + 1)
            saved_model.append(model_name)
            torch.save(model.state_dict(), os.path.join(save_model_path, model_name))
            if len(saved_model) > num_max_model:
                rm = saved_model.pop(0)
                if os.path.exists(os.path.join(save_model_path, rm)):
                    os.remove(os.path.join(save_model_path, rm))
                else:
                    print('Cannot find file', os.path.join(save_model_path, rm))

        # Step decay: halve the learning rate every steps_to_decay_lr epochs.
        if (epoch + 1) % steps_to_decay_lr == 0:
            cur_lr /= 2
            for g in opt.param_groups:
                g['lr'] = cur_lr
            print("Reducing learning rate to", cur_lr)


if __name__ == '__main__':
    train()
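
# Assumed usage: with model_bn.py, dataset.py, train_texture.npy and
# val_texture.npy alongside this script, and the checkpoint directory
# writable, run:
#   python train.py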