
Commit 35203d6

Update train.py
1 parent 86b4990 commit 35203d6

File tree

1 file changed: +246 -0 lines changed

train.py (+246 lines)

@@ -1 +1,247 @@
import os
import warnings
from datetime import datetime
from pprint import pprint
import numpy as np
import torch
import torch.backends.cudnn as torchcudnn
from tensorboardX import SummaryWriter
from torch.nn import CrossEntropyLoss
from torch.optim import SGD, Adam
with warnings.catch_warnings():
    warnings.filterwarnings("ignore", category=FutureWarning)
    import argparse
    import random
    import network
    from config import arg_config, proj_root
    from data.OBdataset import create_loader
    from utils.misc import (AvgMeter, construct_path_dict,
                            make_log, pre_mkdir)

parser = argparse.ArgumentParser(description='Model2_multiscale_fix_fm_alpha_test')
parser.add_argument('--kernel_size', type=int, default=3, help='kernel size',
                    choices=[1, 3, 5, 7])
parser.add_argument('--multi_scale', type=int, default=2, help='number of scales',
                    choices=[1, 2, 3, 4, 5])
parser.add_argument('--ex_name', type=str, default="train_topnet3")
parser.add_argument('--resume', action='store_true', help='resume from checkpoint')

args_2 = parser.parse_args()


def setup_seed(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True


setup_seed(0)
# Note: benchmark=True favours speed and can make runs non-deterministic
# even though the seed is fixed above.
torchcudnn.benchmark = True
torchcudnn.enabled = True
torchcudnn.deterministic = True


class Trainer:
    def __init__(self, args, writer):
        super(Trainer, self).__init__()
        self.args = args
        self.writer = writer  # keep a handle so train() does not rely on the module-level global
        pprint(self.args)

        if self.args["suffix"]:
            self.model_name = self.args["model"] + "_" + self.args["suffix"]
        else:
            self.model_name = self.args["model"]
        self.path = construct_path_dict(proj_root=proj_root, exp_name=args_2.ex_name)

        pre_mkdir(path_config=self.path)

        self.save_path = self.path["save"]
        self.save_pre = self.args["save_pre"]
        self.bestF1 = 0.

        self.tr_loader = create_loader(
            self.args["tr_data_path"], self.args["bg_dir"], self.args["fg_dir"], self.args["mask_dir"],
            self.args["input_size"], 'train', self.args["batch_size"], self.args["num_workers"], True,
        )

        self.dev = torch.device(f'cuda:{arg_config["gpu_id"]}') if torch.cuda.is_available() else "cpu"
        self.net = getattr(network, self.args["model"])(pretrained=True).to(self.dev)
        self.loss = CrossEntropyLoss(ignore_index=255, reduction=self.args["reduction"]).to(self.dev)
        self.opti = self.make_optim()
        self.end_epoch = self.args["epoch_num"]
        if self.args["resume"]:
            try:
                self.resume_checkpoint(load_path=self.path["final_full_net"], mode="all")
            except Exception:
                print(f"{self.path['final_full_net']} does not exist and we will load {self.path['final_state_net']}")
                self.resume_checkpoint(load_path=self.path["final_state_net"], mode="onlynet")
                self.start_epoch = self.end_epoch
        else:
            self.start_epoch = 0
        self.iter_num = self.end_epoch * len(self.tr_loader)

    def total_loss(self, train_preds, train_alphas):
        # NOTE: expects `self.loss_funcs` to be populated beforehand; it is not
        # set in __init__, and train() below uses self.loss directly instead.
        loss_list = []
        loss_item_list = []

        assert len(self.loss_funcs) != 0, "please determine loss function `self.loss_funcs`"
        for loss in self.loss_funcs:
            loss_out = loss(train_preds, train_alphas)
            loss_list.append(loss_out)
            loss_item_list.append(f"{loss_out.item():.5f}")

        train_loss = sum(loss_list)
        return train_loss, loss_item_list

    def train(self):
        for curr_epoch in range(self.start_epoch, self.end_epoch):
            self.net.train()
            train_loss_record = AvgMeter()
            out_loss_record = AvgMeter()
            if self.args["lr_type"] in ("poly", "decay", "all_decay"):
                self.change_lr(curr_epoch)
            else:
                raise NotImplementedError

            for train_batch_id, train_data in enumerate(self.tr_loader):
                curr_iter = curr_epoch * len(self.tr_loader) + train_batch_id

                self.opti.zero_grad()
                index, train_bgs, train_masks, train_fgs, train_targets, num, composite_list, feature_pos, w, h, savename = train_data
                train_bgs = train_bgs.to(self.dev, non_blocking=True)
                train_masks = train_masks.to(self.dev, non_blocking=True)
                train_fgs = train_fgs.to(self.dev, non_blocking=True)
                train_targets = train_targets.to(self.dev, non_blocking=True)
                num = num.to(self.dev, non_blocking=True)
                composite_list = composite_list.to(self.dev, non_blocking=True)
                feature_pos = feature_pos.to(self.dev, non_blocking=True)

                train_outs, feature_map = self.net(train_bgs, train_fgs, train_masks, 'train')
                out_loss = self.loss(train_outs, train_targets.long())
                train_loss = out_loss

                train_loss.backward()
                self.opti.step()
                train_iter_loss = train_loss.item()
                train_batch_size = train_bgs.size(0)
                train_loss_record.update(train_iter_loss, train_batch_size)
                # keep out_loss_record in sync so the logged average below is meaningful
                out_loss_record.update(out_loss.item(), train_batch_size)
                if self.args["print_freq"] > 0 and (curr_iter + 1) % self.args["print_freq"] == 0:
                    log = (
                        f"[I:{curr_iter}/{self.iter_num}][E:{curr_epoch}:{self.end_epoch}]>"
                        f"[Lr:{self.opti.param_groups[0]['lr']:.7f}]"
                        f"(L2)[Avg:{train_loss_record.avg:.3f}|Cur:{train_iter_loss:.3f}]"
                    )
                    self.writer.add_scalar('Train/train_loss', train_loss_record.avg, curr_iter)
                    self.writer.add_scalar('Train/out_loss', out_loss_record.avg, curr_iter)
                    print(log)
                    make_log(self.path["tr_log"], log)

            # save the model weights at the end of every epoch
            checkpoint_path = os.path.join(self.args["checkpoint_dir"], '{}_state.pth'.format(curr_epoch))
            torch.save(self.net.state_dict(), checkpoint_path)

    def change_lr(self, curr):
        total_num = self.end_epoch
        if self.args["lr_type"] == "poly":
            # polynomial decay; assumes two param groups (see make_optim)
            ratio = pow((1 - float(curr) / total_num), self.args["lr_decay"])
            self.opti.param_groups[0]["lr"] = self.opti.param_groups[0]["lr"] * ratio
            self.opti.param_groups[1]["lr"] = self.opti.param_groups[0]["lr"]
        elif self.args["lr_type"] == "decay":
            # step decay: multiply the lr by 0.1 every 9 epochs
            ratio = 0.1
            if curr % 9 == 0:
                self.opti.param_groups[0]["lr"] = self.opti.param_groups[0]["lr"] * ratio
                self.opti.param_groups[1]["lr"] = self.opti.param_groups[0]["lr"]
        elif self.args["lr_type"] == "all_decay":
            # halve the base lr every 2 epochs for all param groups
            lr = self.args["lr"] * (0.5 ** (curr // 2))
            for param_group in self.opti.param_groups:
                param_group['lr'] = lr
        else:
            raise NotImplementedError

    def make_optim(self):
        if self.args["optim"] == "sgd_trick":
            params = [
                {
                    "params": [p for name, p in self.net.named_parameters() if ("bias" in name or "bn" in name)],
                    "weight_decay": 0,
                },
                {
                    "params": [
                        p for name, p in self.net.named_parameters() if ("bias" not in name and "bn" not in name)
                    ]
                },
            ]
            optimizer = SGD(
                params,
                lr=self.args["lr"],
                momentum=self.args["momentum"],
                weight_decay=self.args["weight_decay"],
                nesterov=self.args["nesterov"],
            )
        elif self.args["optim"] == "f3_trick":
            backbone, head = [], []
            for name, params_tensor in self.net.named_parameters():
                if "encoder" in name:
                    backbone.append(params_tensor)
                else:
                    head.append(params_tensor)
            params = [
                {"params": backbone, "lr": 0.1 * self.args["lr"]},
                {"params": head, "lr": self.args["lr"]},
            ]
            optimizer = SGD(
                params=params,
                momentum=self.args["momentum"],
                weight_decay=self.args["weight_decay"],
                nesterov=self.args["nesterov"],
            )
        elif self.args["optim"] == "Adam_trick":
            optimizer = Adam(filter(lambda p: p.requires_grad, self.net.parameters()), lr=self.args["lr"])
        else:
            raise NotImplementedError
        print("optimizer = ", optimizer)
        return optimizer

    def save_checkpoint(self, current_epoch, full_net_path, state_net_path):
        state_dict = {
            "epoch": current_epoch,
            "net_state": self.net.state_dict(),
            "opti_state": self.opti.state_dict(),
        }
        torch.save(state_dict, full_net_path)
        torch.save(self.net.state_dict(), state_net_path)

    def resume_checkpoint(self, load_path, mode="all"):
        if os.path.exists(load_path) and os.path.isfile(load_path):
            print(f" =>> loading checkpoint '{load_path}' <<== ")
            checkpoint = torch.load(load_path, map_location=self.dev)
            if mode == "all":
                self.start_epoch = 0
                self.net.load_state_dict(checkpoint["net_state"])
                self.opti.load_state_dict(checkpoint["opti_state"])
                print(f" ==> loaded checkpoint '{load_path}' (epoch {checkpoint['epoch']})")
            elif mode == "onlynet":
                self.net.load_state_dict(checkpoint)
                print(f" ==> loaded checkpoint '{load_path}' (only the net's weight params) <<== ")
            else:
                raise NotImplementedError
        else:
            raise Exception(f"{load_path} does not exist, please check the load path")


if __name__ == "__main__":
    print(torch.device(f'cuda:{arg_config["gpu_id"]}') if torch.cuda.is_available() else "cpu")
    writer = SummaryWriter(logdir="writer/train")
    trainer = Trainer(arg_config, writer)
    print(f" ===========>> {datetime.now()}: Begin training <<=========== ")

    trainer.train()
    print(f" ===========>> {datetime.now()}: End training <<=========== ")
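For reference, a minimal sketch of the configuration dictionary this script expects as `config.arg_config`. The key names are exactly the ones accessed via `self.args[...]` and `arg_config[...]` in the code above; every value below is an illustrative placeholder, not a default shipped with the project.

# Sketch only: key names come from train.py's accesses of arg_config;
# all values here are placeholders, not the project's real settings.
arg_config_example = {
    "model": "TopNet",                   # placeholder class name, resolved via getattr(network, ...)
    "suffix": "",
    "save_pre": False,
    "tr_data_path": "data/train.json",   # placeholder paths
    "bg_dir": "data/bg",
    "fg_dir": "data/fg",
    "mask_dir": "data/mask",
    "input_size": 256,
    "batch_size": 16,
    "num_workers": 4,
    "gpu_id": 0,
    "reduction": "mean",                 # passed to CrossEntropyLoss
    "epoch_num": 20,
    "resume": False,
    "lr_type": "all_decay",              # "poly" | "decay" | "all_decay"
    "lr": 1e-3,
    "lr_decay": 0.9,                     # exponent used by the "poly" schedule
    "print_freq": 50,
    "checkpoint_dir": "checkpoints",
    "optim": "Adam_trick",               # "sgd_trick" | "f3_trick" | "Adam_trick"
    "momentum": 0.9,
    "weight_decay": 5e-4,
    "nesterov": False,
}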
