Skip to content

Commit 1e7d482

Browse files
aronmattt
authored andcommitted
Propagate trace context to webhook and upload requests
Based on the implementation in #1698 for sync cog. If the request to /predict contains headers `traceparent` and `tracestate` defined by w3c Trace Context[^1] then these headers are forwarded on to the webhook and upload calls. This allows observability systems to link requests passing through cog. [^1]: https://www.w3.org/TR/trace-context/ Signed-off-by: technillogue <technillogue@gmail.com>
1 parent 001295d commit 1e7d482

File tree

3 files changed

+79
-6
lines changed

3 files changed

+79
-6
lines changed

python/cog/server/clients.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,14 +46,25 @@ def _get_version() -> str:
4646
WebhookSenderType = Callable[[Any, WebhookEvent], Awaitable[None]]
4747

4848

49-
def webhook_headers() -> "dict[str, str]":
49+
def common_headers() -> "dict[str, str]":
5050
headers = {"user-agent": _user_agent}
51+
return headers
52+
53+
54+
def webhook_headers() -> "dict[str, str]":
55+
headers = common_headers()
5156
auth_token = os.environ.get("WEBHOOK_AUTH_TOKEN")
5257
if auth_token:
5358
headers["authorization"] = "Bearer " + auth_token
59+
5460
return headers
5561

5662

63+
async def on_request_trace_context_hook(request: httpx.Request) -> None:
64+
ctx = current_trace_context() or {}
65+
request.headers.update(ctx)
66+
67+
5768
def httpx_webhook_client() -> httpx.AsyncClient:
5869
return httpx.AsyncClient(headers=webhook_headers(), follow_redirects=True)
5970

@@ -69,7 +80,10 @@ def httpx_retry_client() -> httpx.AsyncClient:
6980
retryable_methods=["POST"],
7081
)
7182
return httpx.AsyncClient(
72-
headers=webhook_headers(), transport=transport, follow_redirects=True
83+
event_hooks={"request": [on_request_trace_context_hook]},
84+
headers=webhook_headers(),
85+
transport=transport,
86+
follow_redirects=True,
7387
)
7488

7589

@@ -91,6 +105,8 @@ def httpx_file_client() -> httpx.AsyncClient:
91105
headers["User-Agent"] = _user_agent
92106

93107
return httpx.AsyncClient(
108+
event_hooks={"request": [on_request_trace_context_hook]},
109+
headers=common_headers(),
94110
transport=transport,
95111
follow_redirects=True,
96112
timeout=timeout,

python/cog/server/http.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -341,10 +341,7 @@ async def predict_idempotent(
341341
respond_async = prefer == "respond-async"
342342

343343
with trace_context(make_trace_context(traceparent, tracestate)):
344-
return shared_predict(
345-
request=request,
346-
respond_async=respond_async,
347-
)
344+
return await shared_predict(request=request, respond_async=respond_async)
348345

349346
async def shared_predict(
350347
*, request: Optional[PredictionRequest], respond_async: bool = False

python/tests/server/test_http.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
import base64
2+
import httpx
23
import io
4+
import respx
35
import time
46
import unittest.mock as mock
57

@@ -629,6 +631,64 @@ def test_asynchronous_prediction_endpoint_with_trace_context(client, match):
629631
assert webhook.call_count == 1
630632

631633

634+
# End-to-end test for passing tracing headers on to downstream services.
635+
@pytest.mark.asyncio
636+
@pytest.mark.respx(base_url="https://example.com")
637+
@uses_predictor_with_client_options(
638+
"output_file", upload_url="https://example.com/upload"
639+
)
640+
async def test_asynchronous_prediction_endpoint_with_trace_context(
641+
respx_mock: respx.MockRouter, client, match
642+
):
643+
webhook = respx_mock.post(
644+
"/webhook",
645+
json__id="12345abcde",
646+
json__status="succeeded",
647+
json__output="https://example.com/upload/file",
648+
headers={
649+
"traceparent": "traceparent-123",
650+
"tracestate": "tracestate-123",
651+
},
652+
).respond(200)
653+
uploader = respx_mock.put(
654+
"/upload/file",
655+
headers={
656+
"content-type": "application/octet-stream",
657+
"traceparent": "traceparent-123",
658+
"tracestate": "tracestate-123",
659+
},
660+
).respond(200)
661+
662+
resp = client.post(
663+
"/predictions",
664+
json={
665+
"id": "12345abcde",
666+
"input": {},
667+
"webhook": "https://example.com/webhook",
668+
"webhook_events_filter": ["completed"],
669+
},
670+
headers={
671+
"Prefer": "respond-async",
672+
"traceparent": "traceparent-123",
673+
"tracestate": "tracestate-123",
674+
},
675+
)
676+
assert resp.status_code == 202
677+
678+
assert resp.json() == match(
679+
{"status": "processing", "output": None, "started_at": mock.ANY}
680+
)
681+
assert resp.json()["started_at"] is not None
682+
683+
n = 0
684+
while webhook.call_count < 1 and n < 10:
685+
time.sleep(0.1)
686+
n += 1
687+
688+
assert webhook.call_count == 1
689+
assert uploader.call_count == 1
690+
691+
632692
@uses_predictor("sleep")
633693
def test_prediction_cancel(client):
634694
resp = client.post("/predictions/123/cancel")

0 commit comments

Comments
 (0)