Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 45 additions & 23 deletions src/transformers/generation/continuous_batching/continuous_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -964,45 +964,67 @@ def cancel_request(self, request_id: str) -> None:
if self.batch_processor is not None:
self.batch_processor.scheduler.set_request_cancellation(request_id)

# TODO:handle benchmarking properly when updating / fixing the requeue logic
def get_result(self, request_id: str | None = None, timeout: float | None = None) -> GenerationOutput | None:
"""Retrieve one result from the output queue.

Args:
timeout: Maximum time to wait for a result

Returns:
Optional[GenerationOutput]: The result data or None if timeout
"""
# Fast exit: no thread + no pending output
if self._generation_thread is None and self.output_queue.empty():
return None

deadline = None if timeout is None else perf_counter() + timeout
deferred: list[GenerationOutput] = []

try:
result = self.output_queue.get(block=True, timeout=timeout)
# NOTE: requeue logic here
if request_id is not None and result.request_id != request_id:
self.output_queue.put(result)
return None
return result
except queue.Empty:
return None

while True:
remaining = None if deadline is None else max(0.0, deadline - perf_counter())
if remaining == 0.0:
return None

try:
result = self.output_queue.get(timeout=remaining)
except queue.Empty:
return None

# Match found
if request_id is None or result.request_id == request_id:
return result

# Defer mismatched result instead of immediately requeuing
deferred.append(result)

finally:
# Reinsert deferred results preserving order
for item in deferred:
self.output_queue.put(item)

def __iter__(self):
"""Iterate over results as they become available."""
while self._generation_thread is not None and self._generation_thread.is_alive():
result = self.get_result(timeout=0.1)
if result is not None:
yield result

# FIXME: stop iteration when request status is finished?
def request_id_iter(self, request_id: str) -> Generator[GenerationOutput]:
"""Iterate over results matching a specific request id as they become available."""
request_cancelled = False
while self._generation_thread is not None and self._generation_thread.is_alive() and not request_cancelled:
"""Iterate over results for a specific request until completion or cancellation."""
request_done = False

while (
not request_done
and self._generation_thread is not None
and self._generation_thread.is_alive()
):
result = self.get_result(request_id=request_id, timeout=0.1)

if result is not None:
yield result

# Stop iteration on terminal state
if result.is_finished():
request_done = True
break

# Stop if request was cancelled
if self.batch_processor is not None:
request_cancelled = self.batch_processor.scheduler.request_is_cancelled(request_id)
if self.batch_processor.scheduler.request_is_cancelled(request_id):
break

@traced
def _generation_step(self) -> None:
Expand Down