[CB] Allow block sharing in hybrid models #42877
Changes from all commits: 57217c9, cc668d6, 2349f38, 65ee04a, 5f1ded2
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -121,7 +121,7 @@ def __init__( | |
| device: torch.device, | ||
| dtype: torch.dtype = torch.float16, | ||
| tp_size: int | None = None, | ||
| allow_prefix_sharing: bool = True, | ||
| allow_block_sharing: bool = True, | ||
| ) -> None: | ||
| """Initialize a paged attention cache for efficient memory usage. Also turns in prefix sharing if the model has | ||
| only full attention layers. | ||
|
|
@@ -132,7 +132,8 @@ def __init__( | |
| device: Device for the cache tensors | ||
| dtype: Data type of the cache | ||
| tp_size: Tensor parallelism size | ||
| allow_prefix_sharing: A flag to allow prefix sharing if the model has only full attention layers. | ||
| allow_block_sharing: A flag to allow block sharing. If the model has only full attention layers, then prefix | ||
| sharing is enabled as well. | ||
| """ | ||
| self.config = config | ||
| self.dtype = dtype | ||
|
|
@@ -220,19 +221,20 @@ def __init__( | |
| logger.info(f"{self.cache_shape = } {self.key_cache[0].shape = } {self.key_cache[0].numel() = }") | ||
|
|
||
| # Block management data structures | ||
| self.allow_block_sharing = allow_block_sharing | ||
| self.group_cache_managers: list[CacheAllocator] = [] | ||
| for i, group_type in enumerate(group_types): | ||
| if group_type == "full_attention": | ||
| cm = FullAttentionCacheAllocator(i, self.block_size) | ||
| cm = FullAttentionCacheAllocator(i, self.block_size, allow_block_sharing=allow_block_sharing) | ||
| elif group_type == "sliding_attention": | ||
| cm = SlidingAttentionCacheAllocator(i, self.block_size, config.sliding_window) | ||
| else: | ||
| raise ValueError(f"Invalid group type: {group_type}") | ||
| self.group_cache_managers.append(cm) | ||
|
|
||
| # We only use prefix sharing if the whole model has only full attention layers | ||
| self.use_prefix_sharing = allow_prefix_sharing and group_types == ["full_attention"] | ||
| self._block_manager = BlockManager(num_blocks, self.block_size, self.use_prefix_sharing) | ||
| # We only use prefix sharing if the whole model has only full attention layers and block sharing is allowed | ||
| self.use_prefix_sharing = allow_block_sharing and group_types == ["full_attention"] | ||
| self._block_manager = BlockManager(num_blocks, self.block_size) | ||
| self.blocks_to_complete: dict[str, int] = {} | ||
| self._total_prefix_length: int = 0 # a counter to measure the impact of prefix sharing, also used in tests | ||
|
|
||
|
|
@@ -352,7 +354,8 @@ def search_prefix_match(self, request_id: str, prompt_ids: list[int]) -> int: | |
| allocated_blocks = [] | ||
| for b in range(len(prompt_ids) // self.block_size): | ||
| tokens = prompt_ids[b * self.block_size : (b + 1) * self.block_size] | ||
| current_hash = self._block_manager.compute_hash(current_hash, tokens) | ||
| # Prefix sharing is only supported when there is only one full attention layer group, so group_id=0. | ||
|
Collaborator: is the comment still valid? I thought this PR allowed different groups to exist, and only acts on 1

Collaborator (Author): Yes, because prefix sharing is still only activated if there is only one group -- cf. comment above

Collaborator: but then to mark as complete we loop on the groups, no?

Collaborator (Author): Yes, because marking a block as complete is useful in the context of block sharing, which can happen in a hybrid model. But here we are in the context of prefix sharing, which is more restrictive, so we know there is only one group. Maybe I am missing something here
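To make the chained-hash lookup discussed in this thread concrete, here is a minimal, illustrative sketch. It is not the transformers implementation: `BLOCK_SIZE`, `hash_to_id`, and `search_prefix` are stand-ins for the PR's `block_size`, `_hash_to_id`, and `search_prefix_match`.

```python
BLOCK_SIZE = 4

def compute_hash(parent_hash, tokens, group_id):
    # Mirrors the shape of the PR's compute_hash: chain the parent's hash
    # with the block's token ids and the layer group the block belongs to.
    return hash((parent_hash, tuple(tokens), group_id))

def search_prefix(prompt_ids, hash_to_id):
    # Walk the prompt block by block; stop at the first unknown hash.
    # group_id=0 because prefix sharing implies a single layer group.
    matched, current_hash = [], None
    for b in range(len(prompt_ids) // BLOCK_SIZE):
        tokens = prompt_ids[b * BLOCK_SIZE : (b + 1) * BLOCK_SIZE]
        current_hash = compute_hash(current_hash, tokens, group_id=0)
        block_id = hash_to_id.get(current_hash)
        if block_id is None:
            break
        matched.append(block_id)
    return matched

# Usage: blocks hashed for one prompt are found again by a longer prompt.
table, h = {}, None
cached = [1, 2, 3, 4, 5, 6, 7, 8]
for b in range(len(cached) // BLOCK_SIZE):
    h = compute_hash(h, cached[b * BLOCK_SIZE : (b + 1) * BLOCK_SIZE], 0)
    table[h] = b  # pretend block id b holds this chunk's KV cache
print(search_prefix([1, 2, 3, 4, 5, 6, 7, 8, 9, 10], table))  # [0, 1]
```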
||
| current_hash = self._block_manager.compute_hash(current_hash, tokens, group_id=0) | ||
| block_id = self._block_manager._hash_to_id.get(current_hash) | ||
| if block_id is not None: | ||
| allocated_blocks.append(block_id) | ||
|
|
@@ -369,18 +372,21 @@ def search_prefix_match(self, request_id: str, prompt_ids: list[int]) -> int: | |
| self._total_prefix_length += prefix_length | ||
| return prefix_length | ||
|
|
||
| def mark_blocks_as_complete(self, state: RequestState) -> None: | ||
| """Marks the blocks that have been computed in the forward pass as complete. If prefix sharing is off, this is | ||
| a no-op.""" | ||
| num_complete_blocks = 0 if not self.use_prefix_sharing else self.blocks_to_complete.pop(state.request_id) | ||
| def mark_shareable_blocks_as_complete(self, state: RequestState) -> None: | ||
| """Marks the blocks allocated to a request (state) as complete if they are shareable and they have been computed | ||
| in the forward pass. A complete block is a block where the KV cache has been fully computed: if the block has | ||
| enough space to hold the cache for N tokens, the block is marked as complete when the cache data is present for | ||
| the N tokens. If block sharing is off, this is a no-op.""" | ||
| num_complete_blocks = 0 if not self.allow_block_sharing else self.blocks_to_complete.pop(state.request_id) | ||
| if num_complete_blocks == 0: | ||
| return None | ||
| cm = self.group_cache_managers[0] # if prefix sharing is on, there is only one group | ||
| self._block_manager.mark_blocks_as_complete( | ||
| num_complete_blocks=num_complete_blocks, | ||
| allocated_blocks=cm.block_table[state.request_id], | ||
| prompt_ids=(state.initial_tokens + state.generated_tokens), | ||
| ) | ||
| for cm in self.group_cache_managers: | ||
| if cm.uses_block_sharing: | ||
|
||
| self._block_manager.mark_shareable_blocks_as_complete( | ||
| num_complete_blocks=num_complete_blocks, | ||
| allocated_blocks=cm.block_table[state.request_id], | ||
| prompt_ids=(state.initial_tokens + state.generated_tokens), | ||
| ) | ||
|
|
||
|
|
||
| # TODO: rework computation with the groups and their sizes | ||
|
|
||
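As a quick numbers check of the "complete block" definition in the docstring above (the values here are assumed for illustration, not taken from the PR): only fully filled blocks can be marked complete, so trailing tokens in a partial block are never shared.

```python
block_size = 4        # tokens per block, assumed for illustration
tokens_computed = 10  # prompt + generated tokens whose KV cache exists
num_complete_blocks = tokens_computed // block_size
print(num_complete_blocks)  # 2 -- the last 2 tokens sit in an incomplete block
```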
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -31,29 +31,31 @@ def reverse_enumerate(xs: list[T]) -> Iterator[tuple[int, T]]: | |
| index -= 1 | ||
|
|
||
|
|
||
| class Block: | ||
| class Block: # TODO: rename to ShareableBlock and update the docs | ||
|
Collaborator: SlidingBlock(Block) are not shareable

Collaborator (Author): Yes, so they won't create this kind of object! Hence the proposed name change.
||
| """A class to represent a block managed by the block manager. We say that a block is complete when the physical KV | ||
| cache it points to is fully computed. A block can have a parent, which is the block that came before in the | ||
| sequence. Once a block is complete, it is given a hash, which takes into account the token ids of the block and | ||
| its parent's hash (if there is a parent).""" | ||
| sequence. Once a block is complete, it is given a hash, which takes into account the token ids of the block, the | ||
| layer (group_id) it belongs to and its parent's hash (if there is a parent).""" | ||
|
|
||
| def __init__(self, id_: int, parent_id: int | None) -> None: | ||
| def __init__(self, id_: int, parent_id: int | None, group_id: int) -> None: | ||
| self.id: int = id_ | ||
| self.parent_id: int | None = parent_id | ||
| self.group_id: int = group_id | ||
| self.hash: int | None = None | ||
| self.ref_count: int = 1 | ||
|
|
||
| def __repr__(self) -> str: | ||
| return f"Block(id={self.id}, parent_id={self.parent_id}, hash={self.hash}, ref_count={self.ref_count})" | ||
| return f"Block(id={self.id}, parent_id={self.parent_id}, group_id={self.group_id}, hash={self.hash}, ref_count={self.ref_count})" | ||
|
|
||
| @property | ||
| def is_complete(self) -> bool: | ||
| return self.hash is not None | ||
|
|
||
|
|
||
| class BlockManager: | ||
| """A class to manage the number of free blocks and block re-use. If prefix sharing is off, the block manager is a | ||
| simple FIFO structure where blocks are either free or in use. If prefix sharing is on, blocks can have 3 states: | ||
| """A class to manage the number of free blocks and block re-use. When a block becomes in use, a flag is passed to | ||
| determine if the block is shareable or not. If it is, then a Block object is created and kept track of internally. | ||
| It can have the following states: | ||
| - in use: one or more requests references this block, thus it cannot be written over. The number of requests | ||
| referencing this block is stored as ref_count in the Block object. | ||
| - un-initialized: the block points to a space in the KV cache tensor that contains no data yet. Those blocks can | ||
|
|
@@ -63,19 +65,19 @@ class BlockManager: | |
| the ref_count of the block and remove it from the list of initialized blocks, because it is now in use. | ||
| Still, the block can be freed if no un-initialized blocks are left. In that case, we remove its hash from the | ||
| hash table. | ||
| If the block is not shareable, we just use the block manager as a FIFO structure where blocks are either free or in | ||
| use. Shareability is determined by the type of cache allocator: blocks created for full attention layers are | ||
| shareable, while blocks created for sliding window attention layers are not. | ||
| There is no structure to keep track of the blocks in use: if a block is neither un-initialized nor initialized, | ||
| it is in use. | ||
| """ | ||
|
|
||
| def __init__(self, num_blocks: int, block_size: int, use_prefix_sharing: bool) -> None: | ||
| """Initializes the block manager with a given number of blocks (num_blocks) of size (block_size). Prefix sharing | ||
| can be turned on with the (use_prefix_sharing) flag, which only happens if the model has only full attention | ||
| layers.""" | ||
| def __init__(self, num_blocks: int, block_size: int) -> None: | ||
| """Initializes the block manager with a given number of blocks (num_blocks) of size (block_size).""" | ||
| self.num_blocks = num_blocks | ||
| self.block_size = block_size | ||
| self._uninit_block_ids = deque(range(num_blocks)) | ||
| self._init_block_ids: dict[int, None] = {} # effectively act as an ordered set | ||
| self._use_prefix_sharing = use_prefix_sharing | ||
| self._hash_to_id: dict[int, int] = {} | ||
| self._id_to_block: dict[int, Block] = {} | ||
|
|
||
|
|
@@ -102,17 +104,20 @@ def has_enough_free_blocks(self, n_blocks: int) -> bool: | |
| self._uninit_block_ids.append(id_to_uninitialize) | ||
| return True | ||
|
|
||
| def get_free_blocks(self, n_blocks: int, last_block_id: int | None) -> list[int] | None: | ||
| """Returns a list of (n_blocks) free block and mark them as no longuer free in the internal data structures. One | ||
| can also pass a (last_block_id) to indicate the last block id in the sequence, which is used to keep track of | ||
| the parent block. If the manager cannot find enough free blocks, it returns None.""" | ||
| def get_free_blocks( | ||
| self, n_blocks: int, last_block_id: int | None, shareable: bool, group_id: int | ||
| ) -> list[int] | None: | ||
| """Returns a list of (n_blocks) free block and mark them as no longuer free in the internal data structures. | ||
| If the (shareable) flag is set to True, a Block object is created to keep track of the block, with the | ||
| (last_block_id) to indicate the last block id in the sequence, also named the parent block. If the manager | ||
| cannot find enough free blocks, it returns None.""" | ||
| if not self.has_enough_free_blocks(n_blocks): | ||
| return None | ||
| allocated_block_ids = [self._uninit_block_ids.popleft() for _ in range(n_blocks)] | ||
| # If we use prefix caching, we keep track of the allocated blocks as partial blocks | ||
| if self._use_prefix_sharing: | ||
| # If the block is shareable, we keep track of the allocated blocks as partial blocks | ||
| if shareable: | ||
|
Comment on lines +117 to +118

Collaborator: in general with

Collaborator (Author): The thing is, for non-shareable blocks, we don't need the Block object at all: we just want the physical block_id. Hence no need to create a Python object that will never be used and keep track of it. So less overhead!

Collaborator: ok, got it!
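A hedged sketch of the two allocation paths being discussed (all names are illustrative and mirror the shape of `get_free_blocks`, not its exact code): non-shareable blocks stay plain ids in a FIFO queue, while shareable blocks additionally get a tracking object.

```python
from collections import deque

class Block:
    def __init__(self, id_, parent_id, group_id):
        self.id, self.parent_id, self.group_id = id_, parent_id, group_id
        self.hash, self.ref_count = None, 1

def get_free_blocks(free_ids, id_to_block, n_blocks, last_block_id, shareable, group_id):
    if len(free_ids) < n_blocks:
        return None  # not enough free blocks
    allocated = [free_ids.popleft() for _ in range(n_blocks)]
    if shareable:  # only shareable blocks pay for a Block object
        for block_id in allocated:
            id_to_block[block_id] = Block(block_id, last_block_id, group_id)
            last_block_id = block_id
    return allocated

free, tracked = deque(range(4)), {}
print(get_free_blocks(free, tracked, 2, None, shareable=False, group_id=1))  # [0, 1]; tracked stays empty
print(get_free_blocks(free, tracked, 2, None, shareable=True, group_id=0))   # [2, 3]; tracked gains 2 Blocks
```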
||
| for block_id in allocated_block_ids: | ||
| block = Block(block_id, last_block_id) | ||
| block = Block(block_id, last_block_id, group_id) | ||
| self._id_to_block[block_id] = block | ||
| last_block_id = block_id | ||
| # In both cases, we return the allocated block ids | ||
|
|
@@ -137,23 +142,23 @@ def decrease_ref_count(self, block_id: int) -> None: | |
| self._id_to_block.pop(block_id) | ||
| self._uninit_block_ids.append(block_id) | ||
|
|
||
| def free_blocks(self, blocks: list[int]) -> None: | ||
| """Marks a list of (blocks) as free. If there is no prefix sharing, we simply add them to the uninitialized | ||
| def free_blocks(self, blocks: list[int], shareable: bool) -> None: | ||
| """Marks a list of (blocks) as free. If the blocks were not (shareable), we simply add them to the uninitialized | ||
| blocks queue. Otherwise, their new state depends on whether they are complete.""" | ||
| if self._use_prefix_sharing: | ||
| if shareable: | ||
| for block_id in blocks: | ||
| self.decrease_ref_count(block_id) | ||
| else: | ||
| self._uninit_block_ids.extend(blocks) | ||
|
|
||
| def mark_blocks_as_complete( | ||
| def mark_shareable_blocks_as_complete( | ||
|
Collaborator: only shareable blocks can be marked as complete and only shareable blocks need this logic (same as above comment) -- how can we better "encapsulate" this?

Collaborator (Author): I guess this will make more sense when Block is renamed to ShareableBlock. In that case, I think the function name makes sense. The issue is I renamed one but not the other imo

Collaborator: ok!
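For reference, a simplified sketch of the dedup step performed below once a shareable block completes: the freshly computed hash either matches an existing block, in which case the request is repointed and the duplicate freed, or registers the block as the canonical copy. `complete_block` and its parameters are invented for illustration; the real loop lives in `mark_shareable_blocks_as_complete`.

```python
from types import SimpleNamespace

def complete_block(block, tokens, parent_hash, hash_to_id, id_to_block, allocated, i, free_block):
    # Hash chains the parent's hash, the block's tokens, and its layer group.
    block.hash = hash((parent_hash, tuple(tokens), block.group_id))
    existing_id = hash_to_id.get(block.hash)
    if existing_id is not None:
        allocated[i] = existing_id               # repoint the request's block table
        id_to_block[existing_id].ref_count += 1  # the canonical block gains a reference
        free_block(block.id)                     # the duplicate returns to the pool
        return existing_id
    hash_to_id[block.hash] = block.id            # first copy becomes the canonical one
    return block.id

# Two requests computed the same first block: the second collapses onto the first.
a = SimpleNamespace(id=0, group_id=0, hash=None, ref_count=1)
b = SimpleNamespace(id=1, group_id=0, hash=None, ref_count=1)
hash_to_id, id_to_block = {}, {0: a, 1: b}
alloc_a, alloc_b = [0], [1]
complete_block(a, [1, 2, 3, 4], None, hash_to_id, id_to_block, alloc_a, 0, lambda bid: None)
complete_block(b, [1, 2, 3, 4], None, hash_to_id, id_to_block, alloc_b, 0, lambda bid: None)
print(alloc_b, a.ref_count)  # [0] 2 -- request b now points at block 0
```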
||
| self, num_complete_blocks: int, allocated_blocks: list[int], prompt_ids: list[int] | ||
| ) -> None: | ||
| """Among the list of (allocated_blocks), mark (num_complete_blocks) incomplete blocks as now complete. The list | ||
| of (prompt_ids) is used to compute the hash of the new block.""" | ||
| # Look for the first complete block, starting from the last block in the sequence | ||
| parent_hash = None | ||
| incomplete_blocks: list[Block] = [] | ||
| incomplete_blocks: list[tuple[int, Block]] = [] | ||
| for i, block_id in reverse_enumerate(allocated_blocks): | ||
| block = self._id_to_block[block_id] | ||
| if block.is_complete: | ||
|
|
@@ -178,7 +183,7 @@ def mark_blocks_as_complete( | |
| # Otherwise, we compute the hash | ||
| num_complete_blocks -= 1 | ||
| tokens = prompt_ids[i * self.block_size : (i + 1) * self.block_size] | ||
| block.hash = self.compute_hash(parent_hash, tokens) | ||
| block.hash = self.compute_hash(parent_hash, tokens, block.group_id) | ||
|
|
||
| existing_block_id = self._hash_to_id.get(block.hash) | ||
| # If the block hash is already in the hash to id mapping, we reference the existing block instead | ||
|
|
@@ -187,19 +192,20 @@ def mark_blocks_as_complete( | |
| allocated_blocks[i] = existing_block_id | ||
| self._id_to_block[existing_block_id].ref_count += 1 | ||
| new_parent_id = existing_block_id | ||
| self.free_blocks([block.id]) | ||
| self.free_blocks([block.id], shareable=True) | ||
|
|
||
| # Otherwise, we add the completed block to the hash table | ||
| else: | ||
| logger.debug(f"Adding new block {block.id} (group {block.group_id}) with hash {block.hash}") | ||
| self._hash_to_id[block.hash] = block.id | ||
|
|
||
| # Update loop variables | ||
| parent_hash = block.hash | ||
|
|
||
| def compute_hash(self, parent_hash: int | None, tokens: list[int]) -> int: | ||
| """Computes the hash of a block containing the given (tokens) with a given (parent_hash). If the block has no | ||
| parent, the parent hash is None.""" | ||
| return hash((parent_hash, tuple(tokens))) | ||
| def compute_hash(self, parent_hash: int | None, tokens: list[int], group_id: int) -> int: | ||
| """Computes the hash of a block identified by the (tokens) it contains, its (parent_hash) and the layer | ||
| (group_id) it belongs to. If the block has no parent, the parent hash is None.""" | ||
| return hash((parent_hash, tuple(tokens), group_id)) | ||
|
|
||
|
|
||
| class CacheAllocator(ABC): | ||
|
|
@@ -208,6 +214,7 @@ class CacheAllocator(ABC): | |
|
|
||
| _index: int | ||
| block_table: dict[str, list[int]] # request_id -> list of block_ids allocated to the request | ||
| uses_block_sharing: bool # flag to determine if the blocks are shareable | ||
|
|
||
| @abstractmethod | ||
| def allocate_blocks(self, n_blocks: int, request_id: str, block_manager: BlockManager) -> int | None: | ||
|
|
@@ -218,7 +225,7 @@ def free_blocks(self, request_id: str, block_manager: BlockManager) -> None: | |
| """Frees all blocks associated with a (request_id) using the (block_manager).""" | ||
| if request_id in self.block_table: | ||
| blocks_to_free = self.block_table.pop(request_id) | ||
| block_manager.free_blocks(blocks_to_free) | ||
| block_manager.free_blocks(blocks_to_free, shareable=self.uses_block_sharing) | ||
| else: | ||
| logger.warning( | ||
| f"CacheAllocator {self._index} attempted to free blocks for non-existent request_id: {request_id}" | ||
|
|
@@ -240,13 +247,14 @@ def get_seqlens_k(self, request_id: str, past_length: int, query_length: int) -> | |
| class FullAttentionCacheAllocator(CacheAllocator): | ||
| """Cache manager for a group of full attention layers.""" | ||
|
|
||
| def __init__(self, index: int, block_size: int) -> None: | ||
| def __init__(self, index: int, block_size: int, allow_block_sharing: bool) -> None: | ||
| """Initializes the cache manager for a group of full attention layers. | ||
| Args: | ||
| - index: the index of the associated layer group | ||
| - block_size: the size of the blocks in the cache | ||
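| - allow_block_sharing: whether blocks allocated by this manager can be shared across requests | ||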
| """ | ||
| self._index = index | ||
| self.uses_block_sharing = allow_block_sharing | ||
| self.block_size = block_size | ||
| self.block_table = {} | ||
|
|
||
|
|
@@ -261,7 +269,7 @@ def allocate_blocks(self, n_blocks: int, request_id: str, block_manager: BlockMa | |
| else: | ||
| last_block_id = self.block_table[request_id][-1] | ||
| # Actual allocation, return early if failed | ||
| allocated_blocks = block_manager.get_free_blocks(n_blocks, last_block_id) | ||
| allocated_blocks = block_manager.get_free_blocks(n_blocks, last_block_id, self.uses_block_sharing, self._index) | ||
| if allocated_blocks is None: | ||
| return None | ||
| self.block_table[request_id].extend(allocated_blocks) | ||
|
|
@@ -315,6 +323,7 @@ def __init__(self, index: int, block_size: int, sliding_window: int) -> None: | |
| - sliding_window: the size of the sliding window | ||
| """ | ||
| self._index = index | ||
| self.uses_block_sharing = False | ||
| self.block_size = block_size | ||
| self.sliding_window = sliding_window | ||
| self._max_blocks_per_request = ceil(self.sliding_window / self.block_size) | ||
|
|
@@ -334,7 +343,9 @@ def allocate_blocks(self, n_blocks: int, request_id: str, block_manager: BlockMa | |
| after_allocation = min(already_allocated + n_blocks, self._max_blocks_per_request) | ||
| actual_n_blocks = after_allocation - already_allocated | ||
| # Classic allocation | ||
| allocated_blocks = block_manager.get_free_blocks(actual_n_blocks, None) # no prefix caching w/ sliding window | ||
| allocated_blocks = block_manager.get_free_blocks( | ||
| actual_n_blocks, None, self.uses_block_sharing, self._index | ||
| ) # no block sharing w/ sliding window | ||
| if allocated_blocks is None: | ||
| return None | ||
| self.block_table[request_id].extend(allocated_blocks) | ||
|
|
||
Collaborator: no rename on the self attr?

Collaborator (Author): Good question! prefix_sharing is only possible if block sharing (which is more of a memory optimization) is enabled AND the model has no sliding window layers: if there are any, they will create sliding window groups with no shareable blocks, hence no prefix sharing.
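Putting that answer in code form, as a sketch with assumed names rather than the PR's actual attributes:

```python
group_types = ["full_attention", "sliding_attention"]  # a hybrid model
allow_block_sharing = True

# Block sharing is decided per group; prefix sharing needs every group
# to be full attention, matching the activation rule quoted above.
per_group_sharing = [allow_block_sharing and t == "full_attention" for t in group_types]
use_prefix_sharing = allow_block_sharing and group_types == ["full_attention"]

print(per_group_sharing)   # [True, False]: the full-attention group still shares blocks
print(use_prefix_sharing)  # False: any sliding-window group rules out prefix sharing
```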