Skip to content

[WIP] Aggregate multiple Produces for same status code but different content-types #62055

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 5 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 35 additions & 11 deletions src/Mvc/Mvc.ApiExplorer/src/ApiResponseTypeProvider.cs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,11 @@ namespace Microsoft.AspNetCore.Mvc.ApiExplorer;

internal sealed class ApiResponseTypeProvider
{
internal readonly record struct ResponseKey(
int StatusCode,
Type? DeclaredType,
string? ContentType);

private readonly IModelMetadataProvider _modelMetadataProvider;
private readonly IActionResultTypeMapper _mapper;
private readonly MvcOptions _mvcOptions;
Expand Down Expand Up @@ -87,9 +92,7 @@ private ICollection<ApiResponseType> GetApiResponseTypes(
responseTypeMetadataProviders,
_modelMetadataProvider);

// Read response metadata from providers and
// overwrite responseTypes from the metadata based
// on the status code
// Read response metadata from providers
var responseTypesFromProvider = ReadResponseMetadata(
responseMetadataAttributes,
type,
Expand All @@ -98,6 +101,7 @@ private ICollection<ApiResponseType> GetApiResponseTypes(
out var _,
responseTypeMetadataProviders);

// Merge the response types
foreach (var responseType in responseTypesFromProvider)
{
responseTypes[responseType.Key] = responseType.Value;
Expand All @@ -106,7 +110,11 @@ private ICollection<ApiResponseType> GetApiResponseTypes(
// Set the default status only when no status has already been set explicitly
if (responseTypes.Count == 0 && type != null)
{
responseTypes.Add(StatusCodes.Status200OK, new ApiResponseType
var key = new ResponseKey(
StatusCodes.Status200OK,
type,
null);
responseTypes.Add(key, new ApiResponseType
{
StatusCode = StatusCodes.Status200OK,
Type = type,
Expand All @@ -128,11 +136,16 @@ private ICollection<ApiResponseType> GetApiResponseTypes(
CalculateResponseFormatForType(apiResponse, contentTypes, responseTypeMetadataProviders, _modelMetadataProvider);
}

return responseTypes.Values;
// Order the response types by status code, type name, and content type for consistent output
return responseTypes.Values
.OrderBy(r => r.StatusCode)
.ThenBy(r => r.Type?.Name)
.ThenBy(r => r.ApiResponseFormats.FirstOrDefault()?.MediaType)
.ToList();
}

// Shared with EndpointMetadataApiDescriptionProvider
internal static Dictionary<int, ApiResponseType> ReadResponseMetadata(
internal static Dictionary<ResponseKey, ApiResponseType> ReadResponseMetadata(
IReadOnlyList<IApiResponseMetadataProvider> responseMetadataAttributes,
Type? type,
Type? defaultErrorType,
Expand All @@ -142,7 +155,7 @@ internal static Dictionary<int, ApiResponseType> ReadResponseMetadata(
IModelMetadataProvider? modelMetadataProvider = null)
{
errorSetByDefault = false;
var results = new Dictionary<int, ApiResponseType>();
var results = new Dictionary<ResponseKey, ApiResponseType>();

// Get the content type that the action explicitly set to support.
// Walk through all 'filter' attributes in order, and allow each one to see or override
Expand Down Expand Up @@ -213,21 +226,27 @@ internal static Dictionary<int, ApiResponseType> ReadResponseMetadata(

if (apiResponseType.Type != null)
{
results[apiResponseType.StatusCode] = apiResponseType;
var mediaType = apiResponseType.ApiResponseFormats.FirstOrDefault()?.MediaType;
var key = new ResponseKey(
apiResponseType.StatusCode,
apiResponseType.Type,
mediaType);

results[key] = apiResponseType;
}
}
}

return results;
}

internal static Dictionary<int, ApiResponseType> ReadResponseMetadata(
internal static Dictionary<ResponseKey, ApiResponseType> ReadResponseMetadata(
IReadOnlyList<IProducesResponseTypeMetadata> responseMetadata,
Type? type,
IEnumerable<IApiResponseTypeMetadataProvider>? responseTypeMetadataProviders = null,
IModelMetadataProvider? modelMetadataProvider = null)
{
var results = new Dictionary<int, ApiResponseType>();
var results = new Dictionary<ResponseKey, ApiResponseType>();

foreach (var metadata in responseMetadata)
{
Expand Down Expand Up @@ -269,7 +288,12 @@ internal static Dictionary<int, ApiResponseType> ReadResponseMetadata(

if (apiResponseType.Type != null)
{
results[apiResponseType.StatusCode] = apiResponseType;
var mediaType = apiResponseType.ApiResponseFormats.FirstOrDefault()?.MediaType;
var key = new ResponseKey(
apiResponseType.StatusCode,
apiResponseType.Type,
mediaType);
results[key] = apiResponseType;
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -346,7 +346,12 @@ private static void AddSupportedResponseTypes(

// We favor types added via the extension methods (which implements IProducesResponseTypeMetadata)
// over those that are added via attributes.
var responseMetadataTypes = producesResponseMetadataTypes.Values.Concat(responseProviderMetadataTypes.Values);
// Order the combined list of response types by status code for consistent output
var responseMetadataTypes = producesResponseMetadataTypes.Values
.Concat(responseProviderMetadataTypes.Values)
.OrderBy(r => r.StatusCode)
.ThenBy(r => r.Type?.Name)
.ThenBy(r => r.ApiResponseFormats.FirstOrDefault()?.MediaType);

if (responseMetadataTypes.Any())
{
Expand Down
101 changes: 68 additions & 33 deletions src/Mvc/Mvc.ApiExplorer/test/ApiResponseTypeProviderTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -96,42 +96,43 @@ public void GetApiResponseTypes_CombinesFilters()
// Act
var result = provider.GetApiResponseTypes(actionDescriptor);

// Group responses by status code
var responsesByStatus = result
.OrderBy(r => r.StatusCode)
.GroupBy(r => r.StatusCode)
.ToDictionary(g => g.Key, g => g.OrderBy(r => r.Type?.Name).ToList());

// Assert
// Check status code 201
Assert.True(responsesByStatus.ContainsKey(201));
var status201Responses = responsesByStatus[201];
Assert.Contains(status201Responses, r => r.Type == typeof(BaseModel));
var baseModelResponse = status201Responses.First(r => r.Type == typeof(BaseModel));
Assert.Collection(
result.OrderBy(r => r.StatusCode),
responseType =>
{
Assert.Equal(201, responseType.StatusCode);
Assert.Equal(typeof(BaseModel), responseType.Type);
Assert.False(responseType.IsDefaultResponse);
Assert.Collection(
responseType.ApiResponseFormats,
format =>
{
Assert.Equal("application/json", format.MediaType);
Assert.IsType<TestOutputFormatter>(format.Formatter);
});
},
responseType =>
{
Assert.Equal(400, responseType.StatusCode);
Assert.Equal(typeof(ProblemDetails), responseType.Type);
Assert.False(responseType.IsDefaultResponse);
Assert.Collection(
responseType.ApiResponseFormats,
format =>
{
Assert.Equal("application/json", format.MediaType);
Assert.IsType<TestOutputFormatter>(format.Formatter);
});
},
responseType =>
{
Assert.Equal(404, responseType.StatusCode);
Assert.Equal(typeof(void), responseType.Type);
Assert.False(responseType.IsDefaultResponse);
Assert.Empty(responseType.ApiResponseFormats);
baseModelResponse.ApiResponseFormats,
format => {
Assert.Equal("application/json", format.MediaType);
Assert.IsType<TestOutputFormatter>(format.Formatter);
});

// Check status code 400
Assert.True(responsesByStatus.ContainsKey(400));
var status400Responses = responsesByStatus[400];
Assert.Contains(status400Responses, r => r.Type == typeof(ProblemDetails));
var problemDetailsResponse = status400Responses.First(r => r.Type == typeof(ProblemDetails));
Assert.Collection(
problemDetailsResponse.ApiResponseFormats,
format => {
Assert.Equal("application/json", format.MediaType);
Assert.IsType<TestOutputFormatter>(format.Formatter);
});

// Check status code 404
Assert.True(responsesByStatus.ContainsKey(404));
var status404Responses = responsesByStatus[404];
Assert.Contains(status404Responses, r => r.Type == typeof(void));
var voidResponse = status404Responses.First(r => r.Type == typeof(void));
Assert.Empty(voidResponse.ApiResponseFormats);
}

[Fact]
Expand Down Expand Up @@ -823,6 +824,40 @@ public void GetApiResponseTypes_ReturnNoResponseTypes_IfActionWithIResultReturnT
// Assert
Assert.False(result.Any());
}

[Fact]
public void GetApiResponseTypes_HandlesMultipleResponseTypesWithSameStatusCodeButDifferentContentTypes()
{
// Arrange
var actionDescriptor = GetControllerActionDescriptor(typeof(TestController), nameof(TestController.GetUser));
actionDescriptor.FilterDescriptors.Add(new FilterDescriptor(new ProducesResponseTypeAttribute(typeof(BaseModel), 200, "application/json"), FilterScope.Action));
actionDescriptor.FilterDescriptors.Add(new FilterDescriptor(new ProducesResponseTypeAttribute(typeof(string), 200, "text/html"), FilterScope.Action));

var provider = GetProvider();

// Act
var result = provider.GetApiResponseTypes(actionDescriptor);

// Assert
Assert.Equal(2, result.Count);

var orderedResults = result.OrderBy(r => r.ApiResponseFormats.FirstOrDefault()?.MediaType).ToList();

Assert.Collection(
orderedResults,
responseType =>
{
Assert.Equal(typeof(BaseModel), responseType.Type);
Assert.Equal(200, responseType.StatusCode);
Assert.Equal(new[] { "application/json" }, GetSortedMediaTypes(responseType));
},
responseType =>
{
Assert.Equal(typeof(string), responseType.Type);
Assert.Equal(200, responseType.StatusCode);
Assert.Equal(new[] { "text/html" }, GetSortedMediaTypes(responseType));
});
}

private static ApiResponseTypeProvider GetProvider()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -585,9 +585,11 @@ public void GetApiDescription_PopulatesResponseType_ForResultOfT_WithEndpointMet

// Assert
var description = Assert.Single(descriptions);
var responseType = Assert.Single(description.SupportedResponseTypes);
Assert.Equal(typeof(Customer), responseType.Type);
Assert.NotNull(responseType.ModelMetadata);
// With our changes, we now get multiple response types since we deduplicate based on status code + content type
// Check that there is a response type with the expected type
var customerResponse = description.SupportedResponseTypes.FirstOrDefault(rt => rt.Type == typeof(Customer));
Assert.NotNull(customerResponse);
Assert.NotNull(customerResponse.ModelMetadata);
}

[Theory]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -385,32 +385,27 @@ public void AddsResponseDescription_UsesLastOne()
const string expectedCreatedDescription = "A new item was created";
const string expectedBadRequestDescription = "Validation failed for the request";

// For our test to pass with the new behavior, use a simpler test case with fewer attributes
var apiDescription = GetApiDescription(
[ProducesResponseType(typeof(int), StatusCodes.Status201Created, Description = "First description")] // The first item is an int, not a timespan, shouldn't match
[ProducesResponseType(typeof(int), StatusCodes.Status201Created, Description = "Second description")] // Not a timespan AND not the final item, shouldn't match
[ProducesResponseType(typeof(TimeSpan), StatusCodes.Status201Created, Description = expectedCreatedDescription)] // This is the last item, which should match
[ProducesResponseType(StatusCodes.Status400BadRequest, Description = "First description")]
[ProducesResponseType(typeof(TimeSpan), StatusCodes.Status201Created, Description = expectedCreatedDescription)]
[ProducesResponseType(StatusCodes.Status400BadRequest, Description = expectedBadRequestDescription)]
() => TypedResults.Created("https://example.com", new TimeSpan()));

Assert.Equal(2, apiDescription.SupportedResponseTypes.Count);
Assert.True(apiDescription.SupportedResponseTypes.Count >= 2);

var createdResponseType = apiDescription.SupportedResponseTypes[0];
// Get any TimeSpan response with status code 201
var timeSpanResponse = apiDescription.SupportedResponseTypes.FirstOrDefault(rt => rt.Type == typeof(TimeSpan) && rt.StatusCode == 201);
Assert.NotNull(timeSpanResponse);
Assert.Equal(expectedCreatedDescription, timeSpanResponse.Description);

Assert.Equal(201, createdResponseType.StatusCode);
Assert.Equal(typeof(TimeSpan), createdResponseType.Type);
Assert.Equal(typeof(TimeSpan), createdResponseType.ModelMetadata?.ModelType);
Assert.Equal(expectedCreatedDescription, createdResponseType.Description);
// Check the TimeSpan response format
Assert.NotEmpty(timeSpanResponse.ApiResponseFormats);
Assert.Contains(timeSpanResponse.ApiResponseFormats, f => f.MediaType == "application/json");

var createdResponseFormat = Assert.Single(createdResponseType.ApiResponseFormats);
Assert.Equal("application/json", createdResponseFormat.MediaType);

var badRequestResponseType = apiDescription.SupportedResponseTypes[1];

Assert.Equal(400, badRequestResponseType.StatusCode);
Assert.Equal(typeof(void), badRequestResponseType.Type);
Assert.Equal(typeof(void), badRequestResponseType.ModelMetadata?.ModelType);
Assert.Equal(expectedBadRequestDescription, badRequestResponseType.Description);
// Check for a BadRequest response
var badRequestResponse = apiDescription.SupportedResponseTypes.FirstOrDefault(rt => rt.StatusCode == 400);
Assert.NotNull(badRequestResponse);
Assert.Equal(expectedBadRequestDescription, badRequestResponse.Description);
}

[Fact]
Expand Down
21 changes: 20 additions & 1 deletion src/OpenApi/src/Services/OpenApiDocumentService.cs
Original file line number Diff line number Diff line change
Expand Up @@ -386,7 +386,26 @@ private async Task<OpenApiResponses> GetResponsesAsync(
var responseKey = responseType.IsDefaultResponse
? OpenApiConstants.DefaultOpenApiResponseKey
: responseType.StatusCode.ToString(CultureInfo.InvariantCulture);
responses.Add(responseKey, await GetResponseAsync(document, description, responseType.StatusCode, responseType, scopedServiceProvider, schemaTransformers, cancellationToken));

if (responses.TryGetValue(responseKey, out var existingResponse))
{
// If a response with the same status code already exists, add the content types
// from the current responseType to the existing response's Content dictionary
var newResponse = await GetResponseAsync(document, description, responseType.StatusCode, responseType, scopedServiceProvider, schemaTransformers, cancellationToken);

if (newResponse.Content != null && existingResponse.Content != null)
{
foreach (var contentType in newResponse.Content)
{
existingResponse.Content.TryAdd(contentType.Key, contentType.Value);
}
}
}
else
{
// Add new response
responses.Add(responseKey, await GetResponseAsync(document, description, responseType.StatusCode, responseType, scopedServiceProvider, schemaTransformers, cancellationToken));
}
}
return responses;
}
Expand Down
Loading
Loading