
Commit dea89b4

njhill authored and xuebwang-amd committed
[Core] Use CpuGpuBuffer for block table tensors (vllm-project#24795)
Signed-off-by: Nick Hill <nhill@redhat.com>
Signed-off-by: xuebwang-amd <xuebwang@amd.com>
1 parent 512ea9d commit dea89b4

6 files changed (+53 −63 lines)


tests/v1/tpu/worker/test_tpu_model_runner.py

Lines changed: 1 addition & 1 deletion
@@ -125,7 +125,7 @@ def _is_req_state_block_table_match(model_runner, req_id: str) -> bool:
         return False
 
     num_blocks = block_table.num_blocks_per_row[req_index]
-    block_table_values = block_table.block_table_np[req_index, :num_blocks]
+    block_table_values = block_table.block_table.np[req_index, :num_blocks]
     return (block_table_values == req_block_ids).all()
 

tests/v1/worker/test_gpu_input_batch.py

Lines changed: 4 additions & 1 deletion
@@ -15,6 +15,7 @@
 from vllm.v1.pool.metadata import PoolingMetadata
 from vllm.v1.sample.logits_processor import LogitsProcessors
 from vllm.v1.sample.metadata import SamplingMetadata
+from vllm.v1.utils import CpuGpuBuffer
 from vllm.v1.worker.block_table import BlockTable, MultiGroupBlockTable
 from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
 
@@ -45,7 +46,7 @@ def _compare_objs(obj1,
 
         is_same = False
         if isinstance(a, torch.Tensor):
-            if (a.numel() == 0 or b.numel() == 0):
+            if a.numel() == 0 or b.numel() == 0:
                 is_same = (a.numel() == 0 and b.numel() == 0)
             elif torch.allclose(a, b):
                 is_same = True
@@ -61,6 +62,8 @@ def _compare_objs(obj1,
             is_same = True  # if we make it here must be same
         elif a == b:
             is_same = True
+        elif isinstance(a, CpuGpuBuffer):
+            is_same = np.allclose(a.np, b.np) and torch.allclose(a.gpu, b.gpu)
         assert is_same, f"Attribute {attr_name} is different"\
             f" in {obj1} and {obj2}: {a} != {b}"

tests/v1/worker/test_gpu_model_runner.py

Lines changed: 1 addition & 1 deletion
@@ -165,7 +165,7 @@ def _is_req_state_block_table_match(model_runner, req_id: str) -> bool:
                          req_state.block_ids[0]):
         return False
     num_blocks = block_table.num_blocks_per_row[req_index]
-    return (block_table.block_table_np[req_index, :num_blocks] ==
+    return (block_table.block_table.np[req_index, :num_blocks] ==
             req_state.block_ids[0]).all()
 

vllm/v1/worker/block_table.py

Lines changed: 32 additions & 43 deletions
@@ -1,12 +1,14 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from typing import Union
 
 import numpy as np
 import torch
 
 from vllm.distributed import get_dcp_group
 from vllm.logger import init_logger
 from vllm.utils import cdiv
+from vllm.v1.utils import CpuGpuBuffer
 
 logger = init_logger(__name__)
 
@@ -29,28 +31,13 @@ def __init__(
         self.pin_memory = pin_memory
         self.device = device
 
-        self.block_table = torch.zeros(
-            (max_num_reqs, max_num_blocks_per_req),
-            device=self.device,
-            dtype=torch.int32,
-        )
-        self.block_table_cpu = torch.zeros(
-            (max_num_reqs, max_num_blocks_per_req),
-            device="cpu",
-            dtype=torch.int32,
-            pin_memory=pin_memory,
-        )
-        self.block_table_np = self.block_table_cpu.numpy()
+        self.block_table = self._make_buffer(max_num_reqs,
+                                             max_num_blocks_per_req,
+                                             dtype=torch.int32)
         self.num_blocks_per_row = np.zeros(max_num_reqs, dtype=np.int32)
 
-        self.slot_mapping_cpu = torch.zeros(self.max_num_batched_tokens,
-                                            dtype=torch.int64,
-                                            device="cpu",
-                                            pin_memory=self.pin_memory)
-        self.slot_mapping_np = self.slot_mapping_cpu.numpy()
-        self.slot_mapping = torch.zeros(self.max_num_batched_tokens,
-                                        dtype=torch.int64,
-                                        device=self.device)
+        self.slot_mapping = self._make_buffer(self.max_num_batched_tokens,
+                                              dtype=torch.int64)
         try:
             self.dcp_world_size = get_dcp_group().world_size
             self.dcp_rank = get_dcp_group().rank_in_group
@@ -69,25 +56,22 @@ def append_row(
         num_blocks = len(block_ids)
         start = self.num_blocks_per_row[row_idx]
         self.num_blocks_per_row[row_idx] += num_blocks
-        self.block_table_np[row_idx, start:start + num_blocks] = block_ids
+        self.block_table.np[row_idx, start:start + num_blocks] = block_ids
 
     def add_row(self, block_ids: list[int], row_idx: int) -> None:
         self.num_blocks_per_row[row_idx] = 0
         self.append_row(block_ids, row_idx)
 
     def move_row(self, src: int, tgt: int) -> None:
         num_blocks = self.num_blocks_per_row[src]
-        self.block_table_np[tgt, :num_blocks] = self.block_table_np[
-            src, :num_blocks]
+        block_table_np = self.block_table.np
+        block_table_np[tgt, :num_blocks] = block_table_np[src, :num_blocks]
         self.num_blocks_per_row[tgt] = num_blocks
 
     def swap_row(self, src: int, tgt: int) -> None:
-        num_blocks_src = self.num_blocks_per_row[src]
-        num_blocks_tgt = self.num_blocks_per_row[tgt]
-        self.num_blocks_per_row[src] = num_blocks_tgt
-        self.num_blocks_per_row[tgt] = num_blocks_src
-
-        self.block_table_np[[src, tgt]] = self.block_table_np[[tgt, src]]
+        src_tgt, tgt_src = [src, tgt], [tgt, src]
+        self.num_blocks_per_row[src_tgt] = self.num_blocks_per_row[tgt_src]
+        self.block_table.np[src_tgt] = self.block_table.np[tgt_src]
 
     def compute_slot_mapping(self, req_indices: np.ndarray,
                              positions: np.ndarray) -> None:
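
The rewritten swap_row leans on NumPy advanced ("fancy") indexing: assigning with index lists [src, tgt] on the left and [tgt, src] on the right swaps both rows in one vectorized statement, and it is safe because a fancy-indexed right-hand side is materialized as a copy before the write. A standalone sketch of the idiom (hypothetical shapes and values, not vLLM code):

import numpy as np

block_table = np.arange(12, dtype=np.int32).reshape(3, 4)
num_blocks_per_row = np.array([4, 2, 3], dtype=np.int32)

src, tgt = 0, 2
src_tgt, tgt_src = [src, tgt], [tgt, src]
# Fancy indexing on the RHS copies the selected rows first, so the
# in-place row swap cannot clobber its own source data.
num_blocks_per_row[src_tgt] = num_blocks_per_row[tgt_src]
block_table[src_tgt] = block_table[tgt_src]

assert block_table[0].tolist() == [8, 9, 10, 11]
assert num_blocks_per_row.tolist() == [3, 2, 4]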
@@ -107,7 +91,7 @@ def compute_slot_mapping(self, req_indices: np.ndarray,
             virtual_block_size = self.block_size * self.dcp_world_size
             block_table_indices = (req_indices * self.max_num_blocks_per_req +
                                    positions // virtual_block_size)
-            block_numbers = self.block_table_np.ravel()[block_table_indices]
+            block_numbers = self.block_table.np.ravel()[block_table_indices]
             # Use virtual_block_size for mask calculation, which marks local
             # tokens.
             virtual_block_offsets = positions % virtual_block_size
@@ -117,40 +101,45 @@ def compute_slot_mapping(self, req_indices: np.ndarray,
             # Calculate slot_mapping
             slot_mapping = block_numbers * self.block_size + block_offsets
             # Write final slots, use -1 for not-local
-            self.slot_mapping_np[:req_indices.shape[0]] = np.where(
+            self.slot_mapping.np[:req_indices.shape[0]] = np.where(
                 mask, slot_mapping, -1)
         else:
             block_table_indices = (req_indices * self.max_num_blocks_per_req +
                                    positions // self.block_size)
-            block_numbers = self.block_table_np.ravel()[block_table_indices]
+            block_numbers = self.block_table.np.ravel()[block_table_indices]
             block_offsets = positions % self.block_size
             np.add(block_numbers * self.block_size,
                    block_offsets,
-                   out=self.slot_mapping_np[:req_indices.shape[0]])
+                   out=self.slot_mapping.np[:req_indices.shape[0]])
 
     def commit_block_table(self, num_reqs: int) -> None:
-        self.block_table[:num_reqs].copy_(self.block_table_cpu[:num_reqs],
-                                          non_blocking=True)
+        self.block_table.copy_to_gpu(num_reqs)
 
     def commit_slot_mapping(self, num_tokens: int) -> None:
-        self.slot_mapping[:num_tokens].copy_(
-            self.slot_mapping_cpu[:num_tokens], non_blocking=True)
+        self.slot_mapping.copy_to_gpu(num_tokens)
 
     def clear(self) -> None:
-        self.block_table.fill_(0)
-        self.block_table_cpu.fill_(0)
+        self.block_table.gpu.fill_(0)
+        self.block_table.cpu.fill_(0)
 
-    def get_device_tensor(self) -> torch.Tensor:
+    def get_device_tensor(self, num_reqs: int) -> torch.Tensor:
         """Returns the device tensor of the block table."""
-        return self.block_table
+        return self.block_table.gpu[:num_reqs]
 
     def get_cpu_tensor(self) -> torch.Tensor:
         """Returns the CPU tensor of the block table."""
-        return self.block_table_cpu
+        return self.block_table.cpu
 
     def get_numpy_array(self) -> np.ndarray:
         """Returns the numpy array of the block table."""
-        return self.block_table_np
+        return self.block_table.np
+
+    def _make_buffer(self, *size: Union[int, torch.SymInt],
+                     dtype: torch.dtype) -> CpuGpuBuffer:
+        return CpuGpuBuffer(*size,
+                            dtype=dtype,
+                            device=self.device,
+                            pin_memory=self.pin_memory)
 
 
 class MultiGroupBlockTable:
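
The refactor replaces each mirrored pair of CPU/GPU tensors (plus a NumPy view) with a single CpuGpuBuffer from vllm.v1.utils. A minimal sketch of the pattern as this diff uses it, exposing .cpu, .np, .gpu, and copy_to_gpu(n); the class below is a hypothetical stand-in, and the real implementation may differ in details:

from typing import Optional

import torch


class CpuGpuBufferSketch:
    """Hypothetical stand-in for vllm.v1.utils.CpuGpuBuffer."""

    def __init__(self, *size: int, dtype: torch.dtype,
                 device: torch.device, pin_memory: bool):
        # Host-side staging tensor; pinned memory lets the H2D copy
        # below run asynchronously (non_blocking=True).
        self.cpu = torch.zeros(*size, dtype=dtype, device="cpu",
                               pin_memory=pin_memory)
        # Zero-copy NumPy view for fast indexed writes on the host
        # (would fail for bfloat16, which has no NumPy dtype).
        self.np = self.cpu.numpy()
        # Device-side tensor that kernels actually read.
        self.gpu = torch.zeros_like(self.cpu, device=device)

    def copy_to_gpu(self, n: Optional[int] = None) -> torch.Tensor:
        # Copy only the leading n entries that were written this step.
        if n is None:
            return self.gpu.copy_(self.cpu, non_blocking=True)
        self.gpu[:n].copy_(self.cpu[:n], non_blocking=True)
        return self.gpu[:n]

With this shape, BlockTable writes block IDs through .np on the host and commits them with copy_to_gpu(num_reqs) or copy_to_gpu(num_tokens), so only the entries actually touched in a step cross the PCIe bus.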

vllm/v1/worker/cpu_model_runner.py

Lines changed: 4 additions & 4 deletions
@@ -89,7 +89,7 @@ def replace_tensor(obj: Any, cpu_attr_name: str,
             assert isinstance(device_tensor, torch.Tensor)
             setattr(obj, device_attr_name, cpu_tensor)
 
-        for k, v in vars(self).items():
+        for v in vars(self).values():
             if isinstance(v, CpuGpuBuffer):
                 v.gpu = v.cpu
 
@@ -98,9 +98,9 @@ def replace_tensor(obj: Any, cpu_attr_name: str,
                 replace_tensor(self.input_batch, k, k[:-11])
 
         for block_table in self.input_batch.block_table.block_tables:
-            for k, v in vars(block_table).items():
-                if k.endswith("_cpu") and isinstance(v, torch.Tensor):
-                    replace_tensor(block_table, k, k[:-4])
+            for v in vars(block_table).values():
+                if isinstance(v, CpuGpuBuffer):
+                    v.gpu = v.cpu
 
     def load_model(self, eep_scale_up: bool = False) -> None:
         logger.info("Starting to load model %s...", self.model_config.model)

vllm/v1/worker/gpu_model_runner.py

Lines changed: 11 additions & 13 deletions
@@ -427,9 +427,6 @@ def _make_buffer(self,
                      *size: Union[int, torch.SymInt],
                      dtype: torch.dtype,
                      numpy: bool = True) -> CpuGpuBuffer:
-        # Bfloat16 torch tensors cannot be directly cast to a numpy array, so
-        # if a bfloat16 buffer is needed without a corresponding numpy array,
-        # don't bother instantiating the numpy array.
         return CpuGpuBuffer(*size,
                             dtype=dtype,
                             device=self.device,
@@ -1062,13 +1059,14 @@ def _prepare_inputs(
             num_common_prefix_blocks = 0
         else:
             blk_table = self.input_batch.block_table[kv_cache_group_id]
-            blk_table_tensor = blk_table.get_device_tensor()[:num_reqs]
-            slot_mapping = blk_table.slot_mapping[:
-                                                  total_num_scheduled_tokens]
+            blk_table_tensor = blk_table.get_device_tensor(num_reqs)
+            slot_mapping = blk_table.slot_mapping.gpu[:
+                                                      total_num_scheduled_tokens]
 
             # Fill unused with -1. Needed for reshape_and_cache in full cuda
             # graph mode.
-            blk_table.slot_mapping[total_num_scheduled_tokens:].fill_(-1)
+            blk_table.slot_mapping.gpu[total_num_scheduled_tokens:].fill_(
+                -1)
             num_common_prefix_blocks = (
                 scheduler_output.
                 num_common_prefix_blocks[kv_cache_group_id])
@@ -2903,10 +2901,10 @@ def _dummy_run(
                 num_actual_tokens=num_tokens,
                 max_query_len=max_query_len,
                 max_seq_len=self.max_model_len,
-                block_table_tensor=self.input_batch.block_table[
-                    kv_cache_group_id].get_device_tensor()[:num_reqs],
-                slot_mapping=self.input_batch.
-                block_table[kv_cache_group_id].slot_mapping[:num_tokens],
+                block_table_tensor=self.input_batch.
+                block_table[kv_cache_group_id].get_device_tensor(num_reqs),
+                slot_mapping=self.input_batch.block_table[
+                    kv_cache_group_id].slot_mapping.gpu[:num_tokens],
                 causal=True)
             for attn_group in self.attn_groups[kv_cache_group_id]:
                 if ubatch_slices is not None:
@@ -3265,8 +3263,8 @@ def freeze_gc():
                 cudagraph_runtime_mode=cudagraph_runtime_mode,
                 uniform_decode=False)
 
-            # Capture full cudagraph for uniform decode batches if we have
-            # dont already have full mixed prefill-decode cudagraphs
+            # Capture full cudagraph for uniform decode batches if we
+            # don't already have full mixed prefill-decode cudagraphs.
             if cudagraph_mode.decode_mode() == CUDAGraphMode.FULL and \
                 cudagraph_mode.separate_routine():
                 max_num_tokens = self.scheduler_config.max_num_seqs * \
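
The explicit .gpu access also keeps the full-cudagraph padding step visible: a captured graph replays its kernels over the fixed-size slot_mapping buffer, so entries past total_num_scheduled_tokens still hold values from earlier, larger batches and must be overwritten with -1, per the diff's comment the sentinel reshape_and_cache expects for unused slots. A toy illustration with plain tensors and hypothetical sizes:

import torch

max_num_tokens, num_scheduled = 8, 3
# Stale tail left over from a previous, larger batch.
slot_mapping = torch.full((max_num_tokens,), 42, dtype=torch.int64)
slot_mapping[:num_scheduled] = torch.tensor([1, 4, 6])
# Without this fill, a replayed full-graph kernel would write to the
# stale slots; -1 marks the tail as padding to skip.
slot_mapping[num_scheduled:].fill_(-1)
print(slot_mapping)  # tensor([ 1,  4,  6, -1, -1, -1, -1, -1])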
