From 9d55b7fa4163c04b1c2f43ac51a296c4bea767e6 Mon Sep 17 00:00:00 2001 From: Leighton Chen Date: Fri, 20 Sep 2024 14:39:05 -0700 Subject: [PATCH] Fix setting custom `TracerProvider` bug with no global `TracerProvider` (#37469) --- .../CHANGELOG.md | 3 + .../exporter/export/trace/_exporter.py | 4 +- .../tests/trace/test_trace.py | 79 ++++++++++++++++--- 3 files changed, 74 insertions(+), 12 deletions(-) diff --git a/sdk/monitor/azure-monitor-opentelemetry-exporter/CHANGELOG.md b/sdk/monitor/azure-monitor-opentelemetry-exporter/CHANGELOG.md index f47ac2749f48..f1ce0af23df1 100644 --- a/sdk/monitor/azure-monitor-opentelemetry-exporter/CHANGELOG.md +++ b/sdk/monitor/azure-monitor-opentelemetry-exporter/CHANGELOG.md @@ -8,6 +8,9 @@ ### Bugs Fixed +- Fix setting custom `TracerProvider` bug + ([#37469](https://github.com/Azure/azure-sdk-for-python/pull/37469)) + ### Other Changes ## 1.0.0b29 (2024-09-10) diff --git a/sdk/monitor/azure-monitor-opentelemetry-exporter/azure/monitor/opentelemetry/exporter/export/trace/_exporter.py b/sdk/monitor/azure-monitor-opentelemetry-exporter/azure/monitor/opentelemetry/exporter/export/trace/_exporter.py index ede4fd08ff64..07cd207a9ee6 100644 --- a/sdk/monitor/azure-monitor-opentelemetry-exporter/azure/monitor/opentelemetry/exporter/export/trace/_exporter.py +++ b/sdk/monitor/azure-monitor-opentelemetry-exporter/azure/monitor/opentelemetry/exporter/export/trace/_exporter.py @@ -72,7 +72,7 @@ class AzureMonitorTraceExporter(BaseExporter, SpanExporter): """Azure Monitor Trace exporter for OpenTelemetry.""" def __init__(self, **kwargs: Any): - self._tracer_provider = kwargs.pop("tracer_provider", get_tracer_provider()) + self._tracer_provider = kwargs.pop("tracer_provider", None) super().__init__(**kwargs) def export(self, spans: Sequence[ReadableSpan], **kwargs: Any) -> SpanExportResult: # pylint: disable=unused-argument @@ -87,7 +87,7 @@ def export(self, spans: Sequence[ReadableSpan], **kwargs: Any) -> SpanExportResu if spans and self._should_collect_otel_resource_metric(): resource = None try: - tracer_provider = self._tracer_provider + tracer_provider = self._tracer_provider or get_tracer_provider() resource = tracer_provider.resource # type: ignore envelopes.append(self._get_otel_resource_envelope(resource)) except AttributeError as e: diff --git a/sdk/monitor/azure-monitor-opentelemetry-exporter/tests/trace/test_trace.py b/sdk/monitor/azure-monitor-opentelemetry-exporter/tests/trace/test_trace.py index a6708726d159..3d76c77e74c2 100644 --- a/sdk/monitor/azure-monitor-opentelemetry-exporter/tests/trace/test_trace.py +++ b/sdk/monitor/azure-monitor-opentelemetry-exporter/tests/trace/test_trace.py @@ -1,7 +1,6 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. - import json import os import platform @@ -67,14 +66,10 @@ def tearDownClass(cls): def test_constructor(self): """Test the constructor.""" tp = trace.TracerProvider() - set_tracer_provider(tp) exporter = AzureMonitorTraceExporter( connection_string="InstrumentationKey=4321abcd-5678-4efa-8abc-1234567890ab", ) - self.assertEqual( - exporter._tracer_provider, - tp, - ) + self.assertIsNone(exporter._tracer_provider) self.assertEqual( exporter._instrumentation_key, "4321abcd-5678-4efa-8abc-1234567890ab", @@ -83,15 +78,13 @@ def test_constructor(self): def test_constructor_tracer_provider(self): """Test the constructor.""" tp = trace.TracerProvider() - tp2 = trace.TracerProvider() - set_tracer_provider(tp) exporter = AzureMonitorTraceExporter( connection_string="InstrumentationKey=4321abcd-5678-4efa-8abc-1234567890ab", - tracer_provider=tp2, + tracer_provider=tp, ) self.assertEqual( exporter._tracer_provider, - tp2, + tp, ) self.assertEqual( exporter._instrumentation_key, @@ -157,6 +150,72 @@ def test_export_success(self): self.assertEqual(result, SpanExportResult.SUCCESS) self.assertEqual(storage_mock.call_count, 1) + def test_export_with_tracer_provider(self): + mock_resource = mock.Mock() + tp = trace.TracerProvider( + resource=mock_resource, + ) + exporter = AzureMonitorTraceExporter( + connection_string="InstrumentationKey=4321abcd-5678-4efa-8abc-1234567890ab", + tracer_provider=tp, + ) + test_span = trace._Span( + name="test", + context=SpanContext( + trace_id=36873507687745823477771305566750195431, + span_id=12030755672171557338, + is_remote=False, + ), + ) + test_span.start() + test_span.end() + with mock.patch( + "azure.monitor.opentelemetry.exporter.AzureMonitorTraceExporter._transmit" + ) as transmit: # noqa: E501 + transmit.return_value = ExportResult.SUCCESS + storage_mock = mock.Mock() + exporter._transmit_from_storage = storage_mock + with mock.patch( + "azure.monitor.opentelemetry.exporter.AzureMonitorTraceExporter._get_otel_resource_envelope" + ) as resource_patch: # noqa: E501 + result = exporter.export([test_span]) + resource_patch.assert_called_once_with(mock_resource) + self.assertEqual(result, SpanExportResult.SUCCESS) + self.assertEqual(storage_mock.call_count, 1) + + def test_export_with_tracer_provider_global(self): + mock_resource = mock.Mock() + tp = trace.TracerProvider( + resource=mock_resource, + ) + set_tracer_provider(tp) + exporter = AzureMonitorTraceExporter( + connection_string="InstrumentationKey=4321abcd-5678-4efa-8abc-1234567890ab", + ) + test_span = trace._Span( + name="test", + context=SpanContext( + trace_id=36873507687745823477771305566750195431, + span_id=12030755672171557338, + is_remote=False, + ), + ) + test_span.start() + test_span.end() + with mock.patch( + "azure.monitor.opentelemetry.exporter.AzureMonitorTraceExporter._transmit" + ) as transmit: # noqa: E501 + transmit.return_value = ExportResult.SUCCESS + storage_mock = mock.Mock() + exporter._transmit_from_storage = storage_mock + with mock.patch( + "azure.monitor.opentelemetry.exporter.AzureMonitorTraceExporter._get_otel_resource_envelope" + ) as resource_patch: # noqa: E501 + result = exporter.export([test_span]) + resource_patch.assert_called_once_with(mock_resource) + self.assertEqual(result, SpanExportResult.SUCCESS) + self.assertEqual(storage_mock.call_count, 1) + @mock.patch("azure.monitor.opentelemetry.exporter.export.trace._exporter._logger") def test_export_exception(self, logger_mock): test_span = trace._Span(