Merged
2 changes: 1 addition & 1 deletion benchmark_v2/framework/benchmark_runner.py
@@ -256,7 +256,7 @@ def time_generate(
if config.continuous_batching:
inputs = self.inputs["input_ids"].tolist()
wall_time_0 = time.perf_counter()
outputs = self.model.generate_batch(inputs, allow_prefix_sharing=False, record_timestamps=True)
outputs = self.model.generate_batch(inputs, allow_block_sharing=False, record_timestamps=True)
else:
streamer = BenchmarkStreamer()
wall_time_0 = time.perf_counter()
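For context, a minimal sketch of how the renamed flag surfaces to callers of continuous batching. The checkpoint and tokenizer handling are illustrative assumptions; only the `generate_batch` keyword arguments come from this diff.

```python
import time

from transformers import AutoModelForCausalLM, AutoTokenizer

# Illustrative checkpoint; any model supported by continuous batching works.
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")

inputs = tokenizer(["Hello world", "Paged attention is"])["input_ids"]
wall_time_0 = time.perf_counter()
# allow_block_sharing replaces the old allow_prefix_sharing keyword.
outputs = model.generate_batch(inputs, allow_block_sharing=False, record_timestamps=True)
wall_time = time.perf_counter() - wall_time_0
```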
40 changes: 23 additions & 17 deletions src/transformers/generation/continuous_batching/cache.py
@@ -121,7 +121,7 @@ def __init__(
device: torch.device,
dtype: torch.dtype = torch.float16,
tp_size: int | None = None,
allow_prefix_sharing: bool = True,
allow_block_sharing: bool = True,
) -> None:
"""Initialize a paged attention cache for efficient memory usage. Also turns in prefix sharing if the model has
only full attention layers.
@@ -132,7 +132,8 @@ def __init__(
device: Device for the cache tensors
dtype: Data type of the cache
tp_size: Tensor parallelism size
allow_prefix_sharing: A flag to allow prefix sharing if the model has only full attention layers.
allow_block_sharing: A flag to allow block sharing. If the model has only full attention layers, prefix
sharing is enabled as well.
"""
self.config = config
self.dtype = dtype
@@ -220,19 +221,20 @@ def __init__(
logger.info(f"{self.cache_shape = } {self.key_cache[0].shape = } {self.key_cache[0].numel() = }")

# Block management data structures
self.allow_block_sharing = allow_block_sharing
self.group_cache_managers: list[CacheAllocator] = []
for i, group_type in enumerate(group_types):
if group_type == "full_attention":
cm = FullAttentionCacheAllocator(i, self.block_size)
cm = FullAttentionCacheAllocator(i, self.block_size, allow_block_sharing=allow_block_sharing)
elif group_type == "sliding_attention":
cm = SlidingAttentionCacheAllocator(i, self.block_size, config.sliding_window)
else:
raise ValueError(f"Invalid group type: {group_type}")
self.group_cache_managers.append(cm)

# We only use prefix sharing if the whole model has only full attention layers
self.use_prefix_sharing = allow_prefix_sharing and group_types == ["full_attention"]
self._block_manager = BlockManager(num_blocks, self.block_size, self.use_prefix_sharing)
# We only use prefix sharing if the whole model has only full attention layers and block sharing is allowed
self.use_prefix_sharing = allow_block_sharing and group_types == ["full_attention"]
Collaborator: no rename on the self attr?

Collaborator Author: Good question! prefix_sharing is only possible if block sharing (which is more of a memory optimization) is enabled AND the model has no sliding window layers: if there are any, they will create sliding window groups with no shareable blocks, hence no prefix sharing.
self._block_manager = BlockManager(num_blocks, self.block_size)
self.blocks_to_complete: dict[str, int] = {}
self._total_prefix_length: int = 0 # a counter to measure the impact of prefix sharing, also used in tests
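To make the gating the author describes concrete, here is a standalone sketch (assumed helper name, mirroring the two lines above) of how the two flags resolve for a hybrid versus a full-attention-only model:

```python
# Prefix sharing requires block sharing to be allowed AND every layer group
# to be full attention; block sharing alone still helps hybrid models.
def resolve_sharing(allow_block_sharing: bool, group_types: list[str]) -> tuple[bool, bool]:
    use_prefix_sharing = allow_block_sharing and group_types == ["full_attention"]
    return allow_block_sharing, use_prefix_sharing

# Hybrid model: sliding-window groups have no shareable blocks -> no prefix sharing.
assert resolve_sharing(True, ["full_attention", "sliding_attention"]) == (True, False)
# Full-attention-only model: both optimizations are active.
assert resolve_sharing(True, ["full_attention"]) == (True, True)
```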

@@ -352,7 +354,8 @@ def search_prefix_match(self, request_id: str, prompt_ids: list[int]) -> int:
allocated_blocks = []
for b in range(len(prompt_ids) // self.block_size):
tokens = prompt_ids[b * self.block_size : (b + 1) * self.block_size]
current_hash = self._block_manager.compute_hash(current_hash, tokens)
# Prefix sharing is only supported when there is only one full attention layer group, so group_id=0.
Collaborator: is the comment still valid? thought this PR allowed different groups to exist, and only acts on 1

Collaborator Author: Yes, because prefix sharing is still only activated if there is only one group -- cf. comment above

Collaborator: but then to mark as complete we loop on the groups no?

Collaborator Author: Yes, because marking a block as complete is useful in the context of block sharing, which can happen in a hybrid model. But here we are in the context of prefix sharing, which is more restrictive, so we know there is only one group. Maybe I am missing something here

current_hash = self._block_manager.compute_hash(current_hash, tokens, group_id=0)
block_id = self._block_manager._hash_to_id.get(current_hash)
if block_id is not None:
allocated_blocks.append(block_id)
@@ -369,18 +372,21 @@ def search_prefix_match(self, request_id: str, prompt_ids: list[int]) -> int:
self._total_prefix_length += prefix_length
return prefix_length
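The hash chaining that makes this lookup work can be sketched in isolation (illustrative, stripped-down names; the real code goes through BlockManager and reference counting):

```python
# Each complete block's hash folds in its parent's hash, so one dictionary
# lookup per block walks the longest cached prefix of a new prompt.
BLOCK_SIZE = 4
hash_to_id: dict[int, int] = {}

def compute_hash(parent_hash: int | None, tokens: list[int], group_id: int = 0) -> int:
    return hash((parent_hash, tuple(tokens), group_id))

# Pretend an earlier request completed two blocks for this prompt.
cached = [1, 2, 3, 4, 5, 6, 7, 8]
parent_hash = None
for b in range(len(cached) // BLOCK_SIZE):
    parent_hash = compute_hash(parent_hash, cached[b * BLOCK_SIZE : (b + 1) * BLOCK_SIZE])
    hash_to_id[parent_hash] = b  # block id b is complete and shareable

# A new request matching the first block, then diverging, re-uses one block.
prompt = [1, 2, 3, 4, 9, 9, 9, 9]
allocated_blocks, parent_hash = [], None
for b in range(len(prompt) // BLOCK_SIZE):
    parent_hash = compute_hash(parent_hash, prompt[b * BLOCK_SIZE : (b + 1) * BLOCK_SIZE])
    block_id = hash_to_id.get(parent_hash)
    if block_id is None:
        break
    allocated_blocks.append(block_id)
print(len(allocated_blocks) * BLOCK_SIZE)  # 4 tokens of shared prefix
```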

def mark_blocks_as_complete(self, state: RequestState) -> None:
"""Marks the blocks that have been computed in the forward pass as complete. If prefix sharing is off, this is
a no-op."""
num_complete_blocks = 0 if not self.use_prefix_sharing else self.blocks_to_complete.pop(state.request_id)
def mark_shareable_blocks_as_complete(self, state: RequestState) -> None:
"""Marks the blocks allocated to a request (state) as complete if they are shareable and they have been computed
in the forward pass. A complete block is a block where the KV cache has been fully computed: if the block has
enough space to hold the cache for N tokens, the block is marked as complete when the cache data is present for
the N tokens. If block sharing is off, this is a no-op."""
num_complete_blocks = 0 if not self.allow_block_sharing else self.blocks_to_complete.pop(state.request_id)
if num_complete_blocks == 0:
return None
cm = self.group_cache_managers[0] # if prefix sharing is on, there is only one group
self._block_manager.mark_blocks_as_complete(
num_complete_blocks=num_complete_blocks,
allocated_blocks=cm.block_table[state.request_id],
prompt_ids=(state.initial_tokens + state.generated_tokens),
)
for cm in self.group_cache_managers:
if cm.uses_block_sharing:
self._block_manager.mark_shareable_blocks_as_complete(
num_complete_blocks=num_complete_blocks,
allocated_blocks=cm.block_table[state.request_id],
prompt_ids=(state.initial_tokens + state.generated_tokens),
)


# TODO: rework computation with the groups and their sizes
79 changes: 45 additions & 34 deletions src/transformers/generation/continuous_batching/cache_manager.py
@@ -31,29 +31,31 @@ def reverse_enumerate(xs: list[T]) -> Iterator[tuple[int, T]]:
index -= 1


class Block:
class Block: # TODO: rename to ShareableBlock and update the docs
Collaborator: SlidingBlock(Block) are not shareable

Collaborator Author: Yes, so they won't create this kind of object! Hence the proposed name change.

"""A class to represent a block managed by the block manager. We say that a block is complete when the physical KV
cache it points to is fully computed. A block can have a parent, which is the block that came before in the
sequence. Once a block is complete, it is given a hash, which takes into account the tokens ids of the block and
its parent's hash (if there is a parent)."""
sequence. Once a block is complete, it is given a hash, which takes into account the token ids of the block, the
layer group (group_id) it belongs to and its parent's hash (if there is a parent)."""

def __init__(self, id_: int, parent_id: int | None) -> None:
def __init__(self, id_: int, parent_id: int | None, group_id: int) -> None:
self.id: int = id_
self.parent_id: int | None = parent_id
self.group_id: int = group_id
self.hash: int | None = None
self.ref_count: int = 1

def __repr__(self) -> str:
return f"Block(id={self.id}, parent_id={self.parent_id}, hash={self.hash}, ref_count={self.ref_count})"
return f"Block(id={self.id}, parent_id={self.parent_id}, group_id={self.group_id}, hash={self.hash}, ref_count={self.ref_count})"

@property
def is_complete(self) -> bool:
return self.hash is not None


class BlockManager:
"""A class to manage the number of free blocks and block re-use. If prefix sharing is off, the block manager is a
simple FIFO structure where blocks are either free or in use. If prefix sharing is on, blocks can have 3 states:
"""A class to manage the number of free blocks and block re-use. When a block becomes in use, a flag is passed to
determine if the block is shareable or not. If it is, then a Block object is created and kept track of internally.
It can have the following states:
- in use: one or more requests references this block, thus it cannot be written over. The number of requests
referencing this block is stored as ref_count in the Block object.
- un-initialized: the block points to a space in the KV cache tensor that contains no data yet. Those blocks can
@@ -63,19 +65,19 @@ class BlockManager:
the ref_count of the block and remove it from the list of initialized blocks, because it is now in use.
Still, the block can be freed if no un-initialized blocks are left. In that case, we remove its hash from the
hash table.
If the block is not shareable, we just use the block manager as a FIFO structure where blocks are either free or in
use. Shareability is determined by the type of cache allocator: blocks created for full attention layers are
shareable, while blocks created for sliding window attention layers are not.
There is no structure to keep track of the blocks in use: if a block is neither un-initialized nor initialized,
it is in use.
"""

def __init__(self, num_blocks: int, block_size: int, use_prefix_sharing: bool) -> None:
"""Initializes the block manager with a given number of blocks (num_blocks) of size (block_size). Prefix sharing
can be turned on with the (use_prefix_sharing) flag, which only happens if the model has only full attention
layers."""
def __init__(self, num_blocks: int, block_size: int) -> None:
"""Initializes the block manager with a given number of blocks (num_blocks) of size (block_size)."""
self.num_blocks = num_blocks
self.block_size = block_size
self._uninit_block_ids = deque(range(num_blocks))
self._init_block_ids: dict[int, None] = {} # effectively act as an ordered set
self._use_prefix_sharing = use_prefix_sharing
self._hash_to_id: dict[int, int] = {}
self._id_to_block: dict[int, Block] = {}

@@ -102,17 +104,20 @@ def has_enough_free_blocks(self, n_blocks: int) -> bool:
self._uninit_block_ids.append(id_to_uninitialize)
return True
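The eviction path above can be illustrated standalone (simplified state only; the real manager also drops the evicted block's hash from the hash table, as the docstring notes):

```python
from collections import deque

# Un-initialized blocks form a FIFO queue; initialized blocks (freed but with
# still-valid data) live in a dict used as an ordered set, evicted oldest-first.
uninit_block_ids = deque([1, 2, 3])
init_block_ids: dict[int, None] = {0: None}  # block 0 was freed after completion

def has_enough_free_blocks(n_blocks: int) -> bool:
    while len(uninit_block_ids) < n_blocks:
        if not init_block_ids:
            return False
        evicted = next(iter(init_block_ids))  # oldest initialized block
        del init_block_ids[evicted]           # its cached data is now forgotten
        uninit_block_ids.append(evicted)
    return True

print(has_enough_free_blocks(4), list(uninit_block_ids))  # True [1, 2, 3, 0]
```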

def get_free_blocks(self, n_blocks: int, last_block_id: int | None) -> list[int] | None:
"""Returns a list of (n_blocks) free block and mark them as no longuer free in the internal data structures. One
can also pass a (last_block_id) to indicate the last block id in the sequence, which is used to keep track of
the parent block. If the manager cannot find enough free blocks, it returns None."""
def get_free_blocks(
self, n_blocks: int, last_block_id: int | None, shareable: bool, group_id: int
) -> list[int] | None:
"""Returns a list of (n_blocks) free block and mark them as no longuer free in the internal data structures.
If the (shareable) flag is set to True, a Block object is created to keep track of the block, with the
(last_block_id) to indicate the last block id in the sequence, also named the parent block. If the manager
cannot find enough free blocks, it returns None."""
if not self.has_enough_free_blocks(n_blocks):
return None
allocated_block_ids = [self._uninit_block_ids.popleft() for _ in range(n_blocks)]
# If we use prefix caching, we keep track of the allocated blocks as partial blocks
if self._use_prefix_sharing:
# If the block is shareable, we keep track of the allocated blocks as partial blocks
if shareable:
Comment on lines +117 to +118

Collaborator: in general with transformers we don't want if/else, we want 2 classes, one for shareable, one for non-shareable. Splitting the logic by class usually scales better

Collaborator Author: The thing is, for non-shareable blocks, we don't need the block object at all: we just want the physical block_id. Hence no need to create a Python object that will never be used and keep track of it. So less overhead!

Collaborator: ok got it!

for block_id in allocated_block_ids:
block = Block(block_id, last_block_id)
block = Block(block_id, last_block_id, group_id)
self._id_to_block[block_id] = block
last_block_id = block_id
# In both cases, we return the allocated block ids
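The design point from the thread above, that non-shareable blocks never pay for a Python object, can be shown with a pared-down stand-in for Block (illustrative only):

```python
from dataclasses import dataclass

@dataclass
class Block:  # pared-down stand-in for the class above
    id: int
    parent_id: int | None
    group_id: int
    hash: int | None = None
    ref_count: int = 1

id_to_block: dict[int, Block] = {}

def track_allocation(block_ids: list[int], shareable: bool, last_block_id: int | None, group_id: int) -> list[int]:
    # Only shareable allocations create a Block object; each block's parent
    # is the block that precedes it in the sequence.
    if shareable:
        for block_id in block_ids:
            id_to_block[block_id] = Block(block_id, last_block_id, group_id)
            last_block_id = block_id
    return block_ids  # non-shareable blocks stay bare physical ids

track_allocation([0, 1], shareable=True, last_block_id=None, group_id=0)
track_allocation([2, 3], shareable=False, last_block_id=None, group_id=1)
print(sorted(id_to_block))  # [0, 1]
```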
@@ -137,23 +142,23 @@ def decrease_ref_count(self, block_id: int) -> None:
self._id_to_block.pop(block_id)
self._uninit_block_ids.append(block_id)

def free_blocks(self, blocks: list[int]) -> None:
"""Marks a list of (blocks) as free. If there is no prefix sharing, we simply add them to the uninitialized
def free_blocks(self, blocks: list[int], shareable: bool) -> None:
"""Marks a list of (blocks) as free. If the blocks were not (shareable), we simply add them to the uninitialized
blocks queue. Otherwise, their new state depends on whether they are complete."""
if self._use_prefix_sharing:
if shareable:
for block_id in blocks:
self.decrease_ref_count(block_id)
else:
self._uninit_block_ids.extend(blocks)

def mark_blocks_as_complete(
def mark_shareable_blocks_as_complete(
Collaborator: only shareable blocks can be marked as complete and only shareable blocks need this logic (same as above comment), how can we better "encapsulate" this?

Collaborator Author: I guess this will make more sense when Block is renamed to ShareableBlock. In that case, I think the function name makes sense. The issue is I renamed one but not the other imo

Collaborator: ok!

self, num_complete_blocks: int, allocated_blocks: list[int], prompt_ids: list[int]
) -> None:
"""Among the list of (allocated_blocks), mark (num_complete_blocks) incomplete blocks as now complete. The list
of (prompt_ids) is used to compute the hash of the new block."""
# Look for the first complete block, starting from the last block in the sequence
parent_hash = None
incomplete_blocks: list[Block] = []
incomplete_blocks: list[tuple[int, Block]] = []
for i, block_id in reverse_enumerate(allocated_blocks):
block = self._id_to_block[block_id]
if block.is_complete:
@@ -178,7 +183,7 @@ def mark_blocks_as_complete(
# Otherwise, we compute the hash
num_complete_blocks -= 1
tokens = prompt_ids[i * self.block_size : (i + 1) * self.block_size]
block.hash = self.compute_hash(parent_hash, tokens)
block.hash = self.compute_hash(parent_hash, tokens, block.group_id)

existing_block_id = self._hash_to_id.get(block.hash)
# If the block hash is already in the hash to id mapping, we reference the existing block instead
@@ -187,19 +192,20 @@
allocated_blocks[i] = existing_block_id
self._id_to_block[existing_block_id].ref_count += 1
new_parent_id = existing_block_id
self.free_blocks([block.id])
self.free_blocks([block.id], shareable=True)

# Otherwise, we add the completed block to the hash table
else:
logger.debug(f"Adding new block {block.id} (group {block.group_id}) with hash {block.hash}")
self._hash_to_id[block.hash] = block.id

# Update loop variables
parent_hash = block.hash

def compute_hash(self, parent_hash: int | None, tokens: list[int]) -> int:
"""Computes the hash of a block containing the given (tokens) with a given (parent_hash). If the block has no
parent, the parent hash is None."""
return hash((parent_hash, tuple(tokens)))
def compute_hash(self, parent_hash: int | None, tokens: list[int], group_id: int) -> int:
"""Computes the hash of a block identified by the (tokens) it contains, its (parent_hash) and the layer
(group_id) it belong to. If the block has no parent, the parent hash is None."""
return hash((parent_hash, tuple(tokens), group_id))
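A quick check of what threading group_id into the hash buys (same tokens, same parent, different layer groups):

```python
def compute_hash(parent_hash: int | None, tokens: list[int], group_id: int) -> int:
    return hash((parent_hash, tuple(tokens), group_id))

tokens = [11, 12, 13, 14]
# Without group_id these would collide and blocks from different layer groups
# could be wrongly shared; with it they hash apart.
print(compute_hash(None, tokens, group_id=0) == compute_hash(None, tokens, group_id=1))  # False
```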


class CacheAllocator(ABC):
@@ -208,6 +214,7 @@ class CacheAllocator:

_index: int
block_table: dict[str, list[int]] # request_id -> list of block_ids allocated to the request
uses_block_sharing: bool # flag to determine if the blocks are shareable

@abstractmethod
def allocate_blocks(self, n_blocks: int, request_id: str, block_manager: BlockManager) -> int | None:
@@ -218,7 +225,7 @@ def free_blocks(self, request_id: str, block_manager: BlockManager) -> None:
"""Frees all blocks associated with a (request_id) using the (block_manager)."""
if request_id in self.block_table:
blocks_to_free = self.block_table.pop(request_id)
block_manager.free_blocks(blocks_to_free)
block_manager.free_blocks(blocks_to_free, shareable=self.uses_block_sharing)
else:
logger.warning(
f"CacheAllocator {self._index} attempted to free blocks for non-existent request_id: {request_id}"
@@ -240,13 +247,14 @@ def get_seqlens_k(self, request_id: str, past_length: int, query_length: int) ->
class FullAttentionCacheAllocator(CacheAllocator):
"""Cache manager for a group of full attention layers."""

def __init__(self, index: int, block_size: int) -> None:
def __init__(self, index: int, block_size: int, allow_block_sharing: bool) -> None:
"""Initializes the cache manager for a group of full attention layers.
Args:
- index: the index of the associated layer group
- block_size: the size of the blocks in the cache
- allow_block_sharing: whether blocks allocated by this group can be shared across requests
"""
self._index = index
self.uses_block_sharing = allow_block_sharing
self.block_size = block_size
self.block_table = {}

@@ -261,7 +269,7 @@ def allocate_blocks(self, n_blocks: int, request_id: str, block_manager: BlockMa
else:
last_block_id = self.block_table[request_id][-1]
# Actual allocation, return early if failed
allocated_blocks = block_manager.get_free_blocks(n_blocks, last_block_id)
allocated_blocks = block_manager.get_free_blocks(n_blocks, last_block_id, self.uses_block_sharing, self._index)
if allocated_blocks is None:
return None
self.block_table[request_id].extend(allocated_blocks)
@@ -315,6 +323,7 @@ def __init__(self, index: int, block_size: int, sliding_window: int) -> None:
- sliding_window: the size of the sliding window
"""
self._index = index
self.uses_block_sharing = False
self.block_size = block_size
self.sliding_window = sliding_window
self._max_blocks_per_request = ceil(self.sliding_window / self.block_size)
@@ -334,7 +343,9 @@ def allocate_blocks(self, n_blocks: int, request_id: str, block_manager: BlockMa
after_allocation = min(already_allocated + n_blocks, self._max_blocks_per_request)
actual_n_blocks = after_allocation - already_allocated
# Classic allocation
allocated_blocks = block_manager.get_free_blocks(actual_n_blocks, None) # no prefix caching w/ sliding window
allocated_blocks = block_manager.get_free_blocks(
actual_n_blocks, None, self.uses_block_sharing, self._index
) # no block sharing w/ sliding window
if allocated_blocks is None:
return None
self.block_table[request_id].extend(allocated_blocks)
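As a worked example of the cap applied above (illustrative numbers): with a 4096-token sliding window and 256-token blocks, a request can never hold more than 16 blocks for this group, so an 8-block top-up when 14 are already allocated only grants 2.

```python
from math import ceil

block_size, sliding_window = 256, 4096
max_blocks_per_request = ceil(sliding_window / block_size)  # 16
already_allocated, n_blocks = 14, 8
after_allocation = min(already_allocated + n_blocks, max_blocks_per_request)
actual_n_blocks = after_allocation - already_allocated
print(max_blocks_per_request, actual_n_blocks)  # 16 2
```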