Skip to content

Commit a2a74aa

Browse files
committed
Style
1 parent 28f5476 commit a2a74aa

File tree

2 files changed

+4
-2
lines changed

2 files changed

+4
-2
lines changed

src/transformers/generation/continuous_batching/cache_manager.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -264,7 +264,9 @@ def get_write_indices(self, request_id: str, past_length: int, query_length: int
264264
def get_seqlens_k(self, request_id: str, past_length: int, query_length: int) -> tuple[str, int]:
265265
"""Returns the attention type of the cache allocator and the key sequence length for the given request_id."""
266266

267-
def fork_blocks(self, source_request_id: str, dst_request_id: str, block_manager: BlockManager) -> tuple[list[int], list[int]]:
267+
def fork_blocks(
268+
self, source_request_id: str, dst_request_id: str, block_manager: BlockManager
269+
) -> tuple[list[int], list[int]]:
268270
"""Fork the cache blocks for a given request_id into a new request_id."""
269271
if source_request_id not in self.block_table:
270272
raise ValueError(f"No block table found for request {source_request_id}")
@@ -283,6 +285,7 @@ def fork_blocks(self, source_request_id: str, dst_request_id: str, block_manager
283285
self.block_table[dst_request_id] = forked_blocks
284286
return source_blocks, forked_blocks
285287

288+
286289
class FullAttentionCacheAllocator(CacheAllocator):
287290
"""Cache manager for a group of full attention layers."""
288291

src/transformers/generation/continuous_batching/continuous_api.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -576,7 +576,6 @@ def update_batch(self) -> None:
576576
for state in self.requests_in_batch:
577577
# If the request has no remaining prompt ids, it means prefill has already ended or just finished
578578
if len(state.remaining_prefill_tokens) == 0:
579-
580579
# If there are no generated tokens yet, it means prefill just ended
581580
if state.generated_len() == 0:
582581
self.metrics.record_ttft_metric(state.created_time, state.request_id)

0 commit comments

Comments
 (0)