Skip to content

Commit 8b61718

Browse files
committed
Propagate trace context to webhook and upload requests
If the request to /predict contains the headers `traceparent` and `tracestate`, as defined by the W3C Trace Context specification[1], then these headers are forwarded on to the webhook and upload calls. This allows observability systems to link requests passing through cog.

[1]: https://www.w3.org/TR/trace-context/

Signed-off-by: Aron Carroll <aron@replicate.com>
1 parent 5d7493f commit 8b61718

File tree

4 files changed

+139
-17
lines changed

4 files changed

+139
-17
lines changed

python/cog/server/http.py

Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -260,6 +260,8 @@ async def healthcheck() -> Any:
260260
async def predict(
261261
request: PredictionRequest = Body(default=None),
262262
prefer: Union[str, None] = Header(default=None),
263+
traceparent: Union[str, None] = Header(default=None),
264+
tracestate: Union[str, None] = Header(default=None),
263265
) -> Any: # type: ignore
264266
"""
265267
Run a single prediction on the model
@@ -272,7 +274,12 @@ async def predict(
272274
# TODO: spec-compliant parsing of Prefer header.
273275
respond_async = prefer == "respond-async"
274276

275-
return _predict(request=request, respond_async=respond_async)
277+
return _predict(
278+
request=request,
279+
respond_async=respond_async,
280+
traceparent=traceparent,
281+
tracestate=tracestate,
282+
)
276283

277284
@limited
278285
@app.put(
@@ -284,6 +291,8 @@ async def predict_idempotent(
284291
prediction_id: str = Path(..., title="Prediction ID"),
285292
request: PredictionRequest = Body(..., title="Prediction Request"),
286293
prefer: Union[str, None] = Header(default=None),
294+
traceparent: Union[str, None] = Header(default=None),
295+
tracestate: Union[str, None] = Header(default=None),
287296
) -> Any:
288297
"""
289298
Run a single prediction on the model (idempotent creation).
@@ -307,10 +316,19 @@ async def predict_idempotent(
307316
# TODO: spec-compliant parsing of Prefer header.
308317
respond_async = prefer == "respond-async"
309318

310-
return _predict(request=request, respond_async=respond_async)
319+
return _predict(
320+
request=request,
321+
respond_async=respond_async,
322+
traceparent=traceparent,
323+
tracestate=tracestate,
324+
)
311325

312326
def _predict(
313-
*, request: PredictionRequest, respond_async: bool = False
327+
*,
328+
request: PredictionRequest,
329+
respond_async: bool = False,
330+
traceparent: Optional[str] = None,
331+
tracestate: Optional[str] = None,
314332
) -> Response:
315333
# [compat] If no body is supplied, assume that this model can be run
316334
# with empty input. This will throw a ValidationError if that's not
@@ -327,7 +345,10 @@ def _predict(
327345
# async predictions. This is unfortunate but required to ensure
328346
# backwards-compatible behaviour for synchronous predictions.
329347
initial_response, async_result = runner.predict(
330-
request, upload=respond_async
348+
request,
349+
upload=respond_async,
350+
traceparent=traceparent,
351+
tracestate=tracestate,
331352
)
332353
except RunnerBusyError:
333354
return JSONResponse(

python/cog/server/runner.py

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,11 @@ def handle_error(error: BaseException) -> None:
9898
# TODO: Make the return type AsyncResult[schema.PredictionResponse] when we
9999
# no longer have to support Python 3.8
100100
def predict(
101-
self, prediction: schema.PredictionRequest, upload: bool = True
101+
self,
102+
prediction: schema.PredictionRequest,
103+
upload: bool = True,
104+
traceparent: Optional[str] = None,
105+
tracestate: Optional[str] = None,
102106
) -> Tuple[schema.PredictionResponse, PredictionTask]:
103107
# It's the caller's responsibility to not call us if we're busy.
104108
if self.is_busy():
@@ -119,9 +123,14 @@ def predict(
119123

120124
self._should_cancel.clear()
121125
upload_url = self._upload_url if upload else None
122-
event_handler = create_event_handler(prediction, upload_url=upload_url)
126+
event_handler = create_event_handler(
127+
prediction,
128+
upload_url=upload_url,
129+
traceparent=traceparent,
130+
tracestate=tracestate,
131+
)
123132

124-
def cleanup(_: schema.PredictionResponse = None) -> None:
133+
def cleanup(_: Optional[schema.PredictionResponse] = None) -> None:
125134
input = cast(Any, prediction.input)
126135
if hasattr(input, "cleanup"):
127136
input.cleanup()
@@ -178,7 +187,10 @@ def cancel(self, prediction_id: Optional[str] = None) -> None:
178187

179188

180189
def create_event_handler(
181-
prediction: schema.PredictionRequest, upload_url: Optional[str] = None
190+
prediction: schema.PredictionRequest,
191+
upload_url: Optional[str] = None,
192+
traceparent: Optional[str] = None,
193+
tracestate: Optional[str] = None,
182194
) -> "PredictionEventHandler":
183195
response = schema.PredictionResponse(**prediction.dict())
184196

@@ -189,7 +201,9 @@ def create_event_handler(
189201

190202
webhook_sender = None
191203
if webhook is not None:
192-
webhook_sender = webhook_caller_filtered(webhook, set(events_filter))
204+
webhook_sender = webhook_caller_filtered(
205+
webhook, set(events_filter), traceparent=traceparent, tracestate=tracestate
206+
)
193207

194208
file_uploader = None
195209
if upload_url is not None:

python/cog/server/webhook.py

Lines changed: 26 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import os
2-
from typing import Any, Callable, Set
2+
from typing import Any, Callable, Set, Optional
33

44
import requests
55
import structlog
@@ -39,8 +39,12 @@ def _get_version() -> str:
3939
def webhook_caller_filtered(
4040
webhook: str,
4141
webhook_events_filter: Set[WebhookEvent],
42+
traceparent: Optional[str] = None,
43+
tracestate: Optional[str] = None,
4244
) -> Callable[[Any, WebhookEvent], None]:
43-
upstream_caller = webhook_caller(webhook)
45+
upstream_caller = webhook_caller(
46+
webhook, traceparent=traceparent, tracestate=tracestate
47+
)
4448

4549
def caller(response: Any, event: WebhookEvent) -> None:
4650
if event in webhook_events_filter:
@@ -49,13 +53,17 @@ def caller(response: Any, event: WebhookEvent) -> None:
4953
return caller
5054

5155

52-
def webhook_caller(webhook: str) -> Callable[[Any], None]:
56+
def webhook_caller(
57+
webhook: str, traceparent: Optional[str] = None, tracestate: Optional[str] = None
58+
) -> Callable[[Any], None]:
5359
# TODO: we probably don't need to create new sessions and new throttlers
5460
# for every prediction.
5561
throttler = ResponseThrottler(response_interval=_response_interval)
5662

57-
default_session = requests_session()
58-
retry_session = requests_session_with_retries()
63+
default_session = requests_session(traceparent=traceparent, tracestate=tracestate)
64+
retry_session = requests_session_with_retries(
65+
traceparent=traceparent, tracestate=tracestate
66+
)
5967

6068
def caller(response: Any) -> None:
6169
if throttler.should_send_response(response):
@@ -73,23 +81,33 @@ def caller(response: Any) -> None:
7381
return caller
7482

7583

76-
def requests_session() -> requests.Session:
84+
def requests_session(
    traceparent: Optional[str] = None, tracestate: Optional[str] = None
) -> requests.Session:
    """
    Build a `requests.Session` with the default headers used for outgoing calls.

    The session's `user-agent` is prefixed with this module's `_user_agent`.
    When the `WEBHOOK_AUTH_TOKEN` environment variable is set to a non-empty
    value, it is sent as a bearer token in the `authorization` header.

    Args:
        traceparent: Value for the `traceparent` header. The header is only
            set when this is a non-empty string.
        tracestate: Value for the `tracestate` header. The header is only
            set when this is a non-empty string.

    Returns:
        The configured session.
    """
    # NOTE: the annotations are deliberately `Optional[str]` rather than
    # `Optional[str | None]`. The `X | Y` form is evaluated when the function
    # is defined and raises TypeError on Python < 3.10.
    session = requests.Session()
    session.headers["user-agent"] = (
        _user_agent + " " + str(session.headers["user-agent"])
    )

    # Trace context values are copied through verbatim; empty strings and
    # None are both treated as "no header supplied".
    if traceparent:
        session.headers["traceparent"] = traceparent
    if tracestate:
        session.headers["tracestate"] = tracestate

    auth_token = os.environ.get("WEBHOOK_AUTH_TOKEN")
    if auth_token:
        session.headers["authorization"] = "Bearer " + auth_token

    return session
86102

87103

88-
def requests_session_with_retries() -> requests.Session:
104+
def requests_session_with_retries(
105+
traceparent: Optional[str | None] = None, tracestate: Optional[str | None] = None
106+
) -> requests.Session:
89107
# This session will retry requests up to 12 times, with exponential
90108
# backoff. In total it'll try for up to roughly 320 seconds, providing
91109
# resilience through temporary networking and availability issues.
92-
session = requests_session()
110+
session = requests_session(traceparent=traceparent, tracestate=tracestate)
93111
adapter = HTTPAdapter(
94112
max_retries=Retry(
95113
total=12,

python/tests/server/test_http.py

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -531,6 +531,75 @@ def test_asynchronous_prediction_endpoint(client, match):
531531
time.sleep(0.1)
532532
n += 1
533533

534+
535+
# End-to-end test for passing tracing headers on to downstream services.
536+
@responses.activate
@uses_predictor_with_client_options(
    "output_file", upload_url="https://example.com/upload"
)
def test_asynchronous_prediction_endpoint_with_trace_context(client, match):
    """
    Send an async prediction carrying trace headers and check they reach both
    the mocked upload endpoint and the mocked webhook endpoint.
    """
    webhook_url = "https://example.com/webhook"
    uploaded_file_url = "https://example.com/upload/file"

    def trace_headers():
        # Fresh dict per use so no two consumers share one mapping.
        return {
            "traceparent": "traceparent-123",
            "tracestate": "tracestate-123",
        }

    # Mocked webhook: registered with both a body matcher and a header
    # matcher, each non-strict.
    expected_webhook_body = {
        "id": "12345abcde",
        "status": "succeeded",
        "output": uploaded_file_url,
    }
    webhook = responses.post(
        webhook_url,
        match=[
            matchers.json_params_matcher(expected_webhook_body, strict_match=False),
            matchers.header_matcher(trace_headers(), strict_match=False),
        ],
        status=200,
    )

    # Mocked upload target: registered with the same header matcher.
    uploader = responses.put(
        uploaded_file_url,
        match=[matchers.header_matcher(trace_headers(), strict_match=False)],
        status=200,
    )

    request_headers = {"Prefer": "respond-async"}
    request_headers.update(trace_headers())
    resp = client.post(
        "/predictions",
        json={
            "id": "12345abcde",
            "input": {},
            "webhook": webhook_url,
            "webhook_events_filter": ["completed"],
        },
        headers=request_headers,
    )
    assert resp.status_code == 202

    payload = resp.json()
    assert payload == match(
        {"status": "processing", "output": None, "started_at": mock.ANY}
    )
    assert payload["started_at"] is not None

    # Poll for the webhook call: at most ten 0.1s sleeps, stopping early once
    # the webhook mock has been hit.
    for _ in range(10):
        if webhook.call_count >= 1:
            break
        time.sleep(0.1)

    assert uploader.call_count == 1
    assert webhook.call_count == 1
535604

536605

0 commit comments

Comments
 (0)