Skip to content

feat(appsec): enable request blocking #630

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 47 additions & 1 deletion datadog_lambda/asm.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
from copy import deepcopy
import logging
import urllib.parse
from copy import deepcopy
from typing import Any, Dict, List, Optional, Union

from ddtrace.contrib.internal.trace_utils import _get_request_header_client_ip
from ddtrace.internal import core
from ddtrace.internal.utils import get_blocked, set_blocked
from ddtrace.internal.utils import http as http_utils
from ddtrace.trace import Span

from datadog_lambda.trigger import (
Expand Down Expand Up @@ -50,6 +53,7 @@ def asm_set_context(event_source: _EventSource):
This allows the AppSecSpanProcessor to know information about the event
at the moment the span is created and skip it when not relevant.
"""

if event_source.event_type not in _http_event_types:
core.set_item("appsec_skip_next_lambda_event", True)

Expand Down Expand Up @@ -126,6 +130,14 @@ def asm_start_request(
span.set_tag_str("http.client_ip", request_ip)
span.set_tag_str("network.client.ip", request_ip)

# Encode the parsed query and append it to reconstruct the original raw URI expected by AppSec.
if parsed_query:
try:
encoded_query = urllib.parse.urlencode(parsed_query, doseq=True)
raw_uri += "?" + encoded_query # type: ignore
except Exception:
pass

core.dispatch(
# The matching listener is registered in ddtrace.appsec._handlers
"aws_lambda.start_request",
Expand Down Expand Up @@ -182,3 +194,37 @@ def asm_start_response(
response_headers,
),
)


def get_asm_blocked_response(
event_source: _EventSource,
) -> Optional[Dict[str, Any]]:
"""Get the blocked response for the given event source."""
if event_source.event_type not in _http_event_types:
return None

blocked = get_blocked()
if not blocked:
return None
set_blocked(blocked)

desired_type = blocked.get("type", "auto")
if desired_type == "none":
content_type = "text/plain; charset=utf-8"
content = ""
else:
content_type = blocked.get("content-type", "application/json")
content = http_utils._get_blocked_template(content_type)

response_headers = {
"content-type": content_type,
}
if "location" in blocked:
response_headers["location"] = blocked["location"]

return {
"statusCode": blocked.get("status_code", 403),
"headers": response_headers,
"body": content,
"isBase64Encoded": False,
}
25 changes: 23 additions & 2 deletions datadog_lambda/wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,13 @@
from importlib import import_module
from time import time_ns

from datadog_lambda.asm import asm_set_context, asm_start_response, asm_start_request
from ddtrace.internal._exceptions import BlockingException
from datadog_lambda.asm import (
asm_set_context,
asm_start_response,
asm_start_request,
get_asm_blocked_response,
)
from datadog_lambda.extension import should_use_extension, flush_extension
from datadog_lambda.cold_start import (
set_cold_start,
Expand Down Expand Up @@ -120,6 +126,7 @@ def __init__(self, func):
self.span = None
self.inferred_span = None
self.response = None
self.blocking_response = None

if config.profiling_enabled:
self.prof = profiler.Profiler(env=config.env, service=config.service)
Expand Down Expand Up @@ -155,12 +162,21 @@ def __init__(self, func):
except Exception as e:
logger.error(format_err_with_traceback(e))

def _get_blocking_response(self):
if not config.appsec_enabled:
return None
return get_asm_blocked_response(self.event_source)

def __call__(self, event, context, **kwargs):
"""Executes when the wrapped function gets called"""
self._before(event, context)
try:
if self.blocking_response:
return self.blocking_response
self.response = self.func(event, context, **kwargs)
return self.response
except BlockingException:
self.blocking_response = self._get_blocking_response()
except Exception:
from datadog_lambda.metric import submit_errors_metric

Expand All @@ -171,6 +187,8 @@ def __call__(self, event, context, **kwargs):
raise
finally:
self._after(event, context)
if self.blocking_response:
return self.blocking_response

def _inject_authorizer_span_headers(self, request_id):
reference_span = self.inferred_span if self.inferred_span else self.span
Expand Down Expand Up @@ -203,6 +221,7 @@ def _inject_authorizer_span_headers(self, request_id):
def _before(self, event, context):
try:
self.response = None
self.blocking_response = None
set_cold_start(init_timestamp_ns)

if not should_use_extension:
Expand Down Expand Up @@ -253,6 +272,7 @@ def _before(self, event, context):
)
if config.appsec_enabled:
asm_start_request(self.span, event, event_source, self.trigger_tags)
self.blocking_response = self._get_blocking_response()
else:
set_correlation_ids()
if config.profiling_enabled and is_new_sandbox():
Expand Down Expand Up @@ -286,13 +306,14 @@ def _after(self, event, context):
if status_code:
self.span.set_tag("http.status_code", status_code)

if config.appsec_enabled:
if config.appsec_enabled and not self.blocking_response:
asm_start_response(
self.span,
status_code,
self.event_source,
response=self.response,
)
self.blocking_response = self._get_blocking_response()

self.span.finish()

Expand Down
71 changes: 66 additions & 5 deletions tests/test_asm.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,17 @@
import pytest
from unittest.mock import MagicMock, patch

from datadog_lambda.asm import asm_start_request, asm_start_response
from datadog_lambda.trigger import parse_event_source, extract_trigger_tags
from datadog_lambda.asm import (
asm_start_request,
asm_start_response,
get_asm_blocked_response,
)
from datadog_lambda.trigger import (
EventTypes,
_EventSource,
extract_trigger_tags,
parse_event_source,
)
from tests.utils import get_mock_context

event_samples = "tests/event_samples/"
Expand All @@ -15,7 +24,7 @@
"application_load_balancer",
"application-load-balancer.json",
"72.12.164.125",
"/lambda",
"/lambda?query=1234ABCD",
"GET",
"",
False,
Expand All @@ -27,7 +36,7 @@
"application_load_balancer_multivalue_headers",
"application-load-balancer-mutivalue-headers.json",
"72.12.164.125",
"/lambda",
"/lambda?query=1234ABCD",
"GET",
"",
False,
Expand All @@ -51,7 +60,7 @@
"api_gateway",
"api-gateway.json",
"127.0.0.1",
"/path/to/resource",
"/path/to/resource?foo=bar",
"POST",
"eyJ0ZXN0IjoiYm9keSJ9",
True,
Expand Down Expand Up @@ -199,6 +208,30 @@
),
]

ASM_BLOCKED_RESPONSE_TEST_CASES = [
(
{"status_code": 403, "type": "auto"},
403,
{"content-type": "application/json"},
),
(
{"status_code": 401, "content-type": "text/html", "location": "/login"},
401,
{"content-type": "text/html", "location": "/login"},
),
(
{"status_code": 301, "type": "none", "location": "/redirect"},
301,
{"content-type": "text/plain; charset=utf-8", "location": "/redirect"},
),
(
{"status_code": 302, "location": "https://datadoghq.com"},
302,
{"content-type": "application/json", "location": "https://datadoghq.com"},
),
({"type": "auto"}, 403, {"content-type": "application/json"}),
]


@pytest.mark.parametrize(
"name,file,expected_ip,expected_uri,expected_method,expected_body,expected_base64,expected_query,expected_path_params,expected_route",
Expand Down Expand Up @@ -327,3 +360,31 @@ def test_asm_start_response_parametrized(
else:
# Verify core.dispatch was not called for non-HTTP events
mock_core.dispatch.assert_not_called()


@pytest.mark.parametrize(
"blocked_config, expected_status, expected_headers",
ASM_BLOCKED_RESPONSE_TEST_CASES,
)
@patch("datadog_lambda.asm.get_blocked")
def test_get_asm_blocked_response_blocked(
mock_get_blocked,
blocked_config,
expected_status,
expected_headers,
):
mock_get_blocked.return_value = blocked_config
event_source = _EventSource(event_type=EventTypes.API_GATEWAY)
response = get_asm_blocked_response(event_source)
assert response["statusCode"] == expected_status
assert response["headers"] == expected_headers


@patch("datadog_lambda.asm.get_blocked")
def test_get_asm_blocked_response_not_blocked(
mock_get_blocked,
):
mock_get_blocked.return_value = None
event_source = _EventSource(event_type=EventTypes.API_GATEWAY)
response = get_asm_blocked_response(event_source)
assert response is None
111 changes: 110 additions & 1 deletion tests/test_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import os
import unittest

from unittest.mock import patch, call, ANY
from unittest.mock import MagicMock, patch, call, ANY
from datadog_lambda.constants import TraceHeader

import datadog_lambda.wrapper as wrapper
Expand Down Expand Up @@ -660,3 +660,112 @@ def lambda_handler(event, context):
lambda_handler(lambda_event, lambda_context)

self.assertEqual(len(flushes), 0)


class TestLambdaWrapperAppsecBlocking(unittest.TestCase):
def setUp(self):
os.environ["DD_APPSEC_ENABLED"] = "true"
os.environ["DD_TRACE_ENABLED"] = "true"

self.addCleanup(os.environ.pop, "DD_APPSEC_ENABLED", None)
self.addCleanup(os.environ.pop, "DD_TRACE_ENABLED", None)

patcher = patch("datadog_lambda.wrapper.asm_set_context")
self.mock_asm_set_context = patcher.start()
self.addCleanup(patcher.stop)

patcher = patch("datadog_lambda.wrapper.asm_start_request")
self.mock_asm_start_request = patcher.start()
self.addCleanup(patcher.stop)

patcher = patch("datadog_lambda.wrapper.asm_start_response")
self.mock_asm_start_response = patcher.start()
self.addCleanup(patcher.stop)

patcher = patch("datadog_lambda.wrapper.get_asm_blocked_response")
self.mock_get_asm_blocking_response = patcher.start()
self.addCleanup(patcher.stop)

self.fake_blocking_response = {
"statusCode": "403",
"headers": {
"Content-Type": "application/json",
},
"body": '{"message": "Blocked by AppSec"}',
"isBase64Encoded": False,
}

def test_blocking_before(self):
self.mock_get_asm_blocking_response.return_value = self.fake_blocking_response

mock_handler = MagicMock()

lambda_handler = wrapper.datadog_lambda_wrapper(mock_handler)

response = lambda_handler({}, get_mock_context())
self.assertEqual(response, self.fake_blocking_response)

mock_handler.assert_not_called()

self.mock_asm_set_context.assert_called_once()
self.mock_asm_start_request.assert_called_once()
self.mock_asm_start_response.assert_not_called()

def test_blocking_during(self):
self.mock_get_asm_blocking_response.return_value = None

@wrapper.datadog_lambda_wrapper
def lambda_handler(event, context):
self.mock_get_asm_blocking_response.return_value = (
self.fake_blocking_response
)
raise wrapper.BlockingException()

response = lambda_handler({}, get_mock_context())
self.assertEqual(response, self.fake_blocking_response)

self.mock_asm_set_context.assert_called_once()
self.mock_asm_start_request.assert_called_once()
self.mock_asm_start_response.assert_not_called()

def test_blocking_after(self):
self.mock_get_asm_blocking_response.return_value = None

@wrapper.datadog_lambda_wrapper
def lambda_handler(event, context):
self.mock_get_asm_blocking_response.return_value = (
self.fake_blocking_response
)
return {
"statusCode": 200,
"body": "This should not be returned",
}

response = lambda_handler({}, get_mock_context())
self.assertEqual(response, self.fake_blocking_response)

self.mock_asm_set_context.assert_called_once()
self.mock_asm_start_request.assert_called_once()
self.mock_asm_start_response.assert_called_once()

def test_no_blocking_appsec_disabled(self):
os.environ["DD_APPSEC_ENABLED"] = "false"

self.mock_get_asm_blocking_response.return_value = self.fake_blocking_response

expected_response = {
"statusCode": 200,
"body": "This should be returned",
}

@wrapper.datadog_lambda_wrapper
def lambda_handler(event, context):
return expected_response

response = lambda_handler({}, get_mock_context())
self.assertEqual(response, expected_response)

self.mock_get_asm_blocking_response.assert_not_called()
self.mock_asm_set_context.assert_not_called()
self.mock_asm_start_request.assert_not_called()
self.mock_asm_start_response.assert_not_called()
Loading