Skip to content

Commit 04e78e6

Browse files
authored
[CB] Allow block sharing in hybrid models (#42877)
* Allow block sharing in hybrid architectures * nit and style * Better docstring for mark_shareable_blocks_as_complete
1 parent d7dd443 commit 04e78e6

File tree

6 files changed

+168
-94
lines changed

6 files changed

+168
-94
lines changed

benchmark_v2/framework/benchmark_runner.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -256,7 +256,7 @@ def time_generate(
256256
if config.continuous_batching:
257257
inputs = self.inputs["input_ids"].tolist()
258258
wall_time_0 = time.perf_counter()
259-
outputs = self.model.generate_batch(inputs, allow_prefix_sharing=False, record_timestamps=True)
259+
outputs = self.model.generate_batch(inputs, allow_block_sharing=False, record_timestamps=True)
260260
else:
261261
streamer = BenchmarkStreamer()
262262
wall_time_0 = time.perf_counter()

src/transformers/generation/continuous_batching/cache.py

Lines changed: 23 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ def __init__(
121121
device: torch.device,
122122
dtype: torch.dtype = torch.float16,
123123
tp_size: int | None = None,
124-
allow_prefix_sharing: bool = True,
124+
allow_block_sharing: bool = True,
125125
) -> None:
126126
"""Initialize a paged attention cache for efficient memory usage. Also turns on prefix sharing if the model has
127127
only full attention layers.
@@ -132,7 +132,8 @@ def __init__(
132132
device: Device for the cache tensors
133133
dtype: Data type of the cache
134134
tp_size: Tensor parallelism size
135-
allow_prefix_sharing: A flag to allow prefix sharing if the model has only full attention layers.
135+
allow_block_sharing: A flag to allow block sharing. If the model has some full attention layers, then prefix
136+
sharing is enabled as well.
136137
"""
137138
self.config = config
138139
self.dtype = dtype
@@ -220,19 +221,20 @@ def __init__(
220221
logger.info(f"{self.cache_shape = } {self.key_cache[0].shape = } {self.key_cache[0].numel() = }")
221222

222223
# Block management data structures
224+
self.allow_block_sharing = allow_block_sharing
223225
self.group_cache_managers: list[CacheAllocator] = []
224226
for i, group_type in enumerate(group_types):
225227
if group_type == "full_attention":
226-
cm = FullAttentionCacheAllocator(i, self.block_size)
228+
cm = FullAttentionCacheAllocator(i, self.block_size, allow_block_sharing=allow_block_sharing)
227229
elif group_type == "sliding_attention":
228230
cm = SlidingAttentionCacheAllocator(i, self.block_size, config.sliding_window)
229231
else:
230232
raise ValueError(f"Invalid group type: {group_type}")
231233
self.group_cache_managers.append(cm)
232234

233-
# We only use prefix sharing if the whole model has only full attention layers
234-
self.use_prefix_sharing = allow_prefix_sharing and group_types == ["full_attention"]
235-
self._block_manager = BlockManager(num_blocks, self.block_size, self.use_prefix_sharing)
235+
# We only use prefix sharing if the whole model has only full attention layers and block sharing is allowed
236+
self.use_prefix_sharing = allow_block_sharing and group_types == ["full_attention"]
237+
self._block_manager = BlockManager(num_blocks, self.block_size)
236238
self.blocks_to_complete: dict[str, int] = {}
237239
self._total_prefix_length: int = 0 # a counter to measure the impact of prefix sharing, also used in tests
238240

@@ -352,7 +354,8 @@ def search_prefix_match(self, request_id: str, prompt_ids: list[int]) -> int:
352354
allocated_blocks = []
353355
for b in range(len(prompt_ids) // self.block_size):
354356
tokens = prompt_ids[b * self.block_size : (b + 1) * self.block_size]
355-
current_hash = self._block_manager.compute_hash(current_hash, tokens)
357+
# Prefix sharing is only supported when there is only one full attention layer group, so group_id=0.
358+
current_hash = self._block_manager.compute_hash(current_hash, tokens, group_id=0)
356359
block_id = self._block_manager._hash_to_id.get(current_hash)
357360
if block_id is not None:
358361
allocated_blocks.append(block_id)
@@ -369,18 +372,21 @@ def search_prefix_match(self, request_id: str, prompt_ids: list[int]) -> int:
369372
self._total_prefix_length += prefix_length
370373
return prefix_length
371374

372-
def mark_blocks_as_complete(self, state: RequestState) -> None:
373-
"""Marks the blocks that have been computed in the forward pass as complete. If prefix sharing is off, this is
374-
a no-op."""
375-
num_complete_blocks = 0 if not self.use_prefix_sharing else self.blocks_to_complete.pop(state.request_id)
375+
def mark_shareable_blocks_as_complete(self, state: RequestState) -> None:
376+
"""Marks the blocks allocated to a request (state) as complete if they are shareable and they have been computed
377+
in the forward pass. A complete block is a block where the KV cache has been fully computed: if the block has
378+
enough space to hold the cache for N tokens, the block is marked as complete when the cache data is present for
379+
the N tokens. If block sharing is off, this is a no-op."""
380+
num_complete_blocks = 0 if not self.allow_block_sharing else self.blocks_to_complete.pop(state.request_id)
376381
if num_complete_blocks == 0:
377382
return None
378-
cm = self.group_cache_managers[0] # if prefix sharing is on, there is only one group
379-
self._block_manager.mark_blocks_as_complete(
380-
num_complete_blocks=num_complete_blocks,
381-
allocated_blocks=cm.block_table[state.request_id],
382-
prompt_ids=(state.initial_tokens + state.generated_tokens),
383-
)
383+
for cm in self.group_cache_managers:
384+
if cm.uses_block_sharing:
385+
self._block_manager.mark_shareable_blocks_as_complete(
386+
num_complete_blocks=num_complete_blocks,
387+
allocated_blocks=cm.block_table[state.request_id],
388+
prompt_ids=(state.initial_tokens + state.generated_tokens),
389+
)
384390

385391

386392
# TODO: rework computation with the groups and their sizes

src/transformers/generation/continuous_batching/cache_manager.py

Lines changed: 45 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -31,29 +31,31 @@ def reverse_enumerate(xs: list[T]) -> Iterator[tuple[int, T]]:
3131
index -= 1
3232

3333

34-
class Block:
34+
class Block: # TODO: rename to ShareableBlock and update the docs
3535
"""A class to represent a block managed by the block manager. We say that a block is complete when the physical KV
3636
cache it points to is fully computed. A block can have a parent, which is the block that came before in the
37-
sequence. Once a block is complete, it is given a hash, which takes into account the tokens ids of the block and
38-
its parent's hash (if there is a parent)."""
37+
sequence. Once a block is complete, it is given a hash, which takes into account the tokens ids of the block, the
38+
layer (group_id) it belongs to and its parent's hash (if there is a parent)."""
3939

40-
def __init__(self, id_: int, parent_id: int | None) -> None:
40+
def __init__(self, id_: int, parent_id: int | None, group_id: int) -> None:
4141
self.id: int = id_
4242
self.parent_id: int | None = parent_id
43+
self.group_id: int = group_id
4344
self.hash: int | None = None
4445
self.ref_count: int = 1
4546

4647
def __repr__(self) -> str:
47-
return f"Block(id={self.id}, parent_id={self.parent_id}, hash={self.hash}, ref_count={self.ref_count})"
48+
return f"Block(id={self.id}, parent_id={self.parent_id}, group_id={self.group_id}, hash={self.hash}, ref_count={self.ref_count})"
4849

4950
@property
5051
def is_complete(self) -> bool:
5152
return self.hash is not None
5253

5354

5455
class BlockManager:
55-
"""A class to manage the number of free blocks and block re-use. If prefix sharing is off, the block manager is a
56-
simple FIFO structure where blocks are either free or in use. If prefix sharing is on, blocks can have 3 states:
56+
"""A class to manage the number of free blocks and block re-use. When a block becomes in use, a flag is passed to
57+
determine if the block is shareable or not. If it is, then a Block object is created and kept track of internally.
58+
It can have the following states:
5759
- in use: one or more requests references this block, thus it cannot be written over. The number of requests
5860
referencing this block is stored as ref_count in the Block object.
5961
- un-initialized: the block points to a space in the KV cache tensor that contains no data yet. Those blocks can
@@ -63,19 +65,19 @@ class BlockManager:
6365
the ref_count of the block and remove it from the list of initialized blocks, because it is now in use.
6466
Still, the block can be freed if no un-initialized blocks are left. In that case, we remove its hash from the
6567
hash table.
68+
If the block is not shareable, we just use the block manager as a FIFO structure where blocks are either free or in
69+
use. Sharability is determined by the type of cache allocator: blocks created for full attention layers are
70+
shareable, while blocks created for sliding window attention layers are not.
6671
There is no structure to keep track of the blocks in use: if a block is neither un-initialized nor initialized,
6772
it is in use.
6873
"""
6974

70-
def __init__(self, num_blocks: int, block_size: int, use_prefix_sharing: bool) -> None:
71-
"""Initializes the block manager with a given number of blocks (num_blocks) of size (block_size). Prefix sharing
72-
can be turned on with the (use_prefix_sharing) flag, which only happens if the model has only full attention
73-
layers."""
75+
def __init__(self, num_blocks: int, block_size: int) -> None:
76+
"""Initializes the block manager with a given number of blocks (num_blocks) of size (block_size)."""
7477
self.num_blocks = num_blocks
7578
self.block_size = block_size
7679
self._uninit_block_ids = deque(range(num_blocks))
7780
self._init_block_ids: dict[int, None] = {} # effectively act as an ordered set
78-
self._use_prefix_sharing = use_prefix_sharing
7981
self._hash_to_id: dict[int, int] = {}
8082
self._id_to_block: dict[int, Block] = {}
8183

@@ -102,17 +104,20 @@ def has_enough_free_blocks(self, n_blocks: int) -> bool:
102104
self._uninit_block_ids.append(id_to_uninitialize)
103105
return True
104106

105-
def get_free_blocks(self, n_blocks: int, last_block_id: int | None) -> list[int] | None:
106-
"""Returns a list of (n_blocks) free blocks and marks them as no longer free in the internal data structures. One
107-
can also pass a (last_block_id) to indicate the last block id in the sequence, which is used to keep track of
108-
the parent block. If the manager cannot find enough free blocks, it returns None."""
107+
def get_free_blocks(
108+
self, n_blocks: int, last_block_id: int | None, shareable: bool, group_id: int
109+
) -> list[int] | None:
110+
"""Returns a list of (n_blocks) free blocks and marks them as no longer free in the internal data structures.
111+
If the (shareable) flag is set to True, a Block object is created to keep track of the block, with the
112+
(last_block_id) to indicate the last block id in the sequence, also named the parent block. If the manager
113+
cannot find enough free blocks, it returns None."""
109114
if not self.has_enough_free_blocks(n_blocks):
110115
return None
111116
allocated_block_ids = [self._uninit_block_ids.popleft() for _ in range(n_blocks)]
112-
# If we use prefix caching, we keep track of the allocated blocks as partial blocks
113-
if self._use_prefix_sharing:
117+
# If the block is shareable, we keep track of the allocated blocks as partial blocks
118+
if shareable:
114119
for block_id in allocated_block_ids:
115-
block = Block(block_id, last_block_id)
120+
block = Block(block_id, last_block_id, group_id)
116121
self._id_to_block[block_id] = block
117122
last_block_id = block_id
118123
# In both cases, we return the allocated block ids
@@ -137,23 +142,23 @@ def decrease_ref_count(self, block_id: int) -> None:
137142
self._id_to_block.pop(block_id)
138143
self._uninit_block_ids.append(block_id)
139144

140-
def free_blocks(self, blocks: list[int]) -> None:
141-
"""Marks a list of (blocks) as free. If there is no prefix sharing, we simply add them to the uninitialized
145+
def free_blocks(self, blocks: list[int], shareable: bool) -> None:
146+
"""Marks a list of (blocks) as free. If the blocks were not (shareable), we simply add them to the uninitialized
142147
blocks queue. Otherwise, their new state depends on whether they are complete."""
143-
if self._use_prefix_sharing:
148+
if shareable:
144149
for block_id in blocks:
145150
self.decrease_ref_count(block_id)
146151
else:
147152
self._uninit_block_ids.extend(blocks)
148153

149-
def mark_blocks_as_complete(
154+
def mark_shareable_blocks_as_complete(
150155
self, num_complete_blocks: int, allocated_blocks: list[int], prompt_ids: list[int]
151156
) -> None:
152157
"""Among the list of (allocated_blocks), mark (num_complete_blocks) incomplete blocks as now complete. The list
153158
of (prompt_ids) is used to compute the hash of the new block."""
154159
# Look for the first complete block, starting from the last block in the sequence
155160
parent_hash = None
156-
incomplete_blocks: list[Block] = []
161+
incomplete_blocks: list[tuple[int, Block]] = []
157162
for i, block_id in reverse_enumerate(allocated_blocks):
158163
block = self._id_to_block[block_id]
159164
if block.is_complete:
@@ -178,7 +183,7 @@ def mark_blocks_as_complete(
178183
# Otherwise, we compute the hash
179184
num_complete_blocks -= 1
180185
tokens = prompt_ids[i * self.block_size : (i + 1) * self.block_size]
181-
block.hash = self.compute_hash(parent_hash, tokens)
186+
block.hash = self.compute_hash(parent_hash, tokens, block.group_id)
182187

183188
existing_block_id = self._hash_to_id.get(block.hash)
184189
# If the block hash is already in the hash to id mapping, we reference the existing block instead
@@ -187,19 +192,20 @@ def mark_blocks_as_complete(
187192
allocated_blocks[i] = existing_block_id
188193
self._id_to_block[existing_block_id].ref_count += 1
189194
new_parent_id = existing_block_id
190-
self.free_blocks([block.id])
195+
self.free_blocks([block.id], shareable=True)
191196

192197
# Otherwise, we add the completed block to the hash table
193198
else:
199+
logger.debug(f"Adding new block {block.id} (group {block.group_id}) with hash {block.hash}")
194200
self._hash_to_id[block.hash] = block.id
195201

196202
# Update loop variables
197203
parent_hash = block.hash
198204

199-
def compute_hash(self, parent_hash: int | None, tokens: list[int]) -> int:
200-
"""Computes the hash of a block containing the given (tokens) with a given (parent_hash). If the block has no
201-
parent, the parent hash is None."""
202-
return hash((parent_hash, tuple(tokens)))
205+
def compute_hash(self, parent_hash: int | None, tokens: list[int], group_id: int) -> int:
206+
"""Computes the hash of a block identified by the (tokens) it contains, its (parent_hash) and the layer
207+
(group_id) it belong to. If the block has no parent, the parent hash is None."""
208+
return hash((parent_hash, tuple(tokens), group_id))
203209

204210

205211
class CacheAllocator(ABC):
@@ -208,6 +214,7 @@ class CacheAllocator(ABC):
208214

209215
_index: int
210216
block_table: dict[str, list[int]] # request_id -> list of block_ids allocated to the request
217+
uses_block_sharing: bool # flag to determine if the blocks are shareable
211218

212219
@abstractmethod
213220
def allocate_blocks(self, n_blocks: int, request_id: str, block_manager: BlockManager) -> int | None:
@@ -218,7 +225,7 @@ def free_blocks(self, request_id: str, block_manager: BlockManager) -> None:
218225
"""Frees all blocks associated with a (request_id) using the (block_manager)."""
219226
if request_id in self.block_table:
220227
blocks_to_free = self.block_table.pop(request_id)
221-
block_manager.free_blocks(blocks_to_free)
228+
block_manager.free_blocks(blocks_to_free, shareable=self.uses_block_sharing)
222229
else:
223230
logger.warning(
224231
f"CacheAllocator {self._index} attempted to free blocks for non-existent request_id: {request_id}"
@@ -240,13 +247,14 @@ def get_seqlens_k(self, request_id: str, past_length: int, query_length: int) ->
240247
class FullAttentionCacheAllocator(CacheAllocator):
241248
"""Cache manager for a group of full attention layers."""
242249

243-
def __init__(self, index: int, block_size: int) -> None:
250+
def __init__(self, index: int, block_size: int, allow_block_sharing: bool) -> None:
244251
"""Initializes the cache manager for a group of full attention layers.
245252
Args:
246253
- index: the index of the associated layer group
247254
- block_size: the size of the blocks in the cache
248255
"""
249256
self._index = index
257+
self.uses_block_sharing = allow_block_sharing
250258
self.block_size = block_size
251259
self.block_table = {}
252260

@@ -261,7 +269,7 @@ def allocate_blocks(self, n_blocks: int, request_id: str, block_manager: BlockMa
261269
else:
262270
last_block_id = self.block_table[request_id][-1]
263271
# Actual allocation, return early if failed
264-
allocated_blocks = block_manager.get_free_blocks(n_blocks, last_block_id)
272+
allocated_blocks = block_manager.get_free_blocks(n_blocks, last_block_id, self.uses_block_sharing, self._index)
265273
if allocated_blocks is None:
266274
return None
267275
self.block_table[request_id].extend(allocated_blocks)
@@ -315,6 +323,7 @@ def __init__(self, index: int, block_size: int, sliding_window: int) -> None:
315323
- sliding_window: the size of the sliding window
316324
"""
317325
self._index = index
326+
self.uses_block_sharing = False
318327
self.block_size = block_size
319328
self.sliding_window = sliding_window
320329
self._max_blocks_per_request = ceil(self.sliding_window / self.block_size)
@@ -334,7 +343,9 @@ def allocate_blocks(self, n_blocks: int, request_id: str, block_manager: BlockMa
334343
after_allocation = min(already_allocated + n_blocks, self._max_blocks_per_request)
335344
actual_n_blocks = after_allocation - already_allocated
336345
# Classic allocation
337-
allocated_blocks = block_manager.get_free_blocks(actual_n_blocks, None) # no prefix caching w/ sliding window
346+
allocated_blocks = block_manager.get_free_blocks(
347+
actual_n_blocks, None, self.uses_block_sharing, self._index
348+
) # no block sharing w/ sliding window
338349
if allocated_blocks is None:
339350
return None
340351
self.block_table[request_id].extend(allocated_blocks)

0 commit comments

Comments
 (0)