@@ -31,29 +31,31 @@ def reverse_enumerate(xs: list[T]) -> Iterator[tuple[int, T]]:
3131 index -= 1
3232
3333
34- class Block :
34+ class Block : # TODO: rename to ShareableBlock and update the docs
3535 """A class to represent a block managed by the block manager. We say that a block is complete when the physical KV
3636 cache it points to is fully computed. A block can have a parent, which is the block that came before in the
37- sequence. Once a block is complete, it is given a hash, which takes into account the tokens ids of the block and
38- its parent's hash (if there is a parent)."""
37+ sequence. Once a block is complete, it is given a hash, which takes into account the tokens ids of the block, the
38+ layer (group_id) it belongs to and its parent's hash (if there is a parent)."""
3939
40- def __init__ (self , id_ : int , parent_id : int | None ) -> None :
40+ def __init__ (self , id_ : int , parent_id : int | None , group_id : int ) -> None :
4141 self .id : int = id_
4242 self .parent_id : int | None = parent_id
43+ self .group_id : int = group_id
4344 self .hash : int | None = None
4445 self .ref_count : int = 1
4546
4647 def __repr__ (self ) -> str :
47- return f"Block(id={ self .id } , parent_id={ self .parent_id } , hash={ self .hash } , ref_count={ self .ref_count } )"
48+ return f"Block(id={ self .id } , parent_id={ self .parent_id } , group_id= { self . group_id } , hash={ self .hash } , ref_count={ self .ref_count } )"
4849
4950 @property
5051 def is_complete (self ) -> bool :
5152 return self .hash is not None
5253
5354
5455class BlockManager :
55- """A class to manage the number of free blocks and block re-use. If prefix sharing is off, the block manager is a
56- simple FIFO structure where blocks are either free or in use. If prefix sharing is on, blocks can have 3 states:
56+ """A class to manage the number of free blocks and block re-use. When a block becomes in use, a flag is passed to
57+ determine if the block is shareable or not. If it is, then a Block object is created and kept track of internally.
58+ It can have the following states:
5759 - in use: one or more requests reference this block, thus it cannot be written over. The number of requests
5860 referencing this block is stored as ref_count in the Block object.
5961 - un-initialized: the block points to a space in the KV cache tensor that contains no data yet. Those blocks can
@@ -63,19 +65,19 @@ class BlockManager:
6365 the ref_count of the block and remove it from the list of initialized blocks, because it is now in use.
6466 Still, the block can be freed if no un-initialized blocks are left. In that case, we remove its hash from the
6567 hash table.
68+ If the block is not shareable, we just use the block manager as a FIFO structure where blocks are either free or in
69+ use. Shareability is determined by the type of cache allocator: blocks created for full attention layers are
70+ shareable, while blocks created for sliding window attention layers are not.
6671 There is no structure to keep track of the blocks in use: if a block is neither un-initialized nor initialized,
6772 it is in use.
6873 """
6974
70- def __init__ (self , num_blocks : int , block_size : int , use_prefix_sharing : bool ) -> None :
71- """Initializes the block manager with a given number of blocks (num_blocks) of size (block_size). Prefix sharing
72- can be turned on with the (use_prefix_sharing) flag, which only happens if the model has only full attention
73- layers."""
75+ def __init__ (self , num_blocks : int , block_size : int ) -> None :
76+ """Initializes the block manager with a given number of blocks (num_blocks) of size (block_size)."""
7477 self .num_blocks = num_blocks
7578 self .block_size = block_size
7679 self ._uninit_block_ids = deque (range (num_blocks ))
7780 self ._init_block_ids : dict [int , None ] = {} # effectively act as an ordered set
78- self ._use_prefix_sharing = use_prefix_sharing
7981 self ._hash_to_id : dict [int , int ] = {}
8082 self ._id_to_block : dict [int , Block ] = {}
8183
@@ -102,17 +104,20 @@ def has_enough_free_blocks(self, n_blocks: int) -> bool:
102104 self ._uninit_block_ids .append (id_to_uninitialize )
103105 return True
104106
105- def get_free_blocks (self , n_blocks : int , last_block_id : int | None ) -> list [int ] | None :
106- """Returns a list of (n_blocks) free block and mark them as no longuer free in the internal data structures. One
107- can also pass a (last_block_id) to indicate the last block id in the sequence, which is used to keep track of
108- the parent block. If the manager cannot find enough free blocks, it returns None."""
107+ def get_free_blocks (
108+ self , n_blocks : int , last_block_id : int | None , shareable : bool , group_id : int
109+ ) -> list [int ] | None :
110+ """Returns a list of (n_blocks) free blocks and marks them as no longer free in the internal data structures.
111+ If the (shareable) flag is set to True, a Block object is created to keep track of the block, with the
112+ (last_block_id) to indicate the last block id in the sequence, also named the parent block. If the manager
113+ cannot find enough free blocks, it returns None."""
109114 if not self .has_enough_free_blocks (n_blocks ):
110115 return None
111116 allocated_block_ids = [self ._uninit_block_ids .popleft () for _ in range (n_blocks )]
112- # If we use prefix caching , we keep track of the allocated blocks as partial blocks
113- if self . _use_prefix_sharing :
117+ # If the block is shareable , we keep track of the allocated blocks as partial blocks
118+ if shareable :
114119 for block_id in allocated_block_ids :
115- block = Block (block_id , last_block_id )
120+ block = Block (block_id , last_block_id , group_id )
116121 self ._id_to_block [block_id ] = block
117122 last_block_id = block_id
118123 # In both cases, we return the allocated block ids
@@ -137,23 +142,23 @@ def decrease_ref_count(self, block_id: int) -> None:
137142 self ._id_to_block .pop (block_id )
138143 self ._uninit_block_ids .append (block_id )
139144
140- def free_blocks (self , blocks : list [int ]) -> None :
141- """Marks a list of (blocks) as free. If there is no prefix sharing , we simply add them to the uninitialized
145+ def free_blocks (self , blocks : list [int ], shareable : bool ) -> None :
146+ """Marks a list of (blocks) as free. If the blocks were not (shareable) , we simply add them to the uninitialized
142147 blocks queue. Otherwise, their new state depends on whether they are complete."""
143- if self . _use_prefix_sharing :
148+ if shareable :
144149 for block_id in blocks :
145150 self .decrease_ref_count (block_id )
146151 else :
147152 self ._uninit_block_ids .extend (blocks )
148153
149- def mark_blocks_as_complete (
154+ def mark_shareable_blocks_as_complete (
150155 self , num_complete_blocks : int , allocated_blocks : list [int ], prompt_ids : list [int ]
151156 ) -> None :
152157 """Among the list of (allocated_blocks), mark (num_complete_blocks) incomplete blocks as now complete. The list
153158 of (prompt_ids) is used to compute the hash of the new block."""
154159 # Look for the first complete block, starting from the last block in the sequence
155160 parent_hash = None
156- incomplete_blocks : list [Block ] = []
161+ incomplete_blocks : list [tuple [ int , Block ] ] = []
157162 for i , block_id in reverse_enumerate (allocated_blocks ):
158163 block = self ._id_to_block [block_id ]
159164 if block .is_complete :
@@ -178,7 +183,7 @@ def mark_blocks_as_complete(
178183 # Otherwise, we compute the hash
179184 num_complete_blocks -= 1
180185 tokens = prompt_ids [i * self .block_size : (i + 1 ) * self .block_size ]
181- block .hash = self .compute_hash (parent_hash , tokens )
186+ block .hash = self .compute_hash (parent_hash , tokens , block . group_id )
182187
183188 existing_block_id = self ._hash_to_id .get (block .hash )
184189 # If the block hash is already in the hash to id mapping, we reference the existing block instead
@@ -187,19 +192,20 @@ def mark_blocks_as_complete(
187192 allocated_blocks [i ] = existing_block_id
188193 self ._id_to_block [existing_block_id ].ref_count += 1
189194 new_parent_id = existing_block_id
190- self .free_blocks ([block .id ])
195+ self .free_blocks ([block .id ], shareable = True )
191196
192197 # Otherwise, we add the completed block to the hash table
193198 else :
199+ logger .debug (f"Adding new block { block .id } (group { block .group_id } ) with hash { block .hash } " )
194200 self ._hash_to_id [block .hash ] = block .id
195201
196202 # Update loop variables
197203 parent_hash = block .hash
198204
199- def compute_hash (self , parent_hash : int | None , tokens : list [int ]) -> int :
200- """Computes the hash of a block containing the given (tokens) with a given (parent_hash). If the block has no
201- parent, the parent hash is None."""
202- return hash ((parent_hash , tuple (tokens )))
205+ def compute_hash (self , parent_hash : int | None , tokens : list [int ], group_id : int ) -> int :
206+ """Computes the hash of a block identified by the (tokens) it contains, its (parent_hash) and the layer
207+ (group_id) it belongs to. If the block has no parent, the parent hash is None."""
208+ return hash ((parent_hash , tuple (tokens ), group_id ))
203209
204210
205211class CacheAllocator (ABC ):
@@ -208,6 +214,7 @@ class CacheAllocator(ABC):
208214
209215 _index : int
210216 block_table : dict [str , list [int ]] # request_id -> list of block_ids allocated to the request
217+ uses_block_sharing : bool # flag to determine if the blocks are shareable
211218
212219 @abstractmethod
213220 def allocate_blocks (self , n_blocks : int , request_id : str , block_manager : BlockManager ) -> int | None :
@@ -218,7 +225,7 @@ def free_blocks(self, request_id: str, block_manager: BlockManager) -> None:
218225 """Frees all blocks associated with a (request_id) using the (block_manager)."""
219226 if request_id in self .block_table :
220227 blocks_to_free = self .block_table .pop (request_id )
221- block_manager .free_blocks (blocks_to_free )
228+ block_manager .free_blocks (blocks_to_free , shareable = self . uses_block_sharing )
222229 else :
223230 logger .warning (
224231 f"CacheAllocator { self ._index } attempted to free blocks for non-existent request_id: { request_id } "
@@ -240,13 +247,14 @@ def get_seqlens_k(self, request_id: str, past_length: int, query_length: int) ->
240247class FullAttentionCacheAllocator (CacheAllocator ):
241248 """Cache manager for a group of full attention layers."""
242249
243- def __init__ (self , index : int , block_size : int ) -> None :
250+ def __init__ (self , index : int , block_size : int , allow_block_sharing : bool ) -> None :
244251 """Initializes the cache manager for a group of full attention layers.
245252 Args:
246253 - index: the index of the associated layer group
247254 - block_size: the size of the blocks in the cache
248255 """
249256 self ._index = index
257+ self .uses_block_sharing = allow_block_sharing
250258 self .block_size = block_size
251259 self .block_table = {}
252260
@@ -261,7 +269,7 @@ def allocate_blocks(self, n_blocks: int, request_id: str, block_manager: BlockMa
261269 else :
262270 last_block_id = self .block_table [request_id ][- 1 ]
263271 # Actual allocation, return early if failed
264- allocated_blocks = block_manager .get_free_blocks (n_blocks , last_block_id )
272+ allocated_blocks = block_manager .get_free_blocks (n_blocks , last_block_id , self . uses_block_sharing , self . _index )
265273 if allocated_blocks is None :
266274 return None
267275 self .block_table [request_id ].extend (allocated_blocks )
@@ -315,6 +323,7 @@ def __init__(self, index: int, block_size: int, sliding_window: int) -> None:
315323 - sliding_window: the size of the sliding window
316324 """
317325 self ._index = index
326+ self .uses_block_sharing = False
318327 self .block_size = block_size
319328 self .sliding_window = sliding_window
320329 self ._max_blocks_per_request = ceil (self .sliding_window / self .block_size )
@@ -334,7 +343,9 @@ def allocate_blocks(self, n_blocks: int, request_id: str, block_manager: BlockMa
334343 after_allocation = min (already_allocated + n_blocks , self ._max_blocks_per_request )
335344 actual_n_blocks = after_allocation - already_allocated
336345 # Classic allocation
337- allocated_blocks = block_manager .get_free_blocks (actual_n_blocks , None ) # no prefix caching w/ sliding window
346+ allocated_blocks = block_manager .get_free_blocks (
347+ actual_n_blocks , None , self .uses_block_sharing , self ._index
348+ ) # no block sharing w/ sliding window
338349 if allocated_blocks is None :
339350 return None
340351 self .block_table [request_id ].extend (allocated_blocks )
0 commit comments