Skip to content

syl/fix setup shutdown bug #1819

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jul 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion python/cog/server/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,8 @@ async def wrapped(*args: "P.args", **kwargs: "P.kwargs") -> "T":

return wrapped

if "train" in config:
# if train is set but null/blank, don't do training
if config.get("train"):
try:
# TODO: avoid loading trainer code in this process
trainer = load_predictor_from_ref(config["train"])
Expand Down Expand Up @@ -240,6 +241,7 @@ def startup() -> None:

@app.on_event("shutdown")
def shutdown() -> None:
log.info("app shutdown event has occurred")
runner.shutdown()

@app.get("/")
Expand Down Expand Up @@ -447,6 +449,7 @@ def stop(self) -> None:

self._thread.join(timeout=5)
if not self._thread.is_alive():
log.info("server thread is not running after join")
return

log.warn("failed to exit after 5 seconds, setting force_exit")
Expand Down
18 changes: 17 additions & 1 deletion python/cog/server/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,12 @@ async def inner() -> SetupResult:
except Exception:
logs.append(traceback.format_exc())
status = schema.Status.FAILED
except BaseException:
self.log("caught BaseException during setup, did something go wrong?")
logs.append(traceback.format_exc())
status = schema.Status.FAILED

# fixme: handle BaseException if mux.read times out and gets cancelled

if status is None:
logs.append("Error: did not receive 'done' event from setup!")
Expand Down Expand Up @@ -230,6 +236,10 @@ def handle_error(task: RunnerTask) -> None:
self.log.error("caught exception while running setup", exc_info=True)
if self._shutdown_event is not None:
self._shutdown_event.set()
except BaseException:
self.log.error("caught base exception while running setup", exc_info=True)
if self._shutdown_event is not None:
self._shutdown_event.set()

result = asyncio.create_task(inner())
result.add_done_callback(handle_error)
Expand Down Expand Up @@ -330,14 +340,15 @@ async def async_predict_handling_errors() -> schema.PredictionResponse:
await event_handler.failed(error=str(e))
self.log.warn("failed to download url path from input", exc_info=True)
return event_handler.response
except Exception as e:
except Exception as e: # should this be BaseException?
tb = traceback.format_exc()
await event_handler.append_logs(tb)
await event_handler.failed(error=str(e))
self.log.error(
"caught exception while running prediction", exc_info=True
)
if self._shutdown_event is not None:
self.log.info("setting shutdown_event")
self._shutdown_event.set()
raise # we don't actually want to raise anymore but w/e
finally:
Expand All @@ -361,15 +372,18 @@ async def async_predict_handling_errors() -> schema.PredictionResponse:
return (response, result)

def shutdown(self) -> None:
self.log.info("runner.shutdown called")
if self._state == WorkerState.DEFUNCT:
return
# shutdown requested, but keep reading events
self._shutting_down = True

if self._child.is_alive():
self.log.info("child is alive during shutdown, sending Shutdown event")
self._events.send(Shutdown())

def terminate(self) -> None:
self.log.info("runner.terminate is called")
for _, task in self._predictions.values():
task.cancel()
if self._state == WorkerState.DEFUNCT:
Expand All @@ -380,6 +394,7 @@ def terminate(self) -> None:

if self._child.is_alive():
self._child.terminate()
self.log.info("joining child worker")
self._child.join()
self._events.close()

Expand Down Expand Up @@ -438,6 +453,7 @@ async def _read_events(self) -> None:
# this is the same event as self._terminating
# we need to set it so mux.reads wake up and throw an error if needed
self._mux.terminating.set()
self.log.info("exited _read_events")


class PredictionEventHandler:
Expand Down
4 changes: 4 additions & 0 deletions python/cog/server/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,7 @@ def _setup(self) -> None:
# Could be a function or a class
if hasattr(self._predictor, "setup"):
if inspect.iscoroutinefunction(self._predictor.setup):
# we should probably handle Shutdown during this process?
self.loop.run_until_complete(run_setup_async(self._predictor))
else:
run_setup(self._predictor)
Expand Down Expand Up @@ -188,6 +189,7 @@ def _loop_sync(self) -> None:
while True:
ev = self._events.recv()
if isinstance(ev, Shutdown):
self._log("got Shutdown event")
break
if isinstance(ev, PredictionInput):
self._predict_sync(ev)
Expand All @@ -210,6 +212,7 @@ async def _loop_async(self) -> None:
except asyncio.CancelledError:
return
if isinstance(ev, Shutdown):
self._log("got shutdown event [async]")
return
if isinstance(ev, PredictionInput):
# keep track of these so they can be cancelled
Expand Down Expand Up @@ -295,6 +298,7 @@ def _predict_sync(self, input: PredictionInput) -> None:
send(PredictionOutput(payload=make_encodeable(result)))

def _signal_handler(self, signum: int, frame: Optional[types.FrameType]) -> None:
# perhaps we should handle shutdown during setup using a signal?
if self._predictor and is_async(get_predict(self._predictor)):
# we could try also canceling the async task around here
# but for now in async mode signals are ignored
Expand Down