Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Sanitize Azure.Core HTTP Activity URL #24545

Merged
merged 2 commits into from
Oct 26, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions sdk/core/Azure.Core/src/Pipeline/HttpPipelineBuilder.cs
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,10 @@ void AddCustomerPolicies(HttpPipelinePosition position)
}
}

DiagnosticsOptions diagnostics = options.Diagnostics;

var sanitizer = new HttpMessageSanitizer(diagnostics.LoggedHeaderNames.ToArray(), diagnostics.LoggedQueryParameters.ToArray());

bool isDistributedTracingEnabled = options.Diagnostics.IsDistributedTracingEnabled;

policies.Add(ReadClientRequestIdPolicy.Shared);
Expand All @@ -73,7 +77,6 @@ void AddCustomerPolicies(HttpPipelinePosition position)

policies.Add(ClientRequestIdPolicy.Shared);

DiagnosticsOptions diagnostics = options.Diagnostics;
if (diagnostics.IsTelemetryEnabled)
{
policies.Add(CreateTelemetryPolicy(options));
Expand All @@ -92,13 +95,12 @@ void AddCustomerPolicies(HttpPipelinePosition position)
{
string assemblyName = options.GetType().Assembly!.GetName().Name!;

policies.Add(new LoggingPolicy(diagnostics.IsLoggingContentEnabled, diagnostics.LoggedContentSizeLimit,
diagnostics.LoggedHeaderNames.ToArray(), diagnostics.LoggedQueryParameters.ToArray(), assemblyName));
policies.Add(new LoggingPolicy(diagnostics.IsLoggingContentEnabled, diagnostics.LoggedContentSizeLimit, sanitizer, assemblyName));
}

policies.Add(new ResponseBodyPolicy(options.Retry.NetworkTimeout));

policies.Add(new RequestActivityPolicy(isDistributedTracingEnabled, ClientDiagnostics.GetResourceProviderNamespace(options.GetType().Assembly)));
policies.Add(new RequestActivityPolicy(isDistributedTracingEnabled, ClientDiagnostics.GetResourceProviderNamespace(options.GetType().Assembly), sanitizer));

AddCustomerPolicies(HttpPipelinePosition.BeforeTransport);

Expand Down
4 changes: 2 additions & 2 deletions sdk/core/Azure.Core/src/Pipeline/Internal/LoggingPolicy.cs
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@ namespace Azure.Core.Pipeline
{
internal class LoggingPolicy : HttpPipelinePolicy
{
public LoggingPolicy(bool logContent, int maxLength, string[] allowedHeaderNames, string[] allowedQueryParameters, string? assemblyName)
public LoggingPolicy(bool logContent, int maxLength, HttpMessageSanitizer sanitizer, string? assemblyName)
{
_sanitizer = new HttpMessageSanitizer(allowedQueryParameters, allowedHeaderNames);
_sanitizer = sanitizer;
_logContent = logContent;
_maxLength = maxLength;
_assemblyName = assemblyName;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ internal class RequestActivityPolicy : HttpPipelinePolicy
{
private readonly bool _isDistributedTracingEnabled;
private readonly string? _resourceProviderNamespace;
private readonly HttpMessageSanitizer _sanitizer;

private const string TraceParentHeaderName = "traceparent";
private const string TraceStateHeaderName = "tracestate";
Expand All @@ -20,10 +21,11 @@ internal class RequestActivityPolicy : HttpPipelinePolicy
private static readonly DiagnosticListener s_diagnosticSource = new DiagnosticListener("Azure.Core");
private static readonly object? s_activitySource = ActivityExtensions.CreateActivitySource("Azure.Core.Http");

public RequestActivityPolicy(bool isDistributedTracingEnabled, string? resourceProviderNamespace)
public RequestActivityPolicy(bool isDistributedTracingEnabled, string? resourceProviderNamespace, HttpMessageSanitizer httpMessageSanitizer)
{
_isDistributedTracingEnabled = isDistributedTracingEnabled;
_resourceProviderNamespace = resourceProviderNamespace;
_sanitizer = httpMessageSanitizer;
}

public override ValueTask ProcessAsync(HttpMessage message, ReadOnlyMemory<HttpPipelinePolicy> pipeline)
Expand Down Expand Up @@ -54,7 +56,7 @@ private async ValueTask ProcessAsync(HttpMessage message, ReadOnlyMemory<HttpPip
{
using var scope = new DiagnosticScope("Azure.Core.Http.Request", s_diagnosticSource, message, s_activitySource, DiagnosticScope.ActivityKind.Client);
scope.AddAttribute("http.method", message.Request.Method.Method);
scope.AddAttribute("http.url", message.Request.Uri.ToString());
scope.AddAttribute("http.url", _sanitizer.SanitizeUrl(message.Request.Uri.ToString()));
scope.AddAttribute("requestId", message.Request.ClientRequestId);

if (_resourceProviderNamespace != null)
Expand Down
29 changes: 15 additions & 14 deletions sdk/core/Azure.Core/tests/EventSourceTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,9 @@ public class EventSourceTests : SyncAsyncPolicyTestBase

private TestEventListener _listener;

private string[] s_allowedHeaders = new[] { "Date", "Custom-Header", "Custom-Response-Header" };
private string[] s_allowedQueryParameters = new[] { "api-version" };
private static string[] s_allowedHeaders = new[] { "Date", "Custom-Header", "Custom-Response-Header" };
private static string[] s_allowedQueryParameters = new[] { "api-version" };
private static HttpMessageSanitizer _sanitizer = new HttpMessageSanitizer(s_allowedQueryParameters, s_allowedHeaders);

public EventSourceTests(bool isAsync) : base(isAsync)
{
Expand Down Expand Up @@ -79,7 +80,7 @@ public async Task SendingRequestProducesEvents()

MockTransport mockTransport = CreateMockTransport(response);

var pipeline = new HttpPipeline(mockTransport, new[] { new LoggingPolicy(logContent: true, int.MaxValue, s_allowedHeaders, s_allowedQueryParameters, "Test-SDK") });
var pipeline = new HttpPipeline(mockTransport, new[] { new LoggingPolicy(logContent: true, int.MaxValue, _sanitizer, "Test-SDK") });
string requestId = null;

await SendRequestAsync(pipeline, request =>
Expand Down Expand Up @@ -131,7 +132,7 @@ public void GettingExceptionResponseProducesEvents()
throw exception;
});

var pipeline = new HttpPipeline(mockTransport, new[] { new LoggingPolicy(logContent: true, int.MaxValue, s_allowedHeaders, s_allowedQueryParameters, "Test-SDK") });
var pipeline = new HttpPipeline(mockTransport, new[] { new LoggingPolicy(logContent: true, int.MaxValue, _sanitizer, "Test-SDK") });
string requestId = null;

Assert.ThrowsAsync<InvalidOperationException>(async () => await SendRequestAsync(pipeline, request =>
Expand Down Expand Up @@ -172,7 +173,7 @@ public async Task FailingAccessTokenBackgroundRefreshProducesEvents()
return new MockResponse(200);
});

var pipeline = new HttpPipeline(mockTransport, new HttpPipelinePolicy[] { policy, new LoggingPolicy(logContent: true, int.MaxValue, s_allowedHeaders, s_allowedQueryParameters, "Test-SDK") });
var pipeline = new HttpPipeline(mockTransport, new HttpPipelinePolicy[] { policy, new LoggingPolicy(logContent: true, int.MaxValue, _sanitizer, "Test-SDK") });
await SendRequestAsync(pipeline, request =>
{
request.Method = RequestMethod.Get;
Expand Down Expand Up @@ -207,7 +208,7 @@ public async Task GettingErrorRequestProducesEvents()

MockTransport mockTransport = CreateMockTransport(response);

var pipeline = new HttpPipeline(mockTransport, new[] { new LoggingPolicy(logContent: true, int.MaxValue, s_allowedHeaders, s_allowedQueryParameters, "Test-SDK") });
var pipeline = new HttpPipeline(mockTransport, new[] { new LoggingPolicy(logContent: true, int.MaxValue, _sanitizer, "Test-SDK") });
string requestId = null;

await SendRequestAsync(pipeline, request =>
Expand Down Expand Up @@ -240,7 +241,7 @@ public async Task RequestContentIsLoggedAsText()
var response = new MockResponse(500);
MockTransport mockTransport = CreateMockTransport(response);

var pipeline = new HttpPipeline(mockTransport, new[] { new LoggingPolicy(logContent: true, int.MaxValue, s_allowedHeaders, s_allowedQueryParameters, "Test-SDK") });
var pipeline = new HttpPipeline(mockTransport, new[] { new LoggingPolicy(logContent: true, int.MaxValue, _sanitizer, "Test-SDK") });
string requestId = null;

await SendRequestAsync(pipeline, request =>
Expand Down Expand Up @@ -270,7 +271,7 @@ public async Task ContentIsNotLoggedAsTextWhenDisabled()

MockTransport mockTransport = CreateMockTransport(response);

var pipeline = new HttpPipeline(mockTransport, new[] { new LoggingPolicy(logContent: false, int.MaxValue, s_allowedHeaders, s_allowedQueryParameters, "Test-SDK") });
var pipeline = new HttpPipeline(mockTransport, new[] { new LoggingPolicy(logContent: false, int.MaxValue, _sanitizer, "Test-SDK") });

await SendRequestAsync(pipeline, request =>
{
Expand All @@ -291,7 +292,7 @@ public async Task ContentIsNotLoggedWhenDisabled()

MockTransport mockTransport = CreateMockTransport(response);

var pipeline = new HttpPipeline(mockTransport, new[] { new LoggingPolicy(logContent: false, int.MaxValue, s_allowedHeaders, s_allowedQueryParameters, "Test-SDK") });
var pipeline = new HttpPipeline(mockTransport, new[] { new LoggingPolicy(logContent: false, int.MaxValue, _sanitizer, "Test-SDK") });

await SendRequestAsync(pipeline, request =>
{
Expand All @@ -311,7 +312,7 @@ public async Task RequestContentIsNotLoggedWhenDisabled()

MockTransport mockTransport = CreateMockTransport(response);

var pipeline = new HttpPipeline(mockTransport, new[] { new LoggingPolicy(logContent: false, int.MaxValue, s_allowedHeaders, s_allowedQueryParameters, "Test-SDK") });
var pipeline = new HttpPipeline(mockTransport, new[] { new LoggingPolicy(logContent: false, int.MaxValue, _sanitizer, "Test-SDK") });

await SendRequestAsync(pipeline, request =>
{
Expand Down Expand Up @@ -500,7 +501,7 @@ public async Task RequestContentLogsAreLimitedInLength()
var response = new MockResponse(500);
MockTransport mockTransport = CreateMockTransport(response);

var pipeline = new HttpPipeline(mockTransport, new[] { new LoggingPolicy(logContent: true, 5, s_allowedHeaders, s_allowedQueryParameters, "Test-SDK") });
var pipeline = new HttpPipeline(mockTransport, new[] { new LoggingPolicy(logContent: true, 5, _sanitizer, "Test-SDK") });
string requestId = null;

await SendRequestAsync(pipeline, request =>
Expand Down Expand Up @@ -554,7 +555,7 @@ public async Task HeadersAndQueryParametersAreSanitized()

MockTransport mockTransport = CreateMockTransport(response);

var pipeline = new HttpPipeline(mockTransport, new[] { new LoggingPolicy(logContent: false, int.MaxValue, s_allowedHeaders, s_allowedQueryParameters, "Test-SDK") });
var pipeline = new HttpPipeline(mockTransport, new[] { new LoggingPolicy(logContent: false, int.MaxValue, _sanitizer, "Test-SDK") });
string requestId = null;

await SendRequestAsync(pipeline, request =>
Expand Down Expand Up @@ -598,7 +599,7 @@ public async Task HeadersAndQueryParametersAreNotSanitizedWhenStars()

MockTransport mockTransport = CreateMockTransport(response);

var pipeline = new HttpPipeline(mockTransport, new[] { new LoggingPolicy(logContent: false, int.MaxValue, new[] { "*" }, new[] { "*" }, "Test-SDK") });
var pipeline = new HttpPipeline(mockTransport, new[] { new LoggingPolicy(logContent: false, int.MaxValue, new HttpMessageSanitizer(new[] { "*" }, new[] { "*" }), "Test-SDK") });
string requestId = null;

await SendRequestAsync(pipeline, request =>
Expand Down Expand Up @@ -647,7 +648,7 @@ private async Task<Response> SendRequest(bool isSeekable, bool isError, Action<M
setupRequest?.Invoke(mockResponse);

MockTransport mockTransport = CreateMockTransport(mockResponse);
var pipeline = new HttpPipeline(mockTransport, new[] { new LoggingPolicy(logContent: true, maxLength, s_allowedHeaders, s_allowedQueryParameters, "Test-SDK") });
var pipeline = new HttpPipeline(mockTransport, new[] { new LoggingPolicy(logContent: true, maxLength, _sanitizer, "Test-SDK") });

Response response = await SendRequestAsync(pipeline, request =>
{
Expand Down
33 changes: 31 additions & 2 deletions sdk/core/Azure.Core/tests/RequestActivityPolicyTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@ public RequestActivityPolicyTests(bool isAsync) : base(isAsync)
{
}

private static readonly RequestActivityPolicy s_enabledPolicy = new RequestActivityPolicy(true, "Microsoft.Azure.Core.Cool.Tests");
private static string[] s_allowedQueryParameters = new[] { "api-version" };
private static HttpMessageSanitizer _sanitizer = new HttpMessageSanitizer(s_allowedQueryParameters, Array.Empty<string>());
private static readonly RequestActivityPolicy s_enabledPolicy = new RequestActivityPolicy(true, "Microsoft.Azure.Core.Cool.Tests", _sanitizer);

[Test]
[NonParallelizable]
Expand Down Expand Up @@ -66,6 +68,33 @@ public async Task ActivityIsCreatedForRequest()
CollectionAssert.Contains(activity.Tags, new KeyValuePair<string, string>("az.namespace", "Microsoft.Azure.Core.Cool.Tests"));
}

[Test]
[NonParallelizable]
public async Task UriAttributeIsSanitized()
{
Activity activity = null;
using var testListener = new TestDiagnosticListener("Azure.Core");

MockTransport mockTransport = CreateMockTransport(_ =>
{
activity = Activity.Current;
return new MockResponse(201);
});

string clientRequestId = null;
Task<Response> requestTask = SendRequestAsync(mockTransport, request =>
{
request.Method = RequestMethod.Get;
request.Uri.Reset(new Uri("http://example.com?api-version=v2&sas=secret value"));
clientRequestId = request.ClientRequestId;
}, s_enabledPolicy);

await requestTask;

CollectionAssert.Contains(activity.Tags, new KeyValuePair<string, string>("http.url", "http://example.com/?api-version=v2&sas=REDACTED"));
CollectionAssert.IsEmpty(activity.Tags.Where(kvp => kvp.Value.Contains("secret")));
}

[Test]
[NonParallelizable]
public async Task ActivityMarkedAsErrorForErrorResponse()
Expand Down Expand Up @@ -239,7 +268,7 @@ public async Task ActivityIsNotCreatedWhenDisabled()

var transport = new MockTransport(new MockResponse(200));

await SendGetRequest(transport, new RequestActivityPolicy(isDistributedTracingEnabled: false, "Microsoft.Azure.Core.Cool.Tests"));
await SendGetRequest(transport, new RequestActivityPolicy(isDistributedTracingEnabled: false, "Microsoft.Azure.Core.Cool.Tests", _sanitizer));

Assert.AreEqual(0, testListener.Events.Count);
}
Expand Down