Skip to content

Commit

Permalink
WSGI fixes (#148)
Browse files Browse the repository at this point in the history
Fix http.url.
Don't delay calling wrapped app.
  • Loading branch information
Oberon00 authored and reyang committed Sep 24, 2019
1 parent 83aad2f commit 7813924
Show file tree
Hide file tree
Showing 2 changed files with 162 additions and 37 deletions.
76 changes: 57 additions & 19 deletions ext/opentelemetry-ext-wsgi/src/opentelemetry/ext/wsgi/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,14 +44,34 @@ def _add_request_attributes(span, environ):
span.set_attribute("component", "http")
span.set_attribute("http.method", environ["REQUEST_METHOD"])

host = environ.get("HTTP_HOST") or environ["SERVER_NAME"]
host = environ.get("HTTP_HOST")
if not host:
host = environ["SERVER_NAME"]
port = environ["SERVER_PORT"]
# Append the port only when it differs from the default port of the
# request's scheme (80 for http, 443 for https), following the URL
# reconstruction recipe in PEP 3333. The previous condition parsed as
# `(port != "80" and scheme == "http") or port != "443"`, which wrongly
# appended ":80" to plain http requests on the default port.
default_port = "443" if environ["wsgi.url_scheme"] == "https" else "80"
if port != default_port:
    host += ":" + port

# NOTE: Nonstandard
span.set_attribute("http.host", host)

url = (
environ.get("REQUEST_URI")
or environ.get("RAW_URI")
or wsgiref_util.request_uri(environ, include_query=False)
)
url = environ.get("REQUEST_URI") or environ.get("RAW_URI")

if url:
if url[0] == "/":
# We assume that no scheme-relative URLs will be in url here.
# After all, if a request is made to http://myserver//foo, we may get
# //foo which looks like scheme-relative but isn't.
url = environ["wsgi.url_scheme"] + "://" + host + url
elif not url.startswith(environ["wsgi.url_scheme"] + ":"):
# Something fishy is in REQUEST_URI/RAW_URI. Let's fall back to request_uri()
url = wsgiref_util.request_uri(environ)
else:
url = wsgiref_util.request_uri(environ)

span.set_attribute("http.url", url)

@staticmethod
Expand Down Expand Up @@ -85,24 +105,27 @@ def __call__(self, environ, start_response):

tracer = trace.tracer()
path_info = environ["PATH_INFO"] or "/"
parent_span = propagators.extract(get_header_from_environ, environ)
parent_span = propagators.extract(_get_header_from_environ, environ)

with tracer.start_span(
span = tracer.create_span(
path_info, parent_span, kind=trace.SpanKind.SERVER
) as span:
self._add_request_attributes(span, environ)
start_response = self._create_start_response(span, start_response)
)
span.start()
try:
with tracer.use_span(span):
self._add_request_attributes(span, environ)
start_response = self._create_start_response(
span, start_response
)

iterable = self.wsgi(environ, start_response)
try:
for yielded in iterable:
yield yielded
finally:
if hasattr(iterable, "close"):
iterable.close()
iterable = self.wsgi(environ, start_response)
return _end_span_after_iterating(iterable, span, tracer)
except: # noqa
span.end()
raise


def get_header_from_environ(
def _get_header_from_environ(
environ: dict, header_name: str
) -> typing.List[str]:
"""Retrieve the header value from the wsgi environ dictionary.
Expand All @@ -115,3 +138,18 @@ def get_header_from_environ(
if value:
return [value]
return []


# Put this in a subfunction to not delay the call to the wrapped
# WSGI application (instrumentation should change the application
# behavior as little as possible).
def _end_span_after_iterating(iterable, span, tracer):
    """Yield every item of ``iterable`` inside ``span``, then end the span.

    Args:
        iterable: The response iterable returned by the wrapped WSGI
            application. Its ``close`` attribute, if present and truthy,
            is called once iteration finishes or is abandoned.
        span: The span to end; only its ``end()`` method is used here.
        tracer: Object whose ``use_span(span)`` returns a context manager
            that is held while the items are being yielded.

    Yields:
        The items of ``iterable``, unchanged and in order.

    Any exception raised while iterating or by ``close()`` propagates to
    the caller; ``span.end()`` is called exactly once in every case.
    """
    try:
        with tracer.use_span(span):
            for yielded in iterable:
                yield yielded
    finally:
        try:
            close = getattr(iterable, "close", None)
            if close:
                close()
        finally:
            # End the span even if close() raises; otherwise a failing
            # close() would leave the span unfinished forever.
            span.end()
123 changes: 105 additions & 18 deletions ext/opentelemetry-ext-wsgi/tests/test_wsgi_middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import unittest
import unittest.mock as mock
import wsgiref.util as wsgiref_util
from urllib.parse import urlparse

from opentelemetry import trace as trace_api
from opentelemetry.ext.wsgi import OpenTelemetryMiddleware
Expand Down Expand Up @@ -52,6 +53,15 @@ def iter_wsgi(environ, start_response):
return iter_wsgi


def create_gen_wsgi(response):
    """Return a WSGI app written as a generator function.

    The returned app builds the iterable-style app for ``response`` via
    ``create_iter_wsgi`` and re-yields its items. Because the app is a
    generator, none of its body runs until the first item is requested.
    """

    def gen_wsgi(environ, start_response):
        inner = create_iter_wsgi(response)(environ, start_response)
        yield from inner
        # Close the wrapped iterable once it is exhausted, if it supports it.
        if hasattr(inner, "close"):
            inner.close()

    return gen_wsgi


def error_wsgi(environ, start_response):
assert isinstance(environ, dict)
try:
Expand All @@ -66,18 +76,15 @@ def error_wsgi(environ, start_response):
class TestWsgiApplication(unittest.TestCase):
def setUp(self):
tracer = trace_api.tracer()
self.span_context_manager = mock.MagicMock()
self.span_context_manager.__enter__.return_value = mock.create_autospec(
trace_api.Span, spec_set=True
)
self.patcher = mock.patch.object(
self.span = mock.create_autospec(trace_api.Span, spec_set=True)
self.create_span_patcher = mock.patch.object(
tracer,
"start_span",
"create_span",
autospec=True,
spec_set=True,
return_value=self.span_context_manager,
return_value=self.span,
)
self.start_span = self.patcher.start()
self.create_span = self.create_span_patcher.start()

self.write_buffer = io.BytesIO()
self.write = self.write_buffer.write
Expand All @@ -90,11 +97,11 @@ def setUp(self):
self.exc_info = None

def tearDown(self):
self.patcher.stop()
self.create_span_patcher.stop()

def start_response(self, status, response_headers, exc_info=None):
# The span should have started already
self.span_context_manager.__enter__.assert_called_with()
self.span.start.assert_called_once_with()

self.status = status
self.response_headers = response_headers
Expand All @@ -105,12 +112,10 @@ def validate_response(self, response, error=None):
while True:
try:
value = next(response)
self.span_context_manager.__exit__.assert_not_called()
self.assertEqual(0, self.span.end.call_count)
self.assertEqual(value, b"*")
except StopIteration:
self.span_context_manager.__exit__.assert_called_with(
None, None, None
)
self.span.end.assert_called_once_with()
break

self.assertEqual(self.status, "200 OK")
Expand All @@ -125,9 +130,10 @@ def validate_response(self, response, error=None):
self.assertIsNone(self.exc_info)

# Verify that create_span has been called and the span has been started
self.start_span.assert_called_once_with(
self.create_span.assert_called_with(
"/", trace_api.INVALID_SPAN_CONTEXT, kind=trace_api.SpanKind.SERVER
)
self.span.start.assert_called_with()

def test_basic_wsgi_call(self):
app = OpenTelemetryMiddleware(simple_wsgi)
Expand All @@ -139,12 +145,24 @@ def test_wsgi_iterable(self):
iter_wsgi = create_iter_wsgi(original_response)
app = OpenTelemetryMiddleware(iter_wsgi)
response = app(self.environ, self.start_response)
# Verify that start_response has not been called yet
# Verify that start_response has been called
self.assertTrue(self.status)
self.validate_response(response)

# Verify that close has been called exactly once
self.assertEqual(original_response.close_calls, 1)

def test_wsgi_generator(self):
original_response = Response()
gen_wsgi = create_gen_wsgi(original_response)
app = OpenTelemetryMiddleware(gen_wsgi)
response = app(self.environ, self.start_response)
# Verify that start_response has not been called
self.assertIsNone(self.status)
self.validate_response(response)

# Verify that close has been called exactly once
assert original_response.close_calls == 1
self.assertEqual(original_response.close_calls, 1)

def test_wsgi_exc_info(self):
app = OpenTelemetryMiddleware(error_wsgi)
Expand All @@ -159,18 +177,87 @@ def setUp(self):
self.span = mock.create_autospec(trace_api.Span, spec_set=True)

def test_request_attributes(self):
self.environ["QUERY_STRING"] = "foo=bar"

OpenTelemetryMiddleware._add_request_attributes( # noqa pylint: disable=protected-access
self.span, self.environ
)

expected = (
mock.call("component", "http"),
mock.call("http.method", "GET"),
mock.call("http.host", "127.0.0.1"),
mock.call("http.url", "http://127.0.0.1/"),
mock.call("http.url", "http://127.0.0.1/?foo=bar"),
)
self.assertEqual(self.span.set_attribute.call_count, len(expected))
self.span.set_attribute.assert_has_calls(expected, any_order=True)

def validate_url(self, expected_url):
    """Record request attributes and check ``http.url`` and ``http.host``.

    Calls ``_add_request_attributes`` with the fixture's span and environ,
    then asserts that ``http.url`` equals ``expected_url`` and that
    ``http.host`` equals the network location parsed from that URL.
    """
    OpenTelemetryMiddleware._add_request_attributes(  # noqa pylint: disable=protected-access
        self.span, self.environ
    )
    # Collapse the recorded set_attribute calls into name -> value; when a
    # name was set more than once, the most recent value wins.
    recorded = {}
    for call in self.span.set_attribute.call_args_list:
        positional = call[0]
        recorded[positional[0]] = positional[1]

    self.assertIn("http.url", recorded)
    actual_url = recorded["http.url"]
    self.assertEqual(actual_url, expected_url)

    self.assertIn("http.host", recorded)
    netloc = urlparse(actual_url).netloc
    self.assertEqual(recorded["http.host"], netloc)

def test_request_attributes_with_partial_raw_uri(self):
# A RAW_URI holding only a path and fragment is expected to be expanded
# into an absolute URL with the fragment preserved.
# NOTE(review): the 127.0.0.1 host comes from the environ built in setUp,
# which is outside this view.
self.environ["RAW_URI"] = "/#top"
self.validate_url("http://127.0.0.1/#top")

def test_request_attributes_with_partial_raw_uri_and_nonstandard_port(
self
):
# Path-only RAW_URI (with an empty query marker) and no Host header:
# the expected URL carries SERVER_PORT 8080 and keeps the trailing "?".
self.environ["RAW_URI"] = "/?"
del self.environ["HTTP_HOST"]
self.environ["SERVER_PORT"] = "8080"
self.validate_url("http://127.0.0.1:8080/?")

def test_https_uri_port(self):
# Without a Host header and with the https scheme, port 443 is expected
# to be omitted from the URL.
del self.environ["HTTP_HOST"]
self.environ["SERVER_PORT"] = "443"
self.environ["wsgi.url_scheme"] = "https"
self.validate_url("https://127.0.0.1/")

# Any other port is expected to appear explicitly...
self.environ["SERVER_PORT"] = "8080"
self.validate_url("https://127.0.0.1:8080/")

# ...including 80, which is not the default port for https.
self.environ["SERVER_PORT"] = "80"
self.validate_url("https://127.0.0.1:80/")

def test_request_attributes_with_nonstandard_port_and_no_host(self):
# Without a Host header, a non-default SERVER_PORT is expected to be
# appended to the host in the URL.
# NOTE(review): assumes the default environ scheme is http - set up
# outside this view.
del self.environ["HTTP_HOST"]
self.environ["SERVER_PORT"] = "8080"
self.validate_url("http://127.0.0.1:8080/")

# 443 is expected to be shown too, since the expected URL scheme is http.
self.environ["SERVER_PORT"] = "443"
self.validate_url("http://127.0.0.1:443/")

def test_request_attributes_with_nonstandard_port(self):
# A port already present in the Host header is expected to be kept as-is
# in the resulting URL.
self.environ["HTTP_HOST"] += ":8080"
self.validate_url("http://127.0.0.1:8080/")

def test_request_attributes_with_faux_scheme_relative_raw_uri(self):
# A RAW_URI starting with "//" is expected to be treated as a plain path
# and appended after scheme and host, not read as a scheme-relative URL.
self.environ["RAW_URI"] = "//127.0.0.1/?"
self.validate_url("http://127.0.0.1//127.0.0.1/?")

def test_request_attributes_with_pathless_raw_uri(self):
# An absolute RAW_URI without any path is expected to be reported
# unchanged; HTTP_HOST is set to match so http.host agrees with the URL.
self.environ["PATH_INFO"] = ""
self.environ["RAW_URI"] = "http://hello"
self.environ["HTTP_HOST"] = "hello"
self.validate_url("http://hello")

def test_request_attributes_with_full_request_uri(self):
# An absolute REQUEST_URI (with port, query and fragment) is expected to
# be reported verbatim as http.url.
self.environ["HTTP_HOST"] = "127.0.0.1:8080"
self.environ["REQUEST_URI"] = "http://127.0.0.1:8080/?foo=bar#top"
self.validate_url("http://127.0.0.1:8080/?foo=bar#top")

def test_response_attributes(self):
OpenTelemetryMiddleware._add_response_attributes( # noqa pylint: disable=protected-access
self.span, "404 Not Found"
Expand Down

0 comments on commit 7813924

Please sign in to comment.