@@ -1,12 +1,14 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from typing import Union
 
 import numpy as np
 import torch
 
 from vllm.distributed import get_dcp_group
 from vllm.logger import init_logger
 from vllm.utils import cdiv
+from vllm.v1.utils import CpuGpuBuffer
 
 logger = init_logger(__name__)
 
@@ -29,28 +31,13 @@ def __init__(
         self.pin_memory = pin_memory
         self.device = device
 
-        self.block_table = torch.zeros(
-            (max_num_reqs, max_num_blocks_per_req),
-            device=self.device,
-            dtype=torch.int32,
-        )
-        self.block_table_cpu = torch.zeros(
-            (max_num_reqs, max_num_blocks_per_req),
-            device="cpu",
-            dtype=torch.int32,
-            pin_memory=pin_memory,
-        )
-        self.block_table_np = self.block_table_cpu.numpy()
+        self.block_table = self._make_buffer(max_num_reqs,
+                                             max_num_blocks_per_req,
+                                             dtype=torch.int32)
         self.num_blocks_per_row = np.zeros(max_num_reqs, dtype=np.int32)
 
-        self.slot_mapping_cpu = torch.zeros(self.max_num_batched_tokens,
-                                            dtype=torch.int64,
-                                            device="cpu",
-                                            pin_memory=self.pin_memory)
-        self.slot_mapping_np = self.slot_mapping_cpu.numpy()
-        self.slot_mapping = torch.zeros(self.max_num_batched_tokens,
-                                        dtype=torch.int64,
-                                        device=self.device)
+        self.slot_mapping = self._make_buffer(self.max_num_batched_tokens,
+                                              dtype=torch.int64)
         try:
             self.dcp_world_size = get_dcp_group().world_size
             self.dcp_rank = get_dcp_group().rank_in_group
@@ -69,25 +56,22 @@ def append_row(
         num_blocks = len(block_ids)
         start = self.num_blocks_per_row[row_idx]
         self.num_blocks_per_row[row_idx] += num_blocks
-        self.block_table_np[row_idx, start:start + num_blocks] = block_ids
+        self.block_table.np[row_idx, start:start + num_blocks] = block_ids
 
     def add_row(self, block_ids: list[int], row_idx: int) -> None:
         self.num_blocks_per_row[row_idx] = 0
         self.append_row(block_ids, row_idx)
 
     def move_row(self, src: int, tgt: int) -> None:
         num_blocks = self.num_blocks_per_row[src]
-        self.block_table_np[tgt, :num_blocks] = self.block_table_np[
-            src, :num_blocks]
+        block_table_np = self.block_table.np
+        block_table_np[tgt, :num_blocks] = block_table_np[src, :num_blocks]
         self.num_blocks_per_row[tgt] = num_blocks
 
     def swap_row(self, src: int, tgt: int) -> None:
-        num_blocks_src = self.num_blocks_per_row[src]
-        num_blocks_tgt = self.num_blocks_per_row[tgt]
-        self.num_blocks_per_row[src] = num_blocks_tgt
-        self.num_blocks_per_row[tgt] = num_blocks_src
-
-        self.block_table_np[[src, tgt]] = self.block_table_np[[tgt, src]]
+        src_tgt, tgt_src = [src, tgt], [tgt, src]
+        self.num_blocks_per_row[src_tgt] = self.num_blocks_per_row[tgt_src]
+        self.block_table.np[src_tgt] = self.block_table.np[tgt_src]
 
     def compute_slot_mapping(self, req_indices: np.ndarray,
                              positions: np.ndarray) -> None:
@@ -107,7 +91,7 @@ def compute_slot_mapping(self, req_indices: np.ndarray,
             virtual_block_size = self.block_size * self.dcp_world_size
             block_table_indices = (req_indices * self.max_num_blocks_per_req +
                                    positions // virtual_block_size)
-            block_numbers = self.block_table_np.ravel()[block_table_indices]
+            block_numbers = self.block_table.np.ravel()[block_table_indices]
             # Use virtual_block_size for mask calculation, which marks local
             # tokens.
             virtual_block_offsets = positions % virtual_block_size
@@ -117,40 +101,45 @@ def compute_slot_mapping(self, req_indices: np.ndarray,
             # Calculate slot_mapping
             slot_mapping = block_numbers * self.block_size + block_offsets
             # Write final slots, use -1 for not-local
-            self.slot_mapping_np[:req_indices.shape[0]] = np.where(
+            self.slot_mapping.np[:req_indices.shape[0]] = np.where(
                 mask, slot_mapping, -1)
         else:
             block_table_indices = (req_indices * self.max_num_blocks_per_req +
                                    positions // self.block_size)
-            block_numbers = self.block_table_np.ravel()[block_table_indices]
+            block_numbers = self.block_table.np.ravel()[block_table_indices]
             block_offsets = positions % self.block_size
             np.add(block_numbers * self.block_size,
                    block_offsets,
-                   out=self.slot_mapping_np[:req_indices.shape[0]])
+                   out=self.slot_mapping.np[:req_indices.shape[0]])
 
     def commit_block_table(self, num_reqs: int) -> None:
-        self.block_table[:num_reqs].copy_(self.block_table_cpu[:num_reqs],
-                                          non_blocking=True)
+        self.block_table.copy_to_gpu(num_reqs)
 
     def commit_slot_mapping(self, num_tokens: int) -> None:
-        self.slot_mapping[:num_tokens].copy_(
-            self.slot_mapping_cpu[:num_tokens], non_blocking=True)
+        self.slot_mapping.copy_to_gpu(num_tokens)
 
     def clear(self) -> None:
-        self.block_table.fill_(0)
-        self.block_table_cpu.fill_(0)
+        self.block_table.gpu.fill_(0)
+        self.block_table.cpu.fill_(0)
 
-    def get_device_tensor(self) -> torch.Tensor:
+    def get_device_tensor(self, num_reqs: int) -> torch.Tensor:
         """Returns the device tensor of the block table."""
-        return self.block_table
+        return self.block_table.gpu[:num_reqs]
 
     def get_cpu_tensor(self) -> torch.Tensor:
         """Returns the CPU tensor of the block table."""
-        return self.block_table_cpu
+        return self.block_table.cpu
 
     def get_numpy_array(self) -> np.ndarray:
         """Returns the numpy array of the block table."""
-        return self.block_table_np
+        return self.block_table.np
+
+    def _make_buffer(self, *size: Union[int, torch.SymInt],
+                     dtype: torch.dtype) -> CpuGpuBuffer:
+        return CpuGpuBuffer(*size,
+                            dtype=dtype,
+                            device=self.device,
+                            pin_memory=self.pin_memory)
 
 
 class MultiGroupBlockTable:
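
Note: the diff above folds each table's trio of device tensor, pinned CPU tensor, and NumPy view into one CpuGpuBuffer exposing .gpu, .cpu, .np, and copy_to_gpu(). The following is only a rough sketch of such a helper, inferred from those usages; the real vllm.v1.utils.CpuGpuBuffer may differ in its actual implementation.

from typing import Optional

import torch


class CpuGpuBufferSketch:
    """Hypothetical stand-in for vllm.v1.utils.CpuGpuBuffer."""

    def __init__(self, *size: int, dtype: torch.dtype,
                 device: torch.device, pin_memory: bool) -> None:
        # Pinned CPU staging tensor that host code writes into.
        self.cpu = torch.zeros(*size, dtype=dtype, device="cpu",
                               pin_memory=pin_memory)
        # Zero-copy NumPy view of the CPU tensor for vectorized writes.
        self.np = self.cpu.numpy()
        # Device tensor that kernels read from after a commit.
        self.gpu = torch.zeros(*size, dtype=dtype, device=device)

    def copy_to_gpu(self, n: Optional[int] = None) -> torch.Tensor:
        # Copy the first n rows (or the whole buffer) host-to-device.
        # non_blocking=True lets the copy overlap with other work since
        # the source is pinned memory.
        if n is None:
            self.gpu.copy_(self.cpu, non_blocking=True)
            return self.gpu
        self.gpu[:n].copy_(self.cpu[:n], non_blocking=True)
        return self.gpu[:n]

Under that assumption, commit_block_table(num_reqs) collapses to a single copy_to_gpu(num_reqs) call, which mirrors the removed non-blocking copy_ of the first num_reqs rows from the pinned CPU tensor to the device tensor.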