14
14
from .rcnn .utils .comm import synchronize , get_rank
15
15
from .rcnn .modeling .relation_heads .relation_heads import build_roi_relation_head
16
16
17
- SCENE_PAESER_DICT = {"sg_baseline" , "sg_imp" } #, "msdn": MSDN}
17
+ SCENE_PAESER_DICT = {"sg_baseline" , "sg_imp" , "sg_msdn" } #, "msdn": MSDN}
18
18
19
19
class SceneParser (GeneralizedRCNN ):
20
20
"Scene Parser"
@@ -25,6 +25,41 @@ def __init__(self, cfg):
25
25
self .rel_heads = None
26
26
if cfg .MODEL .RELATION_ON and self .cfg .MODEL .ALGORITHM in SCENE_PAESER_DICT :
27
27
self .rel_heads = build_roi_relation_head (cfg , self .backbone .out_channels )
28
+ self ._freeze_components (self .cfg )
29
+
30
+ def _freeze_components (self , cfg ):
31
+ if cfg .MODEL .BACKBONE .FREEZE_PARAMETER :
32
+ for param in self .backbone .parameters ():
33
+ param .requires_grad = False
34
+
35
+ if cfg .MODEL .RPN .FREEZE_PARAMETER :
36
+ for param in self .rpn .parameters ():
37
+ param .requires_grad = False
38
+
39
+ if cfg .MODEL .ROI_BOX_HEAD .FREEZE_PARAMETER :
40
+ for param in self .roi_heads .box .parameters ():
41
+ param .requires_grad = False
42
+
43
+ def train (self ):
44
+ if self .cfg .MODEL .BACKBONE .FREEZE_PARAMETER :
45
+ self .backbone .eval ()
46
+ else :
47
+ self .backbone .train ()
48
+
49
+ if self .cfg .MODEL .RPN .FREEZE_PARAMETER :
50
+ self .rpn .eval ()
51
+ else :
52
+ self .rpn .train ()
53
+
54
+ if self .cfg .MODEL .ROI_BOX_HEAD .FREEZE_PARAMETER :
55
+ self .roi_heads .eval ()
56
+ else :
57
+ self .roi_heads .train ()
58
+
59
+ self .rel_heads .train ()
60
+
61
+ def eval (self ):
62
+ self .eval ()
28
63
29
64
def forward (self , images , targets = None ):
30
65
"""
@@ -44,10 +79,11 @@ def forward(self, images, targets=None):
44
79
images = to_image_list (images )
45
80
features = self .backbone (images .tensors )
46
81
proposals , proposal_losses = self .rpn (images , features , targets )
47
-
82
+ scene_parser_losses = {}
48
83
if self .roi_heads :
49
- x , detections , scene_parser_losses = self .roi_heads (features , proposals , targets )
84
+ x , detections , roi_heads_loss = self .roi_heads (features , proposals , targets )
50
85
result = detections
86
+ scene_parser_losses .update (roi_heads_loss )
51
87
52
88
if self .rel_heads :
53
89
relation_features = features
@@ -60,8 +96,8 @@ def forward(self, images, targets=None):
60
96
relation_features = x
61
97
# During training, self.box() will return the unaltered proposals as "detections"
62
98
# this makes the API consistent during training and testing
63
- x_pairs , detection_pairs , loss_relation = self .rel_heads (relation_features , detections , targets )
64
- losses .update (loss_relation )
99
+ x_pairs , detection_pairs , rel_heads_loss = self .rel_heads (relation_features , detections , targets )
100
+ scene_parser_losses .update (rel_heads_loss )
65
101
66
102
x = (x , x_pairs )
67
103
result = (detections , detection_pairs )
@@ -109,5 +145,6 @@ def build_scene_parser_optimizer(cfg, model, local_rank=0, distributed=False):
109
145
save_to_disk = get_rank () == 0
110
146
checkpointer = SceneParserCheckpointer (cfg , model , optimizer , scheduler , save_dir , save_to_disk ,
111
147
logger = logging .getLogger ("scene_graph_generation.checkpointer" ))
112
- extra_checkpoint_data = checkpointer .load (cfg .MODEL .WEIGHT , resume = cfg .resume )
148
+ model_weight = cfg .MODEL .WEIGHT_DET if cfg .MODEL .WEIGHT_DET != "" else cfg .MODEL .WEIGHT_IMG
149
+ extra_checkpoint_data = checkpointer .load (model_weight , resume = cfg .resume )
113
150
return optimizer , scheduler , checkpointer , extra_checkpoint_data
0 commit comments