Skip to content

optimize webhook serialization and logging #1651

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
May 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 14 additions & 6 deletions python/cog/server/clients.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,10 @@

import httpx
import structlog
from fastapi.encoders import jsonable_encoder

from .. import types
from ..schema import Status, WebhookEvent
from ..schema import PredictionResponse, Status, WebhookEvent
from ..types import Path
from .eventtypes import PredictionInput
from .response_throttler import ResponseThrottler
Expand Down Expand Up @@ -105,6 +106,7 @@ def __init__(self) -> None:
self.retry_webhook_client = httpx_retry_client()
self.file_client = httpx_file_client()
self.download_client = httpx.AsyncClient(follow_redirects=True, http2=True)
self.log = structlog.get_logger(__name__).bind()

async def aclose(self) -> None:
# not used but it's not actually critical to close them
async def send_webhook(
    self, url: str, response: Dict[str, Any], event: WebhookEvent
) -> None:
    """POST the serialized prediction `response` to the webhook `url`.

    Responses whose status is terminal (per Status.is_terminal) are sent
    with the retrying client so the final state is delivered persistently.
    Intermediate updates use the plain client; request errors are logged
    and swallowed because a newer update will supersede them anyway.
    """
    if Status.is_terminal(response["status"]):
        self.log.info("sending terminal webhook with status %s", response["status"])
        # For terminal updates, retry persistently
        await self.retry_webhook_client.post(url, json=response)
    else:
        self.log.info("sending webhook with status %s", response["status"])
        # For other requests, don't retry, and ignore any errors
        try:
            await self.webhook_client.post(url, json=response)
        except httpx.RequestError:
            self.log.warn("caught exception while sending webhook", exc_info=True)

def make_webhook_sender(
    self, url: Optional[str], webhook_events_filter: Collection[WebhookEvent]
) -> WebhookSenderType:
    """Build an async callback that forwards prediction responses to `url`.

    The callback ignores events not in `webhook_events_filter`, throttles
    non-terminal updates through a ResponseThrottler, and serializes the
    PredictionResponse to a plain dict only when it is actually sent.
    """
    throttler = ResponseThrottler(response_interval=_response_interval)

    async def sender(response: PredictionResponse, event: WebhookEvent) -> None:
        if url and event in webhook_events_filter:
            if throttler.should_send_response(response):
                # jsonable_encoder is quite slow in context, it would be ideal
                # to skip the heavy parts of this for well-known output types
                dict_response = jsonable_encoder(response.dict(exclude_unset=True))
                await self.send_webhook(url, dict_response, event)
                throttler.update_last_sent_response_time()

    return sender
Expand Down Expand Up @@ -213,6 +218,9 @@ async def upload_files(self, obj: Any, url: Optional[str]) -> Any:
Iterates through an object from make_encodeable and uploads any files.
When a file is encountered, it will be passed to upload_file. Any paths will be opened and converted to files.
"""
# skip four isinstance checks for fast text models
if type(obj) == str: # noqa: E721
return obj
# # it would be kind of cleaner to make the default file_url
# # instead of skipping entirely, we need to convert to datauri
# if url is None:
Expand Down
7 changes: 3 additions & 4 deletions python/cog/server/response_throttler.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,15 @@
import time
from typing import Any, Dict

from ..schema import Status
from ..schema import PredictionResponse, Status


class ResponseThrottler:
def __init__(self, response_interval: float) -> None:
    """Throttle webhook sends to at most one per `response_interval` seconds."""
    self.response_interval = response_interval
    # 0.0 means "never sent yet", so the very first response always passes.
    self.last_sent_response_time = 0.0

def should_send_response(self, response: PredictionResponse) -> bool:
    """Return True if `response` should be sent to the webhook now.

    Terminal responses are always sendable; non-terminal ones only once at
    least `response_interval` seconds have elapsed since the last send.
    """
    if Status.is_terminal(response.status):
        return True

    return self.seconds_since_last_response() >= self.response_interval
Expand Down
25 changes: 14 additions & 11 deletions python/cog/server/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
import httpx
import structlog
from attrs import define
from fastapi.encoders import jsonable_encoder

from .. import schema, types
from .clients import SKIP_START_EVENT, ClientManager
Expand Down Expand Up @@ -72,6 +71,9 @@ def __init__(

self.client_manager = ClientManager()

# bind logger instead of the module-level logger proxy for performance
self.log = log.bind()

def make_error_handler(self, activity: str) -> Callable[[RunnerTask], None]:
def handle_error(task: RunnerTask) -> None:
exc = task.exception()
Expand All @@ -83,7 +85,7 @@ def handle_error(task: RunnerTask) -> None:
try:
raise exc
except Exception:
log.error(f"caught exception while running {activity}", exc_info=True)
self.log.error(f"caught exception while running {activity}", exc_info=True)
if self._shutdown_event is not None:
self._shutdown_event.set()

Expand Down Expand Up @@ -121,7 +123,7 @@ def predict(
# if upload url was not set, we can respect output_file_prefix
# but maybe we should just throw an error
upload_url = request.output_file_prefix or self._upload_url
event_handler = PredictionEventHandler(request, self.client_manager, upload_url)
event_handler = PredictionEventHandler(request, self.client_manager, upload_url, self.log)
self._response = event_handler.response

prediction_input = PredictionInput.from_request(request)
Expand All @@ -143,13 +145,13 @@ async def async_predict_handling_errors() -> schema.PredictionResponse:
tb = traceback.format_exc()
await event_handler.append_logs(tb)
await event_handler.failed(error=str(e))
log.warn("failed to download url path from input", exc_info=True)
self.log.warn("failed to download url path from input", exc_info=True)
return event_handler.response
except Exception as e:
tb = traceback.format_exc()
await event_handler.append_logs(tb)
await event_handler.failed(error=str(e))
log.error("caught exception while running prediction", exc_info=True)
self.log.error("caught exception while running prediction", exc_info=True)
if self._shutdown_event is not None:
self._shutdown_event.set()
raise # we don't actually want to raise anymore but w/e
Expand Down Expand Up @@ -195,8 +197,10 @@ def __init__(
request: schema.PredictionRequest,
client_manager: ClientManager,
upload_url: Optional[str],
logger: Optional[structlog.BoundLogger] = None,
) -> None:
log.info("starting prediction")
self.logger = logger or log.bind()
self.logger.info("starting prediction")
# maybe this should be a deep copy to not share File state with child worker
self.p = schema.PredictionResponse(**request.dict())
self.p.status = schema.Status.PROCESSING
Expand Down Expand Up @@ -244,7 +248,7 @@ async def append_logs(self, logs: str) -> None:
await self._send_webhook(schema.WebhookEvent.LOGS)

async def succeeded(self) -> None:
log.info("prediction succeeded")
self.logger.info("prediction succeeded")
self.p.status = schema.Status.SUCCEEDED
self._set_completed_at()
# These have been set already: this is to convince the typechecker of
Expand All @@ -257,14 +261,14 @@ async def succeeded(self) -> None:
await self._send_webhook(schema.WebhookEvent.COMPLETED)

async def failed(self, error: str) -> None:
    """Mark the prediction as failed with `error` and emit the COMPLETED webhook."""
    self.logger.info("prediction failed", error=error)
    self.p.status = schema.Status.FAILED
    self.p.error = error
    self._set_completed_at()
    await self._send_webhook(schema.WebhookEvent.COMPLETED)

async def canceled(self) -> None:
    """Mark the prediction as canceled and emit the COMPLETED webhook."""
    self.logger.info("prediction canceled")
    self.p.status = schema.Status.CANCELED
    self._set_completed_at()
    await self._send_webhook(schema.WebhookEvent.COMPLETED)
def _set_completed_at(self) -> None:
    """Stamp the response's completion time with an aware UTC datetime."""
    self.p.completed_at = datetime.now(tz=timezone.utc)

async def _send_webhook(self, event: schema.WebhookEvent) -> None:
    """Forward the current PredictionResponse to the webhook sender for `event`.

    The response object is passed through unserialized; conversion to a dict
    (jsonable_encoder) now happens in ClientManager.make_webhook_sender, only
    when the throttler decides the response will actually be sent.
    """
    await self._webhook_sender(self.response, event)

async def _upload_files(self, output: Any) -> Any:
try:
Expand Down
21 changes: 12 additions & 9 deletions python/tests/server/test_response_throttler.py
Original file line number Diff line number Diff line change
@@ -1,35 +1,38 @@
import time

from cog.schema import Status
from cog.schema import PredictionResponse, Status
from cog.server.response_throttler import ResponseThrottler

# Shared fixtures: a minimal response in a non-terminal state and one in a
# terminal state, matching the PredictionResponse type the throttler now takes.
processing = PredictionResponse(input={}, status=Status.PROCESSING)
succeeded = PredictionResponse(input={}, status=Status.SUCCEEDED)


def test_zero_interval():
    """A zero response_interval never throttles, even right after a send."""
    limiter = ResponseThrottler(response_interval=0)

    assert limiter.should_send_response(processing)
    limiter.update_last_sent_response_time()
    # No waiting is required between sends when the interval is zero.
    assert limiter.should_send_response(succeeded)


def test_terminal_status():
    """Terminal responses bypass throttling; non-terminal ones are held back."""
    limiter = ResponseThrottler(response_interval=10)

    assert limiter.should_send_response(processing)
    limiter.update_last_sent_response_time()
    # Inside the 10s window a non-terminal update is suppressed...
    assert not limiter.should_send_response(processing)
    limiter.update_last_sent_response_time()
    # ...but a terminal update always goes out.
    assert limiter.should_send_response(succeeded)


def test_nonzero_internal():
    """After the interval elapses, non-terminal responses become sendable again."""
    # NOTE(review): the name looks like a typo for "test_nonzero_interval";
    # kept as-is so selecting the test by name keeps working.
    limiter = ResponseThrottler(response_interval=0.2)

    assert limiter.should_send_response(processing)
    limiter.update_last_sent_response_time()
    assert not limiter.should_send_response(processing)
    limiter.update_last_sent_response_time()

    # Sleep past the 0.2s window so the throttler opens up again.
    time.sleep(0.3)

    assert limiter.should_send_response(processing)