
[Core] Prefill Only Tokens Without KV Cache in Batch Requests (Disagg Prefill) #12285

Status: Open. Wants to merge 8 commits into base: main.
4 changes: 2 additions & 2 deletions vllm/distributed/kv_transfer/kv_connector/base.py
@@ -87,7 +87,7 @@ def recv_kv_caches_and_hidden_states(
self, model_executable: torch.nn.Module,
model_input: "ModelInputForGPUWithSamplingMetadata",
kv_caches: List[torch.Tensor]
) -> Tuple[Union[torch.Tensor, IntermediateTensors], bool,
) -> Tuple[Union[torch.Tensor, IntermediateTensors], List[bool],
"ModelInputForGPUWithSamplingMetadata"]:
"""
Receive KV caches and hidden states from the connector.
@@ -110,7 +110,7 @@ def recv_kv_caches_and_hidden_states(
IntermediateTensors):
Concatenated hidden states if all required data is retrieved,
otherwise `None`.
- bypass_model_exec (bool):
- bypass_model_exec (List[bool]):
            Per-request flags indicating, for each request in the batch,
            whether model execution can be skipped (True) or needs to be
            redone (False).
- model_input (ModelInputForGPUWithSamplingMetadata):
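
For reference, a minimal caller-side sketch of how the new per-request `bypass_model_exec` list could be consumed; the names `connector`, `model`, and `run_model` are illustrative placeholders and not part of this change:

```python
# Hypothetical caller-side sketch; `connector`, `model`, and `run_model`
# are assumed placeholder names, not identifiers introduced by this PR.
hidden_states, bypass_model_exec, model_input = (
    connector.recv_kv_caches_and_hidden_states(model, model_input, kv_caches))

if all(bypass_model_exec):
    # Every request in the batch received its KV cache and hidden states,
    # so the forward pass can be skipped and the hidden states reused.
    output = hidden_states
else:
    # At least one request is missing data; rerun the model on the rebuilt
    # (partial-prefill) model_input returned by the connector.
    output = run_model(model, model_input, kv_caches)
```
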
176 changes: 155 additions & 21 deletions vllm/distributed/kv_transfer/kv_connector/simple_connector.py
@@ -7,11 +7,14 @@

But the logic can be extended to support other pipe and lookup buffer.
"""
from copy import deepcopy
from typing import TYPE_CHECKING, List, Optional, Tuple, Union

import torch
from torch.nn.utils.rnn import pad_sequence

from vllm import _custom_ops as ops
from vllm.attention.backends.flash_attn import FlashAttentionMetadata
from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase
from vllm.distributed.kv_transfer.kv_lookup_buffer.simple_buffer import (
@@ -36,6 +39,8 @@ def __init__(

self.config = config.kv_transfer_config
self.tp_size = config.parallel_config.tensor_parallel_size
# The following config is needed to rebuild the model input
self.cache_config = config.cache_config

if self.config.kv_connector == "PyNcclConnector":
from vllm.distributed.kv_transfer.kv_pipe.pynccl_pipe import (
@@ -202,15 +207,9 @@ def recv_kv_caches_and_hidden_states(
self, model_executable: torch.nn.Module,
model_input: "ModelInputForGPUWithSamplingMetadata",
kv_caches: List[torch.Tensor]
) -> Tuple[Union[torch.Tensor, IntermediateTensors], bool,
) -> Tuple[Union[torch.Tensor, IntermediateTensors], List[bool],
"ModelInputForGPUWithSamplingMetadata"]:

# When bypass_model_exec is set to False, it means that at least for one
# request its corresponding KV cache or hidden state is missing.
# In this case we need to do prefilling to recompute missing KV cache
# and hidden states.
bypass_model_exec = True

input_tokens_tensor = model_input.input_tokens
seq_lens = model_input.attn_metadata.seq_lens
slot_mapping = model_input.attn_metadata.slot_mapping.flatten()
@@ -221,6 +220,12 @@
num_computed_tokens_list = []
start_pos_list = []

# When bypass_model_exec[i] is set to False, it means that for
# request[i] its corresponding KV cache or hidden state is missing.
# In this case we need to do prefilling to recompute missing KV cache
# and hidden states of request[i].
bypass_model_exec = [True] * len(seq_lens)

# enumerate different requests
# FIXME(Kuntai): This impl assumes that all requests are prefill.
for idx, slen in enumerate(seq_lens):
@@ -238,7 +243,7 @@
torch.ones_like(current_tokens, dtype=bool))
if ret[0] is None:
# didn't find any match.
bypass_model_exec = False
bypass_model_exec[idx] = False
num_computed_tokens_list.append(0)
continue

@@ -248,17 +253,22 @@
hidden: torch.Tensor = ret[4]

num_computed_tokens = roi.shape[0]
num_computed_tokens_list.append(num_computed_tokens)

# check if both KV cache and the hidden states are received
# If not, need to redo the forwarding to compute missing states
if not all([(num_computed_tokens == num_tokens), hidden is not None
]):
bypass_model_exec = False
                bypass_model_exec[idx] = False
                # Keep the per-request bookkeeping lists aligned: no KV was
                # copied into the paged cache for this request, so treat it
                # as having 0 computed tokens when rebuilding the input.
                num_computed_tokens_list.append(0)
                continue

# update the end position based on how many tokens are cached.
end_pos = start_pos + num_computed_tokens

            # If the entire request was retrieved, leave the last token
            # uncomputed so the rebuilt partial-prefill input still has at
            # least one query token for this request (q_len must stay > 0).
if num_computed_tokens == num_tokens:
num_computed_tokens -= 1
num_computed_tokens_list.append(num_computed_tokens)

# put received KV caches into paged memory
for i in range(model_executable.model.start_layer,
model_executable.model.end_layer):
@@ -282,23 +292,35 @@

hidden_or_intermediate_states_for_one_req.append(hidden)

if not bypass_model_exec:
# Some of the KV cache is not retrieved
# Here we will fall back to normal model forwarding
# But optionally you can adjust model_input so that you only do
# prefilling on those tokens that are missing KV caches.
logger.debug(
"[rank%d]: Failed to receive all KVs and hidden "
"states, redo model forwarding.", torch.distributed.get_rank())
hidden_or_intermediate_states = None

else:
        # Initialize so that the fallback return path below is well-defined
        # when at least one request still needs model execution.
        hidden_or_intermediate_states = None
        all_bypass_flag = True
for idx, bypass_flag in enumerate(bypass_model_exec):
if not bypass_flag:
# Some of the KV cache of this request is not retrieved
# Here we will fall back to normal model forwarding
logger.debug(
"[rank%d]: Failed to receive request %d's"
" KVs and hidden states, "
"redo model forwarding.", torch.distributed.get_rank(),
idx)

hidden_or_intermediate_states = torch.cat(
hidden_or_intermediate_states_for_one_req, dim=0)
all_bypass_flag = False
if all_bypass_flag:
logger.debug(
"[rank%d]: Successfully received all KVs and hidden "
"states, skip model forwarding.", torch.distributed.get_rank())
hidden_or_intermediate_states = torch.cat(
hidden_or_intermediate_states_for_one_req, dim=0)

if not all(bypass_model_exec):
rebuilt_model_input = self.build_partial_prefill_input(
model_input, input_tokens_list, num_computed_tokens_list,
start_pos_list, slot_mapping, kv_caches[0][0].device)
logger.debug("Rebuilt the input!")
return (hidden_or_intermediate_states, bypass_model_exec,
rebuilt_model_input)

return hidden_or_intermediate_states, bypass_model_exec, model_input

def close(self):
@@ -311,3 +333,115 @@ def close(self):
# MooncakePipe reuses data_pipe for signal_pipe, so we only have to
# close the data_pipe.
pass

def build_partial_prefill_input(
self, model_input: "ModelInputForGPUWithSamplingMetadata",
full_tokens_list: List[torch.Tensor],
num_computed_tokens_list: List[int], start_pos_list: List[int],
slot_mapping_flat: torch.Tensor,
device: torch.device) -> "ModelInputForGPUWithSamplingMetadata":
"""Helper function to rebuild the model input for the current request.
"""
assert model_input.attn_metadata is not None
assert isinstance(model_input.attn_metadata, FlashAttentionMetadata), \
"Only FlashAttention backend is supported for now."
assert model_input.attn_metadata.context_lens_tensor is not None
assert model_input.attn_metadata.block_tables is not None
assert model_input.attn_metadata.query_start_loc is not None
assert model_input.input_positions is not None

rebuilt_input_tokens = []
rebuilt_input_positions = []
rebuilt_num_prefills = 0
rebuilt_num_prefill_tokens = 0
rebuilt_slot_mapping = []
rebuilt_max_query_len = 0

rebuilt_block_tables = []

rebuilt_query_start_loc = [0]
rebuilt_context_lens_tensor = []

last_query_start_loc = 0

# recounting query and context lengths
for idx in range(len(full_tokens_list)):
token_tensor = full_tokens_list[idx]
num_token = len(token_tensor)
num_computed_token = num_computed_tokens_list[idx]
start_pos = start_pos_list[idx]
q_len = num_token - num_computed_token

rebuilt_input_tokens.append(token_tensor[num_computed_token:])

assert q_len > 0
start_input_pos_idx = start_pos + num_computed_token
end_input_pos_idx = start_input_pos_idx + q_len
rebuilt_input_positions.append(
model_input.
input_positions[start_input_pos_idx:end_input_pos_idx])

# Attn metadata-related
rebuilt_num_prefills += 1
rebuilt_num_prefill_tokens += q_len
start_slot_idx = start_pos + num_computed_token
end_slot_idx = start_slot_idx + q_len
new_slot_mapping = slot_mapping_flat[start_slot_idx:end_slot_idx]
rebuilt_slot_mapping.append(new_slot_mapping)
rebuilt_max_query_len = max(q_len, rebuilt_max_query_len)
last_query_start_loc += q_len
rebuilt_query_start_loc.append(last_query_start_loc)
rebuilt_context_lens_tensor.append(num_computed_token)

# recover `block_table`
if len(model_input.attn_metadata.block_tables[idx]) > 0:
rebuilt_block_tables.append(
model_input.attn_metadata.block_tables[idx])
else:
                slot_mapping_req = slot_mapping_flat[start_pos:end_slot_idx]
                vllm_block_size = self.cache_config.block_size
                # Take one slot per cache block (stride = block size) and
                # convert the slot index into a block number.
                rebuilt_block_table = slot_mapping_req[::vllm_block_size].to(
                    torch.int32) // vllm_block_size
                rebuilt_block_tables.append(rebuilt_block_table)

# rebuilt attn_metadata
rebuilt_attn_metadata = deepcopy(model_input.attn_metadata)
rebuilt_attn_metadata.num_prefills = rebuilt_num_prefills
rebuilt_attn_metadata.num_prefill_tokens = rebuilt_num_prefill_tokens
rebuilt_attn_metadata.slot_mapping = torch.cat(
rebuilt_slot_mapping).to(device)
rebuilt_attn_metadata.max_query_len = rebuilt_max_query_len
rebuilt_attn_metadata.block_tables = pad_sequence(
rebuilt_block_tables, batch_first=True).to(device)
rebuilt_attn_metadata.query_start_loc = torch.tensor(
rebuilt_query_start_loc,
dtype=model_input.attn_metadata.query_start_loc.dtype).to(device)
rebuilt_attn_metadata.context_lens_tensor = torch.tensor(
rebuilt_context_lens_tensor,
dtype=model_input.attn_metadata.context_lens_tensor.dtype,
).to(device)
rebuilt_attn_metadata._cached_prefill_metadata = None

# import here to avoid circular import.
from vllm.worker.model_runner import (
ModelInputForGPUWithSamplingMetadata)
rebuilt_model_input = ModelInputForGPUWithSamplingMetadata(
input_tokens=torch.cat(rebuilt_input_tokens).to(device),
input_positions=torch.cat(rebuilt_input_positions).to(device),
seq_lens=model_input.seq_lens,
query_lens=model_input.query_lens,
lora_mapping=model_input.lora_mapping,
lora_requests=model_input.lora_requests,
attn_metadata=rebuilt_attn_metadata,
prompt_adapter_mapping=model_input.prompt_adapter_mapping,
prompt_adapter_requests=model_input.prompt_adapter_requests,
multi_modal_kwargs=model_input.multi_modal_kwargs,
request_ids_to_seq_ids=model_input.request_ids_to_seq_ids,
finished_requests_ids=model_input.finished_requests_ids,
virtual_engine=model_input.virtual_engine,
sampling_metadata=model_input.sampling_metadata,
is_prompt=model_input.is_prompt,
async_callback=model_input.async_callback,
)

return rebuilt_model_input
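
To make the bookkeeping in `build_partial_prefill_input` concrete, here is a small standalone sketch using assumed toy values (two prefill requests and a hypothetical block size of 16) that mirrors how the query lengths, `query_start_loc`, `context_lens_tensor`, and the fallback block-table recovery are derived:

```python
# Standalone illustration with assumed toy values; not part of the PR itself.
import torch

block_size = 16                    # hypothetical KV cache block size
full_lens = [8, 5]                 # prompt lengths of two requests
computed = [6, 0]                  # tokens whose KV cache was retrieved

# Per-request query lengths and cumulative query_start_loc.
q_lens = [f - c for f, c in zip(full_lens, computed)]      # [2, 5]
query_start_loc = [0]
for q in q_lens:
    query_start_loc.append(query_start_loc[-1] + q)        # [0, 2, 7]
context_lens = computed                                    # [6, 0]

# Fallback block-table recovery: take one slot per block and map the slot
# index to a block number. Here a 20-token request is assumed to occupy
# cache blocks 3 and 4.
slot_mapping_req = torch.arange(3 * block_size, 3 * block_size + 20)
block_table = slot_mapping_req[::block_size].to(torch.int32) // block_size

print(q_lens, query_start_loc, context_lens)  # [2, 5] [0, 2, 7] [6, 0]
print(block_table)                            # tensor([3, 4], dtype=torch.int32)
```
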
2 changes: 1 addition & 1 deletion vllm/distributed/kv_transfer/kv_transfer_agent.py
@@ -68,7 +68,7 @@ def recv_kv_caches_and_hidden_states(
self, model_executable: torch.nn.Module,
model_input: "ModelInputForGPUWithSamplingMetadata",
kv_caches: List[torch.Tensor]
) -> Tuple[Union[torch.Tensor, IntermediateTensors], bool,
) -> Tuple[Union[torch.Tensor, IntermediateTensors], List[bool],
"ModelInputForGPUWithSamplingMetadata"]:

return self.connector.recv_kv_caches_and_hidden_states(