Skip to content

Commit

Permalink
Feat: Use the lambda runtime as parent context as opt-in
Browse files Browse the repository at this point in the history
In a previous PR open-telemetry#1657 the lambda behaviour was changed so that it
extracs the context from the headers of the lambda event instead of from
the lambda runtime, using the _X_AMZN_TRACE_ID env var.

This behaviour is undesireble as it is a breaking change to existing
users.

This PR will add the previous behaviour through an opt-in flag so that
users can be gracefully migrated or these two "modes" of operation can
exist without conflict.

Signed-off-by: Raphael Silva <rapphil@gmail.com>
  • Loading branch information
rapphil committed Aug 15, 2023
1 parent 1beab82 commit 883592d
Show file tree
Hide file tree
Showing 2 changed files with 122 additions and 75 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ def custom_event_context_extractor(lambda_event):
OTEL_INSTRUMENTATION_AWS_LAMBDA_FLUSH_TIMEOUT = (
"OTEL_INSTRUMENTATION_AWS_LAMBDA_FLUSH_TIMEOUT"
)
OTEL_LAMBDA_USE_AWS_CONTEXT_PROPAGATION = "OTEL_LAMBDA_USE_AWS_CONTEXT_PROPAGATION"


def _default_event_context_extractor(lambda_event: Any) -> Context:
Expand Down Expand Up @@ -140,7 +141,9 @@ def _default_event_context_extractor(lambda_event: Any) -> Context:


def _determine_parent_context(
lambda_event: Any, event_context_extractor: Callable[[Any], Context]
lambda_event: Any,
use_aws_context_propagation: bool,
event_context_extractor: Callable[[Any], Context],
) -> Context:
"""Determine the parent context for the current Lambda invocation.
Expand All @@ -158,15 +161,26 @@ def _determine_parent_context(
A Context with configuration found in the carrier.
"""
parent_context = None

if event_context_extractor:
if use_aws_context_propagation:
parent_context = _get_x_ray_context()
elif event_context_extractor:
parent_context = event_context_extractor(lambda_event)
else:
parent_context = _default_event_context_extractor(lambda_event)

return parent_context


def _get_x_ray_context() -> Optional[Context]:
"""Determine teh context propagated through the lambda runtime"""
xray_env_var = os.environ.get(_X_AMZN_TRACE_ID)
if xray_env_var:
env_context = AwsXRayPropagator().extract({TRACE_HEADER_KEY: xray_env_var})
return env_context

return None


def _determine_links() -> Optional[Sequence[Link]]:
"""Determine if a Link should be added to the Span based on the
environment variable `_X_AMZN_TRACE_ID`.
Expand All @@ -180,31 +194,23 @@ def _determine_links() -> Optional[Sequence[Link]]:
"""
links = None

xray_env_var = os.environ.get(_X_AMZN_TRACE_ID)
x_ray_context = _get_x_ray_context()

if xray_env_var:
env_context = AwsXRayPropagator().extract(
{TRACE_HEADER_KEY: xray_env_var}
)

span_context = get_current_span(env_context).get_span_context()
if x_ray_context:
span_context = get_current_span(x_ray_context).get_span_context()
if span_context.is_valid:
links = [Link(span_context, {"source": "x-ray-env"})]

return links


def _set_api_gateway_v1_proxy_attributes(
lambda_event: Any, span: Span
) -> Span:
def _set_api_gateway_v1_proxy_attributes(lambda_event: Any, span: Span) -> Span:
"""Sets HTTP attributes for REST APIs and v1 HTTP APIs
More info:
https://docs.aws.amazon.com/apigateway/latest/developerguide/set-up-lambda-proxy-integrations.html#api-gateway-simple-proxy-for-lambda-input-format
"""
span.set_attribute(
SpanAttributes.HTTP_METHOD, lambda_event.get("httpMethod")
)
span.set_attribute(SpanAttributes.HTTP_METHOD, lambda_event.get("httpMethod"))

if lambda_event.get("headers"):
if "User-Agent" in lambda_event["headers"]:
Expand All @@ -231,16 +237,12 @@ def _set_api_gateway_v1_proxy_attributes(
f"{lambda_event['resource']}?{urlencode(lambda_event['queryStringParameters'])}",
)
else:
span.set_attribute(
SpanAttributes.HTTP_TARGET, lambda_event["resource"]
)
span.set_attribute(SpanAttributes.HTTP_TARGET, lambda_event["resource"])

return span


def _set_api_gateway_v2_proxy_attributes(
lambda_event: Any, span: Span
) -> Span:
def _set_api_gateway_v2_proxy_attributes(lambda_event: Any, span: Span) -> Span:
"""Sets HTTP attributes for v2 HTTP APIs
More info:
Expand Down Expand Up @@ -289,21 +291,26 @@ def _instrument(
event_context_extractor: Callable[[Any], Context],
tracer_provider: TracerProvider = None,
meter_provider: MeterProvider = None,
use_aws_context_propagation: bool = False,
):
def _instrumented_lambda_handler_call( # noqa pylint: disable=too-many-branches
call_wrapped, instance, args, kwargs
):
orig_handler_name = ".".join(
[wrapped_module_name, wrapped_function_name]
)
orig_handler_name = ".".join([wrapped_module_name, wrapped_function_name])

lambda_event = args[0]

# We are not fully complying with the specification here to be backwards
# compatible with the old version of the specification.
# the ``use_aws_context_propagation`` flag allow us
# to opt-in into the previous behavior
parent_context = _determine_parent_context(
lambda_event, event_context_extractor
lambda_event, use_aws_context_propagation, event_context_extractor
)

links = _determine_links()
links = None
if not use_aws_context_propagation:
links = _determine_links()

span_kind = None
try:
Expand Down Expand Up @@ -354,9 +361,7 @@ def _instrumented_lambda_handler_call( # noqa pylint: disable=too-many-branches
# If the request came from an API Gateway, extract http attributes from the event
# https://github.com/open-telemetry/opentelemetry-specification/blob/main/specification/trace/semantic_conventions/instrumentation/aws-lambda.md#api-gateway
# https://github.com/open-telemetry/opentelemetry-specification/blob/main/specification/trace/semantic_conventions/http.md#http-server-semantic-conventions
if isinstance(lambda_event, dict) and lambda_event.get(
"requestContext"
):
if isinstance(lambda_event, dict) and lambda_event.get("requestContext"):
span.set_attribute(SpanAttributes.FAAS_TRIGGER, "http")

if lambda_event.get("version") == "2.0":
Expand Down Expand Up @@ -424,6 +429,13 @@ def _instrument(self, **kwargs):
Event as input and extracts an OTel Context from it. By default,
the context is extracted from the HTTP headers of an API Gateway
request.
``use_aws_context_propagation``: whether to use the AWS context propagation
to populate the parent context. When set to true, the spans
from the lambda runtime will not be added as span link to the
span of the lambda invocation.
Defaults to False.
"""
_instrument(**kwargs)
"""
lambda_handler = os.environ.get(ORIG_HANDLER, os.environ.get(_HANDLER))
# pylint: disable=attribute-defined-outside-init
Expand Down Expand Up @@ -454,6 +466,11 @@ def _instrument(self, **kwargs):
),
tracer_provider=kwargs.get("tracer_provider"),
meter_provider=kwargs.get("meter_provider"),
use_aws_context_propagation=kwargs.get(
"use_aws_context_propagation",
os.environ.get(OTEL_LAMBDA_USE_AWS_CONTEXT_PROPAGATION, "False").lower()
in ("true", "1", "t"),
),
)

def _uninstrument(self, **kwargs):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
_HANDLER,
_X_AMZN_TRACE_ID,
OTEL_INSTRUMENTATION_AWS_LAMBDA_FLUSH_TIMEOUT,
OTEL_LAMBDA_USE_AWS_CONTEXT_PROPAGATION,
AwsLambdaInstrumentor,
)
from opentelemetry.propagate import get_global_textmap
Expand Down Expand Up @@ -59,9 +60,7 @@ def __init__(self, aws_request_id, invoked_function_arn):
MOCK_XRAY_PARENT_SPAN_ID = 0x3328B8445A6DBAD2
MOCK_XRAY_TRACE_CONTEXT_COMMON = f"Root={TRACE_ID_VERSION}-{MOCK_XRAY_TRACE_ID_STR[:TRACE_ID_FIRST_PART_LENGTH]}-{MOCK_XRAY_TRACE_ID_STR[TRACE_ID_FIRST_PART_LENGTH:]};Parent={MOCK_XRAY_PARENT_SPAN_ID:x}"
MOCK_XRAY_TRACE_CONTEXT_SAMPLED = f"{MOCK_XRAY_TRACE_CONTEXT_COMMON};Sampled=1"
MOCK_XRAY_TRACE_CONTEXT_NOT_SAMPLED = (
f"{MOCK_XRAY_TRACE_CONTEXT_COMMON};Sampled=0"
)
MOCK_XRAY_TRACE_CONTEXT_NOT_SAMPLED = f"{MOCK_XRAY_TRACE_CONTEXT_COMMON};Sampled=0"

# See more:
# https://www.w3.org/TR/trace-context/#examples-of-http-traceparent-headers
Expand Down Expand Up @@ -116,43 +115,70 @@ def tearDown(self):
AwsLambdaInstrumentor().uninstrument()

def test_active_tracing(self):
test_env_patch = mock.patch.dict(
"os.environ",
{
**os.environ,
# Using Active tracing
_X_AMZN_TRACE_ID: MOCK_XRAY_TRACE_CONTEXT_SAMPLED,
},
)
test_env_patch.start()
@dataclass
class Testcase:
name: str
use_aws_context_propagation: str
expected_trace_id: str

AwsLambdaInstrumentor().instrument()
tests = [
Testcase(
name="Use aws context propgation",
use_aws_context_propagation="true",
expected_trace_id=MOCK_XRAY_TRACE_ID,
),
Testcase(
name="Do not use aws context propgation",
use_aws_context_propagation="false",
expected_trace_id=None,
),
]

mock_execute_lambda()
for test in tests:
test_env_patch = mock.patch.dict(
"os.environ",
{
**os.environ,
# Using Active tracing
_X_AMZN_TRACE_ID: MOCK_XRAY_TRACE_CONTEXT_SAMPLED,
OTEL_LAMBDA_USE_AWS_CONTEXT_PROPAGATION: test.use_aws_context_propagation,
},
)
test_env_patch.start()

spans = self.memory_exporter.get_finished_spans()
AwsLambdaInstrumentor().instrument()

assert spans
mock_execute_lambda()

self.assertEqual(len(spans), 1)
span = spans[0]
self.assertEqual(span.name, os.environ[_HANDLER])
self.assertNotEqual(
span.get_span_context().trace_id, MOCK_XRAY_TRACE_ID
)
self.assertEqual(span.kind, SpanKind.SERVER)
self.assertSpanHasAttributes(
span,
{
ResourceAttributes.FAAS_ID: MOCK_LAMBDA_CONTEXT.invoked_function_arn,
SpanAttributes.FAAS_EXECUTION: MOCK_LAMBDA_CONTEXT.aws_request_id,
},
)
spans = self.memory_exporter.get_finished_spans()

parent_context = span.parent
self.assertEqual(None, parent_context)
assert spans

test_env_patch.stop()
self.assertEqual(len(spans), 1)
span = spans[0]
self.assertEqual(span.name, os.environ[_HANDLER])
parent_context = span.parent
if test.expected_trace_id is None:
self.assertNotEqual(
span.get_span_context().trace_id, MOCK_XRAY_TRACE_ID
)
self.assertEqual(None, parent_context)
else:
self.assertEqual(span.get_span_context().trace_id, MOCK_XRAY_TRACE_ID)
self.assertEqual(
parent_context.trace_id, span.get_span_context().trace_id
)
self.assertEqual(span.kind, SpanKind.SERVER)
self.assertSpanHasAttributes(
span,
{
ResourceAttributes.FAAS_ID: MOCK_LAMBDA_CONTEXT.invoked_function_arn,
SpanAttributes.FAAS_EXECUTION: MOCK_LAMBDA_CONTEXT.aws_request_id,
},
)
self.memory_exporter.clear()
AwsLambdaInstrumentor().uninstrument()
test_env_patch.stop()

def test_parent_context_from_lambda_event(self):
@dataclass
Expand Down Expand Up @@ -218,14 +244,10 @@ def custom_event_context_extractor(lambda_event):
assert spans
self.assertEqual(len(spans), 1)
span = spans[0]
self.assertEqual(
span.get_span_context().trace_id, test.expected_traceid
)
self.assertEqual(span.get_span_context().trace_id, test.expected_traceid)

parent_context = span.parent
self.assertEqual(
parent_context.trace_id, span.get_span_context().trace_id
)
self.assertEqual(parent_context.trace_id, span.get_span_context().trace_id)
self.assertEqual(parent_context.span_id, test.expected_parentid)
self.assertEqual(
len(parent_context.trace_state), test.expected_trace_state_len
Expand All @@ -247,6 +269,7 @@ class TestCase:
expected_link_trace_id: int
expected_link_attributes: dict
xray_traceid: str
use_xray_propagator: str

tests = [
TestCase(
Expand All @@ -255,13 +278,23 @@ class TestCase:
expected_link_trace_id=MOCK_XRAY_TRACE_ID,
expected_link_attributes={"source": "x-ray-env"},
xray_traceid=MOCK_XRAY_TRACE_CONTEXT_SAMPLED,
use_xray_propagator="false",
),
TestCase(
name="invalid_xray_trace",
context={},
expected_link_trace_id=None,
expected_link_attributes={},
xray_traceid="0",
use_xray_propagator="false",
),
TestCase(
name="use_xra",
context={},
expected_link_trace_id=None,
expected_link_attributes={},
xray_traceid=MOCK_XRAY_TRACE_CONTEXT_SAMPLED,
use_xray_propagator="true",
),
]
for test in tests:
Expand All @@ -273,6 +306,7 @@ class TestCase:
_X_AMZN_TRACE_ID: test.xray_traceid,
# NOT using the X-Ray Propagator
OTEL_PROPAGATORS: "tracecontext",
OTEL_LAMBDA_USE_AWS_CONTEXT_PROPAGATION: test.use_xray_propagator,
},
)
test_env_patch.start()
Expand All @@ -287,12 +321,8 @@ class TestCase:
self.assertEqual(0, len(span.links))
else:
link = span.links[0]
self.assertEqual(
link.context.trace_id, test.expected_link_trace_id
)
self.assertEqual(
link.attributes, test.expected_link_attributes
)
self.assertEqual(link.context.trace_id, test.expected_link_trace_id)
self.assertEqual(link.attributes, test.expected_link_attributes)

self.memory_exporter.clear()
AwsLambdaInstrumentor().uninstrument()
Expand Down

0 comments on commit 883592d

Please sign in to comment.