async runner #1352

Merged
merged 4 commits into from Nov 22, 2023
1 change: 1 addition & 0 deletions pyproject.toml
@@ -35,6 +35,7 @@ optional-dependencies = { "dev" = [
     'numpy; python_version >= "3.8"',
     "pillow",
     "pytest",
+    "pytest-asyncio",
     "pytest-httpserver",
     "pytest-rerunfailures",
     "pytest-xdist",
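
The new dev dependency is what lets the test suite drive the coroutine-based runner. A minimal sketch of the kind of test it enables, assuming pytest-asyncio's default marker mode (the test name and `fake_setup` coroutine are hypothetical, not from this PR):

```python
import asyncio

import pytest


@pytest.mark.asyncio
async def test_setup_task_can_be_awaited() -> None:
    async def fake_setup() -> dict:
        return {"status": "succeeded"}

    # pytest-asyncio provides the running event loop for the test body.
    task = asyncio.create_task(fake_setup())
    assert (await task)["status"] == "succeeded"
```
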
18 changes: 10 additions & 8 deletions python/cog/server/http.py
@@ -107,7 +107,7 @@ async def root() -> Any:

     @app.get("/health-check")
     async def healthcheck() -> Any:
-        _check_setup_result()
+        await _check_setup_result()
         if app.state.health == Health.READY:
             health = Health.BUSY if runner.is_busy() else Health.READY
         else:
@@ -137,7 +137,7 @@ async def predict(request: PredictionRequest = Body(default=None), prefer: Union
         # TODO: spec-compliant parsing of Prefer header.
         respond_async = prefer == "respond-async"

-        return _predict(request=request, respond_async=respond_async)
+        return await _predict(request=request, respond_async=respond_async)

     @limited
     @app.put(
@@ -172,9 +172,9 @@ async def predict_idempotent(
         # TODO: spec-compliant parsing of Prefer header.
         respond_async = prefer == "respond-async"

-        return _predict(request=request, respond_async=respond_async)
+        return await _predict(request=request, respond_async=respond_async)

-    def _predict(
+    async def _predict(
         *, request: PredictionRequest, respond_async: bool = False
     ) -> Response:
         # [compat] If no body is supplied, assume that this model can be run
@@ -203,7 +203,8 @@ def _predict(
             return JSONResponse(jsonable_encoder(initial_response), status_code=202)

         try:
-            response = PredictionResponse(**async_result.get().dict())
+            prediction = await async_result
+            response = PredictionResponse(**prediction.dict())
         except ValidationError as e:
             _log_invalid_output(e)
             raise HTTPException(status_code=500, detail=str(e)) from e
@@ -239,14 +240,15 @@ async def start_shutdown() -> Any:
         shutdown_event.set()
         return JSONResponse({}, status_code=200)

-    def _check_setup_result() -> Any:
+    async def _check_setup_result() -> Any:
         if app.state.setup_result is None:
             return

-        if not app.state.setup_result.ready():
+        if not app.state.setup_result.done():
             return

-        result = app.state.setup_result.get()
+        # this can raise CancelledError
+        result = app.state.setup_result.result()

         if result["status"] == schema.Status.SUCCEEDED:
             app.state.health = Health.READY
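
For reference, the handler changes above follow a direct mapping from multiprocessing's `AsyncResult` API onto `asyncio.Task`. A standalone sketch of that mapping (illustration only, not code from the diff):

```python
import asyncio


async def main() -> None:
    async def setup() -> dict:
        return {"status": "succeeded"}

    task = asyncio.create_task(setup())

    # AsyncResult.ready()  ->  Task.done()
    print(task.done())  # False: the task has not run yet

    # AsyncResult.get()    ->  await task (or task.result() once done)
    result = await task
    print(result["status"])

    # Task.result() raises CancelledError if the task was cancelled,
    # which is why _check_setup_result notes it can raise.
    print(task.result() == result)


asyncio.run(main())
```
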
95 changes: 42 additions & 53 deletions python/cog/server/runner.py
@@ -1,8 +1,8 @@
+import asyncio
 import io
-import threading
 import traceback
+from asyncio import Task
 from datetime import datetime, timezone
-from multiprocessing.pool import AsyncResult, ThreadPool
 from typing import Any, Callable, Dict, Optional, Tuple

 import requests
@@ -39,48 +39,47 @@ def __init__(
         self,
         *,
         predictor_ref: str,
-        shutdown_event: Optional[threading.Event],
+        shutdown_event: Optional[asyncio.Event],
         upload_url: Optional[str] = None,
     ) -> None:
-        self._thread = None
-        self._threadpool = ThreadPool(processes=1)
-
         self._response: Optional[schema.PredictionResponse] = None
-        self._result: Optional[AsyncResult] = None
+        self._result: Optional[Task] = None

         self._worker = Worker(predictor_ref=predictor_ref)
-        self._should_cancel = threading.Event()
+        self._should_cancel = asyncio.Event()

         self._shutdown_event = shutdown_event
         self._upload_url = upload_url

-    def setup(self) -> AsyncResult:
-        if self.is_busy():
-            raise RunnerBusyError()
-
-        def handle_error(error: BaseException) -> None:
+    def make_error_handler(self, activity: str) -> Callable:
+        def handle_error(task: Task) -> None:
+            exc = task.exception()
+            if not exc:
+                return
             # Re-raise the exception in order to more easily capture exc_info,
             # and then trigger shutdown, as we have no easy way to resume
             # worker state if an exception was thrown.
             try:
-                raise error
+                raise exc
             except Exception:
-                log.error("caught exception while running setup", exc_info=True)
+                log.error(f"caught exception while running {activity}", exc_info=True)
                 if self._shutdown_event is not None:
                     self._shutdown_event.set()

-        self._result = self._threadpool.apply_async(
-            func=setup,
-            kwds={"worker": self._worker},
-            error_callback=handle_error,
-        )
+        return handle_error
+
+    def setup(self) -> "Task[dict[str, Any]]":
+        if self.is_busy():
+            raise RunnerBusyError()
+        self._result = asyncio.create_task(setup(worker=self._worker))
+        self._result.add_done_callback(self.make_error_handler("setup"))
         return self._result

-    # TODO: Make the return type AsyncResult[schema.PredictionResponse] when we
-    # no longer have to support Python 3.8
     def predict(
         self, prediction: schema.PredictionRequest, upload: bool = True
-    ) -> Tuple[schema.PredictionResponse, AsyncResult]:
+    ) -> Tuple[schema.PredictionResponse, "Task[schema.PredictionResponse]"]:
         # It's the caller's responsibility to not call us if we're busy.
         if self.is_busy():
             # If self._result is set, but self._response is not, we're still
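
The `make_error_handler` factory above replaces `apply_async`'s `error_callback` with a `Task` done-callback. A self-contained sketch of that pattern (the extra `cancelled()` guard is an assumption added for safety here, not in the diff):

```python
import asyncio


def make_error_handler(activity: str):
    def handle_error(task: asyncio.Task) -> None:
        if task.cancelled():
            return  # Task.exception() would raise CancelledError here
        exc = task.exception()
        if exc is None:
            return
        print(f"caught exception while running {activity}: {exc!r}")

    return handle_error


async def main() -> None:
    async def failing_setup() -> None:
        raise RuntimeError("boom")

    task = asyncio.create_task(failing_setup())
    task.add_done_callback(make_error_handler("setup"))
    # Swallow the error here; the done-callback still fires on completion.
    await asyncio.gather(task, return_exceptions=True)


asyncio.run(main())
```
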
@@ -101,51 +100,38 @@ def predict(
         upload_url = self._upload_url if upload else None
         event_handler = create_event_handler(prediction, upload_url=upload_url)

-        def cleanup(_: Optional[Any] = None) -> None:
+        def handle_cleanup(_: Task) -> None:
             if hasattr(prediction.input, "cleanup"):
                 prediction.input.cleanup()

-        def handle_error(error: BaseException) -> None:
-            # Re-raise the exception in order to more easily capture exc_info,
-            # and then trigger shutdown, as we have no easy way to resume
-            # worker state if an exception was thrown.
-            try:
-                raise error
-            except Exception:
-                log.error("caught exception while running prediction", exc_info=True)
-                if self._shutdown_event is not None:
-                    self._shutdown_event.set()
-
         self._response = event_handler.response
-        self._result = self._threadpool.apply_async(
-            func=predict,
-            kwds={
-                "worker": self._worker,
-                "request": prediction,
-                "event_handler": event_handler,
-                "should_cancel": self._should_cancel,
-            },
-            callback=cleanup,
-            error_callback=handle_error,
+        coro = predict(
+            worker=self._worker,
+            request=prediction,
+            event_handler=event_handler,
+            should_cancel=self._should_cancel,
         )
+        self._result = asyncio.create_task(coro)
+        self._result.add_done_callback(handle_cleanup)
+        self._result.add_done_callback(self.make_error_handler("prediction"))

         return (self._response, self._result)

     def is_busy(self) -> bool:
         if self._result is None:
             return False

-        if not self._result.ready():
+        if not self._result.done():
             return True

         self._response = None
         self._result = None
         return False

     def shutdown(self) -> None:
+        if self._result:
+            self._result.cancel()
         self._worker.terminate()
-        self._threadpool.terminate()
-        self._threadpool.join()

     def cancel(self, prediction_id: Optional[str] = None) -> None:
         if not self.is_busy():
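
`shutdown()` now cancels the in-flight `Task` instead of tearing down a thread pool. A sketch of what that cancellation looks like to anyone awaiting the task, such as the `/health-check` path (the `long_setup` coroutine is hypothetical):

```python
import asyncio


async def main() -> None:
    async def long_setup() -> dict:
        await asyncio.sleep(60)
        return {"status": "succeeded"}

    task = asyncio.create_task(long_setup())
    await asyncio.sleep(0)  # let the task start running
    task.cancel()           # what shutdown() does to self._result

    try:
        await task
    except asyncio.CancelledError:
        print("setup task was cancelled")


asyncio.run(main())
```
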
@@ -287,13 +273,15 @@ def _upload_files(self, output: Any) -> Any:
             raise FileUploadError("Got error trying to upload output files") from error


-def setup(*, worker: Worker) -> Dict[str, Any]:
+async def setup(*, worker: Worker) -> Dict[str, Any]:
     logs = []
     status = None
     started_at = datetime.now(tz=timezone.utc)

     try:
+        # will be async
         for event in worker.setup():
+            await asyncio.sleep(0)
             if isinstance(event, Log):
                 logs.append(event.message)
             elif isinstance(event, Done):
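
The `worker.setup()` generator is still synchronous, so the `await asyncio.sleep(0)` inserted in the loop is a cooperative yield point: it suspends the coroutine for one event-loop turn per event, letting other handlers run while events are drained. A standalone sketch of the idiom:

```python
import asyncio


async def drain(events: list) -> None:
    # The loop body itself is synchronous; sleep(0) hands control back
    # to the event loop once per event.
    for event in events:
        await asyncio.sleep(0)
        print("processed", event)


async def heartbeat() -> None:
    for _ in range(3):
        print("tick")
        await asyncio.sleep(0)


async def main() -> None:
    # Without the sleep(0) in drain(), heartbeat() would not run until
    # drain() had finished all events.
    await asyncio.gather(drain([1, 2, 3]), heartbeat())


asyncio.run(main())
```
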
@@ -323,19 +311,19 @@ def setup(*, worker: Worker) -> Dict[str, Any]:
     }


-def predict(
+async def predict(
     *,
     worker: Worker,
     request: schema.PredictionRequest,
     event_handler: PredictionEventHandler,
-    should_cancel: threading.Event,
+    should_cancel: asyncio.Event,
 ) -> schema.PredictionResponse:
     # Set up logger context within prediction thread.
     structlog.contextvars.clear_contextvars()
     structlog.contextvars.bind_contextvars(prediction_id=request.id)

     try:
-        return _predict(
+        return await _predict(
             worker=worker,
             request=request,
             event_handler=event_handler,
@@ -348,12 +336,12 @@ def predict(
         raise


-def _predict(
+async def _predict(
     *,
     worker: Worker,
     request: schema.PredictionRequest,
     event_handler: PredictionEventHandler,
-    should_cancel: threading.Event,
+    should_cancel: asyncio.Event,
 ) -> schema.PredictionResponse:
     initial_prediction = request.dict()

@@ -370,8 +358,9 @@ def _predict(
             event_handler.failed(error=str(e))
             log.warn("failed to download url path from input", exc_info=True)
             return event_handler.response
-
+    # will be async
     for event in worker.predict(input_dict, poll=0.1):
+        await asyncio.sleep(0)
         if should_cancel.is_set():
             worker.cancel()
             should_cancel.clear()
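
`should_cancel` is now an `asyncio.Event` checked at each yield point of the prediction loop, so `Runner.cancel()` can interrupt work from the same event loop. A sketch of the pattern (coroutine names are hypothetical):

```python
import asyncio


async def run_prediction(should_cancel: asyncio.Event) -> str:
    for _ in range(100):
        await asyncio.sleep(0)      # yield point, as in _predict
        if should_cancel.is_set():
            should_cancel.clear()
            return "canceled"
    return "succeeded"


async def main() -> None:
    should_cancel = asyncio.Event()
    task = asyncio.create_task(run_prediction(should_cancel))
    should_cancel.set()             # what Runner.cancel() does
    print(await task)               # -> canceled


asyncio.run(main())
```
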
3 changes: 2 additions & 1 deletion python/cog/server/worker.py
@@ -114,7 +114,8 @@ def _wait(
                 if send_heartbeats:
                     yield Heartbeat()
                 continue
-
+            # this needs aioprocessing.Pipe or similar
+            # multiprocessing.Pipe is not async
             ev = self._events.recv()
             yield ev
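
The new comment flags the remaining blocker: `multiprocessing.Pipe`'s `recv()` blocks the event loop. One possible interim workaround (an assumption, not what this PR ships) is to push the blocking `recv` onto a thread:

```python
import asyncio
from multiprocessing import Pipe
from multiprocessing.connection import Connection


async def async_recv(conn: Connection) -> object:
    # Run the blocking recv() in the default thread pool so the event
    # loop stays responsive while waiting for a worker event.
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(None, conn.recv)


async def main() -> None:
    parent, child = Pipe()
    child.send({"event": "done"})
    print(await async_recv(parent))


asyncio.run(main())
```

The aioprocessing library named in the comment packages essentially this same thread-delegation trick behind an async API.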