Commit 8e7c62a

nit

Signed-off-by: Boyuan Feng <boyuan@meta.com>

1 parent c41d02f

2 files changed: 12 additions, 4 deletions


vllm/attention/layer.py

Lines changed: 11 additions & 4 deletions
@@ -111,6 +111,13 @@ def maybe_get_vit_flash_attn_backend(
         return attn_backend, flash_attn_varlen_func


+def allocate_tensor(shape: torch.Size, device: torch.device, dtype: torch.dtype):
+    if get_current_vllm_config().model_config.init_attn_out:
+        return torch.zeros(shape, device=device, dtype=dtype)
+    else:
+        return torch.empty(shape, device=device, dtype=dtype)
+
+
 class Attention(nn.Module, AttentionLayerBase):
     """Attention layer.
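The helper centralizes a zeros-vs-empty decision that previously lived at each call site. For context, a standalone sketch (illustrative, not part of the commit) of the trade-off it toggles:

import torch

shape = torch.Size((4, 8, 64))  # hypothetical output-buffer shape

# torch.zeros launches an extra fill kernel, but contents are deterministic.
zeroed = torch.zeros(shape, dtype=torch.float16)
assert zeroed.abs().sum() == 0

# torch.empty skips initialization entirely: cheaper, but the values are
# whatever bytes the allocator hands back, so every element must be written
# before it is read.
uninit = torch.empty(shape, dtype=torch.float16)

Backends that advertise accept_output_buffer presumably overwrite the whole buffer, which is what would make torch.empty a safe default here.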
@@ -349,8 +356,8 @@ def forward(

         # Use torch.empty to avoid initializing tensor with zero.
         output_numel = output_shape.numel()
-        output_shape = (output_numel//(self.num_heads * self.head_size), self.num_heads, self.head_size)
-        output = torch.empty(output_shape, dtype=output_dtype, device=query.device)
+        output_shape = torch.Size((output_numel//(self.num_heads * self.head_size), self.num_heads, self.head_size))
+        output = allocate_tensor(output_shape, device=query.device, dtype=output_dtype)

         # Reshape the query, key, and value tensors.
         # NOTE(woosuk): We do this outside the custom op to minimize the
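To make the arithmetic concrete, a worked example with made-up sizes (num_heads=8 and head_size=64 are illustrative, not from the commit): a flat buffer of output_numel elements is reinterpreted as (tokens, heads, head_size).

import torch

num_heads, head_size = 8, 64           # hypothetical values
output_shape = torch.Size((16, 512))   # 16 tokens x (8 * 64) hidden size

output_numel = output_shape.numel()    # 16 * 512 = 8192
reshaped = torch.Size((output_numel // (num_heads * head_size), num_heads, head_size))
print(reshaped)                        # torch.Size([16, 8, 64])

Wrapping the tuple in torch.Size also matches the shape: torch.Size annotation on allocate_tensor; a bare tuple would run but would not match the declared type.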
@@ -708,7 +715,7 @@ def forward(
             self.calc_kv_scales(q, kv_c_normed, k_pe)

         if self.attn_backend.accept_output_buffer:
-            output = torch.zeros(output_shape, dtype=q.dtype, device=q.device)
+            output = allocate_tensor(output_shape, dtype=q.dtype, device=q.device)
             self.impl.forward(
                 self,
                 q,
@@ -725,7 +732,7 @@ def forward(
             )
         else:
             if self.attn_backend.accept_output_buffer:
-                output = torch.zeros(output_shape, dtype=q.dtype, device=q.device)
+                output = allocate_tensor(output_shape, dtype=q.dtype, device=q.device)
             torch.ops.vllm.unified_mla_attention_with_output(
                 q,
                 kv_c_normed,
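Both MLA branches previously zero-filled the output unconditionally; routing them through allocate_tensor makes the fill opt-in. A rough micro-benchmark sketch (CPU, illustrative only, not from the commit) of the cost being skipped:

import time
import torch

shape = (8192, 16, 512)  # hypothetical MLA output-buffer size

def bench(make, iters=100):
    for _ in range(5):  # warm up the caching allocator
        make()
    start = time.perf_counter()
    for _ in range(iters):
        make()
    return (time.perf_counter() - start) / iters

t_empty = bench(lambda: torch.empty(shape, dtype=torch.float16))
t_zeros = bench(lambda: torch.zeros(shape, dtype=torch.float16))
print(f"empty: {t_empty * 1e6:.1f} us/alloc, zeros: {t_zeros * 1e6:.1f} us/alloc")
# On GPU the gap is a full fill kernel per forward pass; timing there would
# also need torch.cuda.synchronize() around the loops.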

vllm/config/model.py

Lines changed: 1 addition & 0 deletions
@@ -187,6 +187,7 @@ class ModelConfig:
     - 1k -> 1000\n
     - 1K -> 1024\n
     - 25.6k -> 25,600"""
+    init_attn_out: bool = False
     spec_target_max_model_len: Optional[int] = None
     """Specify the maximum length for spec decoding draft models."""
     quantization: SkipValidation[Optional[QuantizationMethods]] = None
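Taken together, the two files implement a config-gated allocator: init_attn_out defaults to False, so the MLA paths above move from always zero-filling to uninitialized buffers unless a deployment opts back in. A self-contained sketch of the pattern (FakeModelConfig stands in for vllm.config.ModelConfig and the global-config lookup; it is illustrative, not vLLM API):

from dataclasses import dataclass

import torch

@dataclass
class FakeModelConfig:           # stand-in for vllm.config.ModelConfig
    init_attn_out: bool = False  # same default as the commit

def allocate_tensor(shape, device, dtype, cfg: FakeModelConfig):
    # Same decision as the commit, minus get_current_vllm_config().
    if cfg.init_attn_out:
        return torch.zeros(shape, device=device, dtype=dtype)
    return torch.empty(shape, device=device, dtype=dtype)

fast = allocate_tensor((4, 8, 64), "cpu", torch.float16, FakeModelConfig())
safe = allocate_tensor((4, 8, 64), "cpu", torch.float16, FakeModelConfig(init_attn_out=True))
assert safe.abs().sum() == 0  # zero-filled only when the flag is on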
