Commit ad831d4

Refine Evaluation and Resume (#7)
* refine cfg and scheduler
* add auto resume
* add vl_pretrain_load
* rename test to eval
* remove old train and eval engine
* refine evaluation
* delete useless print
1 parent f84cbe7 commit ad831d4

14 files changed: 118 additions & 151 deletions

config/common/train.py

Lines changed: 17 additions & 16 deletions
@@ -4,35 +4,36 @@
 from simrec.scheduler.lr_scheduler import WarmupCosineLR
 
 train = dict(
+    output_dir = "./test",
+    warmup_epochs=3,
+    epochs = 25,
+    base_lr=1e-4,
+    warmup_lr=1e-7,
+    min_lr=1e-6,
     batch_size=8,
-    num_workers=8,
+    log_period=1,
+    data=dict(pin_memory=True, num_workers=8),
+    scheduler=dict(
+        name="cosine",
+        decay_epochs=[30, 35, 37],
+        lr_decay_rate=0.2,
+    ),
     amp=dict(enabled=False),
     ddp=dict(
         backend="nccl",
         init_method="env://",
     ),
     ema=dict(enabled=True, alpha=0.9997, buffer_ema=True),
-    epochs = 25,
-    output_dir = "./test",
-    log_period = 1,
-    resume=dict(enable=False, auto_resume=True, resume_path=""),
+    clip_grad_norm=0.15,
+    auto_resume=dict(enabled=True),
+    resume_path="",
     vl_pretrain_weight="",
-
-    scheduler = LazyCall(WarmupCosineLR)(
-        # optimizer and epochs and n_iter_per_epoch will be set in train.py
-        warmup_epochs = 3,
-        warmup_lr = 0.0000001,
-        base_lr = 0.0001,
-        min_lr = 0.000001,
-    ),
-
     multi_scale_training=dict(
         enabled=True,
         img_scales=[[224,224],[256,256],[288,288],[320,320],[352,352],
                     [384,384],[416,416],[448,448],[480,480],[512,512],
                     [544,544],[576,576],[608,608]]
     ),
-    clip_grad_norm=0.15,
-    log_image = False,
+    log_image=False,
     seed = 123456,
 )
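
The scheduler is no longer a LazyCall object: the warmup and decay fields now sit at the top level of train and are consumed by build_lr_scheduler (see simrec/scheduler/build.py below). As a sketch of the schedule these numbers describe, assuming linear warmup followed by cosine decay (the real WarmupCosineLR may differ in details):

import math

def lr_at(step, n_iter_per_epoch, warmup_epochs=3, epochs=25,
          warmup_lr=1e-7, base_lr=1e-4, min_lr=1e-6):
    """Sketch of a warmup-cosine schedule parameterized like train above."""
    warmup_steps = warmup_epochs * n_iter_per_epoch
    total_steps = epochs * n_iter_per_epoch
    if step < warmup_steps:
        # linear warmup from warmup_lr up to base_lr
        return warmup_lr + (base_lr - warmup_lr) * step / warmup_steps
    # cosine decay from base_lr down to min_lr over the remaining steps
    t = (step - warmup_steps) / (total_steps - warmup_steps)
    return min_lr + 0.5 * (base_lr - min_lr) * (1 + math.cos(math.pi * t))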

config/simrec_refcoco_scratch.py

Lines changed: 3 additions & 1 deletion
@@ -7,4 +7,6 @@
 
 dataset.ann_path["refcoco"] = "/home/rentianhe/dataset/rec/anns/refcoco.json"
 dataset.image_path["refcoco"] = "/home/rentianhe/dataset/rec/images/train2014"
-dataset.mask_path["refcoco"] = "/home/rentianhe/dataset/rec/masks/refcoco"
+dataset.mask_path["refcoco"] = "/home/rentianhe/dataset/rec/masks/refcoco"
+
+train.resume_path = "/home/rentianhe/code/SimREC/output/ckpt_epoch_3.pth"
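
For context, dataset and train here are config objects imported from the common configs and mutated in place; the import lines sit above this hunk. A hypothetical header (module paths assumed, not shown in the diff):

# hypothetical: how a SimREC scratch config pulls in the shared config objects
from config.common.train import train
from config.common.dataset import dataset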

simrec/datasets/dataloader.py

Lines changed: 2 additions & 2 deletions
@@ -31,8 +31,8 @@ def build_loader(cfg, dataset: torch.utils.data.Dataset, rank: int, shuffle=True
         dataset,
         batch_size=cfg.train.batch_size,
         sampler=dist_sampler,
-        num_workers=cfg.train.num_workers,
-        pin_memory=True,
+        num_workers=cfg.train.data.num_workers,
+        pin_memory=cfg.train.data.pin_memory,
         drop_last=drop_last
     )
     return data_loader
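
With this change, worker count and pin-memory come from the nested cfg.train.data dict instead of top-level train keys. A hedged usage sketch (train_set and local_rank are assumed to exist at the call site):

from simrec.datasets.dataloader import build_loader

# sketch: the loader now picks up num_workers / pin_memory from cfg.train.data
train_loader = build_loader(cfg, train_set, rank=local_rank, shuffle=True)
assert train_loader.num_workers == cfg.train.data.num_workers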

simrec/models/heads/rec_heads.py

Lines changed: 0 additions & 1 deletion
@@ -33,7 +33,6 @@ def __init__(
         super().__init__()
         # same padding
         pad = (ksize - 1) // 2
-        print(in_channels,out_channels,ksize,stride)
         self.conv = nn.Conv2d(
             in_channels,
             out_channels,

simrec/models/mcn.py

Lines changed: 1 addition & 1 deletion
@@ -18,7 +18,7 @@
 from simrec.models.heads.mcn_heads import MCNhead
 from simrec.models.backbones.build import build_visual_encoder
 from simrec.models.language_encoders.build import build_language_encoder
-from simrec.layers.fusion_layer import SimpleFusion,MultiScaleFusion,GaranAttention
+from simrec.layers.fusion_layer import SimpleFusion, MultiScaleFusion, GaranAttention
 
 
 class MCN(nn.Module):

simrec/models/simrec.py

Lines changed: 0 additions & 12 deletions
@@ -36,23 +36,11 @@ def __init__(
         super(SimREC, self).__init__()
         self.visual_encoder=visual_backbone
         self.lang_encoder=language_encoder
-        # self.multi_scale_manner = MultiScaleFusion(v_planes=(512, 512, hidden_size), scaled=True)
         self.multi_scale_manner = multi_scale_manner
-        # self.fusion_manner=nn.ModuleList(
-        #     [
-        #         SimpleFusion(v_planes=256, out_planes=512, q_planes=512),
-        #         SimpleFusion(v_planes=512, out_planes=512, q_planes=512),
-        #         SimpleFusion(v_planes=1024, out_planes=512, q_planes=512)
-        #     ]
-        # )
         self.fusion_manner = fusion_manner
-        # self.attention_manner=GaranAttention(512,512)
         self.attention_manner = attention_manner
         self.head=head
 
-        total = sum([param.nelement() for param in self.lang_encoder.parameters()])
-        print(' + Number of lang enc params: %.2fM' % (total / 1e6))
-
 
     def frozen(self,module):
         if getattr(module,'module',False):
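
The deleted lines printed the language-encoder parameter count at construction time. If that number is still wanted, a hedged equivalent can be computed outside the model:

# sketch: parameter counting moved out of SimREC.__init__
def count_params_m(module):
    """Number of parameters in `module`, in millions."""
    return sum(p.nelement() for p in module.parameters()) / 1e6

# e.g. after building the model (logger assumed):
# logger.info(f" + Number of lang enc params: {count_params_m(model.lang_encoder):.2f}M")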

simrec/scheduler/build.py

Lines changed: 27 additions & 6 deletions
@@ -13,10 +13,31 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from simrec.config import instantiate
+from .lr_scheduler import StepLR, WarmupCosineLR
 
-def build_lr_scheduler(cfg, optimizer):
-    """Build learning rate scheduler, defined by ``cfg.train.scheduler``."""
-    cfg.optimizer = optimizer
-    scheduler = instantiate(cfg)
-    return scheduler
+def build_lr_scheduler(cfg, optimizer, n_iter_per_epoch):
+    """Build learning rate scheduler."""
+    scheduler_name = cfg.train.scheduler.name.lower()
+
+    scheduler = None
+    if scheduler_name == "cosine":
+        scheduler = WarmupCosineLR(
+            optimizer=optimizer,
+            warmup_epochs=cfg.train.warmup_epochs,
+            epochs=cfg.train.epochs,
+            warmup_lr=cfg.train.warmup_lr,
+            base_lr=cfg.train.base_lr,
+            min_lr=cfg.train.min_lr,
+            n_iter_per_epoch=n_iter_per_epoch
+        )
+    elif scheduler_name == "step":
+        scheduler = StepLR(
+            optimizer=optimizer,
+            warmup_epochs=cfg.train.warmup_epochs,
+            epochs=cfg.train.epochs,
+            decay_epochs=cfg.train.scheduler.decay_epochs,
+            lr_decay_rate=cfg.train.lr_decay_rate,
+            n_iter_per_epoch=n_iter_per_epoch,
+        )
+
+    return scheduler
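
A hedged sketch of the call site: the training script builds the optimizer and dataloader first, then passes the iteration count through (the returned scheduler's stepping API lives in simrec/scheduler/lr_scheduler.py and is not shown in this diff):

from simrec.scheduler.build import build_lr_scheduler

# sketch: n_iter_per_epoch comes from the dataloader length
n_iter_per_epoch = len(train_loader)
scheduler = build_lr_scheduler(cfg, optimizer, n_iter_per_epoch)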

simrec/utils/checkpoint.py

Lines changed: 25 additions & 54 deletions
@@ -14,64 +14,23 @@
 # limitations under the License.
 
 import os
-import warnings
 
 import torch
-from torch.nn import DataParallel as DP
-from torch.nn.parallel import DistributedDataParallel as DDP
 
+from simrec.utils.distributed import is_main_process
 
-def save_ckpt(net, optimizer,scheduler, misc, __C):
-    path = __C.CKPTs_PATH
-    if not os.path.exists(path):
-        os.mkdir(path)
-    path += '/' + __C.VERSION
-    if not os.path.exists(path):
-        os.mkdir(path)
-    assert isinstance(misc, dict)
-    if isinstance(net, DP) or isinstance(net, DDP):
-        path += '/' + 'dist_'
-    path += str(misc['epoch']) + '.pth.tar'
-    ckpt = {
-        'net_state_dict': net.state_dict(),
-        'optimizer_state_dict': optimizer.state_dict(),
-        'scheduler':scheduler.state_dict(),
-        'epoch':misc['epoch'],
-        'lr':optimizer.param_groups[0]["lr"],
-    }
-    torch.save(ckpt, path)
-
-
-def load_ckpt(net, optimizer,scheduler, path, rank=None):
-    loc = f'cuda:{rank}' if rank is not None else None
-    ckpt = torch.load(path, map_location=loc)
-
-    flag = isinstance(net, DP) or isinstance(net, DDP)
-    if '_dist' in path:
-        if not flag:
-            for name in ckpt['net_state_dict']:
-                assert name.startswith('module.')
-                ckpt['net_state_dict'][name.lstrip('module.')] = ckpt['net_state_dict'].pop(name)
-    else:
-        if flag:
-            for name in ckpt['net_state_dict']:
-                ckpt['net_state_dict']['module.' + name] = ckpt['net_state_dict'].pop(name)
-
-    optimizer.load_state_dict(ckpt['optimizer_state_dict'])
 
-    scheduler.load_state_dict(ckpt['scheduler'])
-
-    missing, unexpected = net.load_state_dict(ckpt['net_state_dict'], strict=False)
-    if unexpected.__len__ != 0:
-        warnings.warn(f'Current model misses {unexpected.__len__} parameters from checkpointing model')
-        for name in missing:
-            print('\n' + name + '\n')
-    if missing.__len__ != 0:
-        warnings.warn(f'Current model contains {missing.__len__} parameters that checkpointing model doesn\'t contain')
-        for name in unexpected:
-            print('\n' + name + '\n')
-
-    return ckpt
+def load_checkpoint(cfg, model, optimizer, scheduler, logger):
+    logger.info(f"==============> Resuming form {cfg.train.resume_path}....................")
+    checkpoint = torch.load(cfg.train.resume_path, map_location=lambda storage, loc: storage.cuda())
+    msg = model.load_state_dict(checkpoint['state_dict'], strict=False)
+    logger.info(msg)
+    optimizer.load_state_dict(checkpoint["optimizer"])
+    scheduler.load_state_dict(checkpoint["scheduler"])
+    start_epoch = checkpoint["epoch"]
+    logger.info("==> loaded checkpoint from {}\n".format(cfg.train.resume_path) +
+                "==> epoch: {} lr: {} ".format(checkpoint['epoch'],checkpoint['lr']))
+    return start_epoch + 1
 
 
 def save_checkpoint(cfg, epoch, model, optimizer, scheduler, logger, det_best=False, seg_best=False):
@@ -99,4 +58,16 @@ def save_checkpoint(cfg, epoch, model, optimizer, scheduler, logger, det_best=False, seg_best=False):
     if seg_best:
         seg_best_model_path = os.path.join(cfg.train.output_dir, f'seg_best_model.pth')
         torch.save(save_state, seg_best_model_path)
-    logger.info(f"checkpoints saved !!!")
+    logger.info(f"checkpoints saved !!!")
+
+
+def auto_resume_helper(output_dir):
+    checkpoints = os.listdir(output_dir)
+    checkpoints = [ckpt for ckpt in checkpoints if ckpt.endswith('pth')]
+    print(f"All checkpoints founded in {output_dir}: {checkpoints}")
+    if len(checkpoints) > 0:
+        resume_file = os.path.join(output_dir, "last_checkpoint.pth")
+    else:
+        resume_file = None
+
+    return resume_file
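
Putting the new helpers together, the intended resume flow is roughly the following (a hedged sketch; the actual call site is the training entry point, which is not part of this diff):

from simrec.utils.checkpoint import auto_resume_helper, load_checkpoint

# sketch: prefer auto-resume from output_dir, fall back to an explicit resume_path
start_epoch = 0
if cfg.train.auto_resume.enabled:
    resume_file = auto_resume_helper(cfg.train.output_dir)
    if resume_file is not None:
        cfg.train.resume_path = resume_file
if cfg.train.resume_path:
    start_epoch = load_checkpoint(cfg, model, optimizer, scheduler, logger)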

simrec/utils/distributed.py

Lines changed: 8 additions & 1 deletion
@@ -75,11 +75,17 @@ def synchronize():
     dist.barrier()
 
 
-
 def cleanup_distributed():
     dist.destroy_process_group()
 
 
+def reduce_tensor(tensor):
+    rt = tensor.clone()
+    dist.all_reduce(rt, op=dist.ReduceOp.SUM)
+    rt /= dist.get_world_size()
+    return rt
+
+
 def reduce_meters(meters, rank, cfg):
     """Sync and flush meters."""
     assert isinstance(meters, dict), "collect AverageMeters into a dict"
@@ -94,6 +100,7 @@ def reduce_meters(meters, rank, cfg):
         value = torch.mean(torch.cat(avg_reduce)).item()
         meter.update_reduce(value)
 
+
 def find_free_port():
     import socket
     sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
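
reduce_tensor averages a value across ranks, which is presumably what the refined evaluation uses to aggregate per-GPU metrics. A hedged usage sketch (n_correct and n_total are assumed names):

import torch
import torch.distributed as dist

from simrec.utils.distributed import reduce_tensor

# sketch: average a per-rank accuracy during distributed evaluation
acc = torch.tensor(n_correct / n_total, device="cuda")
acc = reduce_tensor(acc)  # all-reduce sum, then divide by world size
if dist.get_rank() == 0:
    print(f"eval accuracy: {acc.item():.4f}")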
File renamed without changes.
