8 changes: 6 additions & 2 deletions tests/kernels/test_encoder_decoder_attn.py
@@ -136,7 +136,9 @@ class that Attention will automatically select when it is constructed.
)
if test_pt.num_blocks is None or test_pt.num_heads is None:
# Caller does not require a KV cache
return TestResources(scale, attn_backend, attn, None)
return TestResources(
scale, attn_backend, attn,
torch.tensor([], dtype=torch.float32, device=CUDA_DEVICE))

# Construct KV cache
kv_cache = make_kv_cache(test_pt.num_blocks,
@@ -620,7 +622,9 @@ def _run_encoder_attention_test(
return attn.forward(packed_qkv.query,
packed_qkv.key,
packed_qkv.value,
None,
torch.tensor([],
dtype=torch.float32,
device=packed_qkv.query.device),
attn_metadata,
attn_type=attn_type)

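As a side note, the placeholder the test now passes can be built anywhere a caller has no KV cache to supply. A minimal sketch, assuming a hypothetical helper name and shapes that do not appear in this diff:

```python
import torch

def make_empty_kv_cache(device: torch.device) -> torch.Tensor:
    # An empty float32 tensor (shape [0]) stands in for "no KV cache",
    # so Attention.forward never receives a Python `None`.
    return torch.tensor([], dtype=torch.float32, device=device)

query = torch.randn(4, 8 * 64)                # assumed [num_tokens, num_heads * head_size]
kv_cache = make_empty_kv_cache(query.device)
assert kv_cache.numel() == 0                  # backends read numel() == 0 as "no cache yet"
```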
6 changes: 4 additions & 2 deletions vllm/attention/backends/blocksparse_attn.py
@@ -357,6 +357,8 @@ def forward(
key: shape = [num_tokens, num_kv_heads * head_size]
value: shape = [num_tokens, num_kv_heads * head_size]
kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size]
NOTE: kv_cache will be an empty tensor with shape [0]
for the profiling run.
attn_metadata: Metadata for attention.
Returns:
shape = [num_tokens, num_heads * head_size]
@@ -373,7 +375,7 @@ def forward(
key = key.view(-1, self.num_kv_heads, self.head_size)
value = value.view(-1, self.num_kv_heads, self.head_size)

if kv_cache is not None:
if kv_cache.numel() > 0:
key_cache, value_cache = PagedAttention.split_kv_cache(
kv_cache, self.num_kv_heads, self.head_size)

@@ -399,7 +401,7 @@ def forward(
# When block_tables are not filled, it means q and k are the
# prompt, and they have the same length.

assert kv_cache is None \
assert kv_cache.numel() == 0 \
or prefill_meta.block_tables is None \
or prefill_meta.block_tables.numel() == 0, \
"Does not support prefix-enabled attention."
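The same check recurs in every backend below. A short sketch of the new control flow (the cache shape is illustrative, not taken from this diff):

```python
import torch

def has_kv_cache(kv_cache: torch.Tensor) -> bool:
    # Was: `kv_cache is not None`. An empty tensor (shape [0]) now marks
    # the profiling run; a populated tensor marks a real forward pass.
    return kv_cache.numel() > 0

profiling_cache = torch.tensor([], dtype=torch.float32)   # what the model runners now pass
real_cache = torch.zeros(2, 4, 16 * 8 * 64)               # [2, num_blocks, block_size * num_kv_heads * head_size]

assert not has_kv_cache(profiling_cache)   # skip cache writes, as the old `None` branch did
assert has_kv_cache(real_cache)            # split into key/value caches and write as before
```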
6 changes: 4 additions & 2 deletions vllm/attention/backends/flash_attn.py
@@ -665,6 +665,8 @@ def forward(
key: shape = [num_tokens, num_kv_heads * head_size]
value: shape = [num_tokens, num_kv_heads * head_size]
kv_cache = [2, num_blocks, block_size, num_kv_heads, head_size]
NOTE: kv_cache will be an empty tensor with shape [0]
for the profiling run.
attn_metadata: Metadata for attention.
Returns:
shape = [num_tokens, num_heads * head_size]
@@ -685,7 +687,7 @@ def forward(
key = key.view(-1, self.num_kv_heads, self.head_size)
value = value.view(-1, self.num_kv_heads, self.head_size)

if kv_cache is not None:
if kv_cache.numel() > 0:
key_cache = kv_cache[0]
value_cache = kv_cache[1]

@@ -722,7 +724,7 @@ def forward(

if prefill_meta := attn_metadata.prefill_metadata:
# Prompt run.
if (kv_cache is None or prefill_meta.block_tables is None
if (kv_cache.numel() == 0 or prefill_meta.block_tables is None
or prefill_meta.block_tables.numel() == 0):
# normal attention
# When block_tables are not filled, it means q and k are the
6 changes: 3 additions & 3 deletions vllm/attention/backends/flashinfer.py
@@ -746,7 +746,7 @@ def forward(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
kv_cache: Optional[torch.Tensor],
kv_cache: torch.Tensor,
attn_metadata: FlashInferMetadata,
k_scale: float = 1.0,
v_scale: float = 1.0,
@@ -770,7 +770,7 @@ def forward(
if attn_metadata.num_decode_tokens > 0:
assert attn_metadata.num_prefill_tokens == 0, (
"Chunked prefill is not supported with flashinfer yet.")
if kv_cache is not None:
if kv_cache.numel() > 0:
# Use the same reshape and cache kernel as flash attention.
ops.reshape_and_cache_flash(
key,
@@ -796,7 +796,7 @@ def forward(
# when kv_cache is not provided.
# This happens when vllm runs the profiling to
# determine the number of blocks.
if kv_cache is None:
if kv_cache.numel() == 0:
output = torch.ops.vllm.flash_attn_varlen_func(
q=query,
k=key,
9 changes: 6 additions & 3 deletions vllm/attention/backends/ipex_attn.py
@@ -167,7 +167,7 @@ def forward(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
kv_cache: Optional[torch.Tensor],
kv_cache: torch.Tensor,
attn_metadata: IpexAttnMetadata, # type: ignore
k_scale: float = 1.0,
v_scale: float = 1.0,
@@ -180,6 +180,8 @@ def forward(
key: shape = [num_tokens, num_kv_heads * head_size]
value: shape = [num_tokens, num_kv_heads * head_size]
kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size]
NOTE: kv_cache will be an empty tensor with shape [0]
for the profiling run.
attn_metadata: Metadata for attention.
Returns:
shape = [num_tokens, num_heads * head_size]
@@ -196,7 +198,7 @@ def forward(
key = key.view(-1, self.num_kv_heads, self.head_size)
value = value.view(-1, self.num_kv_heads, self.head_size)

if kv_cache is not None:
if kv_cache.numel() > 0:
key_cache, value_cache = self.split_kv_cache(
kv_cache, self.num_kv_heads, self.head_size)
ipex_ops.reshape_and_cache(
@@ -212,7 +214,8 @@ def forward(

if attn_metadata.is_prompt:
assert attn_metadata.seq_lens is not None
if (kv_cache is None or attn_metadata.block_tables.numel() == 0):
if (kv_cache.numel() == 0
or attn_metadata.block_tables.numel() == 0):
if self.num_kv_heads != self.num_heads:
key = key.repeat_interleave(self.num_queries_per_kv, dim=1)
value = value.repeat_interleave(self.num_queries_per_kv,
12 changes: 7 additions & 5 deletions vllm/attention/backends/pallas.py
@@ -143,7 +143,7 @@ def forward(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
kv_cache: Tuple[Optional[torch.Tensor], Optional[torch.Tensor]],
kv_cache: Tuple[torch.Tensor, torch.Tensor],
attn_metadata: PallasMetadata,
k_scale: float = 1.0,
v_scale: float = 1.0,
@@ -155,8 +155,10 @@ def forward(
query: shape = [batch_size, seq_len, num_heads * head_size]
key: shape = [batch_size, seq_len, num_kv_heads * head_size]
value: shape = [batch_size, seq_len, num_kv_heads * head_size]
key_cache = [num_kv_heads, num_blocks, block_size, head_size]
value_cache = [num_kv_heads, num_blocks, block_size, head_size]
kv_cache[0] = [num_kv_heads, num_blocks, block_size, head_size]
kv_cache[1] = [num_kv_heads, num_blocks, block_size, head_size]
NOTE: kv_cache[0] and kv_cache[1] will be empty tensors
with shape [0] for the profiling run.
attn_metadata: Metadata for attention.
Returns:
shape = [batch_size, seq_len, num_heads * head_size]
@@ -173,7 +175,7 @@ def forward(
value = value.view(batch_size, seq_len, self.num_kv_heads,
self.head_size)

if kv_cache[0] is not None:
if kv_cache[0].numel() > 0:
slot_mapping = attn_metadata.slot_mapping
key_cache, value_cache = kv_cache
write_to_kv_cache(key, value, key_cache, value_cache, slot_mapping)
@@ -205,7 +207,7 @@ def forward(
output = output.permute(0, 2, 1, 3)
else:
# Decoding run.
assert kv_cache is not None
assert kv_cache[0].numel() > 0

pages_per_compute_block = 16 # TODO(woosuk): Tune this value.
if self.megacore_mode == "batch" and batch_size % 2 != 0:
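On the Pallas/TPU path the cache is a per-layer (key_cache, value_cache) tuple rather than a single stacked tensor, so the placeholder is a pair of empty tensors and the check moves to the first element. A sketch with illustrative shapes:

```python
from typing import Tuple

import torch

def is_profiling_run(kv_cache: Tuple[torch.Tensor, torch.Tensor]) -> bool:
    # The old `assert kv_cache is not None` could never fire on a tuple;
    # checking numel() of the key cache reflects the actual intent.
    return kv_cache[0].numel() == 0

placeholder = (torch.tensor([], dtype=torch.float32),
               torch.tensor([], dtype=torch.float32))
real = (torch.zeros(8, 4, 16, 64),   # [num_kv_heads, num_blocks, block_size, head_size]
        torch.zeros(8, 4, 16, 64))

assert is_profiling_run(placeholder)
assert not is_profiling_run(real)
```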
6 changes: 4 additions & 2 deletions vllm/attention/backends/rocm_flash_attn.py
@@ -396,6 +396,8 @@ def forward(
key: shape = [num_tokens, num_kv_heads * head_size]
value: shape = [num_tokens, num_kv_heads * head_size]
kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size]
NOTE: kv_cache will be an empty tensor with shape [0]
for the profiling run.
attn_metadata: Metadata for attention.
Returns:
shape = [num_tokens, num_heads * head_size]
@@ -412,7 +414,7 @@ def forward(
key = key.view(-1, self.num_kv_heads, self.head_size)
value = value.view(-1, self.num_kv_heads, self.head_size)

if kv_cache is not None:
if kv_cache.numel() > 0:
key_cache, value_cache = PagedAttention.split_kv_cache(
kv_cache, self.num_kv_heads, self.head_size)

@@ -449,7 +451,7 @@ def forward(
if prefill_meta := attn_metadata.prefill_metadata:
# Prompt run.
assert prefill_meta.seq_lens is not None
if kv_cache is None or prefill_meta.block_tables.numel() == 0:
if kv_cache.numel() == 0 or prefill_meta.block_tables.numel() == 0:
# triton attention
# When block_tables are not filled, it means q and k are the
# prompt, and they have the same length.
9 changes: 6 additions & 3 deletions vllm/attention/backends/torch_sdpa.py
@@ -151,7 +151,7 @@ def forward(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
kv_cache: Optional[torch.Tensor],
kv_cache: torch.Tensor,
attn_metadata: TorchSDPAMetadata, # type: ignore
k_scale: float = 1.0,
v_scale: float = 1.0,
@@ -164,6 +164,8 @@ def forward(
key: shape = [num_tokens, num_kv_heads * head_size]
value: shape = [num_tokens, num_kv_heads * head_size]
kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size]
NOTE: kv_cache will be an empty tensor with shape [0]
for the profiling run.
attn_metadata: Metadata for attention.
Returns:
shape = [num_tokens, num_heads * head_size]
@@ -180,7 +182,7 @@ def forward(
key = key.view(-1, self.num_kv_heads, self.head_size)
value = value.view(-1, self.num_kv_heads, self.head_size)

if kv_cache is not None:
if kv_cache.numel() > 0:
key_cache, value_cache = PagedAttention.split_kv_cache(
kv_cache, self.num_kv_heads, self.head_size)
PagedAttention.write_to_paged_cache(key, value, key_cache,
@@ -191,7 +193,8 @@ def forward(

if attn_metadata.is_prompt:
assert attn_metadata.seq_lens is not None
if (kv_cache is None or attn_metadata.block_tables.numel() == 0):
if (kv_cache.numel() == 0
or attn_metadata.block_tables.numel() == 0):
if self.num_kv_heads != self.num_heads:
key = key.repeat_interleave(self.num_queries_per_kv, dim=1)
value = value.repeat_interleave(self.num_queries_per_kv,
8 changes: 5 additions & 3 deletions vllm/attention/backends/xformers.py
@@ -445,7 +445,7 @@ def forward(
query: torch.Tensor,
key: Optional[torch.Tensor],
value: Optional[torch.Tensor],
kv_cache: Optional[torch.Tensor],
kv_cache: torch.Tensor,
attn_metadata: "XFormersMetadata",
k_scale: float = 1.0,
v_scale: float = 1.0,
@@ -489,6 +489,8 @@ def forward(
key: shape = [num_tokens, num_kv_heads * head_size]
value: shape = [num_tokens, num_kv_heads * head_size]
kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size]
NOTE: kv_cache will be an empty tensor with shape [0]
for the profiling run.
attn_metadata: Metadata for attention.
attn_type: Select attention type, between encoder attention,
decoder self-attention, or encoder/decoder cross-
@@ -522,7 +524,7 @@ def forward(
# which KV cache memory-mapping & which
# seqlen datastructures we utilize

if (attn_type != AttentionType.ENCODER and kv_cache is not None):
if (attn_type != AttentionType.ENCODER and kv_cache.numel() > 0):
# KV-cache during decoder-self- or
# encoder-decoder-cross-attention, but not
# during encoder attention.
@@ -588,7 +590,7 @@ def forward(

if prefill_meta := attn_metadata.prefill_metadata:
# Prompt run.
if kv_cache is None or prefill_meta.block_tables.numel() == 0:
if kv_cache.numel() == 0 or prefill_meta.block_tables.numel() == 0:
# normal attention.
# block tables are empty if the prompt does not have a cached
# prefix.
8 changes: 7 additions & 1 deletion vllm/worker/embedding_model_runner.py
@@ -97,7 +97,13 @@ def execute_model(
model_executable = self.model

num_layers = self.model_config.get_num_layers(self.parallel_config)
kv_caches = [None] * num_layers
# Use an empty tensor instead of `None` to force Dynamo to pass
# it by reference, rather than by specializing on the value `None`.
# The `dtype` argument does not matter, and we use `float32` as
# a placeholder (it has wide hardware support).
kv_caches = [
torch.tensor([], dtype=torch.float32, device=self.device)
] * num_layers

execute_model_kwargs = {
"input_ids":
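The motivation for the runner-side change is Dynamo's guarding behaviour: a literal `None` in the inputs becomes a specialization of the compiled graph, whereas a tensor is traced by reference. A minimal sketch of the new construction (the layer count and device are placeholders):

```python
import torch

num_layers = 4                   # illustrative; vLLM reads this from the model config
device = torch.device("cpu")     # the runner's device in practice

# Old: kv_caches = [None] * num_layers  -> Dynamo specializes on the value `None`.
# New: one empty float32 tensor per layer, passed by reference.
kv_caches = [
    torch.tensor([], dtype=torch.float32, device=device)
    for _ in range(num_layers)
]
```

The diff itself multiplies a single tensor by `num_layers`, so all layers share one placeholder object; that should be harmless here, since the profiling run never writes to it.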
8 changes: 7 additions & 1 deletion vllm/worker/enc_dec_model_runner.py
@@ -340,7 +340,13 @@ def profile_run(self) -> None:

# Run the model with the dummy inputs.
num_layers = self.model_config.get_num_layers(self.parallel_config)
kv_caches = [None] * num_layers
# Use an empty tensor instead of `None` to force Dynamo to pass
# it by reference, rather than by specializing on the value `None`.
# The `dtype` argument does not matter, and we use `float32` as
# a placeholder (it has wide hardware support).
kv_caches = [
torch.tensor([], dtype=torch.float32, device=self.device)
] * num_layers
finished_requests_ids = [seq.request_id for seq in seqs]
model_input = self.prepare_model_input(
seqs, finished_requests_ids=finished_requests_ids)
8 changes: 7 additions & 1 deletion vllm/worker/model_runner.py
@@ -1223,7 +1223,13 @@ def profile_run(self) -> None:

# Run the model with the dummy inputs.
num_layers = self.model_config.get_num_layers(self.parallel_config)
kv_caches = [None] * num_layers
# Use an empty tensor instead of `None` to force Dynamo to pass
# it by reference, rather than by specializing on the value `None`.
# The `dtype` argument does not matter, and we use `float32` as
# a placeholder (it has wide hardware support).
kv_caches = [
torch.tensor([], dtype=torch.float32, device=self.device)
] * num_layers
finished_requests_ids = [seq.request_id for seq in seqs]
model_input = self.prepare_model_input(
seqs, finished_requests_ids=finished_requests_ids)
4 changes: 2 additions & 2 deletions vllm/worker/tpu_model_runner.py
@@ -714,7 +714,7 @@ def forward(
t: torch.Tensor,
p: torch.Tensor,
num_samples: int,
kv_caches: List[Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]],
kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
) -> torch.Tensor:
"""Executes the forward pass of the model and samples the next token.

@@ -745,7 +745,7 @@ def forward(
)

# Skip this in memory profiling at initialization.
if kv_caches[0][0] is not None:
if kv_caches[0][0].numel() > 0:
# index_copy_(slot_mapping) only works when the inserted dimension
# is 0. However, the KV cache in the Pallas backend has the shape
# [num_kv_heads, num_blocks, block_size, head_size]. To make it
10 changes: 9 additions & 1 deletion vllm/worker/tpu_worker.py
@@ -115,7 +115,15 @@ def determine_num_available_blocks(self) -> Tuple[int, int]:
head_size = self.model_config.get_head_size()
num_kv_heads = self.model_config.get_num_kv_heads(self.parallel_config)

kv_caches = [(None, None) for _ in range(num_layers)]
# Use an empty tensor instead of `None` to force Dynamo to pass
# it by reference, rather than by specializing on the value `None`.
# The `dtype` argument does not matter, and we use `float32` as
# a placeholder (it has wide hardware support).
kv_caches = [(torch.tensor([], dtype=torch.float32,
device=self.device),
torch.tensor([], dtype=torch.float32,
device=self.device))
for _ in range(num_layers)]
self.model_runner._dummy_run(
batch_size=1,
seq_len=self.scheduler_config.max_num_batched_tokens,
8 changes: 7 additions & 1 deletion vllm/worker/xpu_model_runner.py
@@ -464,7 +464,13 @@ def profile_run(self) -> None:

# Run the model with the dummy inputs.
num_layers = self.model_config.get_num_layers(self.parallel_config)
kv_caches = [None] * num_layers
# Use an empty tensor instead of `None` to force Dynamo to pass
# it by reference, rather than by specializing on the value `None`.
# The `dtype` argument does not matter, and we use `float32` as
# a placeholder (it has wide hardware support).
kv_caches = [
torch.tensor([], dtype=torch.float32, device=self.device)
] * num_layers
finished_requests_ids = [seq.request_id for seq in seqs]
model_input = self.prepare_model_input(
seqs, finished_requests_ids=finished_requests_ids)