Skip to content

Commit b84bbbc

Browse files
authored
Fix error with missing content-type header in Starlette (#763)
* Run black (how was this missed?) * Don't assume "content-type" header is present * Better error handling around body capture * Add test for POST with body * Fix isort
1 parent e5efcf3 commit b84bbbc

File tree

2 files changed

+76
-33
lines changed

2 files changed

+76
-33
lines changed

elasticapm/contrib/starlette/utils.py

Lines changed: 17 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -30,12 +30,11 @@
3030

3131
import json
3232

33-
from starlette.requests import Request
34-
from starlette.types import Message
35-
from starlette.responses import Response
36-
3733
from elasticapm.conf import constants
3834
from elasticapm.utils import compat, get_url_dict
35+
from starlette.requests import Request
36+
from starlette.responses import Response
37+
from starlette.types import Message
3938

4039

4140
async def get_data_from_request(request: Request, capture_body=False, capture_headers=True) -> dict:
@@ -51,21 +50,22 @@ async def get_data_from_request(request: Request, capture_body=False, capture_he
5150
"""
5251
result = {
5352
"method": request.method,
54-
"socket": {
55-
"remote_address": _get_client_ip(request),
56-
"encrypted": request.url.is_secure
57-
},
53+
"socket": {"remote_address": _get_client_ip(request), "encrypted": request.url.is_secure},
5854
"cookies": request.cookies,
5955
}
6056
if capture_headers:
6157
result["headers"] = dict(request.headers)
6258

6359
if request.method in constants.HTTP_WITH_BODY:
64-
body = await get_body(request)
65-
if request.headers['content-type'] == "application/x-www-form-urlencoded":
66-
body = await query_params_to_dict(body)
67-
else:
68-
body = json.loads(body)
60+
body = None
61+
try:
62+
body = await get_body(request)
63+
if request.headers.get("content-type") == "application/x-www-form-urlencoded":
64+
body = await query_params_to_dict(body)
65+
else:
66+
body = json.loads(body)
67+
except Exception:
68+
pass
6969
if body is not None:
7070
result["body"] = body if capture_body else "[REDACTED]"
7171

@@ -107,6 +107,7 @@ async def set_body(request: Request, body: bytes):
107107
request (Request)
108108
body (bytes)
109109
"""
110+
110111
async def receive() -> Message:
111112
return {"type": "http.request", "body": body}
112113

@@ -155,9 +156,9 @@ async def query_params_to_dict(query_params: str) -> dict:
155156

156157

157158
def _get_client_ip(request: Request):
158-
x_forwarded_for = request.headers.get('HTTP_X_FORWARDED_FOR')
159+
x_forwarded_for = request.headers.get("HTTP_X_FORWARDED_FOR")
159160
if x_forwarded_for:
160-
ip = x_forwarded_for.split(',')[0]
161+
ip = x_forwarded_for.split(",")[0]
161162
else:
162-
ip = request.headers.get('REMOTE_ADDR')
163+
ip = request.headers.get("REMOTE_ADDR")
163164
return ip

tests/contrib/asyncio/starlette_tests.py

Lines changed: 59 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -33,15 +33,14 @@
3333
starlette = pytest.importorskip("starlette") # isort:skip
3434

3535
import mock
36-
3736
from starlette.applications import Starlette
3837
from starlette.responses import PlainTextResponse
3938
from starlette.testclient import TestClient
40-
from elasticapm.utils.disttracing import TraceParent
41-
from elasticapm import async_capture_span
4239

40+
from elasticapm import async_capture_span
4341
from elasticapm.conf import constants
4442
from elasticapm.contrib.starlette import ElasticAPM
43+
from elasticapm.utils.disttracing import TraceParent
4544

4645
pytestmark = [pytest.mark.starlette]
4746

@@ -50,7 +49,7 @@
5049
def app(elasticapm_client):
5150
app = Starlette()
5251

53-
@app.route("/")
52+
@app.route("/", methods=["GET", "POST"])
5453
async def hi(request):
5554
with async_capture_span("test"):
5655
pass
@@ -68,11 +67,14 @@ async def raise_exception(request):
6867
def test_get(app, elasticapm_client):
6968
client = TestClient(app)
7069

71-
response = client.get('/', headers={
72-
constants.TRACEPARENT_HEADER_NAME: "00-0af7651916cd43dd8448eb211c80319c-b7ad6b7169203331-03",
73-
constants.TRACESTATE_HEADER_NAME: "foo=bar,bar=baz",
74-
"REMOTE_ADDR": "127.0.0.1",
75-
})
70+
response = client.get(
71+
"/",
72+
headers={
73+
constants.TRACEPARENT_HEADER_NAME: "00-0af7651916cd43dd8448eb211c80319c-b7ad6b7169203331-03",
74+
constants.TRACESTATE_HEADER_NAME: "foo=bar,bar=baz",
75+
"REMOTE_ADDR": "127.0.0.1",
76+
},
77+
)
7678

7779
assert response.status_code == 200
7880

@@ -93,15 +95,52 @@ def test_get(app, elasticapm_client):
9395
assert span["name"] == "test"
9496

9597

96-
def test_exception(app, elasticapm_client):
98+
@pytest.mark.parametrize("elasticapm_client", [{"capture_body": "all"}], indirect=True)
99+
def test_post(app, elasticapm_client):
97100
client = TestClient(app)
98101

99-
with pytest.raises(ValueError):
100-
client.get('/raise-exception', headers={
102+
response = client.post(
103+
"/",
104+
headers={
101105
constants.TRACEPARENT_HEADER_NAME: "00-0af7651916cd43dd8448eb211c80319c-b7ad6b7169203331-03",
102106
constants.TRACESTATE_HEADER_NAME: "foo=bar,bar=baz",
103107
"REMOTE_ADDR": "127.0.0.1",
104-
})
108+
},
109+
data={"foo": "bar"},
110+
)
111+
112+
assert response.status_code == 200
113+
114+
assert len(elasticapm_client.events[constants.TRANSACTION]) == 1
115+
transaction = elasticapm_client.events[constants.TRANSACTION][0]
116+
spans = elasticapm_client.spans_for_transaction(transaction)
117+
assert len(spans) == 1
118+
span = spans[0]
119+
120+
assert transaction["name"] == "POST /"
121+
assert transaction["result"] == "HTTP 2xx"
122+
assert transaction["type"] == "request"
123+
assert transaction["span_count"]["started"] == 1
124+
request = transaction["context"]["request"]
125+
request["method"] == "GET"
126+
request["socket"] == {"remote_address": "127.0.0.1", "encrypted": False}
127+
assert request["body"]["foo"] == "bar"
128+
129+
assert span["name"] == "test"
130+
131+
132+
def test_exception(app, elasticapm_client):
133+
client = TestClient(app)
134+
135+
with pytest.raises(ValueError):
136+
client.get(
137+
"/raise-exception",
138+
headers={
139+
constants.TRACEPARENT_HEADER_NAME: "00-0af7651916cd43dd8448eb211c80319c-b7ad6b7169203331-03",
140+
constants.TRACESTATE_HEADER_NAME: "foo=bar,bar=baz",
141+
"REMOTE_ADDR": "127.0.0.1",
142+
},
143+
)
105144

106145
assert len(elasticapm_client.events[constants.TRANSACTION]) == 1
107146
transaction = elasticapm_client.events[constants.TRANSACTION][0]
@@ -129,10 +168,13 @@ def test_traceparent_handling(app, elasticapm_client, header_name):
129168
with mock.patch(
130169
"elasticapm.contrib.flask.TraceParent.from_string", wraps=TraceParent.from_string
131170
) as wrapped_from_string:
132-
response = client.get('/', headers={
133-
header_name: "00-0af7651916cd43dd8448eb211c80319c-b7ad6b7169203331-03",
134-
constants.TRACESTATE_HEADER_NAME: "foo=bar,baz=bazzinga",
135-
})
171+
response = client.get(
172+
"/",
173+
headers={
174+
header_name: "00-0af7651916cd43dd8448eb211c80319c-b7ad6b7169203331-03",
175+
constants.TRACESTATE_HEADER_NAME: "foo=bar,baz=bazzinga",
176+
},
177+
)
136178

137179
assert response.status_code == 200
138180

0 commit comments

Comments
 (0)