Skip to content

Commit 6059af1

Browse files
wip
1 parent b7638b0 commit 6059af1

File tree

2 files changed

+17
-10
lines changed

2 files changed

+17
-10
lines changed

src/guidellm/scheduler/scheduler.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,6 @@ def __init__(
6464

6565
self.worker = worker
6666
self.request_loader = request_loader
67-
self.error_rate: Optional[float] = None
6867

6968
async def run(
7069
self,
@@ -118,16 +117,14 @@ async def run(
118117
if max_error_rate is not None and (max_error_rate < 0 or max_error_rate > 1):
119118
raise ValueError(f"Invalid max_error_rate: {max_error_rate}")
120119

121-
shutdown_event = multiprocessing.Event()
122-
123120
with (
124121
multiprocessing.Manager() as manager,
125122
ProcessPoolExecutor(
126123
max_workers=scheduling_strategy.processes_limit
127124
) as executor,
128125
):
129126
requests_iter: Optional[Iterator[Any]] = None
130-
futures, requests_queue, responses_queue = await self._start_processes(
127+
futures, requests_queue, responses_queue, shutdown_event = await self._start_processes(
131128
manager, executor, scheduling_strategy
132129
)
133130
run_info, requests_iter, times_iter = self._run_setup(
@@ -167,7 +164,9 @@ async def run(
167164
)
168165
if iter_result is not None:
169166
if self._is_max_error_rate_reached(iter_result.run_info):
170-
logger.info(f"Max_error rate of ({iter_result.run_info.max_error_rate}) reached!")
167+
logger.info(f"Max_error rate of ({iter_result.run_info.max_error_rate}) reached, sending "
168+
f"shutdown signal")
169+
shutdown_event.set()
171170
yield iter_result
172171

173172
# yield control to the event loop
@@ -191,8 +190,10 @@ async def _start_processes(
191190
list[asyncio.Future],
192191
multiprocessing.Queue,
193192
multiprocessing.Queue,
193+
multiprocessing.Event
194194
]:
195195
await self.worker.prepare_multiprocessing()
196+
shutdown_event = multiprocessing.Event()
196197
requests_queue = manager.Queue(
197198
maxsize=scheduling_strategy.queued_requests_limit
198199
)
@@ -229,6 +230,7 @@ async def _start_processes(
229230
requests_queue,
230231
responses_queue,
231232
id_,
233+
shutdown_event,
232234
)
233235
)
234236
elif scheduling_strategy.processing_mode == "async":
@@ -240,6 +242,7 @@ async def _start_processes(
240242
responses_queue,
241243
requests_limit,
242244
id_,
245+
shutdown_event,
243246
)
244247
)
245248
else:
@@ -250,7 +253,7 @@ async def _start_processes(
250253

251254
await asyncio.sleep(0.1) # give time for processes to start
252255

253-
return futures, requests_queue, responses_queue
256+
return futures, requests_queue, responses_queue, shutdown_event
254257

255258
def _run_setup(
256259
self,
@@ -385,8 +388,7 @@ def _check_result_ready(
385388
)
386389
raise ValueError(f"Invalid process response type: {process_response}")
387390

388-
@staticmethod
389-
def _is_max_error_rate_reached(run_info: SchedulerRunInfo) -> bool:
391+
def _is_max_error_rate_reached(self, run_info: SchedulerRunInfo) -> bool:
390392
current_error_rate = run_info.errored_requests / run_info.end_number
391393
return current_error_rate > run_info.max_error_rate
392394

src/guidellm/scheduler/worker.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -121,9 +121,13 @@ async def resolve(
121121
...
122122

123123
async def get_request(
124-
self, requests_queue: multiprocessing.Queue
124+
self, requests_queue: multiprocessing.Queue, shutdown_event: multiprocessing.Event, shutdonen_check_
125125
) -> Optional[WorkerProcessRequest[RequestT]]:
126-
return await asyncio.to_thread(requests_queue.get) # type: ignore[attr-defined]
126+
def _get_queue_intermittently(request_queue: multiprocessing.Queue, shutdown_event):
127+
try:
128+
129+
130+
return await asyncio.to_thread(_get_queue_intermittently()) # type: ignore[attr-defined]
127131

128132
async def send_result(
129133
self,
@@ -222,6 +226,7 @@ def process_loop_asynchronous(
222226
results_queue: multiprocessing.Queue,
223227
max_concurrency: int,
224228
process_id: int,
229+
shutdown_event: multiprocessing.Event,
225230
):
226231
async def _process_runner():
227232
pending = asyncio.Semaphore(max_concurrency)

0 commit comments

Comments
 (0)