
Commit 0996a10

Revert low cpu mem tie weights (#29135)
* Revert "Add tie_weights() to LM heads and set bias in set_output_embeddings() (#28948)"

  This reverts commit 725f4ad.

* Revert "Patch to skip failing `test_save_load_low_cpu_mem_usage` tests (#29043)"

  This reverts commit 4156f51.
1 parent 15cfe38 commit 0996a10

26 files changed: 0 additions, 144 deletions
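For context on what is being removed, the pattern follows the sketch below: a hypothetical, plain-PyTorch illustration (the `TinyLMHead` class and the standalone `set_output_embeddings` helper are invented for this sketch, not transformers code). The prediction heads keep a standalone `bias` parameter and alias it onto the decoder; the reverted change (#28948) had additionally re-asserted that alias in a `_tie_weights()` hook and copied the new bias in `set_output_embeddings()`, and those are exactly the lines the diffs below delete.

import torch
from torch import nn


class TinyLMHead(nn.Module):
    """Hypothetical stand-in for the BERT-style prediction heads touched by this commit."""

    def __init__(self, hidden_size: int, vocab_size: int):
        super().__init__()
        self.decoder = nn.Linear(hidden_size, vocab_size, bias=False)
        self.bias = nn.Parameter(torch.zeros(vocab_size))
        # Need a link between the two variables so that the bias is correctly
        # resized with `resize_token_embeddings` (comment kept from the diffs below).
        self.decoder.bias = self.bias

    # Removed by this revert: the hook that re-established the alias above
    # after the decoder module had been swapped out or re-materialized.
    # def _tie_weights(self):
    #     self.decoder.bias = self.bias

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return self.decoder(hidden_states)


def set_output_embeddings(head: TinyLMHead, new_embeddings: nn.Linear) -> None:
    """Mirrors the model-level setter shown in the diffs, post-revert."""
    head.decoder = new_embeddings
    # The reverted change had additionally copied the bias:
    # head.bias = new_embeddings.bias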

src/transformers/models/bert/modeling_bert.py

Lines changed: 0 additions & 6 deletions
@@ -692,9 +692,6 @@ def __init__(self, config):
         # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
         self.decoder.bias = self.bias
 
-    def _tie_weights(self):
-        self.decoder.bias = self.bias
-
     def forward(self, hidden_states):
         hidden_states = self.transform(hidden_states)
         hidden_states = self.decoder(hidden_states)
@@ -1065,7 +1062,6 @@ def get_output_embeddings(self):
 
     def set_output_embeddings(self, new_embeddings):
         self.cls.predictions.decoder = new_embeddings
-        self.cls.predictions.bias = new_embeddings.bias
 
     @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @replace_return_docstrings(output_type=BertForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
@@ -1175,7 +1171,6 @@ def get_output_embeddings(self):
 
     def set_output_embeddings(self, new_embeddings):
         self.cls.predictions.decoder = new_embeddings
-        self.cls.predictions.bias = new_embeddings.bias
 
     @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @add_code_sample_docstrings(
@@ -1329,7 +1324,6 @@ def get_output_embeddings(self):
 
     def set_output_embeddings(self, new_embeddings):
         self.cls.predictions.decoder = new_embeddings
-        self.cls.predictions.bias = new_embeddings.bias
 
     @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @add_code_sample_docstrings(
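The comment about `resize_token_embeddings`, repeated in each head above and below, is about two names pointing at one tensor. Below is a minimal sketch of that aliasing, again in plain PyTorch rather than transformers code, showing why swapping the decoder (as `set_output_embeddings` does) drops the link unless it is re-tied, which is what the reverted `_tie_weights()` hook and bias assignment had done.

import torch
from torch import nn

hidden_size, vocab_size = 8, 10

# Same aliasing pattern as in the prediction heads in these diffs.
bias = nn.Parameter(torch.zeros(vocab_size))
decoder = nn.Linear(hidden_size, vocab_size, bias=False)
decoder.bias = bias
assert decoder.bias is bias  # one tensor reachable under two names

# Replacing the decoder (what set_output_embeddings does with new output
# embeddings, e.g. after the vocabulary grew) leaves `bias` pointing at the
# old, unresized tensor:
new_decoder = nn.Linear(hidden_size, vocab_size + 2)
assert new_decoder.bias is not bias

# Post-revert, only the decoder reference is updated; the reverted change had
# additionally re-pointed the head's bias (bias = new_decoder.bias) and
# re-tied it in _tie_weights().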

src/transformers/models/big_bird/modeling_big_bird.py

Lines changed: 0 additions & 6 deletions
@@ -1707,9 +1707,6 @@ def __init__(self, config):
         # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
         self.decoder.bias = self.bias
 
-    def _tie_weights(self):
-        self.decoder.bias = self.bias
-
     def forward(self, hidden_states):
         hidden_states = self.transform(hidden_states)
         hidden_states = self.decoder(hidden_states)
@@ -2269,7 +2266,6 @@ def get_output_embeddings(self):
 
     def set_output_embeddings(self, new_embeddings):
         self.cls.predictions.decoder = new_embeddings
-        self.cls.predictions.bias = new_embeddings.bias
 
     @add_start_docstrings_to_model_forward(BIG_BIRD_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @replace_return_docstrings(output_type=BigBirdForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
@@ -2382,7 +2378,6 @@ def get_output_embeddings(self):
 
     def set_output_embeddings(self, new_embeddings):
         self.cls.predictions.decoder = new_embeddings
-        self.cls.predictions.bias = new_embeddings.bias
 
     @add_start_docstrings_to_model_forward(BIG_BIRD_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @replace_return_docstrings(output_type=MaskedLMOutput, config_class=_CONFIG_FOR_DOC)
@@ -2524,7 +2519,6 @@ def get_output_embeddings(self):
 
     def set_output_embeddings(self, new_embeddings):
         self.cls.predictions.decoder = new_embeddings
-        self.cls.predictions.bias = new_embeddings.bias
 
     @add_start_docstrings_to_model_forward(BIG_BIRD_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @add_code_sample_docstrings(

src/transformers/models/blip/modeling_blip_text.py

Lines changed: 0 additions & 4 deletions
@@ -523,9 +523,6 @@ def __init__(self, config):
         # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
         self.decoder.bias = self.bias
 
-    def _tie_weights(self):
-        self.decoder.bias = self.bias
-
     def forward(self, hidden_states):
         hidden_states = self.transform(hidden_states)
         hidden_states = self.decoder(hidden_states)
@@ -820,7 +817,6 @@ def get_output_embeddings(self):
 
     def set_output_embeddings(self, new_embeddings):
         self.cls.predictions.decoder = new_embeddings
-        self.cls.predictions.bias = new_embeddings.bias
 
     def forward(
         self,

src/transformers/models/ernie/modeling_ernie.py

Lines changed: 0 additions & 6 deletions
@@ -608,9 +608,6 @@ def __init__(self, config):
         # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
         self.decoder.bias = self.bias
 
-    def _tie_weights(self):
-        self.decoder.bias = self.bias
-
     def forward(self, hidden_states):
         hidden_states = self.transform(hidden_states)
         hidden_states = self.decoder(hidden_states)
@@ -998,7 +995,6 @@ def get_output_embeddings(self):
     # Copied from transformers.models.bert.modeling_bert.BertForPreTraining.set_output_embeddings
     def set_output_embeddings(self, new_embeddings):
         self.cls.predictions.decoder = new_embeddings
-        self.cls.predictions.bias = new_embeddings.bias
 
     @add_start_docstrings_to_model_forward(ERNIE_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @replace_return_docstrings(output_type=ErnieForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
@@ -1113,7 +1109,6 @@ def get_output_embeddings(self):
     # Copied from transformers.models.bert.modeling_bert.BertLMHeadModel.set_output_embeddings
     def set_output_embeddings(self, new_embeddings):
         self.cls.predictions.decoder = new_embeddings
-        self.cls.predictions.bias = new_embeddings.bias
 
     @add_start_docstrings_to_model_forward(ERNIE_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @add_code_sample_docstrings(
@@ -1274,7 +1269,6 @@ def get_output_embeddings(self):
     # Copied from transformers.models.bert.modeling_bert.BertForMaskedLM.set_output_embeddings
     def set_output_embeddings(self, new_embeddings):
         self.cls.predictions.decoder = new_embeddings
-        self.cls.predictions.bias = new_embeddings.bias
 
     @add_start_docstrings_to_model_forward(ERNIE_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @add_code_sample_docstrings(

src/transformers/models/layoutlm/modeling_layoutlm.py

Lines changed: 0 additions & 4 deletions
@@ -589,9 +589,6 @@ def __init__(self, config):
         # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
         self.decoder.bias = self.bias
 
-    def _tie_weights(self):
-        self.decoder.bias = self.bias
-
     def forward(self, hidden_states):
         hidden_states = self.transform(hidden_states)
         hidden_states = self.decoder(hidden_states)
@@ -872,7 +869,6 @@ def get_output_embeddings(self):
 
     def set_output_embeddings(self, new_embeddings):
         self.cls.predictions.decoder = new_embeddings
-        self.cls.predictions.bias = new_embeddings.bias
 
     @add_start_docstrings_to_model_forward(LAYOUTLM_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @replace_return_docstrings(output_type=MaskedLMOutput, config_class=_CONFIG_FOR_DOC)

src/transformers/models/markuplm/modeling_markuplm.py

Lines changed: 0 additions & 3 deletions
@@ -318,9 +318,6 @@ def __init__(self, config):
         # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
         self.decoder.bias = self.bias
 
-    def _tie_weights(self):
-        self.decoder.bias = self.bias
-
     def forward(self, hidden_states):
         hidden_states = self.transform(hidden_states)
         hidden_states = self.decoder(hidden_states)

src/transformers/models/megatron_bert/modeling_megatron_bert.py

Lines changed: 0 additions & 6 deletions
@@ -659,9 +659,6 @@ def __init__(self, config):
         # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
         self.decoder.bias = self.bias
 
-    def _tie_weights(self):
-        self.decoder.bias = self.bias
-
     def forward(self, hidden_states):
         hidden_states = self.transform(hidden_states)
         hidden_states = self.decoder(hidden_states)
@@ -1026,7 +1023,6 @@ def get_output_embeddings(self):
 
     def set_output_embeddings(self, new_embeddings):
         self.cls.predictions.decoder = new_embeddings
-        self.cls.predictions.bias = new_embeddings.bias
 
     @add_start_docstrings_to_model_forward(MEGATRON_BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @replace_return_docstrings(output_type=MegatronBertForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
@@ -1136,7 +1132,6 @@ def get_output_embeddings(self):
 
     def set_output_embeddings(self, new_embeddings):
         self.cls.predictions.decoder = new_embeddings
-        self.cls.predictions.bias = new_embeddings.bias
 
     @add_start_docstrings_to_model_forward(MEGATRON_BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
@@ -1295,7 +1290,6 @@ def get_output_embeddings(self):
 
     def set_output_embeddings(self, new_embeddings):
         self.cls.predictions.decoder = new_embeddings
-        self.cls.predictions.bias = new_embeddings.bias
 
     @add_start_docstrings_to_model_forward(MEGATRON_BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @add_code_sample_docstrings(

src/transformers/models/mpnet/modeling_mpnet.py

Lines changed: 0 additions & 4 deletions
@@ -587,7 +587,6 @@ def get_output_embeddings(self):
 
     def set_output_embeddings(self, new_embeddings):
         self.lm_head.decoder = new_embeddings
-        self.lm_head.bias = new_embeddings.bias
 
     @add_start_docstrings_to_model_forward(MPNET_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @add_code_sample_docstrings(
@@ -660,9 +659,6 @@ def __init__(self, config):
         # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
         self.decoder.bias = self.bias
 
-    def _tie_weights(self):
-        self.decoder.bias = self.bias
-
     def forward(self, features, **kwargs):
         x = self.dense(features)
         x = gelu(x)

src/transformers/models/mra/modeling_mra.py

Lines changed: 0 additions & 4 deletions
@@ -810,9 +810,6 @@ def __init__(self, config):
         # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
         self.decoder.bias = self.bias
 
-    def _tie_weights(self):
-        self.decoder.bias = self.bias
-
     def forward(self, hidden_states):
         hidden_states = self.transform(hidden_states)
         hidden_states = self.decoder(hidden_states)
@@ -1046,7 +1043,6 @@ def get_output_embeddings(self):
 
     def set_output_embeddings(self, new_embeddings):
         self.cls.predictions.decoder = new_embeddings
-        self.cls.predictions.bias = new_embeddings.bias
 
     @add_start_docstrings_to_model_forward(MRA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @add_code_sample_docstrings(

src/transformers/models/nezha/modeling_nezha.py

Lines changed: 0 additions & 5 deletions
@@ -679,9 +679,6 @@ def __init__(self, config):
         # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
         self.decoder.bias = self.bias
 
-    def _tie_weights(self):
-        self.decoder.bias = self.bias
-
     def forward(self, hidden_states):
         hidden_states = self.transform(hidden_states)
         hidden_states = self.decoder(hidden_states)
@@ -1047,7 +1044,6 @@ def get_output_embeddings(self):
 
     def set_output_embeddings(self, new_embeddings):
         self.cls.predictions.decoder = new_embeddings
-        self.cls.predictions.bias = new_embeddings.bias
 
     @add_start_docstrings_to_model_forward(NEZHA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @replace_return_docstrings(output_type=NezhaForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
@@ -1156,7 +1152,6 @@ def get_output_embeddings(self):
 
     def set_output_embeddings(self, new_embeddings):
         self.cls.predictions.decoder = new_embeddings
-        self.cls.predictions.bias = new_embeddings.bias
 
     @add_start_docstrings_to_model_forward(NEZHA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @add_code_sample_docstrings(
