
Commit 3bc44ea

[qwen-vl] Standardize config (#37268)

* update
* fix tests
* fixup
* update
* skip this one
* fixup
* fix

1 parent: 4f96081

File tree: 13 files changed (+202 −55 lines)
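The gist of the change: the text-model hyperparameters that previously sat directly on `Qwen2VLConfig` and `Qwen2_5_VLConfig` move into nested `*TextConfig` sub-configs, matching the `text_config`/`vision_config` layout used by other multimodal models in the library. A minimal sketch of the resulting layout, assuming a transformers build that includes this commit:

```python
from transformers import Qwen2_5_VLConfig

config = Qwen2_5_VLConfig()

# Text hyperparameters now live on a nested sub-config...
print(type(config.text_config).__name__)  # Qwen2_5_VLTextConfig
print(config.text_config.vocab_size)      # 152064, the Qwen2-VL-7B default

# ...next to the vision sub-config that was already nested.
print(type(config.vision_config).__name__)  # Qwen2_5_VLVisionConfig
```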

docs/source/en/model_doc/qwen2_5_vl.md
Lines changed: 5 additions & 0 deletions

```diff
@@ -232,10 +232,15 @@ model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
 
 [[autodoc]] Qwen2_5_VLConfig
 
+## Qwen2_5_VLTextConfig
+
+[[autodoc]] Qwen2_5_VLTextConfig
+
 ## Qwen2_5_VLProcessor
 
 [[autodoc]] Qwen2_5_VLProcessor
 
+
 ## Qwen2_5_VLModel
 
 [[autodoc]] Qwen2_5_VLModel
```

docs/source/en/model_doc/qwen2_vl.md
Lines changed: 4 additions & 0 deletions

```diff
@@ -278,6 +278,10 @@ model = Qwen2VLForConditionalGeneration.from_pretrained(
 
 [[autodoc]] Qwen2VLConfig
 
+## Qwen2VLTextConfig
+
+[[autodoc]] Qwen2VLTextConfig
+
 ## Qwen2VLImageProcessor
 
 [[autodoc]] Qwen2VLImageProcessor
```
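Both docs pages gain an autodoc entry for the new text-only config classes. A quick sanity check, assuming the top-level exports added elsewhere in this commit's 13 files are in place:

```python
from transformers import Qwen2VLTextConfig, Qwen2_5_VLTextConfig

# The new model_type strings match the auto-mapping entries added below.
print(Qwen2VLTextConfig().model_type)     # qwen2_vl_text
print(Qwen2_5_VLTextConfig().model_type)  # qwen2_5_vl_text
```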

src/transformers/models/auto/configuration_auto.py
Lines changed: 6 additions & 0 deletions

```diff
@@ -258,10 +258,12 @@
         ("qwen2", "Qwen2Config"),
         ("qwen2_5_omni", "Qwen2_5OmniConfig"),
         ("qwen2_5_vl", "Qwen2_5_VLConfig"),
+        ("qwen2_5_vl_text", "Qwen2_5_VLTextConfig"),
         ("qwen2_audio", "Qwen2AudioConfig"),
         ("qwen2_audio_encoder", "Qwen2AudioEncoderConfig"),
         ("qwen2_moe", "Qwen2MoeConfig"),
         ("qwen2_vl", "Qwen2VLConfig"),
+        ("qwen2_vl_text", "Qwen2VLTextConfig"),
         ("qwen3", "Qwen3Config"),
         ("qwen3_moe", "Qwen3MoeConfig"),
         ("rag", "RagConfig"),
@@ -625,10 +627,12 @@
         ("qwen2", "Qwen2"),
         ("qwen2_5_omni", "Qwen2_5Omni"),
         ("qwen2_5_vl", "Qwen2_5_VL"),
+        ("qwen2_5_vl_text", "Qwen2_5_VL"),
         ("qwen2_audio", "Qwen2Audio"),
         ("qwen2_audio_encoder", "Qwen2AudioEncoder"),
         ("qwen2_moe", "Qwen2MoE"),
         ("qwen2_vl", "Qwen2VL"),
+        ("qwen2_vl_text", "Qwen2VL"),
         ("qwen3", "Qwen3"),
         ("qwen3_moe", "Qwen3MoE"),
         ("rag", "RAG"),
@@ -793,6 +797,8 @@
     ("chinese_clip_vision_model", "chinese_clip"),
     ("rt_detr_resnet", "rt_detr"),
     ("granitevision", "llava_next"),
+    ("qwen2_5_vl_text", "qwen2_5_vl"),
+    ("qwen2_vl_text", "qwen2_vl"),
     ("sam_vision_model", "sam"),
     ("llama4_text", "llama4"),
     ("blip_2_qformer", "blip_2"),
```

src/transformers/models/auto/modeling_auto.py
Lines changed: 2 additions & 0 deletions

```diff
@@ -234,9 +234,11 @@
         ("qdqbert", "QDQBertModel"),
         ("qwen2", "Qwen2Model"),
         ("qwen2_5_vl", "Qwen2_5_VLModel"),
+        ("qwen2_5_vl_text", "Qwen2_5_VLModel"),
         ("qwen2_audio_encoder", "Qwen2AudioEncoder"),
         ("qwen2_moe", "Qwen2MoeModel"),
         ("qwen2_vl", "Qwen2VLModel"),
+        ("qwen2_vl_text", "Qwen2VLModel"),
         ("qwen3", "Qwen3Model"),
         ("qwen3_moe", "Qwen3MoeModel"),
         ("recurrent_gemma", "RecurrentGemmaModel"),
```

src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py
Lines changed: 1 addition & 1 deletion

```diff
@@ -1792,7 +1792,7 @@ def forward(
 
 
 class Qwen2_5OmniDecoderLayer(nn.Module):
-    def __init__(self, config: Qwen2_5OmniConfig, layer_idx: int):
+    def __init__(self, config: Qwen2_5OmniTextConfig, layer_idx: int):
         super().__init__()
         self.hidden_size = config.hidden_size
 
```

src/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py
Lines changed: 70 additions & 16 deletions

````diff
@@ -67,17 +67,16 @@ def __init__(
         self.initializer_range = initializer_range
 
 
-class Qwen2_5_VLConfig(PretrainedConfig):
+class Qwen2_5_VLTextConfig(PretrainedConfig):
     r"""
-    This is the configuration class to store the configuration of a [`Qwen2_5_VLModel`]. It is used to instantiate a
+    This is the configuration class to store the configuration of a [`Qwen2_5_VLTextModel`]. It is used to instantiate a
     Qwen2-VL model according to the specified arguments, defining the model architecture. Instantiating a configuration
     with the defaults will yield a similar configuration to that of
     Qwen2-VL-7B-Instruct [Qwen/Qwen2-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct).
 
     Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
     documentation from [`PretrainedConfig`] for more information.
 
-
     Args:
         vocab_size (`int`, *optional*, defaults to 152064):
             Vocabulary size of the Qwen2_5_VL model. Defines the number of different tokens that can be represented by the
@@ -120,8 +119,6 @@ class Qwen2_5_VLConfig(PretrainedConfig):
             The number of layers that use SWA (Sliding Window Attention). The bottom layers use SWA while the top use full attention.
         attention_dropout (`float`, *optional*, defaults to 0.0):
             The dropout ratio for the attention probabilities.
-        vision_config (`Dict`, *optional*):
-            The config for the visual encoder initialization.
         rope_scaling (`Dict`, *optional*):
             Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
             and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
@@ -161,20 +158,20 @@
             Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
 
     ```python
-    >>> from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2_5_VLConfig
+    >>> from transformers import Qwen2_5_VLTextModel, Qwen2_5_VLConfig
 
     >>> # Initializing a Qwen2_5_VL style configuration
     >>> configuration = Qwen2_5_VLConfig()
 
     >>> # Initializing a model from the Qwen2-VL-7B style configuration
-    >>> model = Qwen2_5_VLForConditionalGeneration(configuration)
+    >>> model = Qwen2_5_VLTextModel(configuration)
 
     >>> # Accessing the model configuration
     >>> configuration = model.config
     ```"""
 
-    model_type = "qwen2_5_vl"
-    sub_configs = {"vision_config": Qwen2_5_VLVisionConfig}
+    model_type = "qwen2_5_vl_text"
+    base_config_key = "text_config"
     keys_to_ignore_at_inference = ["past_key_values"]
     # Default tensor parallel plan for base model `Qwen2_5_VL`
     base_model_tp_plan = {
@@ -211,15 +208,9 @@ def __init__(
         sliding_window=4096,
         max_window_layers=80,
         attention_dropout=0.0,
-        vision_config=None,
         rope_scaling=None,
         **kwargs,
     ):
-        if isinstance(vision_config, dict):
-            self.vision_config = self.sub_configs["vision_config"](**vision_config)
-        elif vision_config is None:
-            self.vision_config = self.sub_configs["vision_config"]()
-
         self.vocab_size = vocab_size
         self.max_position_embeddings = max_position_embeddings
         self.hidden_size = hidden_size
@@ -257,4 +248,67 @@ def __init__(
         super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
 
 
-__all__ = ["Qwen2_5_VLConfig"]
+class Qwen2_5_VLConfig(PretrainedConfig):
+    r"""
+    This is the configuration class to store the configuration of a [`Qwen2_5_VLModel`]. It is used to instantiate a
+    Qwen2-VL model according to the specified arguments, defining the model architecture. Instantiating a configuration
+    with the defaults will yield a similar configuration to that of
+    Qwen2-VL-7B-Instruct [Qwen/Qwen2-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct).
+
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information.
+
+
+    Args:
+        text_config (`Union[PreTrainedConfig, dict]`, *optional*, defaults to `Qwen2_5_VLTextConfig`):
+            The config object or dictionary of the text backbone.
+        vision_config (`Union[PreTrainedConfig, dict]`, *optional*, defaults to `Qwen2_5_VLVisionConfig`):
+            The config object or dictionary of the vision backbone.
+        image_token_id (`int`, *optional*, defaults to 151655):
+            The image token index to encode the image prompt.
+        video_token_id (`int`, *optional*, defaults to 151656):
+            The video token index to encode the image prompt.
+
+    ```python
+    >>> from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2_5_VLConfig
+
+    >>> # Initializing a Qwen2_5_VL style configuration
+    >>> configuration = Qwen2_5_VLConfig()
+
+    >>> # Initializing a model from the Qwen2-VL-7B style configuration
+    >>> model = Qwen2_5_VLForConditionalGeneration(configuration)
+
+    >>> # Accessing the model configuration
+    >>> configuration = model.config
+    ```"""
+
+    model_type = "qwen2_5_vl"
+    sub_configs = {"vision_config": Qwen2_5_VLVisionConfig, "text_config": Qwen2_5_VLTextConfig}
+    keys_to_ignore_at_inference = ["past_key_values"]
+
+    def __init__(
+        self,
+        text_config=None,
+        vision_config=None,
+        image_token_id=151655,
+        video_token_id=151656,
+        **kwargs,
+    ):
+        if isinstance(vision_config, dict):
+            self.vision_config = self.sub_configs["vision_config"](**vision_config)
+        elif vision_config is None:
+            self.vision_config = self.sub_configs["vision_config"]()
+
+        if isinstance(text_config, dict):
+            self.text_config = self.sub_configs["text_config"](**text_config)
+        elif text_config is None:
+            # For BC use all kwargs to init `TextConfig`
+            self.text_config = self.sub_configs["text_config"](**kwargs)
+
+        self.image_token_id = image_token_id
+        self.video_token_id = video_token_id
+
+        super().__init__(**kwargs)
+
+
+__all__ = ["Qwen2_5_VLConfig", "Qwen2_5_VLTextConfig"]
````

src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py
Lines changed: 13 additions & 9 deletions

```diff
@@ -48,7 +48,7 @@
     logging,
     replace_return_docstrings,
 )
-from .configuration_qwen2_5_vl import Qwen2_5_VLConfig, Qwen2_5_VLVisionConfig
+from .configuration_qwen2_5_vl import Qwen2_5_VLConfig, Qwen2_5_VLTextConfig, Qwen2_5_VLVisionConfig
 
 
 if is_flash_attn_available():
@@ -390,7 +390,7 @@ class Qwen2_5_VLPreTrainedModel(PreTrainedModel):
     _supports_static_cache = False  # TODO (joao): fix. torch.compile failing probably due to `cache_positions`
 
     def _init_weights(self, module):
-        std = self.config.initializer_range
+        std = self.config.get_text_config().initializer_range
         if isinstance(module, (nn.Linear, nn.Conv3d)):
             module.weight.data.normal_(mean=0.0, std=std)
             if module.bias is not None:
@@ -566,7 +566,7 @@ def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor) -> torch.
 
 
 class Qwen2_5_VLRotaryEmbedding(nn.Module):
-    def __init__(self, config: Qwen2_5_VLConfig, device=None):
+    def __init__(self, config: Qwen2_5_VLTextConfig, device=None):
         super().__init__()
         # BC: "rope_type" was originally "type"
         if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
@@ -680,7 +680,7 @@ class Qwen2_5_VLAttention(nn.Module):
     and "Generating Long Sequences with Sparse Transformers".
     """
 
-    def __init__(self, config: Qwen2_5_VLConfig, layer_idx: Optional[int] = None):
+    def __init__(self, config: Qwen2_5_VLTextConfig, layer_idx: Optional[int] = None):
         super().__init__()
         self.config = config
         self.layer_idx = layer_idx
@@ -989,7 +989,7 @@ def forward(
 
 
 class Qwen2_5_VLDecoderLayer(nn.Module):
-    def __init__(self, config: Qwen2_5_VLConfig, layer_idx: int):
+    def __init__(self, config: Qwen2_5_VLTextConfig, layer_idx: int):
         super().__init__()
         self.hidden_size = config.hidden_size
 
@@ -1077,7 +1077,9 @@ def forward(
     Qwen2_5_VL_START_DOCSTRING,
 )
 class Qwen2_5_VLModel(Qwen2_5_VLPreTrainedModel):
-    def __init__(self, config: Qwen2_5_VLConfig):
+    config_class = Qwen2_5_VLTextConfig
+
+    def __init__(self, config: Qwen2_5_VLTextConfig):
         super().__init__(config)
         self.padding_idx = config.pad_token_id
         self.vocab_size = config.vocab_size
@@ -1497,9 +1499,11 @@ class Qwen2_5_VLForConditionalGeneration(Qwen2_5_VLPreTrainedModel, GenerationMi
     def __init__(self, config):
         super().__init__(config)
         self.visual = Qwen2_5_VisionTransformerPretrainedModel._from_config(config.vision_config)
-        self.model = Qwen2_5_VLModel(config)
-        self.vocab_size = config.vocab_size
-        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+        text_config = config.get_text_config()
+        self.model = Qwen2_5_VLModel._from_config(text_config)
+        self.vocab_size = text_config.vocab_size
+        self.lm_head = nn.Linear(text_config.hidden_size, text_config.vocab_size, bias=False)
         self.rope_deltas = None  # cache rope_deltas here
 
         # Initialize weights and apply final processing
```
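`get_text_config()` is what lets `_init_weights` and the `ForConditionalGeneration` constructor work with either config type: on the composite `Qwen2_5_VLConfig` it returns the nested `text_config`, while on a plain text config it returns the config itself. A sketch, assuming this commit is installed:

```python
from transformers import Qwen2_5_VLConfig, Qwen2_5_VLTextConfig

composite = Qwen2_5_VLConfig()
text_only = Qwen2_5_VLTextConfig()

# On the composite config, get_text_config() returns the nested sub-config...
assert composite.get_text_config() is composite.text_config

# ...and on a text-only config it returns the config itself, so shared code
# such as _init_weights can call it unconditionally.
assert text_only.get_text_config() is text_only
```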

src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py
Lines changed: 8 additions & 3 deletions

```diff
@@ -28,7 +28,7 @@
 import torch.utils.checkpoint
 from torch.nn import CrossEntropyLoss
 
-from transformers.models.qwen2_vl.configuration_qwen2_vl import Qwen2VLConfig
+from transformers.models.qwen2_vl.configuration_qwen2_vl import Qwen2VLConfig, Qwen2VLTextConfig
 from transformers.models.qwen2_vl.modeling_qwen2_vl import (
     PatchEmbed,
     PatchMerger,
@@ -110,9 +110,13 @@ def __init__(
         self.initializer_range = initializer_range
 
 
+class Qwen2_5_VLTextConfig(Qwen2VLTextConfig):
+    model_type = "qwen2_5_vl_text"
+
+
 class Qwen2_5_VLConfig(Qwen2VLConfig):
     model_type = "qwen2_5_vl"
-    sub_configs = {"vision_config": Qwen2_5_VLVisionConfig}
+    sub_configs = {"vision_config": Qwen2_5_VLVisionConfig, "text_config": Qwen2_5_VLTextConfig}
 
 
 class Qwen2_5_VLMLP(nn.Module):
@@ -227,7 +231,7 @@ def forward(
 
 class Qwen2_5_VLPreTrainedModel(Qwen2VLPreTrainedModel):
     def _init_weights(self, module):
-        std = self.config.initializer_range
+        std = self.config.get_text_config().initializer_range
         if isinstance(module, (nn.Linear, nn.Conv3d)):
             module.weight.data.normal_(mean=0.0, std=std)
             if module.bias is not None:
@@ -971,6 +975,7 @@ def __call__(
 
 __all__ = [
     "Qwen2_5_VLConfig",
+    "Qwen2_5_VLTextConfig",
    "Qwen2_5_VLForConditionalGeneration",
     "Qwen2_5_VLModel",
     "Qwen2_5_VLPreTrainedModel",
```
