Skip to content

Commit cc668d6

Browse files
committed
nit and style
1 parent 57217c9 commit cc668d6

File tree

2 files changed

+5
-3
lines changed

2 files changed

+5
-3
lines changed

src/transformers/generation/continuous_batching/cache_manager.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ def reverse_enumerate(xs: list[T]) -> Iterator[tuple[int, T]]:
3434
class Block: # TODO: rename to ShareableBlock and update the docs
3535
"""A class to represent a block managed by the block manager. We say that a block is complete when the physical KV
3636
cache it points to is fully computed. A block can have a parent, which is the block that came before in the
37-
sequence. Once a block is complete, it is given a hash, which takes into account the tokens ids of the block, the
37+
sequence. Once a block is complete, it is given a hash, which takes into account the tokens ids of the block, the
3838
layer (group_id) it belong to and its parent's hash (if there is a parent)."""
3939

4040
def __init__(self, id_: int, parent_id: int | None, group_id: int) -> None:
@@ -343,7 +343,9 @@ def allocate_blocks(self, n_blocks: int, request_id: str, block_manager: BlockMa
343343
after_allocation = min(already_allocated + n_blocks, self._max_blocks_per_request)
344344
actual_n_blocks = after_allocation - already_allocated
345345
# Classic allocation
346-
allocated_blocks = block_manager.get_free_blocks(actual_n_blocks, None, self.uses_block_sharing, self._index) # no block sharing w/ sliding window
346+
allocated_blocks = block_manager.get_free_blocks(
347+
actual_n_blocks, None, self.uses_block_sharing, self._index
348+
) # no block sharing w/ sliding window
347349
if allocated_blocks is None:
348350
return None
349351
self.block_table[request_id].extend(allocated_blocks)

tests/generation/test_continuous_batching.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -229,7 +229,7 @@ def _test_continuous_batching_parity(
229229

230230
# Generation with continuous batching
231231
continuous_batching_outputs = model.generate_batch(
232-
inputs=input_ids, generation_config=model.generation_config, allow_prefix_sharing=allow_block_sharing
232+
inputs=input_ids, generation_config=model.generation_config, allow_block_sharing=allow_block_sharing
233233
)
234234

235235
# Prepare non-continuous batching inputs

0 commit comments

Comments
 (0)