diff --git a/sdk/azcore/CHANGELOG.md b/sdk/azcore/CHANGELOG.md index 6b8329201d86..09aeea35e6a5 100644 --- a/sdk/azcore/CHANGELOG.md +++ b/sdk/azcore/CHANGELOG.md @@ -3,6 +3,14 @@ ## 1.5.0-beta.2 (Unreleased) ### Features Added +* Added supporting features to enable distributed tracing. + * Added func `runtime.StartSpan()` for use by SDKs to start spans. + * Added method `WithContext()` to `runtime.Request` to support shallow cloning with a new context. + * Added field `TracingNamespace` to `runtime.PipelineOptions`. + * Added field `Tracer` to `runtime.NewPollerOptions` and `runtime.NewPollerFromResumeTokenOptions` types. + * Added field `SpanFromContext` to `tracing.TracerOptions`. + * Added methods `Enabled()`, `SetAttributes()`, and `SpanFromContext()` to `tracing.Tracer`. + * Added supporting pipeline policies to include HTTP spans when creating clients. ### Breaking Changes diff --git a/sdk/azcore/arm/runtime/pipeline.go b/sdk/azcore/arm/runtime/pipeline.go index 266c74b17bf1..302c19cd4265 100644 --- a/sdk/azcore/arm/runtime/pipeline.go +++ b/sdk/azcore/arm/runtime/pipeline.go @@ -13,6 +13,7 @@ import ( "github.com/Azure/azure-sdk-for-go/sdk/azcore" armpolicy "github.com/Azure/azure-sdk-for-go/sdk/azcore/arm/policy" "github.com/Azure/azure-sdk-for-go/sdk/azcore/cloud" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/exported" azpolicy "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" azruntime "github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime" ) @@ -34,7 +35,7 @@ func NewPipeline(module, version string, cred azcore.TokenCredential, plOpts azr }) perRetry := make([]azpolicy.Policy, len(plOpts.PerRetry), len(plOpts.PerRetry)+1) copy(perRetry, plOpts.PerRetry) - plOpts.PerRetry = append(perRetry, authPolicy) + plOpts.PerRetry = append(perRetry, authPolicy, exported.PolicyFunc(httpTraceNamespacePolicy)) if !options.DisableRPRegistration { regRPOpts := armpolicy.RegistrationOptions{ClientOptions: options.ClientOptions} regPolicy, err := NewRPRegistrationPolicy(cred, ®RPOpts) diff --git a/sdk/azcore/arm/runtime/policy_trace_namespace.go b/sdk/azcore/arm/runtime/policy_trace_namespace.go new file mode 100644 index 000000000000..76aefe8550dd --- /dev/null +++ b/sdk/azcore/arm/runtime/policy_trace_namespace.go @@ -0,0 +1,31 @@ +//go:build go1.18 +// +build go1.18 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package runtime + +import ( + "net/http" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore/arm/internal/resource" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/shared" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/tracing" +) + +// httpTraceNamespacePolicy is a policy that adds the az.namespace attribute to the current Span +func httpTraceNamespacePolicy(req *policy.Request) (resp *http.Response, err error) { + rawTracer := req.Raw().Context().Value(shared.CtxWithTracingTracer{}) + if tracer, ok := rawTracer.(tracing.Tracer); ok { + rt, err := resource.ParseResourceType(req.Raw().URL.Path) + if err == nil { + // add the namespace attribute to the current span + if span, ok := tracer.SpanFromContext(req.Raw().Context()); ok { + span.SetAttributes(tracing.Attribute{Key: "az.namespace", Value: rt.Namespace}) + } + } + } + return req.Next() +} diff --git a/sdk/azcore/arm/runtime/policy_trace_namespace_test.go b/sdk/azcore/arm/runtime/policy_trace_namespace_test.go new file mode 100644 index 000000000000..4ac7484823f8 --- /dev/null +++ b/sdk/azcore/arm/runtime/policy_trace_namespace_test.go @@ -0,0 +1,97 @@ +//go:build go1.18 +// +build go1.18 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package runtime + +import ( + "context" + "net/http" + "testing" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/exported" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/shared" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/tracing" + "github.com/Azure/azure-sdk-for-go/sdk/internal/mock" + "github.com/stretchr/testify/require" +) + +func TestHTTPTraceNamespacePolicy(t *testing.T) { + srv, close := mock.NewServer() + defer close() + + pl := exported.NewPipeline(srv, exported.PolicyFunc(httpTraceNamespacePolicy)) + + // no tracer + req, err := exported.NewRequest(context.Background(), http.MethodGet, srv.URL()) + require.NoError(t, err) + srv.AppendResponse() + _, err = pl.Do(req) + require.NoError(t, err) + + // wrong tracer type + req, err = exported.NewRequest(context.WithValue(context.Background(), shared.CtxWithTracingTracer{}, 0), http.MethodGet, srv.URL()) + require.NoError(t, err) + srv.AppendResponse() + _, err = pl.Do(req) + require.NoError(t, err) + + // no SpanFromContext impl + tr := tracing.NewTracer(func(ctx context.Context, spanName string, options *tracing.SpanOptions) (context.Context, tracing.Span) { + return ctx, tracing.Span{} + }, nil) + req, err = exported.NewRequest(context.WithValue(context.Background(), shared.CtxWithTracingTracer{}, tr), http.MethodGet, srv.URL()) + require.NoError(t, err) + srv.AppendResponse() + _, err = pl.Do(req) + require.NoError(t, err) + + // failed to parse resource ID, shouldn't call SetAttributes + var attrString string + tr = tracing.NewTracer(func(ctx context.Context, spanName string, options *tracing.SpanOptions) (context.Context, tracing.Span) { + return ctx, tracing.Span{} + }, &tracing.TracerOptions{ + SpanFromContext: func(ctx context.Context) (tracing.Span, bool) { + spanImpl := tracing.SpanImpl{ + SetAttributes: func(a ...tracing.Attribute) { + require.Len(t, a, 1) + v, ok := a[0].Value.(string) + require.True(t, ok) + attrString = a[0].Key + ":" + v + }, + } + return tracing.NewSpan(spanImpl), true + }, + }) + req, err = exported.NewRequest(context.WithValue(context.Background(), shared.CtxWithTracingTracer{}, tr), http.MethodGet, srv.URL()) + require.NoError(t, err) + srv.AppendResponse() + _, err = pl.Do(req) + require.NoError(t, err) + require.Empty(t, attrString) + + // success + tr = tracing.NewTracer(func(ctx context.Context, spanName string, options *tracing.SpanOptions) (context.Context, tracing.Span) { + return ctx, tracing.Span{} + }, &tracing.TracerOptions{ + SpanFromContext: func(ctx context.Context) (tracing.Span, bool) { + spanImpl := tracing.SpanImpl{ + SetAttributes: func(a ...tracing.Attribute) { + require.Len(t, a, 1) + v, ok := a[0].Value.(string) + require.True(t, ok) + attrString = a[0].Key + ":" + v + }, + } + return tracing.NewSpan(spanImpl), true + }, + }) + req, err = exported.NewRequest(context.WithValue(context.Background(), shared.CtxWithTracingTracer{}, tr), http.MethodGet, srv.URL()+requestEndpoint) + require.NoError(t, err) + srv.AppendResponse() + _, err = pl.Do(req) + require.NoError(t, err) + require.EqualValues(t, "az.namespace:Microsoft.Storage", attrString) +} diff --git a/sdk/azcore/core.go b/sdk/azcore/core.go index 72c2cf21eef3..29666d2d021f 100644 --- a/sdk/azcore/core.go +++ b/sdk/azcore/core.go @@ -99,6 +99,9 @@ func NewClient(clientName, moduleVersion string, plOpts runtime.PipelineOptions, pl := runtime.NewPipeline(pkg, moduleVersion, plOpts, options) tr := options.TracingProvider.NewTracer(clientName, moduleVersion) + if tr.Enabled() && plOpts.TracingNamespace != "" { + tr.SetAttributes(tracing.Attribute{Key: "az.namespace", Value: plOpts.TracingNamespace}) + } return &Client{pl: pl, tr: tr}, nil } diff --git a/sdk/azcore/core_test.go b/sdk/azcore/core_test.go index 13d3361e1f77..e3288dcde895 100644 --- a/sdk/azcore/core_test.go +++ b/sdk/azcore/core_test.go @@ -7,11 +7,17 @@ package azcore import ( + "context" + "net/http" "reflect" "testing" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/exported" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/shared" "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" "github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/tracing" + "github.com/Azure/azure-sdk-for-go/sdk/internal/mock" "github.com/stretchr/testify/require" ) @@ -131,3 +137,37 @@ func TestNewClientError(t *testing.T) { require.Error(t, err) require.Nil(t, client) } + +func TestNewClientTracingEnabled(t *testing.T) { + srv, close := mock.NewServer() + defer close() + + var attrString string + client, err := NewClient("package.Client", "v1.0.0", runtime.PipelineOptions{TracingNamespace: "Widget.Factory"}, &policy.ClientOptions{ + TracingProvider: tracing.NewProvider(func(name, version string) tracing.Tracer { + return tracing.NewTracer(func(ctx context.Context, spanName string, options *tracing.SpanOptions) (context.Context, tracing.Span) { + require.NotNil(t, options) + for _, attr := range options.Attributes { + if attr.Key == "az.namespace" { + v, ok := attr.Value.(string) + require.True(t, ok) + attrString = attr.Key + ":" + v + } + } + return ctx, tracing.Span{} + }, nil) + }, nil), + Transport: srv, + }) + require.NoError(t, err) + require.NotNil(t, client) + require.NotZero(t, client.Pipeline()) + require.NotZero(t, client.Tracer()) + + const requestEndpoint = "/subscriptions/00000000-0000-0000-0000-000000000000/resourceGroups/fakeResourceGroupo/providers/Microsoft.Storage/storageAccounts/fakeAccountName" + req, err := exported.NewRequest(context.WithValue(context.Background(), shared.CtxWithTracingTracer{}, client.Tracer()), http.MethodGet, srv.URL()+requestEndpoint) + require.NoError(t, err) + srv.AppendResponse() + client.Pipeline().Do(req) + require.EqualValues(t, "az.namespace:Widget.Factory", attrString) +} diff --git a/sdk/azcore/internal/exported/request.go b/sdk/azcore/internal/exported/request.go index fa99d1b7ed1f..48229f5ccd68 100644 --- a/sdk/azcore/internal/exported/request.go +++ b/sdk/azcore/internal/exported/request.go @@ -170,6 +170,14 @@ func (req *Request) Clone(ctx context.Context) *Request { return &r2 } +// WithContext returns a shallow copy of the request with its context changed to ctx. +func (req *Request) WithContext(ctx context.Context) *Request { + r2 := new(Request) + *r2 = *req + r2.req = r2.req.WithContext(ctx) + return r2 +} + // not exported but dependent on Request // PolicyFunc is a type that implements the Policy interface. diff --git a/sdk/azcore/internal/exported/request_test.go b/sdk/azcore/internal/exported/request_test.go index d26b734c82c9..3acc8e7a76ae 100644 --- a/sdk/azcore/internal/exported/request_test.go +++ b/sdk/azcore/internal/exported/request_test.go @@ -194,3 +194,20 @@ func TestNewRequestFail(t *testing.T) { t.Fatal("unexpected request") } } + +func TestRequestWithContext(t *testing.T) { + type ctxKey1 struct{} + type ctxKey2 struct{} + + req1, err := NewRequest(context.WithValue(context.Background(), ctxKey1{}, 1), http.MethodPost, testURL) + require.NoError(t, err) + require.NotNil(t, req1.Raw().Context().Value(ctxKey1{})) + + req2 := req1.WithContext(context.WithValue(context.Background(), ctxKey2{}, 1)) + require.Nil(t, req2.Raw().Context().Value(ctxKey1{})) + require.NotNil(t, req2.Raw().Context().Value(ctxKey2{})) + + // shallow copy, so changing req2 affects req1 + req2.Raw().Header.Add("added-req2", "value") + require.EqualValues(t, "value", req1.Raw().Header.Get("added-req2")) +} diff --git a/sdk/azcore/internal/shared/constants.go b/sdk/azcore/internal/shared/constants.go index e02c40275f07..5b60973308b8 100644 --- a/sdk/azcore/internal/shared/constants.go +++ b/sdk/azcore/internal/shared/constants.go @@ -23,6 +23,7 @@ const ( HeaderUserAgent = "User-Agent" HeaderWWWAuthenticate = "WWW-Authenticate" HeaderXMSClientRequestID = "x-ms-client-request-id" + HeaderXMSRequestID = "x-ms-request-id" ) const BearerTokenPrefix = "Bearer " diff --git a/sdk/azcore/internal/shared/shared.go b/sdk/azcore/internal/shared/shared.go index 7c71df307008..55003b10f0f3 100644 --- a/sdk/azcore/internal/shared/shared.go +++ b/sdk/azcore/internal/shared/shared.go @@ -28,6 +28,9 @@ type CtxWithRetryOptionsKey struct{} // CtxIncludeResponseKey is used as a context key for retrieving the raw response. type CtxIncludeResponseKey struct{} +// CtxWithTracingTracer is used as a context key for adding/retrieving tracing.Tracer. +type CtxWithTracingTracer struct{} + // Delay waits for the duration to elapse or the context to be cancelled. func Delay(ctx context.Context, delay time.Duration) error { select { diff --git a/sdk/azcore/policy/policy.go b/sdk/azcore/policy/policy.go index c427e14d88c3..f4684a5385ef 100644 --- a/sdk/azcore/policy/policy.go +++ b/sdk/azcore/policy/policy.go @@ -29,7 +29,8 @@ type Request = exported.Request // ClientOptions contains optional settings for a client's pipeline. // All zero-value fields will be initialized with default values. type ClientOptions struct { - // APIVersion overrides the default version requested of the service. Set with caution as this package version has not been tested with arbitrary service versions. + // APIVersion overrides the default version requested of the service. + // Set with caution as this package version has not been tested with arbitrary service versions. APIVersion string // Cloud specifies a cloud for the client. The default is Azure Public Cloud. diff --git a/sdk/azcore/runtime/pipeline.go b/sdk/azcore/runtime/pipeline.go index 9d9288f53d3d..b317726351cd 100644 --- a/sdk/azcore/runtime/pipeline.go +++ b/sdk/azcore/runtime/pipeline.go @@ -13,9 +13,29 @@ import ( // PipelineOptions contains Pipeline options for SDK developers type PipelineOptions struct { - AllowedHeaders, AllowedQueryParameters []string - APIVersion APIVersionOptions - PerCall, PerRetry []policy.Policy + // AllowedHeaders is the slice of headers to log with their values intact. + // All headers not in the slice will have their values REDACTED. + // Applies to request and response headers. + AllowedHeaders []string + + // AllowedQueryParameters is the slice of query parameters to log with their values intact. + // All query parameters not in the slice will have their values REDACTED. + AllowedQueryParameters []string + + // APIVersion overrides the default version requested of the service. + // Set with caution as this package version has not been tested with arbitrary service versions. + APIVersion APIVersionOptions + + // PerCall contains custom policies to inject into the pipeline. + // Each policy is executed once per request. + PerCall []policy.Policy + + // PerRetry contains custom policies to inject into the pipeline. + // Each policy is executed once per request, and for each retry of that request. + PerRetry []policy.Policy + + // TracingNamespace contains the value to use for the az.namespace span attribute. + TracingNamespace string } // Pipeline represents a primitive for sending HTTP requests and receiving responses. @@ -58,6 +78,7 @@ func NewPipeline(module, version string, plOpts PipelineOptions, options *policy policies = append(policies, cp.PerRetryPolicies...) policies = append(policies, NewLogPolicy(&cp.Logging)) policies = append(policies, exported.PolicyFunc(httpHeaderPolicy), exported.PolicyFunc(bodyDownloadPolicy)) + policies = append(policies, newHTTPTracePolicy(cp.Logging.AllowedQueryParams)) transport := cp.Transport if transport == nil { transport = defaultHTTPClient diff --git a/sdk/azcore/runtime/policy_http_trace.go b/sdk/azcore/runtime/policy_http_trace.go new file mode 100644 index 000000000000..466094e36983 --- /dev/null +++ b/sdk/azcore/runtime/policy_http_trace.go @@ -0,0 +1,114 @@ +//go:build go1.18 +// +build go1.18 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package runtime + +import ( + "context" + "fmt" + "net/http" + "strings" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/exported" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/shared" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/tracing" +) + +const ( + attrHTTPMethod = "http.method" + attrHTTPURL = "http.url" + attrHTTPUserAgent = "http.user_agent" + attrHTTPStatusCode = "http.status_code" + + attrAZClientReqID = "az.client_request_id" + attrAZServiceReqID = "az.service_request_id" +) + +// newHTTPTracePolicy creates a new instance of the httpTracePolicy. +// - allowedQueryParams contains the user-specified query parameters that don't need to be redacted from the trace +func newHTTPTracePolicy(allowedQueryParams []string) exported.Policy { + return &httpTracePolicy{allowedQP: getAllowedQueryParams(allowedQueryParams)} +} + +// httpTracePolicy is a policy that creates a trace for the HTTP request and its response +type httpTracePolicy struct { + allowedQP map[string]struct{} +} + +// Do implements the pipeline.Policy interfaces for the httpTracePolicy type. +func (h *httpTracePolicy) Do(req *policy.Request) (resp *http.Response, err error) { + rawTracer := req.Raw().Context().Value(shared.CtxWithTracingTracer{}) + if tracer, ok := rawTracer.(tracing.Tracer); ok { + attributes := []tracing.Attribute{ + {Key: attrHTTPMethod, Value: req.Raw().Method}, + {Key: attrHTTPURL, Value: getSanitizedURL(*req.Raw().URL, h.allowedQP)}, + } + + if ua := req.Raw().Header.Get(shared.HeaderUserAgent); ua != "" { + attributes = append(attributes, tracing.Attribute{Key: attrHTTPUserAgent, Value: ua}) + } + if reqID := req.Raw().Header.Get(shared.HeaderXMSClientRequestID); reqID != "" { + attributes = append(attributes, tracing.Attribute{Key: attrAZClientReqID, Value: reqID}) + } + + ctx := req.Raw().Context() + ctx, span := tracer.Start(ctx, "HTTP "+req.Raw().Method, &tracing.SpanOptions{ + Kind: tracing.SpanKindClient, + Attributes: attributes, + }) + + defer func() { + if resp != nil { + span.SetAttributes(tracing.Attribute{Key: attrHTTPStatusCode, Value: resp.StatusCode}) + if resp.StatusCode > 399 { + span.SetStatus(tracing.SpanStatusError, resp.Status) + } + if reqID := resp.Header.Get(shared.HeaderXMSRequestID); reqID != "" { + span.SetAttributes(tracing.Attribute{Key: attrAZServiceReqID, Value: reqID}) + } + } else if err != nil { + // including the output from err.Error() might disclose URL query parameters. + // so instead of attempting to sanitize the output, we simply output the error type. + span.SetStatus(tracing.SpanStatusError, fmt.Sprintf("%T", err)) + } + span.End() + }() + + req = req.WithContext(ctx) + } + resp, err = req.Next() + return +} + +// StartSpanOptions contains the optional values for StartSpan. +type StartSpanOptions struct { + // for future expansion +} + +// StartSpan starts a new tracing span. +// You must call the returned func to terminate the span. Pass the applicable error +// if the span will exit with an error condition. +// - ctx is the parent context of the newly created context +// - name is the name of the span. this is typically the fully qualified name of an API ("Client.Method") +// - tracer is the client's Tracer for creating spans +// - options contains optional values. pass nil to accept any default values +func StartSpan(ctx context.Context, name string, tracer tracing.Tracer, options *StartSpanOptions) (context.Context, func(error)) { + if !tracer.Enabled() { + return ctx, func(err error) {} + } + ctx, span := tracer.Start(ctx, name, &tracing.SpanOptions{ + Kind: tracing.SpanKindInternal, + }) + ctx = context.WithValue(ctx, shared.CtxWithTracingTracer{}, tracer) + return ctx, func(err error) { + if err != nil { + errType := strings.Replace(fmt.Sprintf("%T", err), "*exported.", "*azcore.", 1) + span.SetStatus(tracing.SpanStatusError, fmt.Sprintf("%s:\n%s", errType, err.Error())) + } + span.End() + } +} diff --git a/sdk/azcore/runtime/policy_http_trace_test.go b/sdk/azcore/runtime/policy_http_trace_test.go new file mode 100644 index 000000000000..324bd4d027ff --- /dev/null +++ b/sdk/azcore/runtime/policy_http_trace_test.go @@ -0,0 +1,162 @@ +//go:build go1.18 +// +build go1.18 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package runtime + +import ( + "context" + "io" + "net" + "net/http" + "net/url" + "strings" + "testing" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/exported" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/shared" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/tracing" + "github.com/Azure/azure-sdk-for-go/sdk/internal/mock" + "github.com/stretchr/testify/require" +) + +func TestHTTPTracePolicy(t *testing.T) { + srv, close := mock.NewServer() + defer close() + + pl := exported.NewPipeline(srv, newHTTPTracePolicy([]string{"visibleqp"})) + + // no tracer + req, err := exported.NewRequest(context.Background(), http.MethodGet, srv.URL()) + require.NoError(t, err) + srv.AppendResponse() + _, err = pl.Do(req) + require.NoError(t, err) + + // wrong tracer type + req, err = exported.NewRequest(context.WithValue(context.Background(), shared.CtxWithTracingTracer{}, 0), http.MethodGet, srv.URL()) + require.NoError(t, err) + srv.AppendResponse() + _, err = pl.Do(req) + require.NoError(t, err) + + var fullSpanName string + var spanKind tracing.SpanKind + var spanAttrs []tracing.Attribute + var spanStatus tracing.SpanStatus + var spanStatusStr string + tr := tracing.NewTracer(func(ctx context.Context, spanName string, options *tracing.SpanOptions) (context.Context, tracing.Span) { + fullSpanName = spanName + require.NotNil(t, options) + spanKind = options.Kind + spanAttrs = options.Attributes + spanImpl := tracing.SpanImpl{ + SetAttributes: func(a ...tracing.Attribute) { spanAttrs = append(spanAttrs, a...) }, + SetStatus: func(ss tracing.SpanStatus, s string) { + spanStatus = ss + spanStatusStr = s + }, + } + return ctx, tracing.NewSpan(spanImpl) + }, nil) + + // HTTP ok + req, err = exported.NewRequest(context.WithValue(context.Background(), shared.CtxWithTracingTracer{}, tr), http.MethodGet, srv.URL()+"?foo=redactme&visibleqp=bar") + require.NoError(t, err) + req.Raw().Header.Add(shared.HeaderUserAgent, "my-user-agent") + req.Raw().Header.Add(shared.HeaderXMSClientRequestID, "my-client-request") + srv.AppendResponse(mock.WithHeader(shared.HeaderXMSRequestID, "request-id")) + _, err = pl.Do(req) + require.NoError(t, err) + require.EqualValues(t, tracing.SpanStatusUnset, spanStatus) + require.EqualValues(t, "HTTP GET", fullSpanName) + require.EqualValues(t, tracing.SpanKindClient, spanKind) + require.Len(t, spanAttrs, 6) + require.Contains(t, spanAttrs, tracing.Attribute{Key: attrHTTPMethod, Value: http.MethodGet}) + require.Contains(t, spanAttrs, tracing.Attribute{Key: attrHTTPURL, Value: srv.URL() + "?foo=REDACTED&visibleqp=bar"}) + require.Contains(t, spanAttrs, tracing.Attribute{Key: attrHTTPUserAgent, Value: "my-user-agent"}) + require.Contains(t, spanAttrs, tracing.Attribute{Key: attrAZClientReqID, Value: "my-client-request"}) + require.Contains(t, spanAttrs, tracing.Attribute{Key: attrHTTPStatusCode, Value: http.StatusOK}) + require.Contains(t, spanAttrs, tracing.Attribute{Key: attrAZServiceReqID, Value: "request-id"}) + + // HTTP bad request + req, err = exported.NewRequest(context.WithValue(context.Background(), shared.CtxWithTracingTracer{}, tr), http.MethodGet, srv.URL()) + require.NoError(t, err) + srv.AppendResponse(mock.WithStatusCode(http.StatusBadRequest)) + _, err = pl.Do(req) + require.NoError(t, err) + require.EqualValues(t, tracing.SpanStatusError, spanStatus) + require.EqualValues(t, "400 Bad Request", spanStatusStr) + require.Contains(t, spanAttrs, tracing.Attribute{Key: attrHTTPStatusCode, Value: http.StatusBadRequest}) + + // HTTP error + req, err = exported.NewRequest(context.WithValue(context.Background(), shared.CtxWithTracingTracer{}, tr), http.MethodGet, srv.URL()) + require.NoError(t, err) + srv.AppendError(net.ErrClosed) + _, err = pl.Do(req) + require.Error(t, err) + require.ErrorIs(t, err, net.ErrClosed) + require.EqualValues(t, tracing.SpanStatusError, spanStatus) + require.EqualValues(t, "poll.errNetClosing", spanStatusStr) +} + +func TestStartSpan(t *testing.T) { + // tracing disabled + ctx, end := StartSpan(context.Background(), "TestStartSpan", tracing.Tracer{}, nil) + end(nil) + require.Same(t, context.Background(), ctx) + + // span no error + var startCalled bool + var endCalled bool + tr := tracing.NewTracer(func(ctx context.Context, spanName string, options *tracing.SpanOptions) (context.Context, tracing.Span) { + startCalled = true + require.EqualValues(t, "TestStartSpan", spanName) + require.NotNil(t, options) + require.EqualValues(t, tracing.SpanKindInternal, options.Kind) + spanImpl := tracing.SpanImpl{ + End: func() { endCalled = true }, + } + return ctx, tracing.NewSpan(spanImpl) + }, nil) + ctx, end = StartSpan(context.Background(), "TestStartSpan", tr, nil) + end(nil) + ctxTr := ctx.Value(shared.CtxWithTracingTracer{}) + require.NotNil(t, ctxTr) + _, ok := ctxTr.(tracing.Tracer) + require.True(t, ok) + require.True(t, startCalled) + require.True(t, endCalled) + + // with error + var spanStatus tracing.SpanStatus + var errStr string + tr = tracing.NewTracer(func(ctx context.Context, spanName string, options *tracing.SpanOptions) (context.Context, tracing.Span) { + spanImpl := tracing.SpanImpl{ + End: func() { endCalled = true }, + SetStatus: func(ss tracing.SpanStatus, s string) { + spanStatus = ss + errStr = s + }, + } + return ctx, tracing.NewSpan(spanImpl) + }, nil) + _, end = StartSpan(context.Background(), "TestStartSpan", tr, nil) + u, err := url.Parse("https://contoso.com") + require.NoError(t, err) + resp := &http.Response{ + Status: "the operation failed", + StatusCode: http.StatusBadRequest, + Body: io.NopCloser(strings.NewReader(`{ "error": { "code": "ErrorItFailed", "message": "it's not working" } }`)), + Request: &http.Request{ + Method: http.MethodGet, + URL: u, + }, + } + end(exported.NewResponseError(resp)) + require.EqualValues(t, tracing.SpanStatusError, spanStatus) + require.Contains(t, errStr, "*azcore.ResponseError") + require.Contains(t, errStr, "ERROR CODE: ErrorItFailed") +} diff --git a/sdk/azcore/runtime/poller.go b/sdk/azcore/runtime/poller.go index 0be5210a4a38..e33e49ba26e5 100644 --- a/sdk/azcore/runtime/poller.go +++ b/sdk/azcore/runtime/poller.go @@ -13,6 +13,8 @@ import ( "flag" "fmt" "net/http" + "reflect" + "strings" "time" "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/exported" @@ -23,6 +25,7 @@ import ( "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/pollers/loc" "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/pollers/op" "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/shared" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/tracing" ) // FinalStateVia is the enumerated type for the possible final-state-via values. @@ -53,6 +56,9 @@ type NewPollerOptions[T any] struct { // Handler[T] contains a custom polling implementation. Handler PollingHandler[T] + + // Tracer contains the Tracer from the client that's creating the Poller. + Tracer tracing.Tracer } // NewPoller creates a Poller based on the provided initial response. @@ -69,6 +75,7 @@ func NewPoller[T any](resp *http.Response, pl exported.Pipeline, options *NewPol op: options.Handler, resp: resp, result: result, + tracer: options.Tracer, }, nil } @@ -109,6 +116,7 @@ func NewPoller[T any](resp *http.Response, pl exported.Pipeline, options *NewPol op: opr, resp: resp, result: result, + tracer: options.Tracer, }, nil } @@ -120,6 +128,9 @@ type NewPollerFromResumeTokenOptions[T any] struct { // Handler[T] contains a custom polling implementation. Handler PollingHandler[T] + + // Tracer contains the Tracer from the client that's creating the Poller. + Tracer tracing.Tracer } // NewPollerFromResumeToken creates a Poller from a resume token string. @@ -165,6 +176,7 @@ func NewPollerFromResumeToken[T any](token string, pl exported.Pipeline, options return &Poller[T]{ op: opr, result: result, + tracer: options.Tracer, }, nil } @@ -187,6 +199,7 @@ type Poller[T any] struct { resp *http.Response err error result *T + tracer tracing.Tracer done bool } @@ -202,7 +215,7 @@ type PollUntilDoneOptions struct { // options: pass nil to accept the default values. // NOTE: the default polling frequency is 30 seconds which works well for most operations. However, some operations might // benefit from a shorter or longer duration. -func (p *Poller[T]) PollUntilDone(ctx context.Context, options *PollUntilDoneOptions) (T, error) { +func (p *Poller[T]) PollUntilDone(ctx context.Context, options *PollUntilDoneOptions) (res T, err error) { if options == nil { options = &PollUntilDoneOptions{} } @@ -211,9 +224,13 @@ func (p *Poller[T]) PollUntilDone(ctx context.Context, options *PollUntilDoneOpt cp.Frequency = 30 * time.Second } + ctx, endSpan := StartSpan(ctx, fmt.Sprintf("%s.PollUntilDone", shortenPollerTypeName(reflect.TypeOf(*p).Name())), p.tracer, nil) + defer func() { endSpan(err) }() + // skip the floor check when executing tests so they don't take so long if isTest := flag.Lookup("test.v"); isTest == nil && cp.Frequency < time.Second { - return *new(T), errors.New("polling frequency minimum is one second") + err = errors.New("polling frequency minimum is one second") + return } start := time.Now() @@ -225,22 +242,24 @@ func (p *Poller[T]) PollUntilDone(ctx context.Context, options *PollUntilDoneOpt // initial check for a retry-after header existing on the initial response if retryAfter := shared.RetryAfter(p.resp); retryAfter > 0 { log.Writef(log.EventLRO, "initial Retry-After delay for %s", retryAfter.String()) - if err := shared.Delay(ctx, retryAfter); err != nil { + if err = shared.Delay(ctx, retryAfter); err != nil { logPollUntilDoneExit(err) - return *new(T), err + return } } } // begin polling the endpoint until a terminal state is reached for { - resp, err := p.Poll(ctx) + var resp *http.Response + resp, err = p.Poll(ctx) if err != nil { logPollUntilDoneExit(err) - return *new(T), err + return } if p.Done() { logPollUntilDoneExit("succeeded") - return p.Result(ctx) + res, err = p.Result(ctx) + return } d := cp.Frequency if retryAfter := shared.RetryAfter(resp); retryAfter > 0 { @@ -251,7 +270,7 @@ func (p *Poller[T]) PollUntilDone(ctx context.Context, options *PollUntilDoneOpt } if err = shared.Delay(ctx, d); err != nil { logPollUntilDoneExit(err) - return *new(T), err + return } } } @@ -260,17 +279,22 @@ func (p *Poller[T]) PollUntilDone(ctx context.Context, options *PollUntilDoneOpt // If Poll succeeds, the poller's state is updated and the HTTP response is returned. // If Poll fails, the poller's state is unmodified and the error is returned. // Calling Poll on an LRO that has reached a terminal state will return the last HTTP response. -func (p *Poller[T]) Poll(ctx context.Context) (*http.Response, error) { +func (p *Poller[T]) Poll(ctx context.Context) (resp *http.Response, err error) { if p.Done() { // the LRO has reached a terminal state, don't poll again - return p.resp, nil + resp = p.resp + return } - resp, err := p.op.Poll(ctx) + + ctx, endSpan := StartSpan(ctx, fmt.Sprintf("%s.Poll", shortenPollerTypeName(reflect.TypeOf(*p).Name())), p.tracer, nil) + defer func() { endSpan(err) }() + + resp, err = p.op.Poll(ctx) if err != nil { - return nil, err + return } p.resp = resp - return p.resp, nil + return } // Done returns true if the LRO has reached a terminal state. @@ -283,31 +307,40 @@ func (p *Poller[T]) Done() bool { // If the LRO completed successfully, a populated instance of T is returned. // If the LRO failed or was canceled, an *azcore.ResponseError error is returned. // Calling this on an LRO in a non-terminal state will return an error. -func (p *Poller[T]) Result(ctx context.Context) (T, error) { +func (p *Poller[T]) Result(ctx context.Context) (res T, err error) { if !p.Done() { - return *new(T), errors.New("poller is in a non-terminal state") + err = errors.New("poller is in a non-terminal state") + return } if p.done { // the result has already been retrieved, return the cached value if p.err != nil { - return *new(T), p.err + err = p.err + return } - return *p.result, nil + res = *p.result + return } - err := p.op.Result(ctx, p.result) + + ctx, endSpan := StartSpan(ctx, fmt.Sprintf("%s.Result", shortenPollerTypeName(reflect.TypeOf(*p).Name())), p.tracer, nil) + defer func() { endSpan(err) }() + + err = p.op.Result(ctx, p.result) var respErr *exported.ResponseError if errors.As(err, &respErr) { // the LRO failed. record the error p.err = err } else if err != nil { // the call to Result failed, don't cache anything in this case - return *new(T), err + return } p.done = true if p.err != nil { - return *new(T), p.err + err = p.err + return } - return *p.result, nil + res = *p.result + return } // ResumeToken returns a value representing the poller that can be used to resume @@ -324,3 +357,22 @@ func (p *Poller[T]) ResumeToken() (string, error) { } return tk, err } + +// extracts the type name from the string returned from reflect.Value.Name() +func shortenPollerTypeName(s string) string { + // the value is formatted as follows + // Poller[module/Package.Type].Method + // we want to shorten the generic type parameter string to Type + // anything we don't recognize will be left as-is + begin := strings.Index(s, "[") + end := strings.Index(s, "]") + if begin == -1 || end == -1 { + return s + } + + typeName := s[begin+1 : end] + if i := strings.LastIndex(typeName, "."); i > -1 { + typeName = typeName[i+1:] + } + return s[:begin+1] + typeName + s[end:] +} diff --git a/sdk/azcore/runtime/poller_test.go b/sdk/azcore/runtime/poller_test.go index 3ce04097a0e5..c91b806322a0 100644 --- a/sdk/azcore/runtime/poller_test.go +++ b/sdk/azcore/runtime/poller_test.go @@ -1176,3 +1176,17 @@ func TestNewPollerWithCustomHandler(t *testing.T) { require.NoError(t, err) require.EqualValues(t, "value", *result.Field) } + +func TestShortenPollerTypeName(t *testing.T) { + result := shortenPollerTypeName("Poller[module/package.ClientOperationResponse].PollUntilDone") + require.EqualValues(t, "Poller[ClientOperationResponse].PollUntilDone", result) + + result = shortenPollerTypeName("Poller[package.ClientOperationResponse].PollUntilDone") + require.EqualValues(t, "Poller[ClientOperationResponse].PollUntilDone", result) + + result = shortenPollerTypeName("Poller[ClientOperationResponse].PollUntilDone") + require.EqualValues(t, "Poller[ClientOperationResponse].PollUntilDone", result) + + result = shortenPollerTypeName("Poller.PollUntilDone") + require.EqualValues(t, "Poller.PollUntilDone", result) +} diff --git a/sdk/azcore/tracing/tracing.go b/sdk/azcore/tracing/tracing.go index 75f757cedd3b..f5157005f555 100644 --- a/sdk/azcore/tracing/tracing.go +++ b/sdk/azcore/tracing/tracing.go @@ -45,21 +45,28 @@ func (p Provider) NewTracer(name, version string) (tracer Tracer) { // TracerOptions contains the optional values when creating a Tracer. type TracerOptions struct { - // for future expansion + // SpanFromContext contains the implementation for the Tracer.SpanFromContext method. + SpanFromContext func(context.Context) (Span, bool) } // NewTracer creates a Tracer with the specified values. // - newSpanFn is the underlying implementation for creating Span instances // - options contains optional values; pass nil to accept the default value func NewTracer(newSpanFn func(ctx context.Context, spanName string, options *SpanOptions) (context.Context, Span), options *TracerOptions) Tracer { + if options == nil { + options = &TracerOptions{} + } return Tracer{ - newSpanFn: newSpanFn, + newSpanFn: newSpanFn, + spanFromContextFn: options.SpanFromContext, } } // Tracer is the factory that creates Span instances. type Tracer struct { - newSpanFn func(ctx context.Context, spanName string, options *SpanOptions) (context.Context, Span) + attrs []Attribute + newSpanFn func(ctx context.Context, spanName string, options *SpanOptions) (context.Context, Span) + spanFromContextFn func(ctx context.Context) (Span, bool) } // Start creates a new span and a context.Context that contains it. @@ -68,11 +75,37 @@ type Tracer struct { // - options contains optional values for the span, pass nil to accept any defaults func (t Tracer) Start(ctx context.Context, spanName string, options *SpanOptions) (context.Context, Span) { if t.newSpanFn != nil { - return t.newSpanFn(ctx, spanName, options) + opts := SpanOptions{} + if options != nil { + opts = *options + } + opts.Attributes = append(opts.Attributes, t.attrs...) + return t.newSpanFn(ctx, spanName, &opts) } return ctx, Span{} } +// SetAttributes sets attrs to be applied to each Span. If a key from attrs +// already exists for an attribute of the Span it will be overwritten with +// the value contained in attrs. +func (t *Tracer) SetAttributes(attrs ...Attribute) { + t.attrs = append(t.attrs, attrs...) +} + +// Enabled returns true if this Tracer is capable of creating Spans. +func (t Tracer) Enabled() bool { + return t.newSpanFn != nil +} + +// SpanFromContext returns the Span associated with the current context. +// If the provided context has no Span, false is returned. +func (t Tracer) SpanFromContext(ctx context.Context) (Span, bool) { + if t.spanFromContextFn != nil { + return t.spanFromContextFn(ctx) + } + return Span{}, false +} + // SpanOptions contains optional settings for creating a span. type SpanOptions struct { // Kind indicates the kind of Span. diff --git a/sdk/azcore/tracing/tracing_test.go b/sdk/azcore/tracing/tracing_test.go index da04627e3167..5ca8b3f267de 100644 --- a/sdk/azcore/tracing/tracing_test.go +++ b/sdk/azcore/tracing/tracing_test.go @@ -17,6 +17,8 @@ func TestProviderZeroValues(t *testing.T) { pr := Provider{} tr := pr.NewTracer("name", "version") require.Zero(t, tr) + require.False(t, tr.Enabled()) + tr.SetAttributes() ctx, sp := tr.Start(context.Background(), "spanName", nil) require.Equal(t, context.Background(), ctx) require.Zero(t, sp) @@ -25,6 +27,9 @@ func TestProviderZeroValues(t *testing.T) { sp.End() sp.SetAttributes(Attribute{}) sp.SetStatus(SpanStatusError, "boom") + sp, ok := tr.SpanFromContext(ctx) + require.False(t, ok) + require.Zero(t, sp) } func TestProvider(t *testing.T) { @@ -33,6 +38,7 @@ func TestProvider(t *testing.T) { var endCalled bool var setAttributesCalled bool var setStatusCalled bool + var spanFromContextCalled bool pr := NewProvider(func(name, version string) Tracer { return NewTracer(func(context.Context, string, *SpanOptions) (context.Context, Span) { @@ -43,10 +49,23 @@ func TestProvider(t *testing.T) { SetAttributes: func(...Attribute) { setAttributesCalled = true }, SetStatus: func(SpanStatus, string) { setStatusCalled = true }, }) - }, nil) + }, &TracerOptions{ + SpanFromContext: func(context.Context) (Span, bool) { + spanFromContextCalled = true + return Span{}, true + }, + }) }, nil) tr := pr.NewTracer("name", "version") require.NotZero(t, tr) + require.True(t, tr.Enabled()) + sp, ok := tr.SpanFromContext(context.Background()) + require.True(t, ok) + require.Zero(t, sp) + tr.SetAttributes(Attribute{Key: "some", Value: "attribute"}) + require.Len(t, tr.attrs, 1) + require.EqualValues(t, tr.attrs[0].Key, "some") + require.EqualValues(t, tr.attrs[0].Value, "attribute") ctx, sp := tr.Start(context.Background(), "name", nil) require.NotEqual(t, context.Background(), ctx) @@ -62,4 +81,5 @@ func TestProvider(t *testing.T) { require.True(t, endCalled) require.True(t, setAttributesCalled) require.True(t, setStatusCalled) + require.True(t, spanFromContextCalled) }