Skip to content

call terminate after the last prediction or after a timeout if shutting down #1843

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Aug 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 27 additions & 4 deletions python/cog/server/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,9 +240,12 @@ def startup() -> None:
app.state.setup_task = runner.setup()

@app.on_event("shutdown")
def shutdown() -> None:
async def shutdown() -> None:
# this will fire when Server.stop sets should_exit
# the server and hence the server thread will not exit until this completes
# so we want runner.shutdown to block until everything is good
log.info("app shutdown event has occurred")
runner.shutdown()
await runner.shutdown()

@app.get("/")
async def root() -> Any:
Expand Down Expand Up @@ -440,25 +443,45 @@ def predict(...) -> output_type:

class Server(uvicorn.Server):
def start(self) -> None:
# run is a uvicorn.Server method that runs the server
# it will keep running until server shutdown handlers complete
self._thread = threading.Thread(target=self.run)
self._thread.start()

def stop(self) -> None:
log.info("stopping server")
# https://github.com/encode/uvicorn/blob/master/uvicorn/server.py#L250-L252
# https://github.com/encode/uvicorn/discussions/1103#discussioncomment-941739
# uvicorn's loop will check should_exit to see if it will exit
# once uvicorn starts exiting, the `shutdown` event will fire
self.should_exit = True

self._thread.join(timeout=5)
if not self._thread.is_alive():
log.info("server thread is not running after join")
log.info("server has stopped gracefully, not forcing exit")
return

log.warn("failed to exit after 5 seconds, setting force_exit")
# as of uvicorn 0.30.5, force_exit does three things:
# 1. don't wait for connections to close. if force_exit becomes set
# while waiting for connections to close, uvicorn stops waiting
# https://github.com/encode/uvicorn/blob/master/uvicorn/server.py#L294-L298
# 2. don't wait for background tasks to complete.
# this respects force_exit becoming set after the wait starts
# https://github.com/encode/uvicorn/blob/master/uvicorn/server.py#L300-L305
# 3. when shutdown starts, skip the shutdown event / lifecycle
# the shutdown handler is not interrupted by force_exit becoming set
# https://github.com/encode/uvicorn/blob/master/uvicorn/server.py#L289-L290
self.force_exit = True
# this join is supposed to block until the shutdown handler completes
self._thread.join(timeout=5)
if not self._thread.is_alive():
return

log.warn("failed to exit after another 5 seconds, sending SIGKILL")
# because the child is created with spawn, it won't share a process group
# so killing the parent process will orphan the child
# FIXME: should we manually kill the child?
os.kill(os.getpid(), signal.SIGKILL)


Expand Down Expand Up @@ -569,7 +592,7 @@ def _cpu_count() -> int:
shutdown_event.wait()
except KeyboardInterrupt:
pass

# this will try to shut down gracefully, then kill our process after 10s
s.stop()

# return error exit code when setup failed and cog is running in interactive mode (not k8s)
Expand Down
38 changes: 27 additions & 11 deletions python/cog/server/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,10 +198,11 @@ async def inner() -> SetupResult:
except Exception:
logs.append(traceback.format_exc())
status = schema.Status.FAILED
except BaseException:
self.log.info("caught BaseException during setup, did something go wrong?")
except asyncio.CancelledError:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this just for the log line?

self.log.info("caught CancelledError during setup")
logs.append(traceback.format_exc())
status = schema.Status.FAILED
# unclear if we should re-raise this

# fixme: handle BaseException if mux.read times out and gets cancelled

Expand Down Expand Up @@ -237,9 +238,12 @@ def handle_error(task: RunnerTask) -> None:
if self._shutdown_event is not None:
self._shutdown_event.set()
except BaseException:
self.log.error("caught base exception while running setup", exc_info=True)
self.log.error(
"caught base exception while running setup", exc_info=True
)
if self._shutdown_event is not None:
self._shutdown_event.set()
raise

result = asyncio.create_task(inner())
result.add_done_callback(handle_error)
Expand Down Expand Up @@ -340,7 +344,7 @@ async def async_predict_handling_errors() -> schema.PredictionResponse:
await event_handler.failed(error=str(e))
self.log.warn("failed to download url path from input", exc_info=True)
return event_handler.response
except Exception as e: # should this be BaseException?
except Exception as e: # should this be BaseException?
tb = traceback.format_exc()
await event_handler.append_logs(tb)
await event_handler.failed(error=str(e))
Expand Down Expand Up @@ -371,7 +375,8 @@ async def async_predict_handling_errors() -> schema.PredictionResponse:

return (response, result)

def shutdown(self) -> None:
async def shutdown(self) -> None:
# this is called by the app's shutdown handler. server won't exit until this is done
self.log.info("runner.shutdown called")
if self._state == WorkerState.DEFUNCT:
return
Expand All @@ -382,22 +387,33 @@ def shutdown(self) -> None:
self.log.info("child is alive during shutdown, sending Shutdown event")
self._events.send(Shutdown())

def terminate(self) -> None:
self.log.info("runner.terminate is called")
for _, task in self._predictions.values():
task.cancel()
if self._state == WorkerState.DEFUNCT:
self.log.info("worker state is already defunct, no need to terminate")
return

self._terminating.set()
prediction_tasks = [task for _, task in self._predictions.values()]
try:
if prediction_tasks:
await asyncio.wait(prediction_tasks, timeout=9)
# should we do this?
except TimeoutError:
self.log.warn("runner timeout while waiting for predictions to complete")

self._state = WorkerState.DEFUNCT
# in case we timed out, cancel everything
for task in prediction_tasks:
task.cancel()

# tell _read_events and Mux to exit
self._terminating.set()

if self._child.is_alive():
self._child.terminate()
self.log.info("joining child worker")
self._child.join()
# close the pipe
self._events.close()

# stop reading events from the pipe
if self._read_events_task:
self._read_events_task.cancel()

Expand Down
4 changes: 2 additions & 2 deletions python/tests/server/test_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ async def runner():
await runner.setup()
yield runner
finally:
runner.shutdown()
await runner.shutdown()


@pytest.mark.asyncio
Expand All @@ -58,7 +58,7 @@ async def test_prediction_runner_setup():
assert isinstance(result.started_at, datetime)
assert isinstance(result.completed_at, datetime)
finally:
runner.shutdown()
await runner.shutdown()


@pytest.mark.asyncio
Expand Down