2 changes: 1 addition & 1 deletion tests/v1/tpu/worker/test_tpu_model_runner.py
@@ -125,7 +125,7 @@ def _is_req_state_block_table_match(model_runner, req_id: str) -> bool:
return False

num_blocks = block_table.num_blocks_per_row[req_index]
- block_table_values = block_table.block_table_np[req_index, :num_blocks]
+ block_table_values = block_table.block_table.np[req_index, :num_blocks]
return (block_table_values == req_block_ids).all()


5 changes: 4 additions & 1 deletion tests/v1/worker/test_gpu_input_batch.py
@@ -15,6 +15,7 @@
from vllm.v1.pool.metadata import PoolingMetadata
from vllm.v1.sample.logits_processor import LogitsProcessors
from vllm.v1.sample.metadata import SamplingMetadata
+ from vllm.v1.utils import CpuGpuBuffer
from vllm.v1.worker.block_table import BlockTable, MultiGroupBlockTable
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch

@@ -45,7 +46,7 @@ def _compare_objs(obj1,

is_same = False
if isinstance(a, torch.Tensor):
- if (a.numel() == 0 or b.numel() == 0):
+ if a.numel() == 0 or b.numel() == 0:
is_same = (a.numel() == 0 and b.numel() == 0)
elif torch.allclose(a, b):
is_same = True
@@ -61,6 +62,8 @@ def _compare_objs(obj1,
is_same = True # if we make it here must be same
elif a == b:
is_same = True
+ elif isinstance(a, CpuGpuBuffer):
+ is_same = np.allclose(a.np, b.np) and torch.allclose(a.gpu, b.gpu)
assert is_same, f"Attribute {attr_name} is different"\
f" in {obj1} and {obj2}: {a} != {b}"

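For readers following the new `_compare_objs` branch: below is a minimal stand-in for the buffer type it compares, inferred purely from how `CpuGpuBuffer` is used in this diff (`.cpu`, `.gpu`, `.np`, `copy_to_gpu`). It is an illustrative sketch; the real `vllm.v1.utils.CpuGpuBuffer` may differ in signature and details.

```python
# Illustrative sketch only; the actual vllm.v1.utils.CpuGpuBuffer may differ.
from typing import Optional

import torch


class CpuGpuBufferSketch:
    """A pinned CPU tensor, a same-shaped device tensor, and a NumPy view."""

    def __init__(self, *size: int, dtype: torch.dtype,
                 device: torch.device, pin_memory: bool = False):
        self.cpu = torch.zeros(*size, dtype=dtype, device="cpu",
                               pin_memory=pin_memory)
        self.gpu = torch.zeros(*size, dtype=dtype, device=device)
        # NumPy view sharing memory with the CPU tensor (not possible for
        # bfloat16, which is why a runner may skip creating it).
        self.np = self.cpu.numpy()

    def copy_to_gpu(self, n: Optional[int] = None) -> torch.Tensor:
        # Copy the first n entries (or everything) from host to device.
        if n is None:
            return self.gpu.copy_(self.cpu, non_blocking=True)
        self.gpu[:n].copy_(self.cpu[:n], non_blocking=True)
        return self.gpu[:n]
```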
2 changes: 1 addition & 1 deletion tests/v1/worker/test_gpu_model_runner.py
@@ -165,7 +165,7 @@ def _is_req_state_block_table_match(model_runner, req_id: str) -> bool:
req_state.block_ids[0]):
return False
num_blocks = block_table.num_blocks_per_row[req_index]
- return (block_table.block_table_np[req_index, :num_blocks] ==
+ return (block_table.block_table.np[req_index, :num_blocks] ==
req_state.block_ids[0]).all()


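Both the TPU and GPU model-runner tests above assert the same invariant after the rename: the first `num_blocks` entries of a request's block-table row must equal the request's block IDs. Reduced to plain NumPy with made-up values, the check is:

```python
import numpy as np

# Hypothetical block-table row and the request's expected block IDs.
block_table_row = np.array([7, 3, 9, 0, 0], dtype=np.int32)
req_block_ids = [7, 3, 9]
num_blocks = len(req_block_ids)

assert (block_table_row[:num_blocks] == req_block_ids).all()
```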
75 changes: 32 additions & 43 deletions vllm/v1/worker/block_table.py
@@ -1,12 +1,14 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+ from typing import Union

import numpy as np
import torch

from vllm.distributed import get_dcp_group
from vllm.logger import init_logger
from vllm.utils import cdiv
+ from vllm.v1.utils import CpuGpuBuffer

logger = init_logger(__name__)

@@ -29,28 +31,13 @@ def __init__(
self.pin_memory = pin_memory
self.device = device

- self.block_table = torch.zeros(
- (max_num_reqs, max_num_blocks_per_req),
- device=self.device,
- dtype=torch.int32,
- )
- self.block_table_cpu = torch.zeros(
- (max_num_reqs, max_num_blocks_per_req),
- device="cpu",
- dtype=torch.int32,
- pin_memory=pin_memory,
- )
- self.block_table_np = self.block_table_cpu.numpy()
+ self.block_table = self._make_buffer(max_num_reqs,
+ max_num_blocks_per_req,
+ dtype=torch.int32)
self.num_blocks_per_row = np.zeros(max_num_reqs, dtype=np.int32)

- self.slot_mapping_cpu = torch.zeros(self.max_num_batched_tokens,
- dtype=torch.int64,
- device="cpu",
- pin_memory=self.pin_memory)
- self.slot_mapping_np = self.slot_mapping_cpu.numpy()
- self.slot_mapping = torch.zeros(self.max_num_batched_tokens,
- dtype=torch.int64,
- device=self.device)
+ self.slot_mapping = self._make_buffer(self.max_num_batched_tokens,
+ dtype=torch.int64)
try:
self.dcp_world_size = get_dcp_group().world_size
self.dcp_rank = get_dcp_group().rank_in_group
@@ -69,25 +56,22 @@ def append_row(
num_blocks = len(block_ids)
start = self.num_blocks_per_row[row_idx]
self.num_blocks_per_row[row_idx] += num_blocks
- self.block_table_np[row_idx, start:start + num_blocks] = block_ids
+ self.block_table.np[row_idx, start:start + num_blocks] = block_ids

def add_row(self, block_ids: list[int], row_idx: int) -> None:
self.num_blocks_per_row[row_idx] = 0
self.append_row(block_ids, row_idx)

def move_row(self, src: int, tgt: int) -> None:
num_blocks = self.num_blocks_per_row[src]
- self.block_table_np[tgt, :num_blocks] = self.block_table_np[
- src, :num_blocks]
+ block_table_np = self.block_table.np
+ block_table_np[tgt, :num_blocks] = block_table_np[src, :num_blocks]
self.num_blocks_per_row[tgt] = num_blocks

def swap_row(self, src: int, tgt: int) -> None:
- num_blocks_src = self.num_blocks_per_row[src]
- num_blocks_tgt = self.num_blocks_per_row[tgt]
- self.num_blocks_per_row[src] = num_blocks_tgt
- self.num_blocks_per_row[tgt] = num_blocks_src
-
- self.block_table_np[[src, tgt]] = self.block_table_np[[tgt, src]]
+ src_tgt, tgt_src = [src, tgt], [tgt, src]
+ self.num_blocks_per_row[src_tgt] = self.num_blocks_per_row[tgt_src]
+ self.block_table.np[src_tgt] = self.block_table.np[tgt_src]

def compute_slot_mapping(self, req_indices: np.ndarray,
positions: np.ndarray) -> None:
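The `swap_row` rewrite above leans on NumPy fancy indexing: indexing with a list of indices on the right-hand side materializes a copy before the assignment, so a single statement exchanges both rows (and, with the same index lists, the per-row block counts). A standalone illustration with made-up values:

```python
import numpy as np

block_table_np = np.arange(12, dtype=np.int32).reshape(4, 3)
num_blocks_per_row = np.array([3, 1, 2, 0], dtype=np.int32)

src, tgt = 0, 2
src_tgt, tgt_src = [src, tgt], [tgt, src]

# The right-hand side fancy index is copied before assignment, so the rows
# are swapped without an explicit temporary.
num_blocks_per_row[src_tgt] = num_blocks_per_row[tgt_src]
block_table_np[src_tgt] = block_table_np[tgt_src]

print(num_blocks_per_row)                 # [2 1 3 0]
print(block_table_np[0], block_table_np[2])  # former rows 2 and 0
```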
@@ -107,7 +91,7 @@ def compute_slot_mapping(self, req_indices: np.ndarray,
virtual_block_size = self.block_size * self.dcp_world_size
block_table_indices = (req_indices * self.max_num_blocks_per_req +
positions // virtual_block_size)
- block_numbers = self.block_table_np.ravel()[block_table_indices]
+ block_numbers = self.block_table.np.ravel()[block_table_indices]
# Use virtual_block_size for mask calculation, which marks local
# tokens.
virtual_block_offsets = positions % virtual_block_size
@@ -117,40 +101,45 @@ def compute_slot_mapping(self, req_indices: np.ndarray,
# Calculate slot_mapping
slot_mapping = block_numbers * self.block_size + block_offsets
# Write final slots, use -1 for not-local
- self.slot_mapping_np[:req_indices.shape[0]] = np.where(
+ self.slot_mapping.np[:req_indices.shape[0]] = np.where(
mask, slot_mapping, -1)
else:
block_table_indices = (req_indices * self.max_num_blocks_per_req +
positions // self.block_size)
- block_numbers = self.block_table_np.ravel()[block_table_indices]
+ block_numbers = self.block_table.np.ravel()[block_table_indices]
block_offsets = positions % self.block_size
np.add(block_numbers * self.block_size,
block_offsets,
- out=self.slot_mapping_np[:req_indices.shape[0]])
+ out=self.slot_mapping.np[:req_indices.shape[0]])

def commit_block_table(self, num_reqs: int) -> None:
- self.block_table[:num_reqs].copy_(self.block_table_cpu[:num_reqs],
- non_blocking=True)
+ self.block_table.copy_to_gpu(num_reqs)

def commit_slot_mapping(self, num_tokens: int) -> None:
- self.slot_mapping[:num_tokens].copy_(
- self.slot_mapping_cpu[:num_tokens], non_blocking=True)
+ self.slot_mapping.copy_to_gpu(num_tokens)

def clear(self) -> None:
- self.block_table.fill_(0)
- self.block_table_cpu.fill_(0)
+ self.block_table.gpu.fill_(0)
+ self.block_table.cpu.fill_(0)

- def get_device_tensor(self) -> torch.Tensor:
+ def get_device_tensor(self, num_reqs: int) -> torch.Tensor:
"""Returns the device tensor of the block table."""
- return self.block_table
+ return self.block_table.gpu[:num_reqs]

def get_cpu_tensor(self) -> torch.Tensor:
"""Returns the CPU tensor of the block table."""
- return self.block_table_cpu
+ return self.block_table.cpu

def get_numpy_array(self) -> np.ndarray:
"""Returns the numpy array of the block table."""
- return self.block_table_np
+ return self.block_table.np

+ def _make_buffer(self, *size: Union[int, torch.SymInt],
+ dtype: torch.dtype) -> CpuGpuBuffer:
+ return CpuGpuBuffer(*size,
+ dtype=dtype,
+ device=self.device,
+ pin_memory=self.pin_memory)


class MultiGroupBlockTable:
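As a sanity check on the `compute_slot_mapping` arithmetic above, here is the non-DCP path worked through with made-up numbers (block size, block IDs, and positions are illustrative only):

```python
import numpy as np

block_size = 16
max_num_blocks_per_req = 4

# Hypothetical block table: request 0 owns physical blocks 7 and 3.
block_table = np.zeros((2, max_num_blocks_per_req), dtype=np.int32)
block_table[0, :2] = [7, 3]

req_indices = np.array([0, 0, 0], dtype=np.int64)
positions = np.array([0, 15, 16], dtype=np.int64)

block_table_indices = (req_indices * max_num_blocks_per_req +
                       positions // block_size)
block_numbers = block_table.ravel()[block_table_indices]   # [7, 7, 3]
block_offsets = positions % block_size                      # [0, 15, 0]
slot_mapping = block_numbers * block_size + block_offsets   # [112, 127, 48]
print(slot_mapping)
```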
8 changes: 4 additions & 4 deletions vllm/v1/worker/cpu_model_runner.py
@@ -84,7 +84,7 @@ def replace_tensor(obj: Any, cpu_attr_name: str,
assert isinstance(device_tensor, torch.Tensor)
setattr(obj, device_attr_name, cpu_tensor)

- for k, v in vars(self).items():
+ for v in vars(self).values():
if isinstance(v, CpuGpuBuffer):
v.gpu = v.cpu

@@ -93,9 +93,9 @@ def replace_tensor(obj: Any, cpu_attr_name: str,
replace_tensor(self.input_batch, k, k[:-11])

for block_table in self.input_batch.block_table.block_tables:
- for k, v in vars(block_table).items():
- if k.endswith("_cpu") and isinstance(v, torch.Tensor):
- replace_tensor(block_table, k, k[:-4])
+ for v in vars(block_table).values():
+ if isinstance(v, CpuGpuBuffer):
+ v.gpu = v.cpu

def load_model(self, eep_scale_up: bool = False) -> None:
logger.info("Starting to load model %s...", self.model_config.model)
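The CPU-runner change above replaces the old string-suffix matching with a type check and simply aliases each buffer's `.gpu` attribute to its `.cpu` tensor, so every later "device" access and host-to-device copy operates on the same host memory. A tiny demonstration of the aliasing idea with a stand-in object (names assumed, not the real runner code):

```python
import torch
from types import SimpleNamespace

# Stand-in for a CpuGpuBuffer on a machine with no accelerator.
buf = SimpleNamespace(cpu=torch.zeros(4, dtype=torch.int64))
buf.gpu = buf.cpu  # the CPU model runner points the "device" tensor at host memory

buf.cpu[:2] = torch.tensor([7, 3])
# A commit-style copy is now a self-copy; .gpu already sees the data.
buf.gpu[:2].copy_(buf.cpu[:2])

assert buf.gpu.data_ptr() == buf.cpu.data_ptr()
print(buf.gpu)  # tensor([7, 3, 0, 0])
```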
24 changes: 11 additions & 13 deletions vllm/v1/worker/gpu_model_runner.py
@@ -416,9 +416,6 @@ def _make_buffer(self,
*size: Union[int, torch.SymInt],
dtype: torch.dtype,
numpy: bool = True) -> CpuGpuBuffer:
- # Bfloat16 torch tensors cannot be directly cast to a numpy array, so
- # if a bfloat16 buffer is needed without a corresponding numpy array,
- # don't bother instantiating the numpy array.
return CpuGpuBuffer(*size,
dtype=dtype,
device=self.device,
@@ -1039,13 +1036,14 @@ def _prepare_inputs(
num_common_prefix_blocks = 0
else:
blk_table = self.input_batch.block_table[kv_cache_group_id]
- blk_table_tensor = blk_table.get_device_tensor()[:num_reqs]
- slot_mapping = blk_table.slot_mapping[:
- total_num_scheduled_tokens]
+ blk_table_tensor = blk_table.get_device_tensor(num_reqs)
+ slot_mapping = blk_table.slot_mapping.gpu[:
+ total_num_scheduled_tokens]

# Fill unused with -1. Needed for reshape_and_cache in full cuda
# graph mode.
- blk_table.slot_mapping[total_num_scheduled_tokens:].fill_(-1)
+ blk_table.slot_mapping.gpu[total_num_scheduled_tokens:].fill_(
+ -1)
num_common_prefix_blocks = (
scheduler_output.
num_common_prefix_blocks[kv_cache_group_id])
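The `-1` fill a few lines up keeps the unused tail of the fixed-size slot mapping inert when the whole buffer is consumed, e.g. by `reshape_and_cache` under full CUDA-graph capture. A CPU-only illustration of the padding pattern (sizes and slot values are made up):

```python
import torch

max_num_tokens = 8
slot_mapping_gpu = torch.empty(max_num_tokens, dtype=torch.int64)

total_num_scheduled_tokens = 5
valid = torch.tensor([112, 127, 48, 49, 50], dtype=torch.int64)

slot_mapping_gpu[:total_num_scheduled_tokens] = valid
# Entries past the scheduled tokens are set to -1 so kernels that always
# consume the full fixed-size buffer can recognize and skip them.
slot_mapping_gpu[total_num_scheduled_tokens:].fill_(-1)
print(slot_mapping_gpu)  # tensor([112, 127,  48,  49,  50,  -1,  -1,  -1])
```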
@@ -2761,10 +2759,10 @@ def _dummy_run(
num_actual_tokens=num_tokens,
max_query_len=max_query_len,
max_seq_len=self.max_model_len,
- block_table_tensor=self.input_batch.block_table[
- kv_cache_group_id].get_device_tensor()[:num_reqs],
- slot_mapping=self.input_batch.
- block_table[kv_cache_group_id].slot_mapping[:num_tokens],
+ block_table_tensor=self.input_batch.
+ block_table[kv_cache_group_id].get_device_tensor(num_reqs),
+ slot_mapping=self.input_batch.block_table[
+ kv_cache_group_id].slot_mapping.gpu[:num_tokens],
causal=True)

for attn_group in self.attn_groups[kv_cache_group_id]:
@@ -3105,8 +3103,8 @@ def freeze_gc():
cudagraph_runtime_mode=cudagraph_runtime_mode,
uniform_decode=False)

- # Capture full cudagraph for uniform decode batches if we have
- # dont already have full mixed prefill-decode cudagraphs
+ # Capture full cudagraph for uniform decode batches if we
+ # don't already have full mixed prefill-decode cudagraphs.
if cudagraph_mode.decode_mode() == CUDAGraphMode.FULL and \
cudagraph_mode.separate_routine():
max_num_tokens = self.scheduler_config.max_num_seqs * \