Skip to content

Commit 8f196cb

Browse files
arontechnillogue
authored and committed
Propagate trace context to webhook and upload requests
Based on the implementation in #1698 for sync cog. If the request to /predict contains the headers `traceparent` and `tracestate`, defined by the W3C Trace Context specification[^1], then these headers are forwarded on to the webhook and upload calls. This allows observability systems to link requests passing through cog. [^1]: https://www.w3.org/TR/trace-context/ Signed-off-by: technillogue <technillogue@gmail.com>
1 parent 8d834f0 commit 8f196cb

File tree

4 files changed

+163
-11
lines changed

4 files changed

+163
-11
lines changed

python/cog/server/clients.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from .eventtypes import PredictionInput
1616
from .response_throttler import ResponseThrottler
1717
from .retry_transport import RetryTransport
18+
from .telemetry import current_trace_context
1819

1920
log = structlog.get_logger(__name__)
2021

@@ -45,14 +46,25 @@ def _get_version() -> str:
4546
WebhookSenderType = Callable[[Any, WebhookEvent], Awaitable[None]]
4647

4748

48-
def webhook_headers() -> "dict[str, str]":
49+
def common_headers() -> "dict[str, str]":
    """Return the headers attached to every outgoing HTTP request made by cog."""
    return {"user-agent": _user_agent}
52+
53+
54+
def webhook_headers() -> "dict[str, str]":
    """
    Return headers for webhook requests: the common headers plus a bearer
    token taken from the WEBHOOK_AUTH_TOKEN environment variable, when set.
    """
    headers = common_headers()
    token = os.environ.get("WEBHOOK_AUTH_TOKEN")
    if token:
        headers["authorization"] = "Bearer " + token
    return headers
5461

5562

63+
async def on_request_trace_context_hook(request: "httpx.Request") -> None:
    """
    httpx request event hook: copy the current trace context headers
    (traceparent/tracestate), if any, onto the outgoing request.
    """
    trace_headers = current_trace_context() or {}
    request.headers.update(trace_headers)
66+
67+
5668
def httpx_webhook_client() -> "httpx.AsyncClient":
    """Return an async HTTP client for sending webhooks, following redirects."""
    client = httpx.AsyncClient(headers=webhook_headers(), follow_redirects=True)
    return client
5870

@@ -68,7 +80,10 @@ def httpx_retry_client() -> httpx.AsyncClient:
6880
retryable_methods=["POST"],
6981
)
7082
return httpx.AsyncClient(
71-
headers=webhook_headers(), transport=transport, follow_redirects=True
83+
event_hooks={"request": [on_request_trace_context_hook]},
84+
headers=webhook_headers(),
85+
transport=transport,
86+
follow_redirects=True,
7287
)
7388

7489

@@ -87,6 +102,8 @@ def httpx_file_client() -> httpx.AsyncClient:
87102
# httpx default for pool is 5, use that
88103
timeout = httpx.Timeout(connect=10, read=15, write=None, pool=5)
89104
return httpx.AsyncClient(
105+
event_hooks={"request": [on_request_trace_context_hook]},
106+
headers=common_headers(),
90107
transport=transport,
91108
follow_redirects=True,
92109
timeout=timeout,

python/cog/server/http.py

Lines changed: 30 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
Dict,
2020
Optional,
2121
TypeVar,
22-
Union,
2322
)
2423

2524
if TYPE_CHECKING:
@@ -52,6 +51,7 @@
5251
SetupTask,
5352
UnknownPredictionError,
5453
)
54+
from .telemetry import make_trace_context, trace_context
5555

5656
log = structlog.get_logger("cog.server.http")
5757

@@ -190,9 +190,16 @@ class TrainingRequest(
190190
)
191191
def train(
192192
request: TrainingRequest = Body(default=None),
193-
prefer: Union[str, None] = Header(default=None),
193+
prefer: Optional[str] = Header(default=None),
194+
traceparent: Optional[str] = Header(default=None, include_in_schema=False),
195+
tracestate: Optional[str] = Header(default=None, include_in_schema=False),
194196
) -> Any: # type: ignore
195-
return predict(request, prefer)
197+
return predict(
198+
request,
199+
prefer=prefer,
200+
traceparent=traceparent,
201+
tracestate=tracestate,
202+
)
196203

197204
@app.put(
198205
"/trainings/{training_id}",
@@ -202,9 +209,17 @@ def train(
202209
def train_idempotent(
203210
training_id: str = Path(..., title="Training ID"),
204211
request: TrainingRequest = Body(..., title="Training Request"),
205-
prefer: Union[str, None] = Header(default=None),
212+
prefer: Optional[str] = Header(default=None),
213+
traceparent: Optional[str] = Header(default=None, include_in_schema=False),
214+
tracestate: Optional[str] = Header(default=None, include_in_schema=False),
206215
) -> Any:
207-
return predict_idempotent(training_id, request, prefer)
216+
return predict_idempotent(
217+
prediction_id=training_id,
218+
request=request,
219+
prefer=prefer,
220+
traceparent=traceparent,
221+
tracestate=tracestate,
222+
)
208223

209224
@app.post("/trainings/{training_id}/cancel")
210225
def cancel_training(training_id: str = Path(..., title="Training ID")) -> Any:
@@ -270,7 +285,9 @@ async def ready() -> Any:
270285
)
271286
async def predict(
272287
request: PredictionRequest = Body(default=None),
273-
prefer: Union[str, None] = Header(default=None),
288+
prefer: Optional[str] = Header(default=None),
289+
traceparent: Optional[str] = Header(default=None, include_in_schema=False),
290+
tracestate: Optional[str] = Header(default=None, include_in_schema=False),
274291
) -> Any: # type: ignore
275292
"""
276293
Run a single prediction on the model
@@ -285,7 +302,8 @@ async def predict(
285302
# TODO: spec-compliant parsing of Prefer header.
286303
respond_async = prefer == "respond-async"
287304

288-
return await shared_predict(request=request, respond_async=respond_async)
305+
with trace_context(make_trace_context(traceparent, tracestate)):
306+
return await shared_predict(request=request, respond_async=respond_async)
289307

290308
@limited
291309
@app.put(
@@ -296,7 +314,9 @@ async def predict(
296314
async def predict_idempotent(
297315
prediction_id: str = Path(..., title="Prediction ID"),
298316
request: PredictionRequest = Body(..., title="Prediction Request"),
299-
prefer: Union[str, None] = Header(default=None),
317+
prefer: Optional[str] = Header(default=None),
318+
traceparent: Optional[str] = Header(default=None, include_in_schema=False),
319+
tracestate: Optional[str] = Header(default=None, include_in_schema=False),
300320
) -> Any:
301321
"""
302322
Run a single prediction on the model (idempotent creation).
@@ -314,7 +334,8 @@ async def predict_idempotent(
314334
# TODO: spec-compliant parsing of Prefer header.
315335
respond_async = prefer == "respond-async"
316336

317-
return await shared_predict(request=request, respond_async=respond_async)
337+
with trace_context(make_trace_context(traceparent, tracestate)):
338+
return await shared_predict(request=request, respond_async=respond_async)
318339

319340
async def shared_predict(
320341
*, request: Optional[PredictionRequest], respond_async: bool = False

python/cog/server/telemetry.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
from contextlib import contextmanager
2+
from contextvars import ContextVar
3+
from typing import Generator, Optional
4+
5+
# TypedDict was added in 3.8
6+
from typing_extensions import TypedDict
7+
8+
9+
# See: https://www.w3.org/TR/trace-context/
class TraceContext(TypedDict, total=False):
    """W3C Trace Context headers propagated from inbound to outbound requests."""

    traceparent: str
    tracestate: str


# Holds the trace context of the request currently being handled, if any.
TRACE_CONTEXT: ContextVar[Optional[TraceContext]] = ContextVar(
    "trace_context", default=None
)
18+
19+
20+
def make_trace_context(
    traceparent: Optional[str] = None, tracestate: Optional[str] = None
) -> "TraceContext":
    """
    Create a trace context dictionary from the given traceparent and
    tracestate header values; missing or empty values are omitted. This is
    used to pass the trace context between services.
    """
    candidates = {"traceparent": traceparent, "tracestate": tracestate}
    return {name: value for name, value in candidates.items() if value}  # type: ignore
33+
34+
35+
def current_trace_context() -> "Optional[TraceContext]":
    """
    Return the trace context of the request currently being handled, or
    None if there is none. These headers need to be added to all outgoing
    HTTP requests.
    """
    return TRACE_CONTEXT.get()
41+
42+
43+
@contextmanager
def trace_context(ctx: "TraceContext") -> Generator[None, None, None]:
    """
    Install *ctx* as the current trace context for the duration of the
    `with` block, restoring the previous context afterwards. The inbound
    request's context is used to link requests across the system and needs
    to be added to all internal outgoing HTTP requests.
    """
    token = TRACE_CONTEXT.set(ctx)
    try:
        yield
    finally:
        TRACE_CONTEXT.reset(token)

python/tests/server/test_http.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
import base64
2+
import httpx
23
import io
4+
import respx
35
import time
46
import unittest.mock as mock
57

@@ -560,6 +562,64 @@ def test_asynchronous_prediction_endpoint(client, match):
560562
assert webhook.call_count == 1
561563

562564

565+
# End-to-end test for passing tracing headers on to downstream services.
@pytest.mark.asyncio
@pytest.mark.respx(base_url="https://example.com")
@uses_predictor_with_client_options(
    "output_file", upload_url="https://example.com/upload"
)
async def test_asynchronous_prediction_endpoint_with_trace_context(
    respx_mock: respx.MockRouter, client, match
):
    # Both downstream calls must carry the trace headers from the inbound request.
    trace_headers = {
        "traceparent": "traceparent-123",
        "tracestate": "tracestate-123",
    }
    webhook = respx_mock.post(
        "/webhook",
        json__id="12345abcde",
        json__status="succeeded",
        json__output="https://example.com/upload/file",
        headers=trace_headers,
    ).respond(200)
    uploader = respx_mock.put(
        "/upload/file",
        headers={"content-type": "application/octet-stream", **trace_headers},
    ).respond(200)

    resp = client.post(
        "/predictions",
        json={
            "id": "12345abcde",
            "input": {},
            "webhook": "https://example.com/webhook",
            "webhook_events_filter": ["completed"],
        },
        headers={"Prefer": "respond-async", **trace_headers},
    )
    assert resp.status_code == 202

    assert resp.json() == match(
        {"status": "processing", "output": None, "started_at": mock.ANY}
    )
    assert resp.json()["started_at"] is not None

    # Poll briefly for the async webhook delivery to complete.
    for _ in range(10):
        if webhook.call_count >= 1:
            break
        time.sleep(0.1)

    assert webhook.call_count == 1
    assert uploader.call_count == 1
621+
622+
563623
@uses_predictor("sleep")
564624
def test_prediction_cancel(client):
565625
resp = client.post("/predictions/123/cancel")

0 commit comments

Comments
 (0)