Skip to content

Propagate trace context to webhook and upload requests #1698

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jun 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions python/cog/code_xforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@


def load_module_from_string(
name: str, source: Union[str, None]
) -> Union[types.ModuleType, None]:
name: str, source: Optional[str]
) -> Optional[types.ModuleType]:
if not source or not name:
return None
module = types.ModuleType(name)
Expand Down
2 changes: 1 addition & 1 deletion python/cog/predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,7 @@ def load_full_predictor_from_file(

def load_slim_predictor_from_file(
module_path: str, class_name: str, method_name: str
) -> Union[types.ModuleType, None]:
) -> Optional[types.ModuleType]:
with open(module_path, encoding="utf-8") as file:
source_code = file.read()
stripped_source = code_xforms.strip_model_source_code(
Expand Down
51 changes: 40 additions & 11 deletions python/cog/server/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
Dict,
Optional,
TypeVar,
Union,
)

if TYPE_CHECKING:
Expand Down Expand Up @@ -56,6 +55,7 @@
SetupTask,
UnknownPredictionError,
)
from .telemetry import make_trace_context, trace_context

log = structlog.get_logger("cog.server.http")

Expand Down Expand Up @@ -185,9 +185,16 @@ class TrainingRequest(
)
def train(
request: TrainingRequest = Body(default=None),
prefer: Union[str, None] = Header(default=None),
prefer: Optional[str] = Header(default=None),
traceparent: Optional[str] = Header(
default=None, include_in_schema=False
),
tracestate: Optional[str] = Header(
default=None, include_in_schema=False
),
) -> Any: # type: ignore
return predict(request, prefer)
with trace_context(make_trace_context(traceparent, tracestate)):
return predict(request, prefer)

@app.put(
"/trainings/{training_id}",
Expand All @@ -197,9 +204,16 @@ def train(
def train_idempotent(
training_id: str = Path(..., title="Training ID"),
request: TrainingRequest = Body(..., title="Training Request"),
prefer: Union[str, None] = Header(default=None),
prefer: Optional[str] = Header(default=None),
traceparent: Optional[str] = Header(
default=None, include_in_schema=False
),
tracestate: Optional[str] = Header(
default=None, include_in_schema=False
),
) -> Any:
return predict_idempotent(training_id, request, prefer)
with trace_context(make_trace_context(traceparent, tracestate)):
return predict_idempotent(training_id, request, prefer)

@app.post("/trainings/{training_id}/cancel")
def cancel_training(
Expand Down Expand Up @@ -259,7 +273,9 @@ async def healthcheck() -> Any:
)
async def predict(
request: PredictionRequest = Body(default=None),
prefer: Union[str, None] = Header(default=None),
prefer: Optional[str] = Header(default=None),
traceparent: Optional[str] = Header(default=None, include_in_schema=False),
tracestate: Optional[str] = Header(default=None, include_in_schema=False),
) -> Any: # type: ignore
"""
Run a single prediction on the model
Expand All @@ -272,7 +288,11 @@ async def predict(
# TODO: spec-compliant parsing of Prefer header.
respond_async = prefer == "respond-async"

return _predict(request=request, respond_async=respond_async)
with trace_context(make_trace_context(traceparent, tracestate)):
return _predict(
request=request,
respond_async=respond_async,
)

@limited
@app.put(
Expand All @@ -283,7 +303,9 @@ async def predict(
async def predict_idempotent(
prediction_id: str = Path(..., title="Prediction ID"),
request: PredictionRequest = Body(..., title="Prediction Request"),
prefer: Union[str, None] = Header(default=None),
prefer: Optional[str] = Header(default=None),
traceparent: Optional[str] = Header(default=None, include_in_schema=False),
tracestate: Optional[str] = Header(default=None, include_in_schema=False),
) -> Any:
"""
Run a single prediction on the model (idempotent creation).
Expand All @@ -307,10 +329,16 @@ async def predict_idempotent(
# TODO: spec-compliant parsing of Prefer header.
respond_async = prefer == "respond-async"

return _predict(request=request, respond_async=respond_async)
with trace_context(make_trace_context(traceparent, tracestate)):
return _predict(
request=request,
respond_async=respond_async,
)

def _predict(
*, request: PredictionRequest, respond_async: bool = False
*,
request: Optional[PredictionRequest],
respond_async: bool = False,
) -> Response:
# [compat] If no body is supplied, assume that this model can be run
# with empty input. This will throw a ValidationError if that's not
Expand All @@ -327,7 +355,8 @@ def _predict(
# async predictions. This is unfortunate but required to ensure
# backwards-compatible behaviour for synchronous predictions.
initial_response, async_result = runner.predict(
request, upload=respond_async
request,
upload=respond_async,
)
except RunnerBusyError:
return JSONResponse(
Expand Down
24 changes: 20 additions & 4 deletions python/cog/server/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
from ..json import upload_files
from .eventtypes import Done, Heartbeat, Log, PredictionOutput, PredictionOutputType
from .probes import ProbeHelper
from .telemetry import current_trace_context
from .useragent import get_user_agent
from .webhook import SKIP_START_EVENT, webhook_caller_filtered
from .worker import Worker

Expand Down Expand Up @@ -98,7 +100,9 @@ def handle_error(error: BaseException) -> None:
# TODO: Make the return type AsyncResult[schema.PredictionResponse] when we
# no longer have to support Python 3.8
def predict(
self, prediction: schema.PredictionRequest, upload: bool = True
self,
prediction: schema.PredictionRequest,
upload: bool = True,
) -> Tuple[schema.PredictionResponse, PredictionTask]:
# It's the caller's responsibility to not call us if we're busy.
if self.is_busy():
Expand All @@ -119,9 +123,12 @@ def predict(

self._should_cancel.clear()
upload_url = self._upload_url if upload else None
event_handler = create_event_handler(prediction, upload_url=upload_url)
event_handler = create_event_handler(
prediction,
upload_url=upload_url,
)

def cleanup(_: schema.PredictionResponse = None) -> None:
def cleanup(_: Optional[schema.PredictionResponse] = None) -> None:
input = cast(Any, prediction.input)
if hasattr(input, "cleanup"):
input.cleanup()
Expand Down Expand Up @@ -178,7 +185,8 @@ def cancel(self, prediction_id: Optional[str] = None) -> None:


def create_event_handler(
prediction: schema.PredictionRequest, upload_url: Optional[str] = None
prediction: schema.PredictionRequest,
upload_url: Optional[str] = None,
) -> "PredictionEventHandler":
response = schema.PredictionResponse(**prediction.dict())

Expand Down Expand Up @@ -452,6 +460,14 @@ def _predict(

def _make_file_upload_http_client() -> requests.Session:
session = requests.Session()
session.headers["user-agent"] = (
get_user_agent() + " " + str(session.headers["user-agent"])
)

ctx = current_trace_context() or {}
for key, value in ctx.items():
session.headers[key] = str(value)

adapter = HTTPAdapter(
max_retries=Retry(
total=3,
Expand Down
54 changes: 54 additions & 0 deletions python/cog/server/telemetry.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
from contextlib import contextmanager
from contextvars import ContextVar
from typing import Generator, Optional

# TypedDict was added in 3.8
from typing_extensions import TypedDict


# See: https://www.w3.org/TR/trace-context/
class TraceContext(TypedDict, total=False):
    # Both keys are optional (total=False): an inbound request may carry
    # either, both, or neither of the W3C trace-context headers.
    traceparent: str
    tracestate: str


TRACE_CONTEXT: ContextVar[Optional[TraceContext]] = ContextVar(
    "trace_context", default=None
)


def make_trace_context(
    traceparent: Optional[str] = None, tracestate: Optional[str] = None
) -> TraceContext:
    """
    Build a trace context dict from the given traceparent and tracestate
    header values. This is used to pass the trace context between services.
    Missing or empty header values are omitted from the result.
    """
    headers = (("traceparent", traceparent), ("tracestate", tracestate))
    return TraceContext(**{name: value for name, value in headers if value})


def current_trace_context() -> Optional[TraceContext]:
    """
    Return the trace context for the current execution context, or None if
    no inbound request installed one. Its entries need to be added as HTTP
    headers to all outgoing HTTP requests.
    """
    return TRACE_CONTEXT.get()


@contextmanager
def trace_context(ctx: TraceContext) -> Generator[None, None, None]:
    """
    Install ``ctx`` as the current trace context for the duration of the
    ``with`` block, restoring the previous value on exit. The context links
    requests across the system and needs to be added to all internal
    outgoing HTTP requests.
    """
    token = TRACE_CONTEXT.set(ctx)
    try:
        yield
    finally:
        # Always restore the prior context, even if the body raises.
        TRACE_CONTEXT.reset(token)
17 changes: 17 additions & 0 deletions python/cog/server/useragent.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
def _get_version() -> str:
try:
try:
from importlib.metadata import version
except ImportError:
pass
else:
return version("cog")
import pkg_resources

return pkg_resources.get_distribution("cog").version
except Exception:
return "unknown"


def get_user_agent() -> str:
return f"cog-worker/{_get_version()}"
26 changes: 8 additions & 18 deletions python/cog/server/webhook.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,26 +8,11 @@

from ..schema import Status, WebhookEvent
from .response_throttler import ResponseThrottler
from .telemetry import current_trace_context
from .useragent import get_user_agent

log = structlog.get_logger(__name__)


def _get_version() -> str:
try:
try:
from importlib.metadata import version
except ImportError:
pass
else:
return version("cog")
import pkg_resources

return pkg_resources.get_distribution("cog").version
except Exception:
return "unknown"


_user_agent = f"cog-worker/{_get_version()}"
_response_interval = float(os.environ.get("COG_THROTTLE_RESPONSE_INTERVAL", 0.5))

# HACK: signal that we should skip the start webhook when the response interval
Expand Down Expand Up @@ -76,8 +61,13 @@ def caller(response: Any) -> None:
def requests_session() -> requests.Session:
session = requests.Session()
session.headers["user-agent"] = (
_user_agent + " " + str(session.headers["user-agent"])
get_user_agent() + " " + str(session.headers["user-agent"])
)

ctx = current_trace_context() or {}
for key, value in ctx.items():
session.headers[key] = str(value)

auth_token = os.environ.get("WEBHOOK_AUTH_TOKEN")
if auth_token:
session.headers["authorization"] = "Bearer " + auth_token
Expand Down
69 changes: 69 additions & 0 deletions python/tests/server/test_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -531,6 +531,75 @@ def test_asynchronous_prediction_endpoint(client, match):
time.sleep(0.1)
n += 1


# End-to-end test for passing tracing headers on to downstream services.
@responses.activate
@uses_predictor_with_client_options(
    "output_file", upload_url="https://example.com/upload"
)
def test_asynchronous_prediction_endpoint_with_trace_context(client, match):
    # Stub webhook receiver: the mock only matches (and thus counts a call)
    # if the final payload reports success AND the request carries the same
    # W3C trace headers that were sent with the inbound prediction request.
    webhook = responses.post(
        "https://example.com/webhook",
        match=[
            matchers.json_params_matcher(
                {
                    "id": "12345abcde",
                    "status": "succeeded",
                    "output": "https://example.com/upload/file",
                },
                strict_match=False,
            ),
            matchers.header_matcher(
                {
                    "traceparent": "traceparent-123",
                    "tracestate": "tracestate-123",
                },
                strict_match=False,
            ),
        ],
        status=200,
    )
    # Stub file-upload endpoint: must likewise receive the trace headers.
    uploader = responses.put(
        "https://example.com/upload/file",
        match=[
            matchers.header_matcher(
                {
                    "traceparent": "traceparent-123",
                    "tracestate": "tracestate-123",
                },
                strict_match=False,
            ),
        ],
        status=200,
    )

    # Kick off an async prediction, supplying the trace-context headers that
    # should be propagated to both downstream requests above.
    resp = client.post(
        "/predictions",
        json={
            "id": "12345abcde",
            "input": {},
            "webhook": "https://example.com/webhook",
            "webhook_events_filter": ["completed"],
        },
        headers={
            "Prefer": "respond-async",
            "traceparent": "traceparent-123",
            "tracestate": "tracestate-123",
        },
    )
    assert resp.status_code == 202

    assert resp.json() == match(
        {"status": "processing", "output": None, "started_at": mock.ANY}
    )
    assert resp.json()["started_at"] is not None

    # Poll for up to ~1s for the background prediction to complete and fire
    # the completion webhook.
    n = 0
    while webhook.call_count < 1 and n < 10:
        time.sleep(0.1)
        n += 1

    # Exactly one upload and one webhook call; the header matchers above
    # guarantee each carried the propagated trace context.
    assert uploader.call_count == 1
    assert webhook.call_count == 1


Expand Down
Loading