Skip to content

feat(asm): add handlers to support the AWS Lambda framework #13638

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Jun 17, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 81 additions & 2 deletions ddtrace/appsec/_handlers.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,22 @@
import io
import json
from typing import Any
from typing import Dict
from typing import Optional

import xmltodict

from ddtrace._trace.span import Span
from ddtrace.appsec._asm_request_context import _call_waf
from ddtrace.appsec._asm_request_context import _call_waf_first
from ddtrace.appsec._asm_request_context import get_blocked
from ddtrace.appsec._constants import SPAN_DATA_NAMES
from ddtrace.appsec._http_utils import extract_cookies_from_headers
from ddtrace.appsec._http_utils import normalize_headers
from ddtrace.appsec._http_utils import parse_http_body
from ddtrace.contrib import trace_utils
from ddtrace.contrib.internal.trace_utils_base import _get_request_header_user_agent
from ddtrace.contrib.internal.trace_utils_base import _set_url_tag
from ddtrace.ext import SpanTypes
from ddtrace.ext import http
from ddtrace.internal import core
from ddtrace.internal.constants import RESPONSE_HEADERS
Expand Down Expand Up @@ -53,7 +61,7 @@ def _on_set_http_meta(
response_headers,
response_cookies,
):
if asm_config._asm_enabled and span.span_type == SpanTypes.WEB:
if asm_config._asm_enabled and span.span_type in asm_config._asm_http_span_types:
# avoid circular import
from ddtrace.appsec._asm_request_context import set_waf_address

Expand All @@ -77,6 +85,74 @@ def _on_set_http_meta(
set_waf_address(k, v)


# AWS Lambda
def _on_lambda_start_request(
    span: Span,
    request_headers: Dict[str, str],
    request_ip: Optional[str],
    body: Optional[str],
    is_body_base64: bool,
    raw_uri: str,
    route: str,
    method: str,
    parsed_query: Dict[str, Any],
):
    """Publish the request side of an AWS Lambda invocation to the WAF.

    Does nothing unless ASM is enabled and the span type is one of the
    configured HTTP span types. Otherwise the headers are normalized, the
    body is parsed according to its content type, cookies are split out of
    the headers, and everything is forwarded through ``_on_set_http_meta``
    before ``_call_waf_first`` is invoked.
    """
    if not asm_config._asm_enabled or span.span_type not in asm_config._asm_http_span_types:
        return

    waf_headers = normalize_headers(request_headers)
    parsed_body = parse_http_body(waf_headers, body, is_body_base64)
    # Cookie extraction removes the cookie header from ``waf_headers``, so it
    # runs after the body has been parsed and before the headers are forwarded.
    cookies = extract_cookies_from_headers(waf_headers)

    # ``_on_set_http_meta`` is called positionally; the order below must not change.
    http_meta = (
        span,
        request_ip,
        raw_uri,
        route,
        method,
        waf_headers,
        cookies,
        parsed_query,
        None,  # slot not supplied by this handler
        parsed_body,
        None,  # response status code: not known yet
        None,  # response headers: not known yet
        None,  # response cookies: not known yet
    )
    _on_set_http_meta(*http_meta)

    _call_waf_first(("aws_lambda",))


def _on_lambda_start_response(
    span: Span,
    status_code: str,
    response_headers: Dict[str, str],
):
    """Publish the response side of an AWS Lambda invocation to the WAF.

    Does nothing unless ASM is enabled and the span type is one of the
    configured HTTP span types. Otherwise the response headers are
    normalized, cookies are split out of them, and the status code, headers
    and cookies are forwarded through ``_on_set_http_meta`` before
    ``_call_waf`` is invoked.
    """
    if not asm_config._asm_enabled or span.span_type not in asm_config._asm_http_span_types:
        return

    headers = normalize_headers(response_headers)
    # Cookie extraction removes the cookie header from ``headers`` in place.
    cookies = extract_cookies_from_headers(headers)

    # The nine positional slots between ``span`` and the status code carry
    # request data, which this handler does not have.
    no_request_data = (None,) * 9
    _on_set_http_meta(span, *no_request_data, status_code, headers, cookies)

    _call_waf(("aws_lambda",))


# ASGI


Expand Down Expand Up @@ -307,6 +383,9 @@ def listen():

core.on("asgi.request.parse.body", _on_asgi_request_parse_body, "await_receive_and_body")

core.on("aws_lambda.start_request", _on_lambda_start_request)
core.on("aws_lambda.start_response", _on_lambda_start_response)

core.on("grpc.server.response.message", _on_grpc_server_response)
core.on("grpc.server.data", _on_grpc_server_data)

Expand Down
81 changes: 81 additions & 0 deletions ddtrace/appsec/_http_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
import base64
from http.cookies import CookieError
from http.cookies import SimpleCookie
import json
from typing import Any
from typing import Dict
from typing import Optional
from typing import Union
from urllib.parse import parse_qs

import xmltodict

from ddtrace.internal.utils import http as http_utils


def normalize_headers(
    request_headers: Dict[str, str],
) -> Dict[str, Optional[str]]:
    """Normalize headers according to the WAF expectations.

    The WAF expects headers to be lowercased and empty values to be None.

    Header names go through ``http_utils.normalize_header_name``. Values are
    converted to ``str`` and stripped of surrounding whitespace; a value that
    is falsy, or empty once stripped, is stored as ``None``.

    :param request_headers: raw header mapping as received from the caller.
    :return: a new dict keyed by normalized header name.
    """
    headers: Dict[str, Optional[str]] = {}
    for key, value in request_headers.items():
        # Strip before testing for emptiness so that whitespace-only values
        # are reported as None rather than as an empty string.
        stripped = str(value).strip() if value else ""
        headers[http_utils.normalize_header_name(key)] = stripped or None
    return headers


def parse_http_body(
    normalized_headers: Dict[str, Optional[str]],
    body: Optional[str],
    is_body_base64: bool,
) -> Union[str, Dict[str, Any], None]:
    """Parse a request body based on the content-type header.

    The media type is matched case-insensitively and without its parameters,
    so ``application/json; charset=utf-8`` is handled like
    ``application/json``.

    :param normalized_headers: headers with lowercased names; only
        ``content-type`` is read here, but the whole mapping is passed on to
        the multipart parser.
    :param body: raw body, or ``None`` when the request has no body.
    :param is_body_base64: whether ``body`` must be base64-decoded first.
    :return: the parsed body (JSON value, form/XML/multipart dict, or the
        plain text itself), or ``None`` when there is no body, no content
        type, an unsupported content type, or any decoding/parsing failure.
    """
    if body is None:
        return None
    if is_body_base64:
        try:
            body = base64.b64decode(body).decode()
        except (ValueError, TypeError):
            # binascii.Error and UnicodeDecodeError are both ValueError subclasses.
            return None

    # Parsing is best-effort: a malformed body must never break the request,
    # so any failure below is reported as "no parsed body".
    try:
        content_type = normalized_headers.get("content-type")
        if not content_type:
            return None

        # Drop parameters such as "; charset=utf-8" or "; boundary=..." and
        # ignore case before matching the media type.
        media_type = content_type.split(";", 1)[0].strip().lower()

        if media_type in ("application/json", "application/vnd.api+json", "text/json"):
            return json.loads(body)
        elif media_type in ("application/x-url-encoded", "application/x-www-form-urlencoded"):
            return parse_qs(body)
        elif media_type in ("application/xml", "text/xml"):
            return xmltodict.parse(body)
        elif media_type.startswith("multipart/form-data"):
            # The multipart parser receives the untouched headers.
            return http_utils.parse_form_multipart(body, normalized_headers)
        elif media_type == "text/plain":
            return body
        else:
            return None

    except Exception:
        return None


def extract_cookies_from_headers(
    normalized_headers: Dict[str, Optional[str]],
) -> Optional[Dict[str, str]]:
    """Extract cookies from the WAF headers.

    Looks for a ``cookie`` header first, then ``set-cookie``. The first one
    found is removed from ``normalized_headers`` (the mapping is modified in
    place) and parsed; any other cookie header is left untouched.

    :param normalized_headers: headers with lowercased names.
    :return: cookie name -> value for the header that was found (empty when
        its value is empty or ``None``), or ``None`` when neither header is
        present.
    """
    # A tuple rather than a set: when both headers are present the precedence
    # must not depend on string hash randomisation.
    for name in ("cookie", "set-cookie"):
        if name not in normalized_headers:
            continue
        header = normalized_headers.pop(name)
        cookie = SimpleCookie()
        if header:
            try:
                cookie.load(header)
            except CookieError:
                # Malformed cookie data from the client must not break the
                # handler; keep whatever was loaded before the failure.
                pass
        return {k: v.value for k, v in cookie.items()}
    return None
6 changes: 1 addition & 5 deletions ddtrace/appsec/_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@
from ddtrace.appsec._utils import DDWaf_result
from ddtrace.constants import _ORIGIN_KEY
from ddtrace.constants import _RUNTIME_FAMILY
from ddtrace.ext import SpanTypes
from ddtrace.internal._unpatched import unpatched_open as open # noqa: A004
from ddtrace.internal.logger import get_logger
from ddtrace.internal.rate_limiter import RateLimiter
Expand Down Expand Up @@ -232,9 +231,6 @@ def _waf_action(
be retrieved from the `core`. This can be used when you don't want to store
the value in the `core` before checking the `WAF`.
"""
if span.span_type not in (SpanTypes.WEB, SpanTypes.HTTP, SpanTypes.GRPC):
return None

if _asm_request_context.get_blocked():
# We still must run the waf if we need to extract schemas for API SECURITY
if not custom_data or not custom_data.get("PROCESSOR_SETTINGS", {}).get("extract-schema", False):
Expand Down Expand Up @@ -365,7 +361,7 @@ def _is_needed(self, address: str) -> bool:
return address in self._addresses_to_keep

def on_span_finish(self, span: Span) -> None:
if span.span_type in {SpanTypes.WEB, SpanTypes.GRPC}:
if span.span_type in asm_config._asm_processed_span_types:
_asm_request_context.call_waf_callback_no_instrumentation()
self._ddwaf._at_request_end()
_asm_request_context.end_context(span)
2 changes: 2 additions & 0 deletions ddtrace/settings/asm.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ class ASMConfig(DDConfig):
if _asm_static_rule_file == "":
_asm_static_rule_file = None
_asm_processed_span_types = {SpanTypes.WEB, SpanTypes.GRPC}
_asm_http_span_types = {SpanTypes.WEB}
_iast_enabled = tracer_config._from_endpoint.get("iast_enabled", DDConfig.var(bool, IAST.ENV, default=False))
_iast_request_sampling = DDConfig.var(float, IAST.ENV_REQUEST_SAMPLING, default=30.0)
_iast_debug = DDConfig.var(bool, IAST.ENV_DEBUG, default=False, private=True)
Expand Down Expand Up @@ -230,6 +231,7 @@ def __init__(self):

if in_aws_lambda():
self._asm_processed_span_types.add(SpanTypes.SERVERLESS)
self._asm_http_span_types.add(SpanTypes.SERVERLESS)

# As a first step, only Threat Management in monitoring mode should be enabled in AWS Lambda
tracer_config._remote_config_enabled = False
Expand Down
118 changes: 118 additions & 0 deletions tests/appsec/appsec/test_appsec_http_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
import pytest

from ddtrace.appsec import _http_utils


@pytest.mark.parametrize(
    "input_headers, expected",
    [
        # Header names are lowercased; values keep their case.
        ({"Host": "Example.COM"}, {"host": "Example.COM"}),
        # Empty values become None; names and values lose surrounding spaces.
        (
            {"X-Custom-None": "", "Content-Type": "application/json", "X-Custom-Spacing ": " trim spaces "},
            {"x-custom-none": None, "content-type": "application/json", "x-custom-spacing": "trim spaces"},
        ),
    ],
)
def test_normalize_headers(input_headers, expected):
    """``normalize_headers`` returns the expected WAF-ready header mapping."""
    assert _http_utils.normalize_headers(input_headers) == expected


@pytest.mark.parametrize(
    "headers, body, is_body_base64, expected_output",
    [
        # Body is None
        ({}, None, False, None),
        # Base64 encoded body - text/plain
        (
            {"content-type": "text/plain"},
            "dGV4dCBib2R5",
            True,
            "text body",
        ),
        # Base64 encoded body - application/json
        (
            {"content-type": "application/json"},
            "eyJrZXkiOiAidmFsdWUifQ==",
            True,
            {"key": "value"},
        ),
        # Base64 decoding failure - text/plain
        (
            {"content-type": "text/plain"},
            "invalid_base64_string",
            True,
            None,
        ),
        # JSON content types
        ({"content-type": "application/json"}, '{"key": "value"}', False, {"key": "value"}),
        ({"content-type": "application/vnd.api+json"}, '{"key": "value"}', False, {"key": "value"}),
        ({"content-type": "text/json"}, '{"key": "value"}', False, {"key": "value"}),
        # Form urlencoded
        (
            {"content-type": "application/x-www-form-urlencoded"},
            "key=value&key2=value2",
            False,
            {"key": ["value"], "key2": ["value2"]},
        ),
        # XML content types
        ({"content-type": "application/xml"}, "<root><key>value</key></root>", False, {"root": {"key": "value"}}),
        ({"content-type": "text/xml"}, "<root><key>value</key></root>", False, {"root": {"key": "value"}}),
        # Text plain
        ({"content-type": "text/plain"}, "simple text body", False, "simple text body"),
        # Unsupported content type
        ({"content-type": "application/octet-stream"}, "binary data", False, None),
        # No content type provided
        ({}, "some body", False, None),
        # Invalid JSON
        ({"content-type": "application/json"}, "not a valid json string", False, None),
        # Invalid XML
        ({"content-type": "application/xml"}, "<root><key>value</missing_key></root>", False, None),
        # Multipart form data
        (
            {"content-type": "multipart/form-data; boundary=boundary"},
            (
                "--boundary\r\n"
                'Content-Disposition: form-data; name="formPart"\r\n'
                "content-type: application/x-www-form-urlencoded\r\n"
                "\r\n"
                "key=value\r\n"
                "--boundary--"
            ),
            False,
            {"formPart": {"key": ["value"]}},  # Result expected from the real multipart parser (nothing is mocked)
        ),
        # Base64 decoding failure - application/xml
        (
            {"content-type": "application/xml"},
            "invalid_base64_and_invalid_xml",
            True,
            None,
        ),
    ],
)
def test_parse_http_body(headers, body, is_body_base64, expected_output):
    """``parse_http_body`` decodes and parses each supported content type.

    Unsupported content types, a missing content type, and every
    decoding or parsing failure are expected to yield ``None``.
    """
    # No fixture is requested: the test exercises the real parsers end to end.
    result = _http_utils.parse_http_body(headers, body, is_body_base64)
    assert result == expected_output


@pytest.mark.parametrize(
    "input_headers, expected",
    [
        # Request cookies: every name/value pair is returned.
        (
            {"cookie": "sessionid=abc123; csrftoken=xyz789"},
            {"sessionid": "abc123", "csrftoken": "xyz789"},
        ),
        # Response cookie: attributes such as Path and HttpOnly are not reported as cookies.
        (
            {"set-cookie": "sessionid=abc123; Path=/; HttpOnly"},
            {"sessionid": "abc123"},
        ),
        # A cookie header that is present but empty or None yields an empty dict.
        ({"cookie": ""}, {}),
        ({"cookie": None}, {}),
        ({"set-cookie": None}, {}),
    ],
)
def test_extract_cookies_from_headers(input_headers, expected):
    """Tests for extract_cookies_from_headers."""
    assert _http_utils.extract_cookies_from_headers(input_headers) == expected
Loading