Skip to content

Commit e23cc54

Browse files
committed
add freeze_parameter for model components
1 parent f75d76d commit e23cc54

File tree

10 files changed

+150
-21
lines changed

10 files changed

+150
-21
lines changed

configs/baseline_res101.yaml

+9-4
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,24 @@
11
DATASET:
22
NAME: "vg"
33
MODE: "benchmark"
4-
TRAIN_BATCH_SIZE: 6
4+
TRAIN_BATCH_SIZE: 16
55
TEST_BATCH_SIZE: 1
66
MODEL:
7-
WEIGHT: "catalog://ImageNetPretrained/MSRA/R-101"
7+
WEIGHT_IMG: "catalog://ImageNetPretrained/MSRA/R-101"
8+
WEIGHT_DET: "checkpoints/vg_benchmark_object/R-101-C4/faster_rcnn/BatchSize_6/Base_LR_0.005/checkpoint_0099999.pth"
9+
RELATION_ON: True
810
ALGORITHM: "sg_baseline"
911
USE_FREQ_PRIOR: False
1012
BACKBONE:
1113
CONV_BODY: "R-101-C4"
14+
FREEZE_PARAMETER: True
15+
RPN:
16+
FREEZE_PARAMETER: True
1217
ROI_HEADS:
13-
BATCH_SIZE_PER_IMAGE: 384
18+
BATCH_SIZE_PER_IMAGE: 512
1419
ROI_BOX_HEAD:
1520
NUM_CLASSES: 151
16-
RELATION_ON: True
21+
FREEZE_PARAMETER: True
1722
ROI_RELATION_HEAD:
1823
NUM_CLASSES: 51
1924
SOLVER:

configs/faster_rcnn_res101.yaml

+1-1
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ DATASET:
55
TRAIN_BATCH_SIZE: 6
66
TEST_BATCH_SIZE: 1
77
MODEL:
8-
WEIGHT: "catalog://ImageNetPretrained/MSRA/R-101"
8+
WEIGHT_IMG: "catalog://ImageNetPretrained/MSRA/R-101"
99
ALGORITHM: "faster_rcnn"
1010
BACKBONE:
1111
CONV_BODY: "R-101-C4"

configs/grcnn_res101.yaml

+26
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
DATASET:
2+
NAME: "vg"
3+
MODE: "benchmark"
4+
TRAIN_BATCH_SIZE: 6
5+
TEST_BATCH_SIZE: 1
6+
MODEL:
7+
WEIGHT_IMG: "catalog://ImageNetPretrained/MSRA/R-101"
8+
WEIGHT_DET: ""
9+
ALGORITHM: "sg_grcnn"
10+
BACKBONE:
11+
CONV_BODY: "R-101-C4"
12+
FREEZE_PARAMETER: False
13+
RPN:
14+
FREEZE_PARAMETER: False
15+
ROI_HEADS:
16+
BATCH_SIZE_PER_IMAGE: 384
17+
ROI_BOX_HEAD:
18+
NUM_CLASSES: 151
19+
FREEZE_PARAMETER: False
20+
ROI_RELATION_HEAD:
21+
NUM_CLASSES: 51
22+
SOLVER:
23+
BASE_LR: 5e-3
24+
MAX_ITER: 100000
25+
STEPS: (70000,90000)
26+
CHECKPOINT_PERIOD: 5000

configs/imp_res101.yaml

+6-2
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,19 @@ DATASET:
44
TRAIN_BATCH_SIZE: 6
55
TEST_BATCH_SIZE: 1
66
MODEL:
7-
WEIGHT: "catalog://ImageNetPretrained/MSRA/R-101"
7+
WEIGHT_IMG: "catalog://ImageNetPretrained/MSRA/R-101"
8+
WEIGHT_DET: ""
89
ALGORITHM: "sg_imp"
910
BACKBONE:
1011
CONV_BODY: "R-101-C4"
12+
FREEZE_PARAMETER: False
13+
RPN:
14+
FREEZE_PARAMETER: False
1115
ROI_HEADS:
1216
BATCH_SIZE_PER_IMAGE: 384
1317
ROI_BOX_HEAD:
1418
NUM_CLASSES: 151
15-
RELATION_ON: True
19+
FREEZE_PARAMETER: False
1620
ROI_RELATION_HEAD:
1721
NUM_CLASSES: 51
1822
SOLVER:

configs/motifnet_res101.yaml

+26
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
DATASET:
2+
NAME: "vg"
3+
MODE: "benchmark"
4+
TRAIN_BATCH_SIZE: 6
5+
TEST_BATCH_SIZE: 1
6+
MODEL:
7+
WEIGHT_IMG: "catalog://ImageNetPretrained/MSRA/R-101"
8+
WEIGHT_DET: ""
9+
ALGORITHM: "sg_motifnet"
10+
BACKBONE:
11+
CONV_BODY: "R-101-C4"
12+
FREEZE_PARAMETER: False
13+
RPN:
14+
FREEZE_PARAMETER: False
15+
ROI_HEADS:
16+
BATCH_SIZE_PER_IMAGE: 384
17+
ROI_BOX_HEAD:
18+
NUM_CLASSES: 151
19+
FREEZE_PARAMETER: False
20+
ROI_RELATION_HEAD:
21+
NUM_CLASSES: 51
22+
SOLVER:
23+
BASE_LR: 5e-3
24+
MAX_ITER: 100000
25+
STEPS: (70000,90000)
26+
CHECKPOINT_PERIOD: 5000

configs/msdn_res101.yaml

+26
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
DATASET:
2+
NAME: "vg"
3+
MODE: "benchmark"
4+
TRAIN_BATCH_SIZE: 6
5+
TEST_BATCH_SIZE: 1
6+
MODEL:
7+
WEIGHT_IMG: "catalog://ImageNetPretrained/MSRA/R-101"
8+
WEIGHT_DET: "checkpoints/vg_benchmark_object/R-101-C4/faster_rcnn/BatchSize_6/Base_LR_0.005/checkpoint_0099999.pth"
9+
ALGORITHM: "sg_msdn"
10+
BACKBONE:
11+
CONV_BODY: "R-101-C4"
12+
FREEZE_PARAMETER: False
13+
RPN:
14+
FREEZE_PARAMETER: False
15+
ROI_HEADS:
16+
BATCH_SIZE_PER_IMAGE: 384
17+
ROI_BOX_HEAD:
18+
NUM_CLASSES: 151
19+
FREEZE_PARAMETER: False
20+
ROI_RELATION_HEAD:
21+
NUM_CLASSES: 51
22+
SOLVER:
23+
BASE_LR: 5e-3
24+
MAX_ITER: 100000
25+
STEPS: (70000,90000)
26+
CHECKPOINT_PERIOD: 5000

lib/config/defaults.py

+8-1
Original file line numberDiff line numberDiff line change
@@ -46,14 +46,18 @@
4646
_C.MODEL.DEVICE = "cuda"
4747
_C.MODEL.META_ARCHITECTURE = "GeneralizedRCNN"
4848
_C.MODEL.CLS_AGNOSTIC_BBOX_REG = False
49-
_C.MODEL.WEIGHT = ""
49+
_C.MODEL.WEIGHT_IMG = "" # weight loading path for imagenet pretrained model
50+
_C.MODEL.WEIGHT_DET = "" # weight loading path for detector pre-trained model
51+
_C.MODEL.WEIGHT_SGG = "" # weight loading path for scene graph generator pre-trained model
52+
5053
# If the WEIGHT starts with a catalog://, like :R-50, the code will look for
5154
# the path in paths_catalog. Else, it will use it as the specified absolute
5255
# path
5356
_C.MODEL.BACKBONE = CN() # Backbone options
5457
_C.MODEL.BACKBONE.CONV_BODY = "R-50-C4" # The backbone conv body to use # (e.g., 'FPN.add_fpn_ResNet101_conv5_body' to specify a ResNet-101-FPN backbone)
5558
_C.MODEL.BACKBONE.FREEZE_CONV_BODY_AT = 2 # Add StopGrad at a specified stage so the bottom layers are frozen
5659
_C.MODEL.BACKBONE.OUT_CHANNELS = 256 * 4
60+
_C.MODEL.BACKBONE.FREEZE_PARAMETER = False
5761

5862
_C.MODEL.FPN = CN() # FPN options
5963
_C.MODEL.FPN.USE_GN = False
@@ -110,6 +114,8 @@
110114
# Custom rpn head, empty to use default conv or separable conv
111115
_C.MODEL.RPN.RPN_HEAD = "SingleConvRPNHead"
112116

117+
_C.MODEL.RPN.FREEZE_PARAMETER = False
118+
113119
# ---------------------------------------------------------------------------- #
114120
# ROI HEADS options
115121
# ---------------------------------------------------------------------------- #
@@ -160,6 +166,7 @@
160166
_C.MODEL.ROI_BOX_HEAD.DILATION = 1
161167
_C.MODEL.ROI_BOX_HEAD.CONV_HEAD_DIM = 256
162168
_C.MODEL.ROI_BOX_HEAD.NUM_STACKED_CONVS = 4
169+
_C.MODEL.ROI_BOX_HEAD.FREEZE_PARAMETER = False
163170

164171
''''''
165172
_C.MODEL.ROI_RELATION_HEAD = CN()

lib/scene_parser/parser.py

+43-6
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from .rcnn.utils.comm import synchronize, get_rank
1515
from .rcnn.modeling.relation_heads.relation_heads import build_roi_relation_head
1616

17-
SCENE_PAESER_DICT = {"sg_baseline", "sg_imp"} #, "msdn": MSDN}
17+
SCENE_PAESER_DICT = {"sg_baseline", "sg_imp", "sg_msdn"} #, "msdn": MSDN}
1818

1919
class SceneParser(GeneralizedRCNN):
2020
"Scene Parser"
@@ -25,6 +25,41 @@ def __init__(self, cfg):
2525
self.rel_heads = None
2626
if cfg.MODEL.RELATION_ON and self.cfg.MODEL.ALGORITHM in SCENE_PAESER_DICT:
2727
self.rel_heads = build_roi_relation_head(cfg, self.backbone.out_channels)
28+
self._freeze_components(self.cfg)
29+
30+
def _freeze_components(self, cfg):
31+
if cfg.MODEL.BACKBONE.FREEZE_PARAMETER:
32+
for param in self.backbone.parameters():
33+
param.requires_grad = False
34+
35+
if cfg.MODEL.RPN.FREEZE_PARAMETER:
36+
for param in self.rpn.parameters():
37+
param.requires_grad = False
38+
39+
if cfg.MODEL.ROI_BOX_HEAD.FREEZE_PARAMETER:
40+
for param in self.roi_heads.box.parameters():
41+
param.requires_grad = False
42+
43+
def train(self):
44+
if self.cfg.MODEL.BACKBONE.FREEZE_PARAMETER:
45+
self.backbone.eval()
46+
else:
47+
self.backbone.train()
48+
49+
if self.cfg.MODEL.RPN.FREEZE_PARAMETER:
50+
self.rpn.eval()
51+
else:
52+
self.rpn.train()
53+
54+
if self.cfg.MODEL.ROI_BOX_HEAD.FREEZE_PARAMETER:
55+
self.roi_heads.eval()
56+
else:
57+
self.roi_heads.train()
58+
59+
self.rel_heads.train()
60+
61+
def eval(self):
62+
self.eval()
2863

2964
def forward(self, images, targets=None):
3065
"""
@@ -44,10 +79,11 @@ def forward(self, images, targets=None):
4479
images = to_image_list(images)
4580
features = self.backbone(images.tensors)
4681
proposals, proposal_losses = self.rpn(images, features, targets)
47-
82+
scene_parser_losses = {}
4883
if self.roi_heads:
49-
x, detections, scene_parser_losses = self.roi_heads(features, proposals, targets)
84+
x, detections, roi_heads_loss = self.roi_heads(features, proposals, targets)
5085
result = detections
86+
scene_parser_losses.update(roi_heads_loss)
5187

5288
if self.rel_heads:
5389
relation_features = features
@@ -60,8 +96,8 @@ def forward(self, images, targets=None):
6096
relation_features = x
6197
# During training, self.box() will return the unaltered proposals as "detections"
6298
# this makes the API consistent during training and testing
63-
x_pairs, detection_pairs, loss_relation = self.rel_heads(relation_features, detections, targets)
64-
losses.update(loss_relation)
99+
x_pairs, detection_pairs, rel_heads_loss = self.rel_heads(relation_features, detections, targets)
100+
scene_parser_losses.update(rel_heads_loss)
65101

66102
x = (x, x_pairs)
67103
result = (detections, detection_pairs)
@@ -109,5 +145,6 @@ def build_scene_parser_optimizer(cfg, model, local_rank=0, distributed=False):
109145
save_to_disk = get_rank() == 0
110146
checkpointer = SceneParserCheckpointer(cfg, model, optimizer, scheduler, save_dir, save_to_disk,
111147
logger=logging.getLogger("scene_graph_generation.checkpointer"))
112-
extra_checkpoint_data = checkpointer.load(cfg.MODEL.WEIGHT, resume=cfg.resume)
148+
model_weight =cfg.MODEL.WEIGHT_DET if cfg.MODEL.WEIGHT_DET != "" else cfg.MODEL.WEIGHT_IMG
149+
extra_checkpoint_data = checkpointer.load(model_weight, resume=cfg.resume)
113150
return optimizer, scheduler, checkpointer, extra_checkpoint_data

lib/scene_parser/rcnn/modeling/roi_heads/roi_heads.py

+1-3
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,9 @@ def __init__(self, cfg, heads):
1919
self.keypoint.feature_extractor = self.box.feature_extractor
2020

2121
def forward(self, features, proposals, targets=None):
22-
losses = {}
2322
# TODO rename x to roi_box_features, if it doesn't increase memory consumption
2423
x, detections, loss_box = self.box(features, proposals, targets)
25-
losses.update(loss_box)
26-
return x, detections, losses
24+
return x, detections, loss_box
2725

2826

2927
def build_roi_heads(cfg, in_channels):

lib/scene_parser/rcnn/utils/checkpoint.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ def save(self, name, **kwargs):
5050
self.tag_last_checkpoint(save_file)
5151

5252
def load(self, f=None, resume=0, use_latest=True):
53-
if self.has_last_checkpoint() and use_latest and resume == 0:
53+
if self.has_last_checkpoint() and use_latest and resume > 0:
5454
# override argument with existing checkpoint
5555
f = self.get_last_checkpoint_file()
5656
elif self.has_checkpoint(resume) and resume > 0:
@@ -62,14 +62,14 @@ def load(self, f=None, resume=0, use_latest=True):
6262
self.logger.info("Loading checkpoint from {}".format(f))
6363
checkpoint = self._load_file(f)
6464
self._load_model(checkpoint)
65-
if "optimizer" in checkpoint and self.optimizer:
65+
if "optimizer" in checkpoint and self.optimizer and "sg" in f:
6666
self.logger.info("Loading optimizer from {}".format(f))
6767
self.optimizer.load_state_dict(checkpoint.pop("optimizer"))
68-
if "scheduler" in checkpoint and self.scheduler:
68+
if "scheduler" in checkpoint and self.scheduler and "sg" in f:
6969
self.logger.info("Loading scheduler from {}".format(f))
7070
self.scheduler.load_state_dict(checkpoint.pop("scheduler"))
71+
checkpoint['iteration'] = resume # if we load detector, the we should not use its start iteration
7172

72-
# return any further checkpoint data
7373
return checkpoint
7474

7575
def has_last_checkpoint(self):

0 commit comments

Comments
 (0)