@@ -274,6 +274,8 @@ def __init__(
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
         use_data_parallel: bool = False,
+        attn_backend: _Backend = _Backend.TORCH_SDPA,
+        use_upstream_fa: bool = False,
     ) -> None:
         super().__init__()
         # Per attention head and per partition values.
@@ -300,25 +302,8 @@ def __init__(
                                       quant_config=quant_config,
                                       prefix=f"{prefix}.proj",
                                       disable_tp=use_data_parallel)
-
-        # Detect attention implementation.
-        self.attn_backend = get_vit_attn_backend(
-            head_size=self.hidden_size_per_attention_head,
-            dtype=torch.get_default_dtype())
-        self.use_upstream_fa = False
-        if self.attn_backend != _Backend.FLASH_ATTN and \
-            check_upstream_fa_availability(
-                torch.get_default_dtype()):
-            self.attn_backend = _Backend.FLASH_ATTN
-            self.use_upstream_fa = True
-
-        if self.attn_backend not in {
-                _Backend.FLASH_ATTN, _Backend.TORCH_SDPA, _Backend.XFORMERS,
-                _Backend.ROCM_AITER_FA
-        }:
-            raise RuntimeError(
-                f"Qwen2.5-VL does not support {self.attn_backend} backend now."
-            )
+        self.attn_backend = attn_backend
+        self.use_upstream_fa = use_upstream_fa
         self.is_flash_attn_backend = self.attn_backend in {
             _Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA
         }
@@ -443,6 +428,8 @@ def __init__(
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
         use_data_parallel: bool = False,
+        attn_backend: _Backend = _Backend.TORCH_SDPA,
+        use_upstream_fa: bool = False,
     ) -> None:
         super().__init__()
         if norm_layer is None:
@@ -455,7 +442,9 @@ def __init__(
                                             projection_size=dim,
                                             quant_config=quant_config,
                                             prefix=f"{prefix}.attn",
-                                            use_data_parallel=use_data_parallel)
+                                            use_data_parallel=use_data_parallel,
+                                            attn_backend=attn_backend,
+                                            use_upstream_fa=use_upstream_fa)
         self.mlp = Qwen2_5_VisionMLP(dim,
                                      mlp_hidden_dim,
                                      act_fn=act_fn,
@@ -627,17 +616,35 @@ def __init__(
         head_dim = self.hidden_size // self.num_heads
         self.rotary_pos_emb = Qwen2_5_VisionRotaryEmbedding(head_dim // 2)
 
+        use_upstream_fa = False
+        self.attn_backend = get_vit_attn_backend(
+            head_size=head_dim, dtype=torch.get_default_dtype())
+        if self.attn_backend != _Backend.FLASH_ATTN and \
+            check_upstream_fa_availability(
+                torch.get_default_dtype()):
+            self.attn_backend = _Backend.FLASH_ATTN
+            use_upstream_fa = True
+
+        if self.attn_backend not in {
+                _Backend.FLASH_ATTN, _Backend.TORCH_SDPA, _Backend.XFORMERS,
+                _Backend.ROCM_AITER_FA
+        }:
+            raise RuntimeError(
+                f"Qwen2.5-VL does not support {self.attn_backend} backend now."
+            )
+
         self.blocks = nn.ModuleList([
-            Qwen2_5_VisionBlock(dim=self.hidden_size,
-                                num_heads=self.num_heads,
-                                mlp_hidden_dim=vision_config.intermediate_size,
-                                act_fn=get_act_and_mul_fn(
-                                    vision_config.hidden_act),
-                                norm_layer=norm_layer,
-                                quant_config=quant_config,
-                                prefix=f"{prefix}.blocks.{layer_idx}",
-                                use_data_parallel=use_data_parallel)
-            for layer_idx in range(depth)
+            Qwen2_5_VisionBlock(
+                dim=self.hidden_size,
+                num_heads=self.num_heads,
+                mlp_hidden_dim=vision_config.intermediate_size,
+                act_fn=get_act_and_mul_fn(vision_config.hidden_act),
+                norm_layer=norm_layer,
+                quant_config=quant_config,
+                prefix=f"{prefix}.blocks.{layer_idx}",
+                use_data_parallel=use_data_parallel,
+                attn_backend=self.attn_backend,
+                use_upstream_fa=use_upstream_fa) for layer_idx in range(depth)
         ])
         self.merger = Qwen2_5_VisionPatchMerger(
             d_model=vision_config.out_hidden_size,
@@ -648,12 +655,6 @@ def __init__(
             prefix=f"{prefix}.merger",
             use_data_parallel=use_data_parallel,
         )
-        self.attn_backend = get_vit_attn_backend(
-            head_size=head_dim, dtype=torch.get_default_dtype())
-        if self.attn_backend != _Backend.FLASH_ATTN and \
-            check_upstream_fa_availability(
-                torch.get_default_dtype()):
-            self.attn_backend = _Backend.FLASH_ATTN
 
     @property
     def dtype(self) -> torch.dtype:
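
Taken together, the diff hoists the ViT attention-backend detection out of Qwen2_5_VisionAttention and into Qwen2_5_VisionTransformer, which now decides the backend once and passes it, together with the upstream-flash-attn flag, into every block it constructs. Below is a minimal, self-contained sketch of that pattern. The _Backend enum and the two detection helpers are simplified stand-ins for the vLLM internals named in the diff, and the class names are shortened, so treat this as an illustration of the refactor rather than the actual vLLM code.

    import enum

    import torch
    import torch.nn as nn


    class _Backend(enum.Enum):
        # Stand-in for vLLM's _Backend enum; only the members used here.
        FLASH_ATTN = enum.auto()
        TORCH_SDPA = enum.auto()
        XFORMERS = enum.auto()
        ROCM_AITER_FA = enum.auto()


    def get_vit_attn_backend(head_size: int, dtype: torch.dtype) -> _Backend:
        # Placeholder for vLLM's platform-dependent backend selection.
        return _Backend.TORCH_SDPA


    def check_upstream_fa_availability(dtype: torch.dtype) -> bool:
        # Placeholder for vLLM's probe of the upstream flash-attn package.
        return False


    class VisionBlock(nn.Module):
        # Stands in for Qwen2_5_VisionBlock / Qwen2_5_VisionAttention: it only
        # records the backend decision made by the parent transformer.

        def __init__(self, attn_backend: _Backend,
                     use_upstream_fa: bool) -> None:
            super().__init__()
            self.attn_backend = attn_backend
            self.use_upstream_fa = use_upstream_fa
            self.is_flash_attn_backend = attn_backend in {
                _Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA
            }


    class VisionTransformer(nn.Module):
        # Stands in for Qwen2_5_VisionTransformer: detect the backend once,
        # upgrade to upstream flash-attn when available, validate it, and
        # share the result with every block instead of re-detecting per layer.

        def __init__(self, depth: int, head_dim: int) -> None:
            super().__init__()
            use_upstream_fa = False
            self.attn_backend = get_vit_attn_backend(
                head_size=head_dim, dtype=torch.get_default_dtype())
            if self.attn_backend != _Backend.FLASH_ATTN and \
                    check_upstream_fa_availability(torch.get_default_dtype()):
                self.attn_backend = _Backend.FLASH_ATTN
                use_upstream_fa = True
            if self.attn_backend not in {
                    _Backend.FLASH_ATTN, _Backend.TORCH_SDPA,
                    _Backend.XFORMERS, _Backend.ROCM_AITER_FA
            }:
                raise RuntimeError(
                    f"Unsupported ViT attention backend: {self.attn_backend}")
            self.blocks = nn.ModuleList([
                VisionBlock(self.attn_backend, use_upstream_fa)
                for _ in range(depth)
            ])

As a quick check of the intended behaviour, constructing VisionTransformer(depth=4, head_dim=80) and inspecting blocks[0].attn_backend shows all blocks sharing the single backend chosen by the parent, which is exactly what threading attn_backend and use_upstream_fa through the constructors achieves in the diff above.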