Skip to content

Commit

Permalink
Update instrumentations to use tracer_provider for creating tracer if…
Browse files Browse the repository at this point in the history
… given, otherwise use global tracer provider (#402)
  • Loading branch information
srikanthccv authored Apr 28, 2021
1 parent bdbc249 commit 3ec7736
Show file tree
Hide file tree
Showing 33 changed files with 408 additions and 95 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
([#387](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/387))
- Update redis instrumentation to follow semantic conventions
([#403](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/403))
- Update instrumentations to use tracer_provider for creating tracer if given, otherwise use global tracer provider
([#402](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/402))
- `opentelemetry-instrumentation-wsgi` Replaced `name_callback` with `request_hook`
and `response_hook` callbacks.
([#424](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/424))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,17 +83,20 @@ def _instrument(self, **kwargs):
tracer_provider=tracer_provider,
)

# pylint:disable=no-self-use
def _uninstrument(self, **kwargs):
""""Disable aiopg instrumentation"""
wrappers.unwrap_connect()
wrappers.unwrap_create_pool()

# pylint:disable=no-self-use
def instrument_connection(self, connection):
def instrument_connection(self, connection, tracer_provider=None):
"""Enable instrumentation in a aiopg connection.
Args:
connection: The connection to instrument.
tracer_provider: The optional tracer provider to use. If omitted
the current globally configured one is used.
Returns:
An instrumented connection.
Expand All @@ -103,6 +106,8 @@ def instrument_connection(self, connection):
connection,
self._DATABASE_SYSTEM,
self._CONNECTION_ATTRIBUTES,
version=__version__,
tracer_provider=tracer_provider,
)

def uninstrument_connection(self, connection):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ async def traced_execution(
else self._db_api_integration.name
)

with self._db_api_integration.get_tracer().start_as_current_span(
with self._db_api_integration._tracer.start_as_current_span(
name, kind=SpanKind.CLIENT
) as span:
self._populate_span(span, cursor, *args)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,32 @@ def test_instrument_connection(self):
spans_list = self.memory_exporter.get_finished_spans()
self.assertEqual(len(spans_list), 1)

def test_custom_tracer_provider_instrument_connection(self):
resource = resources.Resource.create(
{"service.name": "db-test-service"}
)
result = self.create_tracer_provider(resource=resource)
tracer_provider, exporter = result

cnx = async_call(aiopg.connect(database="test"))

cnx = AiopgInstrumentor().instrument_connection(
cnx, tracer_provider=tracer_provider
)

cursor = async_call(cnx.cursor())
query = "SELECT * FROM test"
async_call(cursor.execute(query))

spans_list = exporter.get_finished_spans()
self.assertEqual(len(spans_list), 1)
span = spans_list[0]

self.assertEqual(
span.resource.attributes["service.name"], "db-test-service"
)
self.assertIs(span.resource, resource)

def test_uninstrument_connection(self):
AiopgInstrumentor().instrument()
cnx = async_call(aiopg.connect(database="test"))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -167,11 +167,19 @@ class OpenTelemetryMiddleware:
and a tuple, representing the desired span name and a
dictionary with any additional span attributes to set.
Optional: Defaults to get_default_span_details.
tracer_provider: The optional tracer provider to use. If omitted
the current globally configured one is used.
"""

def __init__(self, app, excluded_urls=None, span_details_callback=None):
def __init__(
self,
app,
excluded_urls=None,
span_details_callback=None,
tracer_provider=None,
):
self.app = guarantee_single_callable(app)
self.tracer = trace.get_tracer(__name__, __version__)
self.tracer = trace.get_tracer(__name__, __version__, tracer_provider)
self.span_details_callback = (
span_details_callback or get_default_span_details
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,13 @@

import opentelemetry.instrumentation.asgi as otel_asgi
from opentelemetry import trace as trace_api
from opentelemetry.sdk import resources
from opentelemetry.semconv.trace import SpanAttributes
from opentelemetry.test.asgitestutil import (
AsgiTestBase,
setup_testing_defaults,
)
from opentelemetry.test.test_base import TestBase


async def http_app(scope, receive, send):
Expand Down Expand Up @@ -211,6 +213,22 @@ def update_expected_span_name(expected):
outputs = self.get_all_output()
self.validate_outputs(outputs, modifiers=[update_expected_span_name])

def test_custom_tracer_provider_otel_asgi(self):
resource = resources.Resource.create({"service-test-key": "value"})
result = TestBase.create_tracer_provider(resource=resource)
tracer_provider, exporter = result

app = otel_asgi.OpenTelemetryMiddleware(
simple_asgi, tracer_provider=tracer_provider
)
self.seed_app(app)
self.send_default_request()
span_list = exporter.get_finished_spans()
for span in span_list:
self.assertEqual(
span.resource.attributes["service-test-key"], "value"
)

def test_behavior_with_scope_server_as_none(self):
"""Test that middleware is ok when server is none in scope."""

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,6 @@
from opentelemetry.trace import SpanKind
from opentelemetry.trace.status import Status, StatusCode

_APPLIED = "_opentelemetry_tracer"


def _hydrate_span_from_args(connection, query, parameters) -> dict:
"""Get network and database attributes from connection."""
Expand Down Expand Up @@ -98,16 +96,11 @@ class AsyncPGInstrumentor(BaseInstrumentor):
def __init__(self, capture_parameters=False):
super().__init__()
self.capture_parameters = capture_parameters
self._tracer = None

def _instrument(self, **kwargs):
tracer_provider = kwargs.get(
"tracer_provider", trace.get_tracer_provider()
)
setattr(
asyncpg,
_APPLIED,
tracer_provider.get_tracer("asyncpg", __version__),
)
tracer_provider = kwargs.get("tracer_provider")
self._tracer = trace.get_tracer(__name__, __version__, tracer_provider)

for method in [
"Connection.execute",
Expand All @@ -121,7 +114,6 @@ def _instrument(self, **kwargs):
)

def _uninstrument(self, **__):
delattr(asyncpg, _APPLIED)
for method in [
"execute",
"executemany",
Expand All @@ -132,13 +124,14 @@ def _uninstrument(self, **__):
unwrap(asyncpg.Connection, method)

async def _do_execute(self, func, instance, args, kwargs):
tracer = getattr(asyncpg, _APPLIED)

exception = None
params = getattr(instance, "_params", {})
name = args[0] if args[0] else params.get("database", "postgresql")

with tracer.start_as_current_span(name, kind=SpanKind.CLIENT) as span:
with self._tracer.start_as_current_span(
name, kind=SpanKind.CLIENT
) as span:
if span.is_recording():
span_attributes = _hydrate_span_from_args(
instance,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,17 +1,10 @@
import asyncpg
from asyncpg import Connection

from opentelemetry.instrumentation.asyncpg import AsyncPGInstrumentor
from opentelemetry.test.test_base import TestBase


class TestAsyncPGInstrumentation(TestBase):
def test_instrumentation_flags(self):
AsyncPGInstrumentor().instrument()
self.assertTrue(hasattr(asyncpg, "_opentelemetry_tracer"))
AsyncPGInstrumentor().uninstrument()
self.assertFalse(hasattr(asyncpg, "_opentelemetry_tracer"))

def test_duplicated_instrumentation(self):
AsyncPGInstrumentor().instrument()
AsyncPGInstrumentor().instrument()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -228,21 +228,18 @@ def __init__(
}
self._name = name
self._version = version
self._tracer_provider = tracer_provider
self._tracer = get_tracer(
self._name,
instrumenting_library_version=self._version,
tracer_provider=tracer_provider,
)
self.capture_parameters = capture_parameters
self.database_system = database_system
self.connection_props = {}
self.span_attributes = {}
self.name = ""
self.database = ""

def get_tracer(self):
return get_tracer(
self._name,
instrumenting_library_version=self._version,
tracer_provider=self._tracer_provider,
)

def wrapped_connection(
self,
connect_method: typing.Callable[..., typing.Any],
Expand Down Expand Up @@ -370,7 +367,7 @@ def traced_execution(
else self._db_api_integration.name
)

with self._db_api_integration.get_tracer().start_as_current_span(
with self._db_api_integration._tracer.start_as_current_span(
name, kind=SpanKind.CLIENT
) as span:
self._populate_span(span, cursor, *args)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

from opentelemetry import trace as trace_api
from opentelemetry.instrumentation import dbapi
from opentelemetry.sdk import resources
from opentelemetry.semconv.trace import SpanAttributes
from opentelemetry.test.test_base import TestBase

Expand All @@ -41,7 +42,7 @@ def test_span_succeeded(self):
"user": "user",
}
db_integration = dbapi.DatabaseApiIntegration(
self.tracer, "testcomponent", connection_attributes
"testname", "testcomponent", connection_attributes
)
mock_connection = db_integration.wrapped_connection(
mock_connect, {}, connection_props
Expand Down Expand Up @@ -73,7 +74,7 @@ def test_span_succeeded(self):

def test_span_name(self):
db_integration = dbapi.DatabaseApiIntegration(
self.tracer, "testcomponent", {}
"testname", "testcomponent", {}
)
mock_connection = db_integration.wrapped_connection(
mock_connect, {}, {}
Expand Down Expand Up @@ -106,7 +107,7 @@ def test_span_succeeded_with_capture_of_statement_parameters(self):
"user": "user",
}
db_integration = dbapi.DatabaseApiIntegration(
self.tracer,
"testname",
"testcomponent",
connection_attributes,
capture_parameters=True,
Expand Down Expand Up @@ -155,12 +156,10 @@ def test_span_not_recording(self):
"host": "server_host",
"user": "user",
}
mock_tracer = mock.Mock()
mock_span = mock.Mock()
mock_span.is_recording.return_value = False
mock_tracer.start_span.return_value = mock_span
db_integration = dbapi.DatabaseApiIntegration(
mock_tracer, "testcomponent", connection_attributes
"testname", "testcomponent", connection_attributes
)
mock_connection = db_integration.wrapped_connection(
mock_connect, {}, connection_props
Expand Down Expand Up @@ -192,9 +191,30 @@ def test_span_failed(self):
self.assertIs(span.status.status_code, trace_api.StatusCode.ERROR)
self.assertEqual(span.status.description, "Exception: Test Exception")

def test_custom_tracer_provider_dbapi(self):
resource = resources.Resource.create({"db-resource-key": "value"})
result = self.create_tracer_provider(resource=resource)
tracer_provider, exporter = result

db_integration = dbapi.DatabaseApiIntegration(
self.tracer, "testcomponent", tracer_provider=tracer_provider
)
mock_connection = db_integration.wrapped_connection(
mock_connect, {}, {}
)
cursor = mock_connection.cursor()
with self.assertRaises(Exception):
cursor.execute("Test query", throw_exception=True)

spans_list = exporter.get_finished_spans()
self.assertEqual(len(spans_list), 1)
span = spans_list[0]
self.assertEqual(span.resource.attributes["db-resource-key"], "value")
self.assertIs(span.status.status_code, trace_api.StatusCode.ERROR)

def test_executemany(self):
db_integration = dbapi.DatabaseApiIntegration(
self.tracer, "testcomponent"
"testname", "testcomponent"
)
mock_connection = db_integration.wrapped_connection(
mock_connect, {}, {}
Expand All @@ -210,7 +230,7 @@ def test_executemany(self):

def test_callproc(self):
db_integration = dbapi.DatabaseApiIntegration(
self.tracer, "testcomponent"
"testname", "testcomponent"
)
mock_connection = db_integration.wrapped_connection(
mock_connect, {}, {}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ def response_hook(span, request, response):
from opentelemetry.instrumentation.django.middleware import _DjangoMiddleware
from opentelemetry.instrumentation.django.version import __version__
from opentelemetry.instrumentation.instrumentor import BaseInstrumentor
from opentelemetry.trace import get_tracer

_logger = getLogger(__name__)

Expand All @@ -105,6 +106,13 @@ def _instrument(self, **kwargs):
if environ.get(OTEL_PYTHON_DJANGO_INSTRUMENT) == "False":
return

tracer_provider = kwargs.get("tracer_provider")
tracer = get_tracer(
__name__, __version__, tracer_provider=tracer_provider,
)

_DjangoMiddleware._tracer = tracer

_DjangoMiddleware._otel_request_hook = kwargs.pop("request_hook", None)
_DjangoMiddleware._otel_response_hook = kwargs.pop(
"response_hook", None
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
)
from opentelemetry.propagate import extract
from opentelemetry.semconv.trace import SpanAttributes
from opentelemetry.trace import Span, SpanKind, get_tracer, use_span
from opentelemetry.trace import Span, SpanKind, use_span
from opentelemetry.util.http import get_excluded_urls, get_traced_request_attrs

try:
Expand Down Expand Up @@ -82,6 +82,7 @@ class _DjangoMiddleware(MiddlewareMixin):

_traced_request_attrs = get_traced_request_attrs("DJANGO")
_excluded_urls = get_excluded_urls("DJANGO")
_tracer = None

_otel_request_hook: Callable[[Span, HttpRequest], None] = None
_otel_response_hook: Callable[
Expand Down Expand Up @@ -125,9 +126,7 @@ def process_request(self, request):

token = attach(extract(request_meta, getter=wsgi_getter))

tracer = get_tracer(__name__, __version__)

span = tracer.start_span(
span = self._tracer.start_span(
self._get_span_name(request),
kind=SpanKind.SERVER,
start_time=request_meta.get(
Expand Down
Loading

0 comments on commit 3ec7736

Please sign in to comment.