Skip to content

predict_time_share needs to be set before sending the completed webhook #1683

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 33 additions & 17 deletions python/cog/server/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,29 @@ class SetupResult:
# this is a major outstanding piece of work for merging into main


class TimeShareTracker:
def __init__(self) -> None:
self._time_shares_per_prediction: "dict[str, float]" = {}
self._last_updated_time_shares = 0.0

def update_time_shares(self) -> None:
now = time.time()
if self._time_shares_per_prediction:
elapsed = now - self._last_updated_time_shares
incurred_cost = elapsed / len(self._time_shares_per_prediction)
for prediction_id in self._time_shares_per_prediction:
self._time_shares_per_prediction[prediction_id] += incurred_cost
self._last_updated_time_shares = now

def start_tracking(self, id: str) -> None:
self.update_time_shares()
self._time_shares_per_prediction[id] = 0.0

def end_tracking(self, id: str) -> float:
self.update_time_shares()
return self._time_shares_per_prediction.pop(id)


class PredictionRunner:
def __init__(
self,
Expand Down Expand Up @@ -118,8 +141,6 @@ def __init__(

# A pipe with which to communicate with the child worker.
events, child_events = _spawn.Pipe()
self._time_shares_per_prediction: "dict[str, float]" = {}
self._last_updated_time_shares = 0.0
self._child = _ChildWorker(predictor_ref, child_events, tee_output)
self._events: "AsyncConnection[tuple[str, PublicEventType]]" = AsyncConnection(
events
Expand All @@ -132,15 +153,7 @@ def __init__(
# </worker code>
# bind logger instead of the module-level logger proxy for performance
self.log = log.bind()

def update_time_shares(self) -> None:
now = time.time()
if self._time_shares_per_prediction:
elapsed = now - self._last_updated_time_shares
incurred_cost = elapsed / len(self._time_shares_per_prediction)
for prediction_id in self._time_shares_per_prediction:
self._time_shares_per_prediction[prediction_id] += incurred_cost
self._last_updated_time_shares = now
self.time_share_tracker = TimeShareTracker() if concurrency > 1 else None

def activity_info(self) -> "dict[str, int]":
return {"max": self._concurrency, "current": len(self._predictions_in_flight)}
Expand Down Expand Up @@ -284,7 +297,7 @@ def predict(
# that breaks one of the tests, but happens Rarely in production,
# so let's ignore it for now
event_handler = PredictionEventHandler(
request, self.client_manager, upload_url, self.log
request, self.client_manager, upload_url, self.log, self.time_share_tracker
)
response = event_handler.response

Expand All @@ -302,14 +315,11 @@ async def async_predict_handling_errors() -> schema.PredictionResponse:
real_path = await v.convert(self.client_manager.download_client)
prediction_input.payload[k] = real_path
async with self._semaphore:
self.update_time_shares()
self._time_shares_per_prediction[request.id] = 0.0
if self.time_share_tracker:
self.time_share_tracker.start_tracking(request.id)
self._events.send(prediction_input)
event_stream = self._mux.read(prediction_input.id, poll=poll)
result = await event_handler.handle_event_stream(event_stream)
self.update_time_shares()
time_share = self._time_shares_per_prediction.pop(request.id)
result.metrics["predict_time_share"] = time_share
return result
except httpx.HTTPError as e:
tb = traceback.format_exc()
Expand Down Expand Up @@ -434,6 +444,7 @@ def __init__(
client_manager: ClientManager,
upload_url: Optional[str],
logger: Optional[structlog.BoundLogger] = None,
time_share_tracker: Optional[TimeShareTracker] = None,
) -> None:
self.logger = logger or log.bind()
self.logger.info("starting prediction")
Expand All @@ -452,6 +463,7 @@ def __init__(
)
self._upload_url = upload_url
self._output_type = None
self.time_share_tracker = time_share_tracker

# HACK: don't send an initial webhook if we're trying to optimize for
# latency (this guarantees that the first output webhook won't be
Expand Down Expand Up @@ -494,6 +506,10 @@ async def succeeded(self) -> None:
self.p.metrics["predict_time"] = (
self.p.completed_at - self.p.started_at
).total_seconds()
# there shouldn't be a PredictionResponse without an id, but make the types good
if self.time_share_tracker and self.p.id:
time_share = self.time_share_tracker.end_tracking(self.p.id)
self.p.metrics["predict_time_share"] = time_share
await self._send_webhook(schema.WebhookEvent.COMPLETED)

async def failed(self, error: str) -> None:
Expand Down
2 changes: 1 addition & 1 deletion python/cog/server/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,14 +89,14 @@ async def read(
# janky mutable container for a single eventual ChildWorker
worker_reference: "dict[None, _ChildWorker]" = {}


def emit_metric(metric_name: str, metric_value: "float | int") -> None:
worker = worker_reference.get(None, None)
if worker is None:
raise Exception("Attempted to emit metric but worker is not running")
worker._emit_metric(metric_name, metric_value)



class _ChildWorker(_spawn.Process): # type: ignore
def __init__(
self,
Expand Down