@@ -111,6 +111,13 @@ def maybe_get_vit_flash_attn_backend(
     return attn_backend, flash_attn_varlen_func
 
 
+def allocate_tensor(shape: torch.Size, device: torch.device, dtype: torch.dtype):
+    if get_current_vllm_config().model_config.init_attn_out:
+        return torch.zeros(shape, device=device, dtype=dtype)
+    else:
+        return torch.empty(shape, device=device, dtype=dtype)
+
+
 class Attention(nn.Module, AttentionLayerBase):
     """Attention layer.
 
@@ -349,8 +356,8 @@ def forward(
 
             # Use torch.empty to avoid initializing tensor with zero.
             output_numel = output_shape.numel()
-            output_shape = (output_numel // (self.num_heads * self.head_size), self.num_heads, self.head_size)
-            output = torch.empty(output_shape, dtype=output_dtype, device=query.device)
+            output_shape = torch.Size((output_numel // (self.num_heads * self.head_size), self.num_heads, self.head_size))
+            output = allocate_tensor(output_shape, device=query.device, dtype=output_dtype)
 
             # Reshape the query, key, and value tensors.
             # NOTE(woosuk): We do this outside the custom op to minimize the
@@ -708,7 +715,7 @@ def forward(
                 self.calc_kv_scales(q, kv_c_normed, k_pe)
 
             if self.attn_backend.accept_output_buffer:
-                output = torch.zeros(output_shape, dtype=q.dtype, device=q.device)
+                output = allocate_tensor(output_shape, dtype=q.dtype, device=q.device)
                 self.impl.forward(
                     self,
                     q,
@@ -725,7 +732,7 @@ def forward(
                 )
         else:
             if self.attn_backend.accept_output_buffer:
-                output = torch.zeros(output_shape, dtype=q.dtype, device=q.device)
+                output = allocate_tensor(output_shape, dtype=q.dtype, device=q.device)
                 torch.ops.vllm.unified_mla_attention_with_output(
                     q,
                     kv_c_normed,
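For reference, a minimal, self-contained sketch of the allocation switch this diff introduces: the real helper reads the flag from `get_current_vllm_config().model_config.init_attn_out`, while the version below takes the flag as an explicit parameter so it runs outside vLLM. The parameter and the example shapes are illustrative stand-ins, not part of the patch.

```python
import torch


def allocate_tensor(
    shape: torch.Size,
    device: torch.device,
    dtype: torch.dtype,
    init_attn_out: bool = False,  # stand-in for model_config.init_attn_out
) -> torch.Tensor:
    # Zero-fill the attention output buffer when requested; otherwise skip
    # initialization and save the memset, as the diff's helper does.
    if init_attn_out:
        return torch.zeros(shape, device=device, dtype=dtype)
    return torch.empty(shape, device=device, dtype=dtype)


# Example: an output buffer shaped (num_tokens, num_heads, head_size),
# mirroring the shape computed in Attention.forward above.
num_heads, head_size = 8, 64
out_shape = torch.Size((16, num_heads, head_size))
buf = allocate_tensor(out_shape, device=torch.device("cpu"), dtype=torch.float16)
print(buf.shape)  # torch.Size([16, 8, 64])
```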