Skip to content

Commit

Permalink
Clean up use of suppress_instrumentation in context and fix httpx bug
Browse files Browse the repository at this point in the history
  • Loading branch information
adriangb committed Nov 14, 2023
1 parent ba190a8 commit ef2c7cc
Show file tree
Hide file tree
Showing 25 changed files with 166 additions and 321 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -94,8 +94,8 @@ def response_hook(span: Span, params: typing.Union[
from opentelemetry.instrumentation.aiohttp_client.version import __version__
from opentelemetry.instrumentation.instrumentor import BaseInstrumentor
from opentelemetry.instrumentation.utils import (
_SUPPRESS_INSTRUMENTATION_KEY,
http_status_to_status_code,
is_instrumentation_enabled,
unwrap,
)
from opentelemetry.propagate import inject
Expand Down Expand Up @@ -179,7 +179,7 @@ async def on_request_start(
trace_config_ctx: types.SimpleNamespace,
params: aiohttp.TraceRequestStartParams,
):
if context_api.get_value(_SUPPRESS_INSTRUMENTATION_KEY):
if not is_instrumentation_enabled():
trace_config_ctx.span = None
return

Expand Down Expand Up @@ -282,7 +282,7 @@ def _instrument(

# pylint:disable=unused-argument
def instrumented_init(wrapped, instance, args, kwargs):
if context_api.get_value(_SUPPRESS_INSTRUMENTATION_KEY):
if not is_instrumentation_enabled():
return wrapped(*args, **kwargs)

client_trace_configs = list(kwargs.get("trace_configs") or [])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
from opentelemetry.instrumentation.aiohttp_client import (
AioHttpClientInstrumentor,
)
from opentelemetry.instrumentation.utils import _SUPPRESS_INSTRUMENTATION_KEY
from opentelemetry.instrumentation.utils import suppress_instrumentation
from opentelemetry.semconv.trace import SpanAttributes
from opentelemetry.test.test_base import TestBase
from opentelemetry.trace import Span, StatusCode
Expand Down Expand Up @@ -506,25 +506,17 @@ async def uninstrument_request(server: aiohttp.test_utils.TestServer):
self.assert_spans(1)

def test_suppress_instrumentation(self):
token = context.attach(
context.set_value(_SUPPRESS_INSTRUMENTATION_KEY, True)
)
try:
with suppress_instrumentation():
run_with_test_server(
self.get_default_request(), self.URL, self.default_handler
)
finally:
context.detach(token)
self.assert_spans(0)

@staticmethod
async def suppressed_request(server: aiohttp.test_utils.TestServer):
async with aiohttp.test_utils.TestClient(server) as client:
token = context.attach(
context.set_value(_SUPPRESS_INSTRUMENTATION_KEY, True)
)
await client.get(TestAioHttpClientInstrumentor.URL)
context.detach(token)
with suppress_instrumentation():
await client.get(TestAioHttpClientInstrumentor.URL)

def test_suppress_instrumentation_after_creation(self):
run_with_test_server(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
from opentelemetry import context, propagate, trace
from opentelemetry.instrumentation.instrumentor import BaseInstrumentor
from opentelemetry.instrumentation.utils import (
_SUPPRESS_INSTRUMENTATION_KEY,
is_instrumentation_enabled,
unwrap,
)
from opentelemetry.propagators.textmap import CarrierT, Getter, Setter
Expand Down Expand Up @@ -218,7 +218,7 @@ def _create_processing_span(

def _wrap_send_message(self, sqs_class: type) -> None:
def send_wrapper(wrapped, instance, args, kwargs):
if context.get_value(_SUPPRESS_INSTRUMENTATION_KEY):
if not is_instrumentation_enabled():
return wrapped(*args, **kwargs)
queue_url = kwargs.get("QueueUrl")
# The method expect QueueUrl and Entries params, so if they are None, we call wrapped to receive the
Expand Down Expand Up @@ -252,7 +252,7 @@ def send_batch_wrapper(wrapped, instance, args, kwargs):
# The method expect QueueUrl and Entries params, so if they are None, we call wrapped to receive the
# original exception
if (
context.get_value(_SUPPRESS_INSTRUMENTATION_KEY)
not is_instrumentation_enabled()
or not queue_url
or not entries
):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,9 +87,6 @@ def response_hook(span, service_name, operation_name, result):
from wrapt import wrap_function_wrapper

from opentelemetry import context as context_api

# FIXME: fix the importing of this private attribute when the location of the _SUPPRESS_HTTP_INSTRUMENTATION_KEY is defined.
from opentelemetry.context import _SUPPRESS_HTTP_INSTRUMENTATION_KEY
from opentelemetry.instrumentation.botocore.extensions import _find_extension
from opentelemetry.instrumentation.botocore.extensions.types import (
_AwsSdkCallContext,
Expand All @@ -98,7 +95,8 @@ def response_hook(span, service_name, operation_name, result):
from opentelemetry.instrumentation.botocore.version import __version__
from opentelemetry.instrumentation.instrumentor import BaseInstrumentor
from opentelemetry.instrumentation.utils import (
_SUPPRESS_INSTRUMENTATION_KEY,
is_instrumentation_enabled,
suppress_http_instrumentation,
unwrap,
)
from opentelemetry.propagators.aws.aws_xray_propagator import AwsXRayPropagator
Expand Down Expand Up @@ -171,7 +169,7 @@ def _patched_endpoint_prepare_request(

# pylint: disable=too-many-branches
def _patched_api_call(self, original_func, instance, args, kwargs):
if context_api.get_value(_SUPPRESS_INSTRUMENTATION_KEY):
if not is_instrumentation_enabled():
return original_func(*args, **kwargs)

call_context = _determine_call_context(instance, args)
Expand Down Expand Up @@ -200,25 +198,21 @@ def _patched_api_call(self, original_func, instance, args, kwargs):
_safe_invoke(extension.before_service_call, span)
self._call_request_hook(span, call_context)

token = context_api.attach(
context_api.set_value(_SUPPRESS_HTTP_INSTRUMENTATION_KEY, True)
)

result = None
try:
result = original_func(*args, **kwargs)
except ClientError as error:
result = getattr(error, "response", None)
_apply_response_attributes(span, result)
_safe_invoke(extension.on_error, span, error)
raise
else:
_apply_response_attributes(span, result)
_safe_invoke(extension.on_success, span, result)
with suppress_http_instrumentation():
result = None
try:
result = original_func(*args, **kwargs)
except ClientError as error:
result = getattr(error, "response", None)
_apply_response_attributes(span, result)
_safe_invoke(extension.on_error, span, error)
raise
else:
_apply_response_attributes(span, result)
_safe_invoke(extension.on_success, span, result)
finally:
context_api.detach(token)
_safe_invoke(extension.after_service_call)

self._call_response_hook(span, call_context, result)

return result
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,14 +27,8 @@
)

from opentelemetry import trace as trace_api
from opentelemetry.context import (
_SUPPRESS_HTTP_INSTRUMENTATION_KEY,
attach,
detach,
set_value,
)
from opentelemetry.instrumentation.botocore import BotocoreInstrumentor
from opentelemetry.instrumentation.utils import _SUPPRESS_INSTRUMENTATION_KEY
from opentelemetry.instrumentation.utils import suppress_http_instrumentation, suppress_instrumentation
from opentelemetry.propagate import get_global_textmap, set_global_textmap
from opentelemetry.propagators.aws.aws_xray_propagator import TRACE_HEADER_KEY
from opentelemetry.semconv.trace import SpanAttributes
Expand Down Expand Up @@ -341,23 +335,16 @@ def check_headers(**kwargs):
@mock_xray
def test_suppress_instrumentation_xray_client(self):
xray_client = self._make_client("xray")
token = attach(set_value(_SUPPRESS_INSTRUMENTATION_KEY, True))
try:
with suppress_instrumentation():
xray_client.put_trace_segments(TraceSegmentDocuments=["str1"])
xray_client.put_trace_segments(TraceSegmentDocuments=["str2"])
finally:
detach(token)
self.assertEqual(0, len(self.get_finished_spans()))

@mock_xray
def test_suppress_http_instrumentation_xray_client(self):
xray_client = self._make_client("xray")
token = attach(set_value(_SUPPRESS_HTTP_INSTRUMENTATION_KEY, True))
try:
with suppress_http_instrumentation():
xray_client.put_trace_segments(TraceSegmentDocuments=["str1"])
xray_client.put_trace_segments(TraceSegmentDocuments=["str2"])
finally:
detach(token)
self.assertEqual(2, len(self.get_finished_spans()))

@mock_s3
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,10 +70,7 @@ def _traced_execute_async(func, instance, args, kwargs):
if span.is_recording():
span.set_attribute(SpanAttributes.DB_NAME, instance.keyspace)
span.set_attribute(SpanAttributes.DB_SYSTEM, "cassandra")
span.set_attribute(
SpanAttributes.NET_PEER_NAME,
instance.cluster.contact_points,
)
span.set_attribute(SpanAttributes.NET_PEER_NAME, instance.cluster.contact_points)

if include_db_statement:
query = args[0]
Expand All @@ -82,9 +79,7 @@ def _traced_execute_async(func, instance, args, kwargs):
response = func(*args, **kwargs)
return response

wrap_function_wrapper(
"cassandra.cluster", "Session.execute_async", _traced_execute_async
)
wrap_function_wrapper("cassandra.cluster", "Session.execute_async", _traced_execute_async)


class CassandraInstrumentor(BaseInstrumentor):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,25 +46,15 @@ def tearDown(self):
def test_instrument_uninstrument(self):
instrumentation = CassandraInstrumentor()
instrumentation.instrument()
self.assertTrue(
isinstance(
cassandra.cluster.Session.execute_async, BoundFunctionWrapper
)
)
self.assertTrue(isinstance(cassandra.cluster.Session.execute_async, BoundFunctionWrapper))

instrumentation.uninstrument()
self.assertFalse(
isinstance(
cassandra.cluster.Session.execute_async, BoundFunctionWrapper
)
)
self.assertFalse(isinstance(cassandra.cluster.Session.execute_async, BoundFunctionWrapper))

@mock.patch("cassandra.cluster.Cluster.connect")
@mock.patch("cassandra.cluster.Session.__init__")
@mock.patch("cassandra.cluster.Session._create_response_future")
def test_instrumentor(
self, mock_create_response_future, mock_session_init, mock_connect
):
def test_instrumentor(self, mock_create_response_future, mock_session_init, mock_connect):
mock_create_response_future.return_value = mock.Mock()
mock_session_init.return_value = None
mock_connect.return_value = cassandra.cluster.Session()
Expand Down Expand Up @@ -95,9 +85,7 @@ def test_instrumentor(
@mock.patch("cassandra.cluster.Cluster.connect")
@mock.patch("cassandra.cluster.Session.__init__")
@mock.patch("cassandra.cluster.Session._create_response_future")
def test_custom_tracer_provider(
self, mock_create_response_future, mock_session_init, mock_connect
):
def test_custom_tracer_provider(self, mock_create_response_future, mock_session_init, mock_connect):
mock_create_response_future.return_value = mock.Mock()
mock_session_init.return_value = None
mock_connect.return_value = cassandra.cluster.Session()
Expand All @@ -119,9 +107,7 @@ def test_custom_tracer_provider(
@mock.patch("cassandra.cluster.Cluster.connect")
@mock.patch("cassandra.cluster.Session.__init__")
@mock.patch("cassandra.cluster.Session._create_response_future")
def test_instrument_connection_no_op_tracer_provider(
self, mock_create_response_future, mock_session_init, mock_connect
):
def test_instrument_connection_no_op_tracer_provider(self, mock_create_response_future, mock_session_init, mock_connect):
mock_create_response_future.return_value = mock.Mock()
mock_session_init.return_value = None
mock_connect.return_value = cassandra.cluster.Session()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -328,9 +328,7 @@ def test_flask_metric_values(self):
if isinstance(point, NumberDataPoint):
self.assertEqual(point.value, 0)

def _assert_basic_metric(
self, expected_duration_attributes, expected_requests_count_attributes
):
def _assert_basic_metric(self, expected_duration_attributes, expected_requests_count_attributes):
metrics_list = self.memory_metrics_reader.get_metrics_data()
for resource_metric in metrics_list.resource_metrics:
for scope_metrics in resource_metric.scope_metrics:
Expand Down Expand Up @@ -396,7 +394,7 @@ def test_basic_metric_nonstandard_http_method_success(self):
)

@patch.dict(
"os.environ",
"os.environ",
{
OTEL_PYTHON_INSTRUMENTATION_HTTP_CAPTURE_ALL_METHODS: "1",
},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,11 @@
import grpc
from grpc.aio import ClientCallDetails

from opentelemetry import context
from opentelemetry.instrumentation.grpc._client import (
OpenTelemetryClientInterceptor,
_carrier_setter,
)
from opentelemetry.instrumentation.utils import _SUPPRESS_INSTRUMENTATION_KEY
from opentelemetry.instrumentation.utils import is_instrumentation_enabled
from opentelemetry.propagate import inject
from opentelemetry.semconv.trace import SpanAttributes
from opentelemetry.trace.status import Status, StatusCode
Expand Down Expand Up @@ -139,9 +138,10 @@ async def _wrap_stream_response(self, span, call):
span.end()

def tracing_skipped(self, client_call_details):
return context.get_value(
_SUPPRESS_INSTRUMENTATION_KEY
) or not self.rpc_matches_filters(client_call_details)
return (
not is_instrumentation_enabled()
or not self.rpc_matches_filters(client_call_details)
)

def rpc_matches_filters(self, client_call_details):
return self._filter is None or self._filter(client_call_details)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from opentelemetry import context, trace
from opentelemetry.instrumentation.grpc import grpcext
from opentelemetry.instrumentation.grpc._utilities import RpcInfo
from opentelemetry.instrumentation.utils import _SUPPRESS_INSTRUMENTATION_KEY
from opentelemetry.instrumentation.utils import is_instrumentation_enabled
from opentelemetry.propagate import inject
from opentelemetry.propagators.textmap import Setter
from opentelemetry.semconv.trace import SpanAttributes
Expand Down Expand Up @@ -123,7 +123,7 @@ def _trace_result(self, span, rpc_info, result):
return result

def _intercept(self, request, metadata, client_info, invoker):
if context.get_value(_SUPPRESS_INSTRUMENTATION_KEY):
if not is_instrumentation_enabled():
return invoker(request, metadata)

if not metadata:
Expand Down Expand Up @@ -219,7 +219,7 @@ def _intercept_server_stream(
def intercept_stream(
self, request_or_iterator, metadata, client_info, invoker
):
if context.get_value(_SUPPRESS_INSTRUMENTATION_KEY):
if not is_instrumentation_enabled():
return invoker(request_or_iterator, metadata)

if self._filter is not None and not self._filter(client_info):
Expand Down
Loading

0 comments on commit ef2c7cc

Please sign in to comment.