11# SPDX-License-Identifier: Apache-2.0 
22# SPDX-FileCopyrightText: Copyright contributors to the vLLM project 
3+ from  typing  import  Union 
34
45import  numpy  as  np 
56import  torch 
67
78from  vllm .distributed  import  get_dcp_group 
89from  vllm .logger  import  init_logger 
910from  vllm .utils  import  cdiv 
11+ from  vllm .v1 .utils  import  CpuGpuBuffer 
1012
1113logger  =  init_logger (__name__ )
1214
@@ -29,28 +31,13 @@ def __init__(
2931        self .pin_memory  =  pin_memory 
3032        self .device  =  device 
3133
32-         self .block_table  =  torch .zeros (
33-             (max_num_reqs , max_num_blocks_per_req ),
34-             device = self .device ,
35-             dtype = torch .int32 ,
36-         )
37-         self .block_table_cpu  =  torch .zeros (
38-             (max_num_reqs , max_num_blocks_per_req ),
39-             device = "cpu" ,
40-             dtype = torch .int32 ,
41-             pin_memory = pin_memory ,
42-         )
43-         self .block_table_np  =  self .block_table_cpu .numpy ()
34+         self .block_table  =  self ._make_buffer (max_num_reqs ,
35+                                              max_num_blocks_per_req ,
36+                                              dtype = torch .int32 )
4437        self .num_blocks_per_row  =  np .zeros (max_num_reqs , dtype = np .int32 )
4538
46-         self .slot_mapping_cpu  =  torch .zeros (self .max_num_batched_tokens ,
47-                                             dtype = torch .int64 ,
48-                                             device = "cpu" ,
49-                                             pin_memory = self .pin_memory )
50-         self .slot_mapping_np  =  self .slot_mapping_cpu .numpy ()
51-         self .slot_mapping  =  torch .zeros (self .max_num_batched_tokens ,
52-                                         dtype = torch .int64 ,
53-                                         device = self .device )
39+         self .slot_mapping  =  self ._make_buffer (self .max_num_batched_tokens ,
40+                                               dtype = torch .int64 )
5441        try :
5542            self .dcp_world_size  =  get_dcp_group ().world_size 
5643            self .dcp_rank  =  get_dcp_group ().rank_in_group 
@@ -69,25 +56,22 @@ def append_row(
6956        num_blocks  =  len (block_ids )
7057        start  =  self .num_blocks_per_row [row_idx ]
7158        self .num_blocks_per_row [row_idx ] +=  num_blocks 
72-         self .block_table_np [row_idx , start :start  +  num_blocks ] =  block_ids 
59+         self .block_table . np [row_idx , start :start  +  num_blocks ] =  block_ids 
7360
7461    def  add_row (self , block_ids : list [int ], row_idx : int ) ->  None :
7562        self .num_blocks_per_row [row_idx ] =  0 
7663        self .append_row (block_ids , row_idx )
7764
7865    def  move_row (self , src : int , tgt : int ) ->  None :
7966        num_blocks  =  self .num_blocks_per_row [src ]
80-         self . block_table_np [ tgt , : num_blocks ]  =  self .block_table_np [ 
81-              src , :num_blocks ]
67+         block_table_np   =  self .block_table . np 
68+         block_table_np [ tgt , : num_blocks ]  =   block_table_np [ src , :num_blocks ]
8269        self .num_blocks_per_row [tgt ] =  num_blocks 
8370
8471    def  swap_row (self , src : int , tgt : int ) ->  None :
85-         num_blocks_src  =  self .num_blocks_per_row [src ]
86-         num_blocks_tgt  =  self .num_blocks_per_row [tgt ]
87-         self .num_blocks_per_row [src ] =  num_blocks_tgt 
88-         self .num_blocks_per_row [tgt ] =  num_blocks_src 
89- 
90-         self .block_table_np [[src , tgt ]] =  self .block_table_np [[tgt , src ]]
72+         src_tgt , tgt_src  =  [src , tgt ], [tgt , src ]
73+         self .num_blocks_per_row [src_tgt ] =  self .num_blocks_per_row [tgt_src ]
74+         self .block_table .np [src_tgt ] =  self .block_table .np [tgt_src ]
9175
9276    def  compute_slot_mapping (self , req_indices : np .ndarray ,
9377                             positions : np .ndarray ) ->  None :
@@ -107,7 +91,7 @@ def compute_slot_mapping(self, req_indices: np.ndarray,
10791            virtual_block_size  =  self .block_size  *  self .dcp_world_size 
10892            block_table_indices  =  (req_indices  *  self .max_num_blocks_per_req  + 
10993                                   positions  //  virtual_block_size )
110-             block_numbers  =  self .block_table_np .ravel ()[block_table_indices ]
94+             block_numbers  =  self .block_table . np .ravel ()[block_table_indices ]
11195            # Use virtual_block_size for mask calculation, which marks local 
11296            # tokens. 
11397            virtual_block_offsets  =  positions  %  virtual_block_size 
@@ -117,40 +101,45 @@ def compute_slot_mapping(self, req_indices: np.ndarray,
117101            # Calculate slot_mapping 
118102            slot_mapping  =  block_numbers  *  self .block_size  +  block_offsets 
119103            # Write final slots, use -1 for not-local 
120-             self .slot_mapping_np [:req_indices .shape [0 ]] =  np .where (
104+             self .slot_mapping . np [:req_indices .shape [0 ]] =  np .where (
121105                mask , slot_mapping , - 1 )
122106        else :
123107            block_table_indices  =  (req_indices  *  self .max_num_blocks_per_req  + 
124108                                   positions  //  self .block_size )
125-             block_numbers  =  self .block_table_np .ravel ()[block_table_indices ]
109+             block_numbers  =  self .block_table . np .ravel ()[block_table_indices ]
126110            block_offsets  =  positions  %  self .block_size 
127111            np .add (block_numbers  *  self .block_size ,
128112                   block_offsets ,
129-                    out = self .slot_mapping_np [:req_indices .shape [0 ]])
113+                    out = self .slot_mapping . np [:req_indices .shape [0 ]])
130114
131115    def  commit_block_table (self , num_reqs : int ) ->  None :
132-         self .block_table [:num_reqs ].copy_ (self .block_table_cpu [:num_reqs ],
133-                                           non_blocking = True )
116+         self .block_table .copy_to_gpu (num_reqs )
134117
135118    def  commit_slot_mapping (self , num_tokens : int ) ->  None :
136-         self .slot_mapping [:num_tokens ].copy_ (
137-             self .slot_mapping_cpu [:num_tokens ], non_blocking = True )
119+         self .slot_mapping .copy_to_gpu (num_tokens )
138120
139121    def  clear (self ) ->  None :
140-         self .block_table .fill_ (0 )
141-         self .block_table_cpu .fill_ (0 )
122+         self .block_table .gpu . fill_ (0 )
123+         self .block_table . cpu .fill_ (0 )
142124
143-     def  get_device_tensor (self ) ->  torch .Tensor :
125+     def  get_device_tensor (self ,  num_reqs :  int ) ->  torch .Tensor :
144126        """Returns the device tensor of the block table.""" 
145-         return  self .block_table 
127+         return  self .block_table . gpu [: num_reqs ] 
146128
147129    def  get_cpu_tensor (self ) ->  torch .Tensor :
148130        """Returns the CPU tensor of the block table.""" 
149-         return  self .block_table_cpu 
131+         return  self .block_table . cpu 
150132
151133    def  get_numpy_array (self ) ->  np .ndarray :
152134        """Returns the numpy array of the block table.""" 
153-         return  self .block_table_np 
135+         return  self .block_table .np 
136+ 
137+     def  _make_buffer (self , * size : Union [int , torch .SymInt ],
138+                      dtype : torch .dtype ) ->  CpuGpuBuffer :
139+         return  CpuGpuBuffer (* size ,
140+                             dtype = dtype ,
141+                             device = self .device ,
142+                             pin_memory = self .pin_memory )
154143
155144
156145class  MultiGroupBlockTable :
0 commit comments