Skip to content

Commit 7aa890c

Browse files
aron and technillogue
authored and committed
Propagate trace context to webhook and upload requests
Based on the implementation in #1698 for sync cog. If the request to /predict contains the `traceparent` and `tracestate` headers defined by W3C Trace Context[^1], then these headers are forwarded on to the webhook and upload calls. This allows observability systems to link requests passing through cog.

[^1]: https://www.w3.org/TR/trace-context/

Signed-off-by: technillogue <technillogue@gmail.com>
1 parent 986550b commit 7aa890c

File tree

4 files changed

+81
-15
lines changed

4 files changed

+81
-15
lines changed

python/cog/server/clients.py

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -43,14 +43,25 @@ def _get_version() -> str:
4343
WebhookSenderType = Callable[[Any, WebhookEvent], Awaitable[None]]
4444

4545

46-
def webhook_headers() -> "dict[str, str]":
46+
def common_headers() -> "dict[str, str]":
    """Return the base headers used as the starting point for outgoing clients.

    Only the ``user-agent`` header is set here. A new dict is built on every
    call, so callers may add their own entries without affecting each other.
    """
    return {"user-agent": _user_agent}
49+
50+
51+
def webhook_headers() -> "dict[str, str]":
    """Return the headers to send with webhook requests.

    Starts from ``common_headers()`` and, when the ``WEBHOOK_AUTH_TOKEN``
    environment variable is set to a non-empty value, adds an
    ``authorization`` header carrying that value as a bearer token.
    """
    result = common_headers()
    token = os.environ.get("WEBHOOK_AUTH_TOKEN")
    if not token:
        # Unset or empty token: send unauthenticated webhooks.
        return result

    result["authorization"] = "Bearer " + token
    return result
5258

5359

60+
async def on_request_trace_context_hook(request: httpx.Request) -> None:
    """httpx ``request`` event hook that adds the current trace context.

    The hook is evaluated per outgoing request, so the headers reflect the
    trace context active when the request is sent rather than when the
    client was constructed. When ``current_trace_context()`` returns a falsy
    value, the request headers are left untouched.
    """
    # NOTE(review): presumably a mapping holding the W3C ``traceparent`` /
    # ``tracestate`` values -- confirm against current_trace_context().
    request.headers.update(current_trace_context() or {})
63+
64+
5465
def httpx_webhook_client() -> httpx.AsyncClient:
    """Build the async HTTP client used for sending webhooks.

    The client sends ``webhook_headers()`` on every request and follows
    redirects. It also registers ``on_request_trace_context_hook`` as a
    request event hook, matching the retry and file clients, so webhooks
    sent through this client carry the trace context as well.
    """
    return httpx.AsyncClient(
        # Hooks on an AsyncClient must be coroutine functions; this one is.
        event_hooks={"request": [on_request_trace_context_hook]},
        headers=webhook_headers(),
        follow_redirects=True,
    )
5667

@@ -66,7 +77,10 @@ def httpx_retry_client() -> httpx.AsyncClient:
6677
retryable_methods=["POST"],
6778
)
6879
return httpx.AsyncClient(
69-
headers=webhook_headers(), transport=transport, follow_redirects=True
80+
event_hooks={"request": [on_request_trace_context_hook]},
81+
headers=webhook_headers(),
82+
transport=transport,
83+
follow_redirects=True,
7084
)
7185

7286

@@ -84,15 +98,14 @@ def httpx_file_client() -> httpx.AsyncClient:
8498
# requests has no write timeout, keep that
8599
# httpx default for pool is 5, use that
86100
timeout = httpx.Timeout(connect=10, read=15, write=None, pool=5)
87-
headers = {key: str(value) for key, value in (current_trace_context() or {})}
88-
headers["User-Agent"] = _user_agent
89101

90102
return httpx.AsyncClient(
103+
event_hooks={"request": [on_request_trace_context_hook]},
104+
headers=common_headers(),
91105
transport=transport,
92106
follow_redirects=True,
93107
timeout=timeout,
94108
http2=True,
95-
headers=headers,
96109
)
97110

98111

python/cog/server/http.py

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -287,10 +287,7 @@ async def predict(
287287
respond_async = prefer == "respond-async"
288288

289289
with trace_context(make_trace_context(traceparent, tracestate)):
290-
return _predict(
291-
request=request,
292-
respond_async=respond_async
293-
)
290+
return await shared_predict(request=request, respond_async=respond_async)
294291

295292
@limited
296293
@app.put(
@@ -328,13 +325,10 @@ async def predict_idempotent(
328325
respond_async = prefer == "respond-async"
329326

330327
with trace_context(make_trace_context(traceparent, tracestate)):
331-
return _predict(
332-
request=request,
333-
respond_async=respond_async,
334-
)
328+
return await shared_predict(request=request, respond_async=respond_async)
335329

336330

337-
async def _predict(
331+
async def shared_predict(
338332
*, request: Optional[PredictionRequest], respond_async: bool = False
339333
) -> Response:
340334
# [compat] If no body is supplied, assume that this model can be run

python/tests/cog/test_files.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
# async def test_put_file_to_signed_endpoint():
1111
# mock_fh = io.BytesIO()
1212
# mock_client = Mock()
13-
1413
# mock_response = Mock(spec=requests.Response)
1514
# mock_response.status_code = 201
1615
# mock_response.text = ""

python/tests/server/test_http.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
import base64
2+
import httpx
23
import io
4+
import respx
35
import time
46
import unittest.mock as mock
57

@@ -604,6 +606,64 @@ def test_asynchronous_prediction_endpoint_with_trace_context(client, match):
604606
assert webhook.call_count == 1
605607

606608

609+
# End-to-end test for passing tracing headers on to downstream services.
#
# NOTE: this test must not reuse the name
# test_asynchronous_prediction_endpoint_with_trace_context, which is already
# defined earlier in this module. Redefining it would rebind the name and
# pytest would silently stop collecting the earlier test.
@pytest.mark.asyncio
@pytest.mark.respx(base_url="https://example.com")
@uses_predictor_with_client_options(
    "output_file", upload_url="https://example.com/upload"
)
async def test_asynchronous_prediction_endpoint_with_trace_context_and_upload(
    respx_mock: respx.MockRouter, client, match
):
    """Trace headers on /predictions reach both the upload and the webhook.

    The mocked routes only match requests that carry the same
    ``traceparent`` / ``tracestate`` values that were sent on the original
    prediction request, so the call counts asserted at the end prove the
    headers were forwarded.
    """
    trace_headers = {
        "traceparent": "traceparent-123",
        "tracestate": "tracestate-123",
    }

    webhook = respx_mock.post(
        "/webhook",
        json__id="12345abcde",
        json__status="succeeded",
        json__output="https://example.com/upload/file",
        headers=dict(trace_headers),
    ).respond(200)
    uploader = respx_mock.put(
        "/upload/file",
        headers={
            "content-type": "application/octet-stream",
            **trace_headers,
        },
    ).respond(200)

    resp = client.post(
        "/predictions",
        json={
            "id": "12345abcde",
            "input": {},
            "webhook": "https://example.com/webhook",
            "webhook_events_filter": ["completed"],
        },
        headers={
            "Prefer": "respond-async",
            **trace_headers,
        },
    )
    assert resp.status_code == 202

    assert resp.json() == match(
        {"status": "processing", "output": None, "started_at": mock.ANY}
    )
    assert resp.json()["started_at"] is not None

    # The prediction completes in the background; poll for up to ~1 second
    # for the completion webhook before asserting on the call counts.
    for _ in range(10):
        if webhook.call_count >= 1:
            break
        time.sleep(0.1)

    assert webhook.call_count == 1
    assert uploader.call_count == 1
665+
666+
607667
@uses_predictor("sleep")
608668
def test_prediction_cancel(client):
609669
resp = client.post("/predictions/123/cancel")

0 commit comments

Comments
 (0)