Skip to content

Commit

Permalink
Move all key access to a single spot and deduplicate context managers
Browse files Browse the repository at this point in the history
  • Loading branch information
adriangb committed Nov 14, 2023
1 parent 7758d5e commit 6ae8d59
Show file tree
Hide file tree
Showing 21 changed files with 122 additions and 231 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -210,10 +210,9 @@ def _patched_api_call(self, original_func, instance, args, kwargs):
raise
else:
_apply_response_attributes(span, result)
_safe_invoke(extension.on_success, span, result)
finally:
_safe_invoke(extension.after_service_call)
_safe_invoke(extension.on_success, span, result)
finally:
_safe_invoke(extension.after_service_call)
self._call_response_hook(span, call_context, result)

return result
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@

from opentelemetry import trace as trace_api
from opentelemetry.instrumentation.botocore import BotocoreInstrumentor
from opentelemetry.instrumentation.utils import suppress_instrumentation
from opentelemetry.instrumentation.utils import suppress_http_instrumentation, suppress_instrumentation
from opentelemetry.propagate import get_global_textmap, set_global_textmap
from opentelemetry.propagators.aws.aws_xray_propagator import TRACE_HEADER_KEY
from opentelemetry.semconv.trace import SpanAttributes
Expand Down Expand Up @@ -342,7 +342,7 @@ def test_suppress_instrumentation_xray_client(self):

@mock_xray
def test_suppress_http_instrumentation_xray_client(self):
with suppress_instrumentation():
with suppress_http_instrumentation():
xray_client.put_trace_segments(TraceSegmentDocuments=["str1"])
xray_client.put_trace_segments(TraceSegmentDocuments=["str2"])
self.assertEqual(2, len(self.get_finished_spans()))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,9 @@
from wrapt import wrap_function_wrapper

from opentelemetry import trace
from opentelemetry.instrumentation.instrumentor import BaseInstrumentor
from opentelemetry.instrumentation.cassandra.package import _instruments
from opentelemetry.instrumentation.cassandra.version import __version__
from opentelemetry.instrumentation.instrumentor import BaseInstrumentor
from opentelemetry.instrumentation.utils import unwrap
from opentelemetry.semconv.trace import SpanAttributes

Expand All @@ -70,10 +70,7 @@ def _traced_execute_async(func, instance, args, kwargs):
if span.is_recording():
span.set_attribute(SpanAttributes.DB_NAME, instance.keyspace)
span.set_attribute(SpanAttributes.DB_SYSTEM, "cassandra")
span.set_attribute(
SpanAttributes.NET_PEER_NAME,
instance.cluster.contact_points,
)
span.set_attribute(SpanAttributes.NET_PEER_NAME, instance.cluster.contact_points)

if include_db_statement:
query = args[0]
Expand All @@ -82,9 +79,7 @@ def _traced_execute_async(func, instance, args, kwargs):
response = func(*args, **kwargs)
return response

wrap_function_wrapper(
"cassandra.cluster", "Session.execute_async", _traced_execute_async
)
wrap_function_wrapper("cassandra.cluster", "Session.execute_async", _traced_execute_async)


class CassandraInstrumentor(BaseInstrumentor):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,25 +46,15 @@ def tearDown(self):
def test_instrument_uninstrument(self):
instrumentation = CassandraInstrumentor()
instrumentation.instrument()
self.assertTrue(
isinstance(
cassandra.cluster.Session.execute_async, BoundFunctionWrapper
)
)
self.assertTrue(isinstance(cassandra.cluster.Session.execute_async, BoundFunctionWrapper))

instrumentation.uninstrument()
self.assertFalse(
isinstance(
cassandra.cluster.Session.execute_async, BoundFunctionWrapper
)
)
self.assertFalse(isinstance(cassandra.cluster.Session.execute_async, BoundFunctionWrapper))

@mock.patch("cassandra.cluster.Cluster.connect")
@mock.patch("cassandra.cluster.Session.__init__")
@mock.patch("cassandra.cluster.Session._create_response_future")
def test_instrumentor(
self, mock_create_response_future, mock_session_init, mock_connect
):
def test_instrumentor(self, mock_create_response_future, mock_session_init, mock_connect):
mock_create_response_future.return_value = mock.Mock()
mock_session_init.return_value = None
mock_connect.return_value = cassandra.cluster.Session()
Expand Down Expand Up @@ -95,9 +85,7 @@ def test_instrumentor(
@mock.patch("cassandra.cluster.Cluster.connect")
@mock.patch("cassandra.cluster.Session.__init__")
@mock.patch("cassandra.cluster.Session._create_response_future")
def test_custom_tracer_provider(
self, mock_create_response_future, mock_session_init, mock_connect
):
def test_custom_tracer_provider(self, mock_create_response_future, mock_session_init, mock_connect):
mock_create_response_future.return_value = mock.Mock()
mock_session_init.return_value = None
mock_connect.return_value = cassandra.cluster.Session()
Expand All @@ -119,9 +107,7 @@ def test_custom_tracer_provider(
@mock.patch("cassandra.cluster.Cluster.connect")
@mock.patch("cassandra.cluster.Session.__init__")
@mock.patch("cassandra.cluster.Session._create_response_future")
def test_instrument_connection_no_op_tracer_provider(
self, mock_create_response_future, mock_session_init, mock_connect
):
def test_instrument_connection_no_op_tracer_provider(self, mock_create_response_future, mock_session_init, mock_connect):
mock_create_response_future.return_value = mock.Mock()
mock_session_init.return_value = None
mock_connect.return_value = cassandra.cluster.Session()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,6 @@ def add(x, y):
from timeit import default_timer
from typing import Collection, Iterable

from billiard import VERSION
from billiard.einfo import ExceptionInfo
from celery import signals # pylint: disable=no-name-in-module

Expand All @@ -77,6 +76,8 @@ def add(x, y):
from opentelemetry.propagators.textmap import Getter
from opentelemetry.semconv.trace import SpanAttributes
from opentelemetry.trace.status import Status, StatusCode
from billiard import VERSION


if VERSION >= (4, 0, 1):
from billiard.einfo import ExceptionWithTraceback
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@

import logging

from billiard import VERSION
from celery import registry # pylint: disable=no-name-in-module
from billiard import VERSION

from opentelemetry.semconv.trace import SpanAttributes

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -112,8 +112,8 @@ def instrument_consumer(consumer: Consumer, tracer_provider=None)
from .package import _instruments
from .utils import (
KafkaPropertiesExtractor,
_create_new_consume_span,
_end_current_consume_span,
_create_new_consume_span,
_enrich_span,
_get_span_name,
_kafka_getter,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,13 @@
from typing import List, Optional

from opentelemetry import context, propagate
from opentelemetry.trace import SpanKind, Link
from opentelemetry.propagators import textmap
from opentelemetry.semconv.trace import (
MessagingDestinationKindValues,
MessagingOperationValues,
SpanAttributes,
)
from opentelemetry.trace import Link, SpanKind

_LOG = getLogger(__name__)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,13 @@

# pylint: disable=no-name-in-module

from opentelemetry.semconv.trace import (
SpanAttributes,
MessagingDestinationKindValues,
)
from opentelemetry.test.test_base import TestBase
from .utils import MockConsumer, MockedMessage

from confluent_kafka import Consumer, Producer

from opentelemetry.instrumentation.confluent_kafka import (
Expand All @@ -25,13 +32,6 @@
KafkaContextGetter,
KafkaContextSetter,
)
from opentelemetry.semconv.trace import (
MessagingDestinationKindValues,
SpanAttributes,
)
from opentelemetry.test.test_base import TestBase

from .utils import MockConsumer, MockedMessage


class TestConfluentKafka(TestBase):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
OTEL_INSTRUMENTATION_HTTP_CAPTURE_HEADERS_SANITIZE_FIELDS,
OTEL_INSTRUMENTATION_HTTP_CAPTURE_HEADERS_SERVER_REQUEST,
OTEL_INSTRUMENTATION_HTTP_CAPTURE_HEADERS_SERVER_RESPONSE,
OTEL_INSTRUMENTATION_HTTP_CAPTURE_HEADERS_SERVER_RESPONSE,
OTEL_PYTHON_INSTRUMENTATION_HTTP_CAPTURE_ALL_METHODS,
get_excluded_urls,
)
Expand Down Expand Up @@ -327,9 +328,7 @@ def test_flask_metric_values(self):
if isinstance(point, NumberDataPoint):
self.assertEqual(point.value, 0)

def _assert_basic_metric(
self, expected_duration_attributes, expected_requests_count_attributes
):
def _assert_basic_metric(self, expected_duration_attributes, expected_requests_count_attributes):
metrics_list = self.memory_metrics_reader.get_metrics_data()
for resource_metric in metrics_list.resource_metrics:
for scope_metrics in resource_metric.scope_metrics:
Expand Down Expand Up @@ -395,7 +394,7 @@ def test_basic_metric_nonstandard_http_method_success(self):
)

@patch.dict(
"os.environ",
"os.environ",
{
OTEL_PYTHON_INSTRUMENTATION_HTTP_CAPTURE_ALL_METHODS: "1",
},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -170,10 +170,7 @@ def response_hook(span, request, response):
from opentelemetry.instrumentation.httpx.package import _instruments
from opentelemetry.instrumentation.httpx.version import __version__
from opentelemetry.instrumentation.instrumentor import BaseInstrumentor
from opentelemetry.instrumentation.utils import (
http_status_to_status_code,
is_http_instrumentation_enabled,
)
from opentelemetry.instrumentation.utils import http_status_to_status_code
from opentelemetry.propagate import inject
from opentelemetry.semconv.trace import SpanAttributes
from opentelemetry.trace import SpanKind, TracerProvider, get_tracer
Expand Down Expand Up @@ -319,7 +316,7 @@ def handle_request(
httpx.Response,
]:
"""Add request info to span."""
if not is_http_instrumentation_enabled():
if context.get_value("suppress_instrumentation"):
return self._transport.handle_request(*args, **kwargs)

method, url, headers, stream, extensions = _extract_parameters(
Expand Down Expand Up @@ -412,7 +409,7 @@ async def handle_async_request(
httpx.Response,
]:
"""Add request info to span."""
if not is_http_instrumentation_enabled():
if context.get_value("suppress_instrumentation"):
return await self._transport.handle_async_request(*args, **kwargs)

method, url, headers, stream, extensions = _extract_parameters(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,14 @@
from platform import python_implementation
from unittest import mock

from opentelemetry.instrumentation.system_metrics import (
SystemMetricsInstrumentor,
)
from opentelemetry.sdk.metrics import MeterProvider
from opentelemetry.sdk.metrics.export import InMemoryMetricReader
from opentelemetry.test.test_base import TestBase

from opentelemetry.instrumentation.system_metrics import (
SystemMetricsInstrumentor,
)


def _mock_netconnection():
NetConnection = namedtuple(
Expand Down Expand Up @@ -170,9 +171,9 @@ def _assert_metrics(self, observer_name, reader, expected):
for data_point in metric.data.data_points:
for expect in expected:
if (
dict(data_point.attributes)
== expect.attributes
and metric.name == observer_name
dict(data_point.attributes)
== expect.attributes
and metric.name == observer_name
):
self.assertEqual(
data_point.value,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,14 +85,13 @@ def response_hook(span, request_obj, response)
Request,
)

from opentelemetry import context
from opentelemetry.instrumentation.instrumentor import BaseInstrumentor
from opentelemetry.instrumentation.urllib.package import _instruments
from opentelemetry.instrumentation.urllib.version import __version__
from opentelemetry.instrumentation.utils import (
http_status_to_status_code,
is_http_instrumentation_enabled,
suppress_http_instrumentation,
http_status_to_status_code,
)
from opentelemetry.metrics import Histogram, get_meter
from opentelemetry.propagate import inject
Expand Down Expand Up @@ -233,8 +232,8 @@ def _instrumented_open_call(
inject(headers)

with suppress_http_instrumentation():
start_time = default_timer()
try:
start_time = default_timer()
result = call_wrapped() # *** PROCEED
except Exception as exc: # pylint: disable=W0703
exception = exc
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -302,9 +302,7 @@ def collect_request_attributes(environ):
"""

result = {
SpanAttributes.HTTP_METHOD: sanitize_method(
environ.get("REQUEST_METHOD")
),
SpanAttributes.HTTP_METHOD: sanitize_method(environ.get("REQUEST_METHOD")),
SpanAttributes.HTTP_SERVER_NAME: environ.get("SERVER_NAME"),
SpanAttributes.HTTP_SCHEME: environ.get("wsgi.url_scheme"),
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -286,26 +286,22 @@ def test_wsgi_metrics(self):
self.assertTrue(number_data_point_seen and histogram_data_point_seen)

def test_nonstandard_http_method(self):
self.environ["REQUEST_METHOD"] = "NONSTANDARD"
self.environ["REQUEST_METHOD"]= "NONSTANDARD"
app = otel_wsgi.OpenTelemetryMiddleware(simple_wsgi)
response = app(self.environ, self.start_response)
self.validate_response(
response, span_name="UNKNOWN /", http_method="UNKNOWN"
)
self.validate_response(response, span_name="UNKNOWN /", http_method="UNKNOWN")

@mock.patch.dict(
"os.environ",
"os.environ",
{
OTEL_PYTHON_INSTRUMENTATION_HTTP_CAPTURE_ALL_METHODS: "1",
},
)
def test_nonstandard_http_method_allowed(self):
self.environ["REQUEST_METHOD"] = "NONSTANDARD"
self.environ["REQUEST_METHOD"]= "NONSTANDARD"
app = otel_wsgi.OpenTelemetryMiddleware(simple_wsgi)
response = app(self.environ, self.start_response)
self.validate_response(
response, span_name="NONSTANDARD /", http_method="NONSTANDARD"
)
self.validate_response(response, span_name="NONSTANDARD /", http_method="NONSTANDARD")

def test_default_span_name_missing_path_info(self):
"""Test that default span_names with missing path info."""
Expand Down
Loading

0 comments on commit 6ae8d59

Please sign in to comment.