Commit f9936d3

track original input/output dtype

Signed-off-by: Jonas Kuebler <kuebj@amazon.com>
1 parent: b30425b

File tree

1 file changed: +2 −1 lines

vllm/attention/layer.py

Lines changed: 2 additions & 1 deletion
```diff
@@ -281,6 +281,7 @@ def forward(
         if attn_metadata.enable_kv_scales_calculation:
             self.calc_kv_scales(query, key, value)
 
+        output_dtype = query.dtype
         if self.query_quant is not None:
             # quantizing with a simple torch operation enables
             # torch.compile to fuse this into previous ops
@@ -293,7 +294,7 @@ def forward(
         output_shape = (output_shape
                         if output_shape is not None else query.shape)
         output = torch.zeros(output_shape,
-                             dtype=query.dtype,
+                             dtype=output_dtype,
                              device=query.device)
         hidden_size = output_shape[-1]
```
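The fix captures the caller's dtype before query quantization can change it, so the output buffer is allocated in the original dtype rather than the quantized one. A minimal standalone sketch of the same pattern (names and the `int8` cast standing in for the quantization step are illustrative, not the actual vLLM API):

```python
import torch


def allocate_output(query: torch.Tensor, quantize: bool) -> torch.Tensor:
    # Record the caller's dtype *before* any quantization step,
    # mirroring the commit's `output_dtype = query.dtype` line.
    output_dtype = query.dtype
    if quantize:
        # A plain cast stands in for the real query quantization;
        # after this, query.dtype no longer matches the caller's dtype.
        query = query.to(torch.int8)
    # Allocate the output in the original dtype, not query.dtype,
    # which is what the `-dtype=query.dtype` -> `+dtype=output_dtype`
    # change in the diff achieves.
    return torch.zeros(query.shape, dtype=output_dtype, device=query.device)
```

Without the saved dtype, enabling query quantization would silently change the attention output tensor's dtype as well.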
