From 8474a1d703bdeb84309a5b58a2f703c648095cc3 Mon Sep 17 00:00:00 2001 From: Liudmila Molkova Date: Mon, 29 Nov 2021 22:21:16 -0800 Subject: [PATCH] Make Cosmos logical spans CLIENT and current (#25571) * Make Cosmos logical spans CLIENT and current Co-authored-by: Trask Stalnaker --- .../resources/spotbugs/spotbugs-exclude.xml | 7 +- .../cosmos/implementation/TracerProvider.java | 211 ++++++++++++++---- .../azure/cosmos/util/CosmosPagedFlux.java | 159 +++++++------ .../com/azure/cosmos/CosmosTracerTest.java | 35 +-- .../implementation/TracerProviderTest.java | 170 ++++++++++++++ 5 files changed, 454 insertions(+), 128 deletions(-) create mode 100644 sdk/cosmos/azure-cosmos/src/test/java/com/azure/cosmos/implementation/TracerProviderTest.java diff --git a/eng/code-quality-reports/src/main/resources/spotbugs/spotbugs-exclude.xml b/eng/code-quality-reports/src/main/resources/spotbugs/spotbugs-exclude.xml index 0f2bdff00192a..17dba7194598f 100755 --- a/eng/code-quality-reports/src/main/resources/spotbugs/spotbugs-exclude.xml +++ b/eng/code-quality-reports/src/main/resources/spotbugs/spotbugs-exclude.xml @@ -1896,7 +1896,12 @@ - + + + + + + diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/TracerProvider.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/TracerProvider.java index 936171c95ff03..f3cecf87f5b40 100644 --- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/TracerProvider.java +++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/TracerProvider.java @@ -3,6 +3,8 @@ package com.azure.cosmos.implementation; import com.azure.core.util.Context; +import com.azure.core.util.tracing.SpanKind; +import com.azure.core.util.tracing.StartSpanOptions; import com.azure.core.util.tracing.Tracer; import com.azure.cosmos.BridgeInternal; import com.azure.cosmos.ConsistencyLevel; @@ -21,8 +23,13 @@ import org.HdrHistogram.ConcurrentDoubleHistogram; import org.slf4j.Logger; 
import org.slf4j.LoggerFactory; + +import reactor.core.CoreSubscriber; +import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; +import reactor.core.publisher.Operators; import reactor.core.publisher.Signal; +import reactor.util.context.ContextView; import java.time.Duration; import java.time.OffsetDateTime; @@ -32,7 +39,6 @@ import java.util.Map; import java.util.Objects; import java.util.Optional; -import java.util.concurrent.atomic.AtomicReference; import java.util.function.Function; import static com.azure.core.util.tracing.Tracer.AZ_TRACING_NAMESPACE_KEY; @@ -53,14 +59,47 @@ public class TracerProvider { public static final String RESOURCE_PROVIDER_NAME = "Microsoft.DocumentDB"; public final Duration CRUD_THRESHOLD_FOR_DIAGNOSTICS = Duration.ofMillis(100); public final Duration QUERY_THRESHOLD_FOR_DIAGNOSTICS = Duration.ofMillis(500); + + private static final String REACTOR_TRACING_CONTEXT_KEY = "tracing-context"; + private static final Object DUMMY_VALUE = new Object(); + private final Mono propagatingMono; + private final Flux propagatingFlux; public TracerProvider(Tracer tracer) { this.tracer = tracer; + this.propagatingMono = new PropagatingMono(); + this.propagatingFlux = new PropagatingFlux(); } public boolean isEnabled() { return tracer != null; } + /** + * Gets {@link Context} from Reactor {@link ContextView}. + * + * @param reactorContext Reactor context instance. + * @return {@link Context} from reactor context or null if not present. + */ + public static Context getContextFromReactorOrNull(ContextView reactorContext) { + Object context = reactorContext.getOrDefault(REACTOR_TRACING_CONTEXT_KEY, null); + + if (context != null && context instanceof Context) { + return (Context) context; + } + + return null; + } + + /** + * Stores {@link Context} in Reactor {@link reactor.util.context.Context}. + * + * @param traceContext {@link Context} context with trace context to store. 
+ * @return {@link reactor.util.context.Context} Reactor context with trace context. + */ + public static reactor.util.context.Context setContextInReactor(Context traceContext) { + return reactor.util.context.Context.of(REACTOR_TRACING_CONTEXT_KEY, traceContext); + } + /** * For each tracer plugged into the SDK a new tracing span is created. *

@@ -73,16 +112,18 @@ public boolean isEnabled() { */ public Context startSpan(String methodName, String databaseId, String endpoint, Context context) { Context local = Objects.requireNonNull(context, "'context' cannot be null."); - local = local.addData(AZ_TRACING_NAMESPACE_KEY, RESOURCE_PROVIDER_NAME); - local = tracer.start(methodName, local); // start the span and return the started span + + StartSpanOptions spanOptions = new StartSpanOptions(SpanKind.CLIENT) + .setAttribute(AZ_TRACING_NAMESPACE_KEY, RESOURCE_PROVIDER_NAME) + .setAttribute(DB_TYPE, DB_TYPE_VALUE) + .setAttribute(TracerProvider.DB_URL, endpoint) + .setAttribute(TracerProvider.DB_STATEMENT, methodName); if (databaseId != null) { - tracer.setAttribute(TracerProvider.DB_INSTANCE, databaseId, local); + spanOptions.setAttribute(TracerProvider.DB_INSTANCE, databaseId); } - tracer.setAttribute(TracerProvider.DB_TYPE, DB_TYPE_VALUE, local); - tracer.setAttribute(TracerProvider.DB_URL, endpoint, local); - tracer.setAttribute(TracerProvider.DB_STATEMENT, methodName, local); - return local; + // start the span and return the started span + return tracer.start(methodName, spanOptions, local); } /** @@ -106,17 +147,19 @@ public void addEvent(String name, Map attributes, OffsetDateTime * Given a context containing the current tracing span the span is marked completed with status info from * {@link Signal}. For each tracer plugged into the SDK the current tracing span is marked as completed. * - * @param context Additional metadata that is passed through the call stack. * @param signal The signal indicates the status and contains the metadata we need to end the tracing span. 
*/ - public > void endSpan(Context context, - Signal signal, - int statusCode) { - Objects.requireNonNull(context, "'context' cannot be null."); + public void endSpan(Signal signal, int statusCode) { Objects.requireNonNull(signal, "'signal' cannot be null."); + Context context = getContextFromReactorOrNull(signal.getContextView()); + if (context == null) { + return; + } + switch (signal.getType()) { case ON_COMPLETE: + case ON_NEXT: end(statusCode, null, context); break; case ON_ERROR: @@ -133,7 +176,7 @@ public > void endSpan(Context conte end(statusCode, throwable, context); break; default: - // ON_SUBSCRIBE and ON_NEXT don't have the information to end the span so just return. + // ON_SUBSCRIBE isn't the right state to end span break; } } @@ -190,6 +233,20 @@ public Mono> traceEnabledCosmosItemResponsePublisher(M thresholdForDiagnosticsOnTracer); } + /** + * Runs given {@code Flux} publisher in the scope of trace context passed in using + * {@link TracerProvider#setContextInReactor(Context)} in {@code contextWrite}. + * Populates active trace context on Reactor's hot path. Reactor's instrumentation for OpenTelemetry + * (or other hypothetical solution) will take care of the cold path. + * + * @param publisher publisher to run. + * @return wrapped publisher. 
+ */ + public Flux runUnderSpanInContext(Flux publisher) { + return propagatingFlux + .flatMap(ignored -> publisher); + } + private Mono traceEnabledPublisher(Mono resultPublisher, Context context, String spanName, @@ -198,41 +255,55 @@ private Mono traceEnabledPublisher(Mono resultPublisher, Function statusCodeFunc, Function diagnosticFunc, Duration thresholdForDiagnosticsOnTracer) { - final AtomicReference parentContext = new AtomicReference<>(Context.NONE); + + if (!isEnabled()) { + return resultPublisher; + } + Optional callDepth = context.getData(COSMOS_CALL_DEPTH); final boolean isNestedCall = callDepth.isPresent(); - return resultPublisher - .doOnSubscribe(ignoredValue -> { - if (isEnabled() && !isNestedCall) { - parentContext.set(this.startSpan(spanName, databaseId, endpoint, - context)); - } - }).doOnSuccess(response -> { - if (isEnabled() && !isNestedCall) { - CosmosDiagnostics cosmosDiagnostics = diagnosticFunc.apply(response); - try { - Duration threshold = thresholdForDiagnosticsOnTracer; - if(threshold == null) { - threshold = CRUD_THRESHOLD_FOR_DIAGNOSTICS; - } + if (isNestedCall) { + return resultPublisher; + } - if (cosmosDiagnostics != null - && cosmosDiagnostics.getDuration() != null - && cosmosDiagnostics.getDuration().compareTo(threshold) > 0) { - addDiagnosticsOnTracerEvent(cosmosDiagnostics, parentContext.get()); + // propagatingMono ensures active span is propagated to the `resultPublisher` + // subscription and hot path. OpenTelemetry reactor's instrumentation will + // propagate it on the cold path. 
+ return propagatingMono + .flatMap(ignored -> resultPublisher) + .doOnEach(signal -> { + switch (signal.getType()) { + case ON_NEXT: + T response = signal.get(); + Context traceContext = getContextFromReactorOrNull(signal.getContextView()); + CosmosDiagnostics cosmosDiagnostics = diagnosticFunc.apply(response); + try { + Duration threshold = thresholdForDiagnosticsOnTracer; + if (threshold == null) { + threshold = CRUD_THRESHOLD_FOR_DIAGNOSTICS; + } + + if (cosmosDiagnostics != null + && cosmosDiagnostics.getDuration() != null + && cosmosDiagnostics.getDuration().compareTo(threshold) > 0) { + addDiagnosticsOnTracerEvent(cosmosDiagnostics, traceContext); + } + } catch (JsonProcessingException ex) { + LOGGER.warn("Error while serializing diagnostics for tracer", ex.getMessage()); } - } catch (JsonProcessingException ex) { - LOGGER.warn("Error while serializing diagnostics for tracer", ex.getMessage()); - } - this.endSpan(parentContext.get(), Signal.complete(), statusCodeFunc.apply(response)); - } - }).doOnError(throwable -> { - if (isEnabled() && !isNestedCall) { - // not adding diagnostics on trace event for exception as this information is already there as - // part of exception message - this.endSpan(parentContext.get(), Signal.error(throwable), ERROR_CODE); - } - }); + + this.endSpan(signal, statusCodeFunc.apply(response)); + break; + case ON_ERROR: + // not adding diagnostics on trace event for exception as this information is already there as + // part of exception message + this.endSpan(signal, ERROR_CODE); + break; + default: + break; + }}) + .contextWrite(setContextInReactor(this.startSpan(spanName, databaseId, endpoint, + context))); } private Mono publisherWithClientTelemetry(Mono resultPublisher, @@ -364,7 +435,7 @@ private ReportPayload createReportPayload(CosmosAsyncClient cosmosAsyncClient, } private void addDiagnosticsOnTracerEvent(CosmosDiagnostics cosmosDiagnostics, Context context) throws JsonProcessingException { - if (cosmosDiagnostics == null) 
{ + if (cosmosDiagnostics == null || context == null) { return; } @@ -499,4 +570,52 @@ private void addDiagnosticsOnTracerEvent(CosmosDiagnostics cosmosDiagnostics, Co this.addEvent("ClientCfgs", attributes, OffsetDateTime.ofInstant(clientSideRequestStatistics.getRequestStartTimeUTC(), ZoneOffset.UTC), context); } + + private static void subscribe(Tracer tracer, CoreSubscriber actual) { + Context context = getContextFromReactorOrNull(actual.currentContext()); + if (context != null) { + AutoCloseable scope = tracer.makeSpanCurrent(context); + try { + actual.onSubscribe(Operators.scalarSubscription(actual, DUMMY_VALUE)); + } finally { + try { + scope.close(); + } catch (Exception e) { + // can't happen + } + } + } else { + actual.onSubscribe(Operators.scalarSubscription(actual, DUMMY_VALUE)); + } + } + + /** + * Helper class allowing running Mono subscription (and anything on the hot path) + * in scope of trace context. This enables OpenTelemetry auto-collection + * to pick it up and correlate lower levels of instrumentation and logs + * to logical Cosmos spans. + * + * OpenTelemetry reactor auto-instrumentation will take care of the cold path. + */ + private final class PropagatingMono extends Mono { + @Override + public void subscribe(CoreSubscriber actual) { + TracerProvider.subscribe(tracer, actual); + } + } + + /** + * Helper class allowing running Flux subscription (and anything on the hot path) + * in scope of trace context. This enables OpenTelemetry auto-collection + * to pick it up and correlate lower levels of instrumentation and logs + * to logical Cosmos spans. + * + * OpenTelemetry reactor auto-instrumentation will take care of the cold path. 
+ */ + private final class PropagatingFlux extends Flux { + @Override + public void subscribe(CoreSubscriber actual) { + TracerProvider.subscribe(tracer, actual); + } + } } diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/util/CosmosPagedFlux.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/util/CosmosPagedFlux.java index 2eba04793abcb..183092e17ae5e 100644 --- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/util/CosmosPagedFlux.java +++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/util/CosmosPagedFlux.java @@ -164,77 +164,100 @@ private CosmosPagedFluxOptions createCosmosPagedFluxOptions() { return cosmosPagedFluxOptions; } + private Flux wrapWithTracingIfEnabled(CosmosPagedFluxOptions pagedFluxOptions, Flux publisher, Context context) { + TracerProvider tracerProvider = pagedFluxOptions.getTracerProvider(); + if (!isTracerEnabled(pagedFluxOptions)) { + return publisher; + } + + return tracerProvider.runUnderSpanInContext(publisher); + } + private Flux> byPage(CosmosPagedFluxOptions pagedFluxOptions, Context context) { - final AtomicReference parentContext = new AtomicReference<>(Context.NONE); AtomicReference startTime = new AtomicReference<>(); - return this.optionsFluxFunction.apply(pagedFluxOptions).doOnSubscribe(ignoredValue -> { - if (pagedFluxOptions.getTracerProvider() != null && pagedFluxOptions.getTracerProvider().isEnabled()) { - parentContext.set(pagedFluxOptions.getTracerProvider().startSpan(pagedFluxOptions.getTracerSpanName(), - pagedFluxOptions.getDatabaseId(), pagedFluxOptions.getServiceEndpoint(), - context)); - } - startTime.set(Instant.now()); - }).doOnComplete(() -> { - if (pagedFluxOptions.getTracerProvider() != null && pagedFluxOptions.getTracerProvider().isEnabled()) { - pagedFluxOptions.getTracerProvider().endSpan(parentContext.get(), Signal.complete(), - HttpConstants.StatusCodes.OK); - } - }).doOnError(throwable -> { - if (pagedFluxOptions.getCosmosAsyncClient() != null && - 
Configs.isClientTelemetryEnabled(BridgeInternal.isClientTelemetryEnabled(pagedFluxOptions.getCosmosAsyncClient())) && - throwable instanceof CosmosException) { - CosmosException cosmosException = (CosmosException) throwable; - // not adding diagnostics on trace event for exception as this information is already there as - // part of exception message - if (this.cosmosDiagnosticsAccessor.isDiagnosticsCapturedInPagedFlux(cosmosException.getDiagnostics()).compareAndSet(false, true)) { - fillClientTelemetry(pagedFluxOptions.getCosmosAsyncClient(), 0, pagedFluxOptions.getContainerId(), - pagedFluxOptions.getDatabaseId(), - pagedFluxOptions.getOperationType(), pagedFluxOptions.getResourceType(), - BridgeInternal.getContextClient(pagedFluxOptions.getCosmosAsyncClient()).getConsistencyLevel(), - (float) cosmosException.getRequestCharge(), Duration.between(startTime.get(), Instant.now())); - } - } - if (isTracerEnabled(pagedFluxOptions)) { - pagedFluxOptions.getTracerProvider().endSpan(parentContext.get(), Signal.error(throwable), - TracerProvider.ERROR_CODE); - } - startTime.set(Instant.now()); - }).doOnNext(feedResponse -> { - if (isTracerEnabled(pagedFluxOptions) && - this.cosmosDiagnosticsAccessor.isDiagnosticsCapturedInPagedFlux(feedResponse.getCosmosDiagnostics()).compareAndSet(false, true)) { - try { - Duration threshold = pagedFluxOptions.getThresholdForDiagnosticsOnTracer(); - if (threshold == null) { - threshold = pagedFluxOptions.getTracerProvider().QUERY_THRESHOLD_FOR_DIAGNOSTICS; - } - - if (Duration.between(startTime.get(), Instant.now()).compareTo(threshold) > 0) { - addDiagnosticsOnTracerEvent(pagedFluxOptions.getTracerProvider(), - feedResponse.getCosmosDiagnostics(), parentContext.get()); - } - } catch (JsonProcessingException ex) { - LOGGER.warn("Error while serializing diagnostics for tracer", ex.getMessage()); - } - } - // If the user has passed feedResponseConsumer, then call it with each feedResponse - if (feedResponseConsumer != null) { - 
feedResponseConsumer.accept(feedResponse); - } + Flux> result = + wrapWithTracingIfEnabled(pagedFluxOptions, this.optionsFluxFunction.apply(pagedFluxOptions), context) + .doOnSubscribe(ignoredValue -> startTime.set(Instant.now())) + .doOnEach(signal -> { + switch (signal.getType()) { + case ON_COMPLETE: + if (isTracerEnabled(pagedFluxOptions)) { + pagedFluxOptions.getTracerProvider().endSpan(signal, HttpConstants.StatusCodes.OK); + } + break; + case ON_ERROR: + Throwable throwable = signal.getThrowable(); + if (pagedFluxOptions.getCosmosAsyncClient() != null && + Configs.isClientTelemetryEnabled(BridgeInternal.isClientTelemetryEnabled(pagedFluxOptions.getCosmosAsyncClient())) && + throwable instanceof CosmosException) { + CosmosException cosmosException = (CosmosException) throwable; + // not adding diagnostics on trace event for exception as this information is already there as + // part of exception message + if (this.cosmosDiagnosticsAccessor.isDiagnosticsCapturedInPagedFlux(cosmosException.getDiagnostics()).compareAndSet(false, true)) { + fillClientTelemetry(pagedFluxOptions.getCosmosAsyncClient(), 0, pagedFluxOptions.getContainerId(), + pagedFluxOptions.getDatabaseId(), + pagedFluxOptions.getOperationType(), pagedFluxOptions.getResourceType(), + BridgeInternal.getContextClient(pagedFluxOptions.getCosmosAsyncClient()).getConsistencyLevel(), + (float) cosmosException.getRequestCharge(), Duration.between(startTime.get(), Instant.now())); + } + } + + if (isTracerEnabled(pagedFluxOptions)) { + pagedFluxOptions.getTracerProvider().endSpan(signal, TracerProvider.ERROR_CODE); + } + startTime.set(Instant.now()); + break; + case ON_NEXT: + FeedResponse feedResponse = signal.get(); + if (isTracerEnabled(pagedFluxOptions) && + this.cosmosDiagnosticsAccessor.isDiagnosticsCapturedInPagedFlux(feedResponse.getCosmosDiagnostics()).compareAndSet(false, true)) { + try { + Duration threshold = pagedFluxOptions.getThresholdForDiagnosticsOnTracer(); + if (threshold == null) { + 
threshold = pagedFluxOptions.getTracerProvider().QUERY_THRESHOLD_FOR_DIAGNOSTICS; + } + + if (Duration.between(startTime.get(), Instant.now()).compareTo(threshold) > 0) { + addDiagnosticsOnTracerEvent(pagedFluxOptions.getTracerProvider(), + feedResponse.getCosmosDiagnostics(), + TracerProvider.getContextFromReactorOrNull(signal.getContextView())); + } + } catch (JsonProcessingException ex) { + LOGGER.warn("Error while serializing diagnostics for tracer", ex.getMessage()); + } + } + // If the user has passed feedResponseConsumer, then call it with each feedResponse + if (feedResponseConsumer != null) { + feedResponseConsumer.accept(feedResponse); + } + + if (pagedFluxOptions.getCosmosAsyncClient() != null && + Configs.isClientTelemetryEnabled(BridgeInternal.isClientTelemetryEnabled(pagedFluxOptions.getCosmosAsyncClient()))) { + if (this.cosmosDiagnosticsAccessor.isDiagnosticsCapturedInPagedFlux(feedResponse.getCosmosDiagnostics()).compareAndSet(false, true)) { + fillClientTelemetry(pagedFluxOptions.getCosmosAsyncClient(), HttpConstants.StatusCodes.OK, + pagedFluxOptions.getContainerId(), + pagedFluxOptions.getDatabaseId(), + pagedFluxOptions.getOperationType(), pagedFluxOptions.getResourceType(), + BridgeInternal.getContextClient(pagedFluxOptions.getCosmosAsyncClient()).getConsistencyLevel(), + (float) feedResponse.getRequestCharge(), Duration.between(startTime.get(), Instant.now())); + startTime.set(Instant.now()); + }; + } + break; + default: + break; + }}); + + if (isTracerEnabled(pagedFluxOptions)) { + return result.contextWrite(TracerProvider.setContextInReactor( + pagedFluxOptions.getTracerProvider().startSpan(pagedFluxOptions.getTracerSpanName(), + pagedFluxOptions.getDatabaseId(), + pagedFluxOptions.getServiceEndpoint(), + context))); + } - if (pagedFluxOptions.getCosmosAsyncClient() != null && - Configs.isClientTelemetryEnabled(BridgeInternal.isClientTelemetryEnabled(pagedFluxOptions.getCosmosAsyncClient()))) { - if 
(this.cosmosDiagnosticsAccessor.isDiagnosticsCapturedInPagedFlux(feedResponse.getCosmosDiagnostics()).compareAndSet(false, true)) { - fillClientTelemetry(pagedFluxOptions.getCosmosAsyncClient(), HttpConstants.StatusCodes.OK, - pagedFluxOptions.getContainerId(), - pagedFluxOptions.getDatabaseId(), - pagedFluxOptions.getOperationType(), pagedFluxOptions.getResourceType(), - BridgeInternal.getContextClient(pagedFluxOptions.getCosmosAsyncClient()).getConsistencyLevel(), - (float) feedResponse.getRequestCharge(), Duration.between(startTime.get(), Instant.now())); - startTime.set(Instant.now()); - }; - } - }); + return result; } private void fillClientTelemetry(CosmosAsyncClient cosmosAsyncClient, @@ -322,7 +345,7 @@ private ReportPayload createReportPayload(CosmosAsyncClient cosmosAsyncClient, } private void addDiagnosticsOnTracerEvent(TracerProvider tracerProvider, CosmosDiagnostics cosmosDiagnostics, Context parentContext) throws JsonProcessingException { - if (cosmosDiagnostics == null) { + if (cosmosDiagnostics == null || parentContext == null) { return; } diff --git a/sdk/cosmos/azure-cosmos/src/test/java/com/azure/cosmos/CosmosTracerTest.java b/sdk/cosmos/azure-cosmos/src/test/java/com/azure/cosmos/CosmosTracerTest.java index 86b7b80081be4..92945da5f3589 100644 --- a/sdk/cosmos/azure-cosmos/src/test/java/com/azure/cosmos/CosmosTracerTest.java +++ b/sdk/cosmos/azure-cosmos/src/test/java/com/azure/cosmos/CosmosTracerTest.java @@ -3,6 +3,7 @@ package com.azure.cosmos; import com.azure.core.util.Context; +import com.azure.core.util.tracing.StartSpanOptions; import com.azure.core.util.tracing.Tracer; import com.azure.cosmos.implementation.ClientSideRequestStatistics; import com.azure.cosmos.implementation.FeedResponseDiagnostics; @@ -41,6 +42,7 @@ import com.azure.cosmos.rx.TestSuiteBase; import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.databind.ObjectMapper; +import org.mockito.ArgumentCaptor; import org.mockito.ArgumentMatchers; 
import org.mockito.Mockito; import org.mockito.invocation.InvocationOnMock; @@ -423,7 +425,8 @@ public void tracerExceptionSpan() throws Exception { Mockito.doAnswer(tracerProviderCapture).when(tracerProvider).startSpan(ArgumentMatchers.any(), ArgumentMatchers.any(), - ArgumentMatchers.any(), ArgumentMatchers.any()); + ArgumentMatchers.any(), + ArgumentMatchers.any()); Mockito.doAnswer(addEventCapture).when(tracerProvider).addEvent(ArgumentMatchers.any(), ArgumentMatchers.any(), ArgumentMatchers.any(), @@ -482,7 +485,10 @@ private static CosmosStoredProcedureProperties getCosmosStoredProcedurePropertie private Tracer getMockTracer() { Tracer mockTracer = Mockito.mock(Tracer.class); - Mockito.when(mockTracer.start(ArgumentMatchers.any(String.class), ArgumentMatchers.any(Context.class))).thenReturn(Context.NONE); + Mockito.when(mockTracer.start(ArgumentMatchers.any(String.class), + ArgumentMatchers.any(StartSpanOptions.class), + ArgumentMatchers.any(Context.class))) + .thenReturn(Context.NONE); return mockTracer; } @@ -490,24 +496,27 @@ private void verifyTracerAttributes(TracerProvider tracerProvider, Tracer mockTr Context context, String databaseName, int numberOfTimesCalledWithinTest, String errorType, CosmosDiagnostics cosmosDiagnostics, - Map> attributesMap) throws JsonProcessingException { + Map> eventAttributesMap) throws JsonProcessingException { Mockito.verify(tracerProvider, Mockito.times(numberOfTimesCalledWithinTest)).startSpan(ArgumentMatchers.any(), ArgumentMatchers.any(), ArgumentMatchers.any(), ArgumentMatchers.any(Context.class)); + ArgumentCaptor optionsCaptor = ArgumentCaptor.forClass(StartSpanOptions.class); + Mockito.verify(mockTracer, Mockito.times(numberOfTimesCalledWithinTest)) + .start(Mockito.any(), optionsCaptor.capture(), Mockito.any()); + + Map startAttributes = optionsCaptor.getValue().getAttributes(); if (databaseName != null) { - Mockito.verify(mockTracer, 
Mockito.times(numberOfTimesCalledWithinTest)).setAttribute(TracerProvider.DB_INSTANCE, - databaseName, context); + assertThat(startAttributes.get(TracerProvider.DB_INSTANCE)).isEqualTo(databaseName); } - Mockito.verify(mockTracer, Mockito.times(numberOfTimesCalledWithinTest)).setAttribute(TracerProvider.DB_TYPE, - TracerProvider.DB_TYPE_VALUE, context); - Mockito.verify(mockTracer, Mockito.times(numberOfTimesCalledWithinTest)).setAttribute(TracerProvider.DB_URL, - TestConfigurations.HOST, - context); - Mockito.verify(mockTracer, Mockito.times(1)).setAttribute(TracerProvider.DB_STATEMENT, methodName, context); + + assertThat(startAttributes.get(TracerProvider.DB_TYPE)).isEqualTo(TracerProvider.DB_TYPE_VALUE); + assertThat(startAttributes.get(TracerProvider.DB_URL)).isEqualTo(TestConfigurations.HOST); + assertThat(startAttributes.get(TracerProvider.DB_STATEMENT)).isEqualTo(methodName); + assertThat(startAttributes.get(Tracer.AZ_TRACING_NAMESPACE_KEY)).isEqualTo(TracerProvider.RESOURCE_PROVIDER_NAME); //verifying diagnostics as events - verifyTracerDiagnostics(tracerProvider, cosmosDiagnostics, attributesMap); + verifyTracerDiagnostics(tracerProvider, cosmosDiagnostics, eventAttributesMap); } private void verifyTracerDiagnostics(TracerProvider tracerProvider, @@ -675,7 +684,7 @@ private void verifyTracerDiagnostics(TracerProvider tracerProvider, } private class TracerProviderCapture implements Answer { - private Context result = null; + private Context result = Context.NONE; public Context getResult() { return result; diff --git a/sdk/cosmos/azure-cosmos/src/test/java/com/azure/cosmos/implementation/TracerProviderTest.java b/sdk/cosmos/azure-cosmos/src/test/java/com/azure/cosmos/implementation/TracerProviderTest.java new file mode 100644 index 0000000000000..cf85f5d705aa7 --- /dev/null +++ b/sdk/cosmos/azure-cosmos/src/test/java/com/azure/cosmos/implementation/TracerProviderTest.java @@ -0,0 +1,170 @@ +package com.azure.cosmos.implementation; + +import 
com.azure.core.util.Context; +import com.azure.core.util.tracing.SpanKind; +import com.azure.core.util.tracing.StartSpanOptions; +import com.azure.core.util.tracing.Tracer; +import com.azure.cosmos.CosmosException; +import com.azure.cosmos.implementation.changefeed.exceptions.PartitionNotFoundException; +import com.azure.cosmos.models.CosmosResponse; +import org.mockito.ArgumentCaptor; +import org.mockito.Mockito; +import org.testng.annotations.Test; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.core.publisher.Signal; + +import java.util.Map; +import java.util.concurrent.atomic.AtomicBoolean; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.fail; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.ArgumentMatchers.isNull; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; +import static org.testng.Assert.assertThrows; + +public class TracerProviderTest { + @Test(groups = { "unit" }) + public void startSpan() { + Tracer tracerMock = Mockito.mock(Tracer.class); + String methodName = "get item"; + String endpoint = "endpoint"; + String instance = "instance"; + Context context = new Context("foo", "bar"); + + ArgumentCaptor optionsCaptor = ArgumentCaptor.forClass(StartSpanOptions.class); + TracerProvider provider = new TracerProvider(tracerMock); + provider.startSpan(methodName, instance, endpoint, context); + verify(tracerMock, times(1)).start(eq(methodName), optionsCaptor.capture(), eq(context)); + + assertThat(optionsCaptor.getValue().getSpanKind()).isEqualTo(SpanKind.CLIENT); + + Map attributes = optionsCaptor.getValue().getAttributes(); + + assertThat(attributes.get("az.namespace")).isEqualTo("Microsoft.DocumentDB"); + 
assertThat(attributes.get("db.type")).isEqualTo("Cosmos"); + assertThat(attributes.get("db.url")).isEqualTo(endpoint); + assertThat(attributes.get("db.statement")).isEqualTo(methodName); + assertThat(attributes.get("db.instance")).isEqualTo(instance); + } + + @Test(groups = { "unit" }) + public void endSpanSuccess() { + Context sdkContext = new Context("span", new Object()); + reactor.util.context.Context reactorContext = TracerProvider.setContextInReactor(sdkContext); + + Tracer tracerMock = Mockito.mock(Tracer.class); + TracerProvider provider = new TracerProvider(tracerMock); + provider.endSpan(Signal.complete(reactorContext), 200); + verify(tracerMock, times(1)).end(eq(200), isNull(), eq(sdkContext)); + } + + @Test(groups = { "unit" }) + public void endSpanFailureNotCosmosException() { + Context sdkContext = new Context("span", new Object()); + reactor.util.context.Context reactorContext = TracerProvider.setContextInReactor(sdkContext); + + Tracer tracerMock = Mockito.mock(Tracer.class); + TracerProvider provider = new TracerProvider(tracerMock); + Exception ex = new Exception("foo"); + provider.endSpan(Signal.error(ex, reactorContext), 500); + verify(tracerMock, times(1)).end(eq(500), eq(ex), eq(sdkContext)); + } + + @Test(groups = { "unit" }) + public void endSpanFailureCosmosException() { + Context sdkContext = new Context("span", new Object()); + reactor.util.context.Context reactorContext = TracerProvider.setContextInReactor(sdkContext); + + Tracer tracerMock = Mockito.mock(Tracer.class); + TracerProvider provider = new TracerProvider(tracerMock); + Exception ex = new ServiceUnavailableException(); + provider.endSpan(Signal.error(ex, reactorContext), -1); + verify(tracerMock, times(1)).end(eq(503), eq(ex), eq(sdkContext)); + } + + @Test(groups = { "unit" }) + public void traceMonoPublisher() { + Tracer tracerMock = Mockito.mock(Tracer.class); + + CosmosResponse response = Mockito.mock(CosmosResponse.class); + Context sdkContext = new Context("span", new 
Object()); + + TracerProvider provider = new TracerProvider(tracerMock); + AtomicBoolean closed = new AtomicBoolean(false); + when(tracerMock.start(anyString(), any(StartSpanOptions.class), any(Context.class))).thenReturn(sdkContext); + when(tracerMock.makeSpanCurrent(any())).thenReturn(() -> closed.set(true)); + when(response.getStatusCode()).thenReturn(412); + + provider.traceEnabledCosmosResponsePublisher(Mono.deferContextual(ctx -> { + assertThat(TracerProvider.getContextFromReactorOrNull(ctx)).isSameAs(sdkContext); + return Mono.just(response); + }), + Context.NONE, "methodName", "instance", "endpoint").block(); + + ArgumentCaptor optionsCaptor = ArgumentCaptor.forClass(StartSpanOptions.class); + verify(tracerMock, times(1)).start(anyString(), optionsCaptor.capture(), any(Context.class)); + verify(tracerMock, times(1)).makeSpanCurrent(eq(sdkContext)); + verify(tracerMock, times(1)).end(eq(412), any(), eq(sdkContext)); + assertThat(closed.get()).isTrue(); + } + + @Test(groups = { "unit" }) + public void traceMonoPublisherException() { + Tracer tracerMock = Mockito.mock(Tracer.class); + + CosmosResponse response = Mockito.mock(CosmosResponse.class); + TracerProvider provider = new TracerProvider(tracerMock); + AtomicBoolean closed = new AtomicBoolean(false); + when(tracerMock.start(anyString(), any(StartSpanOptions.class), any(Context.class))).thenReturn(Context.NONE); + when(tracerMock.makeSpanCurrent(any())).thenReturn(() -> closed.set(true)); + when(response.getStatusCode()).thenReturn(412); + + Exception ex = new BadRequestException("foo"); + + assertThrows( + CosmosException.class, + () -> provider.traceEnabledCosmosResponsePublisher(Mono.error(ex), + Context.NONE, "methodName", "instance", "endpoint").block()); + + verify(tracerMock, times(1)).start(anyString(), any(StartSpanOptions.class), any(Context.class)); + verify(tracerMock, times(1)).makeSpanCurrent(any()); + verify(tracerMock, times(1)).end(eq(400), eq(ex), any()); + 
assertThat(closed.get()).isTrue(); + } + + @Test(groups = { "unit" }) + public void testSetGetReactorContext() { + Context sdkContext = new Context("span", new Object()); + + reactor.util.context.Context reactorContext = + TracerProvider.setContextInReactor(sdkContext); + + assertThat(TracerProvider.getContextFromReactorOrNull(reactorContext)).isSameAs(sdkContext); + } + + @Test(groups = { "unit" }) + public void traceFluxPropagation() { + Tracer tracerMock = Mockito.mock(Tracer.class); + + CosmosResponse response = Mockito.mock(CosmosResponse.class); + Context sdkContext = new Context("span", new Object()); + + TracerProvider provider = new TracerProvider(tracerMock); + AtomicBoolean closed = new AtomicBoolean(false); + when(tracerMock.makeSpanCurrent(any())).thenReturn(() -> closed.set(true)); + + provider + .runUnderSpanInContext(Flux.just(response)) + .contextWrite(TracerProvider.setContextInReactor(sdkContext)) + .blockLast(); + + verify(tracerMock, times(1)).makeSpanCurrent(eq(sdkContext)); + assertThat(closed.get()).isTrue(); + } +}