Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 61 additions & 24 deletions src/Microsoft.Identity.Web.DownstreamApi/DownstreamApi.cs
Original file line number Diff line number Diff line change
@@ -1,22 +1,22 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

using System;
using System.Collections.Generic;
using System.Diagnostics.CodeAnalysis;
using System.IO;
using System.Linq;
using System.Net.Http;
using System.Runtime.CompilerServices;
using System.Security.Claims;
using System.Text;
using System.Text.Json;
using System.Text.Json.Serialization.Metadata;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Options;
using Microsoft.Identity.Abstractions;
using System;
using System.Collections.Generic;
using System.Diagnostics.CodeAnalysis;
using System.IO;
using System.Linq;
using System.Net.Http;
using System.Runtime.CompilerServices;
using System.Security.Claims;
using System.Text;
using System.Text.Json;
using System.Text.Json.Serialization.Metadata;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Options;
using Microsoft.Identity.Abstractions;
using Microsoft.Identity.Client;

namespace Microsoft.Identity.Web
Expand Down Expand Up @@ -582,11 +582,48 @@ public Task<HttpResponseMessage> CallApiForAppAsync(
{
Logger.UnauthenticatedApiCall(_logger, null);
}
if (!string.IsNullOrEmpty(effectiveOptions.AcceptHeader))
{
httpRequestMessage.Headers.Accept.ParseAdd(effectiveOptions.AcceptHeader);
}
// Opportunity to change the request message
if (!string.IsNullOrEmpty(effectiveOptions.AcceptHeader))
{
httpRequestMessage.Headers.Accept.ParseAdd(effectiveOptions.AcceptHeader);
}

// Add extra headers if specified directly on DownstreamApiOptions
if (effectiveOptions.ExtraHeaderParameters != null)
{
foreach (var header in effectiveOptions.ExtraHeaderParameters)
{
httpRequestMessage.Headers.TryAddWithoutValidation(header.Key, header.Value);
}
}

// Add extra query parameters if specified directly on DownstreamApiOptions
if (effectiveOptions.ExtraQueryParameters != null && effectiveOptions.ExtraQueryParameters.Count > 0)
{
var uriBuilder = new UriBuilder(httpRequestMessage.RequestUri!);
var existingQuery = uriBuilder.Query;
var queryString = new StringBuilder(existingQuery);

foreach (var queryParam in effectiveOptions.ExtraQueryParameters)
{
if (queryString.Length > 1) // if there are existing query parameters
{
queryString.Append('&');
}
else if (queryString.Length == 0)
{
queryString.Append('?');
}

queryString.Append(Uri.EscapeDataString(queryParam.Key));
queryString.Append('=');
queryString.Append(Uri.EscapeDataString(queryParam.Value));
}

uriBuilder.Query = queryString.ToString().TrimStart('?');
httpRequestMessage.RequestUri = uriBuilder.Uri;
}

// Opportunity to change the request message
effectiveOptions.CustomizeHttpRequestMessage?.Invoke(httpRequestMessage);
}

Expand All @@ -608,7 +645,7 @@ private static void AddCallerSDKTelemetry(DownstreamApiOptions effectiveOptions)
CallerSDKDetails["caller-sdk-id"];
effectiveOptions.AcquireTokenOptions.ExtraQueryParameters["caller-sdk-ver"] =
CallerSDKDetails["caller-sdk-ver"];
}
}
}
}
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,204 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

using System;
using System.Collections.Generic;
using System.Linq;
using System.Net.Http;
using System.Security.Claims;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Options;
using Microsoft.Identity.Abstractions;
using Microsoft.Identity.Web.Test.Resource;
using Xunit;

namespace Microsoft.Identity.Web.Tests
{
public class ExtraParametersTests
{
private readonly IAuthorizationHeaderProvider _authorizationHeaderProvider;
private readonly IHttpClientFactory _httpClientFactory;
private readonly IOptionsMonitor<DownstreamApiOptions> _namedDownstreamApiOptions;
private readonly ILogger<DownstreamApi> _logger;
private readonly DownstreamApi _downstreamApi;

public ExtraParametersTests()
{
_authorizationHeaderProvider = new MyAuthorizationHeaderProvider();
_httpClientFactory = new HttpClientFactoryTest();
_namedDownstreamApiOptions = new MyMonitor();
_logger = new LoggerFactory().CreateLogger<DownstreamApi>();

_downstreamApi = new DownstreamApi(
_authorizationHeaderProvider,
_namedDownstreamApiOptions,
_httpClientFactory,
_logger);
}

[Fact]
public async Task UpdateRequestAsync_WithExtraHeaderParameters_AddsHeadersToRequest()
{
// Arrange
var httpRequestMessage = new HttpRequestMessage(HttpMethod.Get, "https://example.com/api");
var options = new DownstreamApiOptions()
{
ExtraHeaderParameters = new Dictionary<string, string>
{
{ "OData-Version", "4.0" },
{ "Custom-Header", "test-value" }
}
};

// Act
await _downstreamApi.UpdateRequestAsync(httpRequestMessage, null, options, false, null, CancellationToken.None);

// Assert
Assert.True(httpRequestMessage.Headers.Contains("OData-Version"));
Assert.True(httpRequestMessage.Headers.Contains("Custom-Header"));
Assert.Equal("4.0", httpRequestMessage.Headers.GetValues("OData-Version").First());
Assert.Equal("test-value", httpRequestMessage.Headers.GetValues("Custom-Header").First());
}

[Fact]
public async Task UpdateRequestAsync_WithExtraQueryParameters_AddsQueryParametersToUrl()
{
// Arrange
var httpRequestMessage = new HttpRequestMessage(HttpMethod.Get, "https://example.com/api");
var options = new DownstreamApiOptions()
{
ExtraQueryParameters = new Dictionary<string, string>
{
{ "param1", "value1" },
{ "param2", "value2" }
}
};

// Act
await _downstreamApi.UpdateRequestAsync(httpRequestMessage, null, options, false, null, CancellationToken.None);

// Assert
var requestUri = httpRequestMessage.RequestUri!.ToString();
Assert.Contains("param1=value1", requestUri, StringComparison.Ordinal);
Assert.Contains("param2=value2", requestUri, StringComparison.Ordinal);
}

[Fact]
public async Task UpdateRequestAsync_WithExtraQueryParameters_AppendsToExistingQuery()
{
// Arrange
var httpRequestMessage = new HttpRequestMessage(HttpMethod.Get, "https://example.com/api?existing=true");
var options = new DownstreamApiOptions()
{
ExtraQueryParameters = new Dictionary<string, string>
{
{ "new", "param" }
}
};

// Act
await _downstreamApi.UpdateRequestAsync(httpRequestMessage, null, options, false, null, CancellationToken.None);

// Assert
var requestUri = httpRequestMessage.RequestUri!.ToString();
Assert.Contains("existing=true", requestUri, StringComparison.Ordinal);
Assert.Contains("new=param", requestUri, StringComparison.Ordinal);
}

[Fact]
public async Task UpdateRequestAsync_WithoutExtraParameters_DoesNotModifyRequest()
{
// Arrange
var originalUri = "https://example.com/api";
var httpRequestMessage = new HttpRequestMessage(HttpMethod.Get, originalUri);
var options = new DownstreamApiOptions(); // No extra parameters

// Act
await _downstreamApi.UpdateRequestAsync(httpRequestMessage, null, options, false, null, CancellationToken.None);

// Assert
Assert.Equal(originalUri, httpRequestMessage.RequestUri!.ToString());
}


[Fact]
public async Task UpdateRequestAsync_WithEmptyExtraParameters_DoesNotModifyRequest()
{
// Arrange
var originalUri = "https://example.com/api";
var httpRequestMessage = new HttpRequestMessage(HttpMethod.Get, originalUri);
var options = new DownstreamApiOptions()
{
ExtraHeaderParameters = new Dictionary<string, string>(), // Empty dictionary
ExtraQueryParameters = new Dictionary<string, string>() // Empty dictionary
};

// Act
await _downstreamApi.UpdateRequestAsync(httpRequestMessage, null, options, false, null, CancellationToken.None);

// Assert
Assert.Equal(originalUri, httpRequestMessage.RequestUri!.ToString());
}

[Fact]
public async Task UpdateRequestAsync_WithSpecialCharacters_EscapesCorrectly()
{
// Arrange
var httpRequestMessage = new HttpRequestMessage(HttpMethod.Get, "https://example.com/api");
var options = new DownstreamApiOptions()
{
ExtraQueryParameters = new Dictionary<string, string>
{
{ "special", "value with spaces & symbols" }
}
};

// Act
await _downstreamApi.UpdateRequestAsync(httpRequestMessage, null, options, false, null, CancellationToken.None);

// Assert
var requestUri = httpRequestMessage.RequestUri!.ToString();
Assert.Contains("special=value with spaces %26 symbols", requestUri, StringComparison.Ordinal);
}

private class MyAuthorizationHeaderProvider : IAuthorizationHeaderProvider
{
public Task<string> CreateAuthorizationHeaderForAppAsync(string scopes, AuthorizationHeaderProviderOptions? downstreamApiOptions = null, CancellationToken cancellationToken = default)
{
return Task.FromResult("Bearer ey");
}

public Task<string> CreateAuthorizationHeaderForUserAsync(IEnumerable<string> scopes, AuthorizationHeaderProviderOptions? authorizationHeaderProviderOptions = null, ClaimsPrincipal? claimsPrincipal = null, CancellationToken cancellationToken = default)
{
return Task.FromResult("Bearer ey");
}

public Task<string> CreateAuthorizationHeaderAsync(IEnumerable<string> scopes, AuthorizationHeaderProviderOptions? authorizationHeaderProviderOptions = null, ClaimsPrincipal? claimsPrincipal = null, CancellationToken cancellationToken = default)
{
return Task.FromResult("Bearer ey");
}
}

private class MyMonitor : IOptionsMonitor<DownstreamApiOptions>
{
public DownstreamApiOptions CurrentValue => new DownstreamApiOptions();

public DownstreamApiOptions Get(string? name)
{
return new DownstreamApiOptions();
}

public DownstreamApiOptions Get(string name, string key)
{
return new DownstreamApiOptions();
}

public IDisposable OnChange(Action<DownstreamApiOptions, string> listener)
{
throw new NotImplementedException();
}
}
}
}
Loading