Skip to content

Commit 64ae6e6

Browse files
authored
fix qwen25-vl grad acc (huggingface#40333)
* fix qwen25-vl grad acc
* fix Qwen2_5_VLForConditionalGeneration for accepts_loss_kwargs
* fix ci
* fix ci
* fix typo
* fix CI
1 parent 6d2bb1e commit 64ae6e6

File tree

5 files changed

+19
-0
lines changed

5 files changed

+19
-0
lines changed

src/transformers/models/glm4v/modeling_glm4v.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -893,6 +893,8 @@ def forward(
893893
class Glm4vModel(Glm4vPreTrainedModel):
894894
base_model_prefix = ""
895895
_checkpoint_conversion_mapping = {}
896+
# Reference: fix gemma3 grad acc #37208
897+
accepts_loss_kwargs = False
896898
config: Glm4vConfig
897899
_no_split_modules = ["Glm4vTextDecoderLayer", "Glm4vVisionBlock"]
898900

@@ -1329,6 +1331,8 @@ class Glm4vCausalLMOutputWithPast(ModelOutput):
13291331
class Glm4vForConditionalGeneration(Glm4vPreTrainedModel, GenerationMixin):
13301332
_checkpoint_conversion_mapping = {}
13311333
_tied_weights_keys = ["lm_head.weight"]
1334+
# Reference: fix gemma3 grad acc #37208
1335+
accepts_loss_kwargs = False
13321336

13331337
def __init__(self, config):
13341338
super().__init__(config)

src/transformers/models/glm4v_moe/modeling_glm4v_moe.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1009,6 +1009,8 @@ def forward(
10091009
class Glm4vMoeModel(Glm4vMoePreTrainedModel):
10101010
base_model_prefix = ""
10111011
_checkpoint_conversion_mapping = {}
1012+
# Reference: fix gemma3 grad acc #37208
1013+
accepts_loss_kwargs = False
10121014
config: Glm4vMoeConfig
10131015
_no_split_modules = ["Glm4vMoeTextDecoderLayer", "Glm4vMoeVisionBlock"]
10141016

@@ -1445,6 +1447,8 @@ class Glm4vMoeCausalLMOutputWithPast(ModelOutput):
14451447
class Glm4vMoeForConditionalGeneration(Glm4vMoePreTrainedModel, GenerationMixin):
14461448
_checkpoint_conversion_mapping = {}
14471449
_tied_weights_keys = ["lm_head.weight"]
1450+
# Reference: fix gemma3 grad acc #37208
1451+
accepts_loss_kwargs = False
14481452

14491453
def __init__(self, config):
14501454
super().__init__(config)

src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -938,6 +938,8 @@ def forward(
938938
class Qwen2_5_VLModel(Qwen2_5_VLPreTrainedModel):
939939
base_model_prefix = ""
940940
_checkpoint_conversion_mapping = {"^model": "language_model"}
941+
# Reference: fix gemma3 grad acc #37208
942+
accepts_loss_kwargs = False
941943
config: Qwen2_5_VLConfig
942944
_no_split_modules = ["Qwen2_5_VLDecoderLayer", "Qwen2_5_VLVisionBlock"]
943945

@@ -1368,6 +1370,8 @@ class Qwen2_5_VLForConditionalGeneration(Qwen2_5_VLPreTrainedModel, GenerationMi
13681370
r"^model(?!\.(language_model|visual))": "model.language_model",
13691371
}
13701372
_tied_weights_keys = ["lm_head.weight"]
1373+
# Reference: fix gemma3 grad acc #37208
1374+
accepts_loss_kwargs = False
13711375

13721376
def __init__(self, config):
13731377
super().__init__(config)

src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -346,6 +346,8 @@ class Qwen2_5_VLModel(Qwen2VLModel):
346346
config: Qwen2_5_VLConfig
347347
base_model_prefix = ""
348348
_no_split_modules = ["Qwen2_5_VLDecoderLayer", "Qwen2_5_VLVisionBlock"]
349+
# Reference: fix gemma3 grad acc #37208
350+
accepts_loss_kwargs = False
349351

350352
def __init__(self, config):
351353
super().__init__(config)
@@ -651,6 +653,9 @@ class Qwen2_5_VLCausalLMOutputWithPast(Qwen2VLCausalLMOutputWithPast):
651653

652654

653655
class Qwen2_5_VLForConditionalGeneration(Qwen2VLForConditionalGeneration):
656+
# Reference: fix gemma3 grad acc #37208
657+
accepts_loss_kwargs = False
658+
654659
def forward(
655660
self,
656661
input_ids: torch.LongTensor = None,

src/transformers/models/qwen2_vl/modeling_qwen2_vl.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -911,6 +911,8 @@ def forward(
911911
class Qwen2VLModel(Qwen2VLPreTrainedModel):
912912
base_model_prefix = ""
913913
_checkpoint_conversion_mapping = {"^model": "language_model"}
914+
# Reference: fix gemma3 grad acc #37208
915+
accepts_loss_kwargs = False
914916

915917
def __init__(self, config: Qwen2VLConfig):
916918
super().__init__(config)

0 commit comments

Comments (0)