Skip to content

Commit 235be08

Browse files
[DETA] fix backbone freeze/unfreeze function (#27843)
* [DETA] fix freeze/unfreeze function * Update src/transformers/models/deta/modeling_deta.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Update src/transformers/models/deta/modeling_deta.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * add freeze/unfreeze test case in DETA * fix type * fix typo 2 --------- Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
1 parent df5c5c6 commit 235be08

File tree

2 files changed

+30
-4
lines changed

2 files changed

+30
-4
lines changed

src/transformers/models/deta/modeling_deta.py

+2-4
Original file line numberDiff line numberDiff line change
@@ -1414,14 +1414,12 @@ def get_encoder(self):
14141414
def get_decoder(self):
14151415
return self.decoder
14161416

1417-
# Copied from transformers.models.deformable_detr.modeling_deformable_detr.DeformableDetrModel.freeze_backbone
14181417
def freeze_backbone(self):
1419-
for name, param in self.backbone.conv_encoder.model.named_parameters():
1418+
for name, param in self.backbone.model.named_parameters():
14201419
param.requires_grad_(False)
14211420

1422-
# Copied from transformers.models.deformable_detr.modeling_deformable_detr.DeformableDetrModel.unfreeze_backbone
14231421
def unfreeze_backbone(self):
1424-
for name, param in self.backbone.conv_encoder.model.named_parameters():
1422+
for name, param in self.backbone.model.named_parameters():
14251423
param.requires_grad_(True)
14261424

14271425
# Copied from transformers.models.deformable_detr.modeling_deformable_detr.DeformableDetrModel.get_valid_ratio

tests/models/deta/test_modeling_deta.py

+28
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,26 @@ def create_and_check_deta_model(self, config, pixel_values, pixel_mask, labels):
162162

163163
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.num_queries, self.hidden_size))
164164

165+
def create_and_check_deta_freeze_backbone(self, config, pixel_values, pixel_mask, labels):
166+
model = DetaModel(config=config)
167+
model.to(torch_device)
168+
model.eval()
169+
170+
model.freeze_backbone()
171+
172+
for _, param in model.backbone.model.named_parameters():
173+
self.parent.assertEqual(False, param.requires_grad)
174+
175+
def create_and_check_deta_unfreeze_backbone(self, config, pixel_values, pixel_mask, labels):
176+
model = DetaModel(config=config)
177+
model.to(torch_device)
178+
model.eval()
179+
180+
model.unfreeze_backbone()
181+
182+
for _, param in model.backbone.model.named_parameters():
183+
self.parent.assertEqual(True, param.requires_grad)
184+
165185
def create_and_check_deta_object_detection_head_model(self, config, pixel_values, pixel_mask, labels):
166186
model = DetaForObjectDetection(config=config)
167187
model.to(torch_device)
@@ -250,6 +270,14 @@ def test_deta_model(self):
250270
config_and_inputs = self.model_tester.prepare_config_and_inputs()
251271
self.model_tester.create_and_check_deta_model(*config_and_inputs)
252272

273+
def test_deta_freeze_backbone(self):
274+
config_and_inputs = self.model_tester.prepare_config_and_inputs()
275+
self.model_tester.create_and_check_deta_freeze_backbone(*config_and_inputs)
276+
277+
def test_deta_unfreeze_backbone(self):
278+
config_and_inputs = self.model_tester.prepare_config_and_inputs()
279+
self.model_tester.create_and_check_deta_unfreeze_backbone(*config_and_inputs)
280+
253281
def test_deta_object_detection_head_model(self):
254282
config_and_inputs = self.model_tester.prepare_config_and_inputs()
255283
self.model_tester.create_and_check_deta_object_detection_head_model(*config_and_inputs)

0 commit comments

Comments
 (0)