diff --git a/src/Middleware/Diagnostics/src/ExceptionHandler/ExceptionHandlerExtensions.cs b/src/Middleware/Diagnostics/src/ExceptionHandler/ExceptionHandlerExtensions.cs index 2cb6d10f77c5..398a894f495c 100644 --- a/src/Middleware/Diagnostics/src/ExceptionHandler/ExceptionHandlerExtensions.cs +++ b/src/Middleware/Diagnostics/src/ExceptionHandler/ExceptionHandlerExtensions.cs @@ -97,6 +97,7 @@ private static IApplicationBuilder SetExceptionHandlerMiddleware(IApplicationBui { var loggerFactory = app.ApplicationServices.GetRequiredService(); var diagnosticListener = app.ApplicationServices.GetRequiredService(); + var exceptionHandlers = app.ApplicationServices.GetRequiredService>(); if (options is null) { @@ -110,7 +111,7 @@ private static IApplicationBuilder SetExceptionHandlerMiddleware(IApplicationBui options.Value.ExceptionHandler = newNext; } - return new ExceptionHandlerMiddlewareImpl(next, loggerFactory, options, diagnosticListener, problemDetailsService).Invoke; + return new ExceptionHandlerMiddlewareImpl(next, loggerFactory, options, diagnosticListener, exceptionHandlers, problemDetailsService).Invoke; }); } diff --git a/src/Middleware/Diagnostics/src/ExceptionHandler/ExceptionHandlerMiddleware.cs b/src/Middleware/Diagnostics/src/ExceptionHandler/ExceptionHandlerMiddleware.cs index 5b9f87514326..c5f6ffb0838c 100644 --- a/src/Middleware/Diagnostics/src/ExceptionHandler/ExceptionHandlerMiddleware.cs +++ b/src/Middleware/Diagnostics/src/ExceptionHandler/ExceptionHandlerMiddleware.cs @@ -2,6 +2,7 @@ // The .NET Foundation licenses this file to you under the MIT license. using System.Diagnostics; +using System.Linq; using Microsoft.AspNetCore.Builder; using Microsoft.AspNetCore.Http; using Microsoft.Extensions.Logging; @@ -29,11 +30,12 @@ public ExceptionHandlerMiddleware( IOptions options, DiagnosticListener diagnosticListener) { - _innerMiddlewareImpl = new ( + _innerMiddlewareImpl = new( next, loggerFactory, options, diagnosticListener, + Enumerable.Empty(), problemDetailsService: null); } diff --git a/src/Middleware/Diagnostics/src/ExceptionHandler/ExceptionHandlerMiddlewareImpl.cs b/src/Middleware/Diagnostics/src/ExceptionHandler/ExceptionHandlerMiddlewareImpl.cs index c62206a43e47..471c37a72869 100644 --- a/src/Middleware/Diagnostics/src/ExceptionHandler/ExceptionHandlerMiddlewareImpl.cs +++ b/src/Middleware/Diagnostics/src/ExceptionHandler/ExceptionHandlerMiddlewareImpl.cs @@ -16,7 +16,7 @@ namespace Microsoft.AspNetCore.Diagnostics; /// /// A middleware for handling exceptions in the application. /// -internal class ExceptionHandlerMiddlewareImpl +internal sealed class ExceptionHandlerMiddlewareImpl { private const int DefaultStatusCode = StatusCodes.Status500InternalServerError; @@ -25,6 +25,7 @@ internal class ExceptionHandlerMiddlewareImpl private readonly ILogger _logger; private readonly Func _clearCacheHeadersDelegate; private readonly DiagnosticListener _diagnosticListener; + private readonly IExceptionHandler[] _exceptionHandlers; private readonly IProblemDetailsService? _problemDetailsService; /// @@ -34,12 +35,14 @@ internal class ExceptionHandlerMiddlewareImpl /// The used for logging. /// The options for configuring the middleware. /// The used for writing diagnostic messages. + /// /// The used for writing messages. public ExceptionHandlerMiddlewareImpl( RequestDelegate next, ILoggerFactory loggerFactory, IOptions options, DiagnosticListener diagnosticListener, + IEnumerable exceptionHandlers, IProblemDetailsService? problemDetailsService = null) { _next = next; @@ -47,6 +50,7 @@ public ExceptionHandlerMiddlewareImpl( _logger = loggerFactory.CreateLogger(); _clearCacheHeadersDelegate = ClearCacheHeaders; _diagnosticListener = diagnosticListener; + _exceptionHandlers = exceptionHandlers as IExceptionHandler[] ?? new List(exceptionHandlers).ToArray(); _problemDetailsService = problemDetailsService; if (_options.ExceptionHandler == null) @@ -133,7 +137,7 @@ private async Task HandleException(HttpContext context, ExceptionDispatchInfo ed edi.Throw(); } - PathString originalPath = context.Request.Path; + var originalPath = context.Request.Path; if (_options.ExceptionHandlingPath.HasValue) { context.Request.Path = _options.ExceptionHandlingPath; @@ -155,24 +159,35 @@ private async Task HandleException(HttpContext context, ExceptionDispatchInfo ed context.Response.StatusCode = DefaultStatusCode; context.Response.OnStarting(_clearCacheHeadersDelegate, context.Response); - var problemDetailsWritten = false; - if (_options.ExceptionHandler != null) + var handled = false; + foreach (var exceptionHandler in _exceptionHandlers) { - await _options.ExceptionHandler!(context); + handled = await exceptionHandler.TryHandleAsync(context, edi.SourceException, context.RequestAborted); + if (handled) + { + break; + } } - else + + if (!handled) { - problemDetailsWritten = await _problemDetailsService!.TryWriteAsync(new() + if (_options.ExceptionHandler is not null) { - HttpContext = context, - AdditionalMetadata = exceptionHandlerFeature.Endpoint?.Metadata, - ProblemDetails = { Status = DefaultStatusCode }, - Exception = edi.SourceException, - }); + await _options.ExceptionHandler!(context); + } + else + { + handled = await _problemDetailsService!.TryWriteAsync(new() + { + HttpContext = context, + AdditionalMetadata = exceptionHandlerFeature.Endpoint?.Metadata, + ProblemDetails = { Status = DefaultStatusCode }, + Exception = edi.SourceException, + }); + } } - // If the response has already started, assume exception handler was successful. - if (context.Response.HasStarted || problemDetailsWritten || context.Response.StatusCode != StatusCodes.Status404NotFound || _options.AllowStatusCode404Response) + if (context.Response.HasStarted || handled || context.Response.StatusCode != StatusCodes.Status404NotFound || _options.AllowStatusCode404Response) { const string eventName = "Microsoft.AspNetCore.Diagnostics.HandledException"; if (_diagnosticListener.IsEnabled() && _diagnosticListener.IsEnabled(eventName)) diff --git a/src/Middleware/Diagnostics/src/ExceptionHandler/ExceptionHandlerServiceCollectionExtensions.cs b/src/Middleware/Diagnostics/src/ExceptionHandler/ExceptionHandlerServiceCollectionExtensions.cs index ce758229b7d7..492e5dd3b1d7 100644 --- a/src/Middleware/Diagnostics/src/ExceptionHandler/ExceptionHandlerServiceCollectionExtensions.cs +++ b/src/Middleware/Diagnostics/src/ExceptionHandler/ExceptionHandlerServiceCollectionExtensions.cs @@ -1,7 +1,9 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. +using System.Diagnostics.CodeAnalysis; using Microsoft.AspNetCore.Builder; +using Microsoft.AspNetCore.Diagnostics; namespace Microsoft.Extensions.DependencyInjection; @@ -38,4 +40,16 @@ public static IServiceCollection AddExceptionHandler(this IServiceColl services.AddOptions().Configure(configureOptions); return services; } + + /// + /// + /// + /// + /// + /// + + public static IServiceCollection AddExceptionHandler<[DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicConstructors)] T>(this IServiceCollection services) where T : class, IExceptionHandler + { + return services.AddSingleton(); + } } diff --git a/src/Middleware/Diagnostics/src/ExceptionHandler/IExceptionHandler.cs b/src/Middleware/Diagnostics/src/ExceptionHandler/IExceptionHandler.cs new file mode 100644 index 000000000000..752230811753 --- /dev/null +++ b/src/Middleware/Diagnostics/src/ExceptionHandler/IExceptionHandler.cs @@ -0,0 +1,21 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using Microsoft.AspNetCore.Http; + +namespace Microsoft.AspNetCore.Diagnostics; + +/// +/// +/// +public interface IExceptionHandler +{ + /// + /// + /// + /// + /// + /// + /// + ValueTask TryHandleAsync(HttpContext httpContext, Exception exception, CancellationToken cancellationToken); +} diff --git a/src/Middleware/Diagnostics/src/PublicAPI.Unshipped.txt b/src/Middleware/Diagnostics/src/PublicAPI.Unshipped.txt index 9d26ac109b2d..7ace23668c8d 100644 --- a/src/Middleware/Diagnostics/src/PublicAPI.Unshipped.txt +++ b/src/Middleware/Diagnostics/src/PublicAPI.Unshipped.txt @@ -1,2 +1,5 @@ #nullable enable +Microsoft.AspNetCore.Diagnostics.IExceptionHandler +Microsoft.AspNetCore.Diagnostics.IExceptionHandler.TryHandleAsync(Microsoft.AspNetCore.Http.HttpContext! httpContext, System.Exception! exception, System.Threading.CancellationToken cancellationToken) -> System.Threading.Tasks.ValueTask Microsoft.AspNetCore.Diagnostics.StatusCodeReExecuteFeature.OriginalStatusCode.get -> int +static Microsoft.Extensions.DependencyInjection.ExceptionHandlerServiceCollectionExtensions.AddExceptionHandler(this Microsoft.Extensions.DependencyInjection.IServiceCollection! services) -> Microsoft.Extensions.DependencyInjection.IServiceCollection! diff --git a/src/Middleware/Diagnostics/test/UnitTests/ExceptionHandlerMiddlewareTest.cs b/src/Middleware/Diagnostics/test/UnitTests/ExceptionHandlerMiddlewareTest.cs index de31f6837aed..07097d9b3576 100644 --- a/src/Middleware/Diagnostics/test/UnitTests/ExceptionHandlerMiddlewareTest.cs +++ b/src/Middleware/Diagnostics/test/UnitTests/ExceptionHandlerMiddlewareTest.cs @@ -114,6 +114,105 @@ public async Task Invoke_ExceptionHandlerCaptureRouteValuesAndEndpoint() await middleware.Invoke(httpContext); } + [Fact] + public async Task IExceptionHandlers_CallNextIfNotHandled() + { + // Arrange + var httpContext = CreateHttpContext(); + + var optionsAccessor = CreateOptionsAccessor(); + + var exceptionHandlers = new List + { + new TestExceptionHandler(false, "1"), + new TestExceptionHandler(false, "2"), + new TestExceptionHandler(true, "3"), + }; + + var middleware = CreateMiddleware(_ => throw new InvalidOperationException(), optionsAccessor, exceptionHandlers); + + // Act & Assert + await middleware.Invoke(httpContext); + + Assert.True(httpContext.Items.ContainsKey("1")); + Assert.True(httpContext.Items.ContainsKey("2")); + Assert.True(httpContext.Items.ContainsKey("3")); + } + + [Fact] + public async Task IExceptionHandlers_SkipIfOneHandle() + { + // Arrange + var httpContext = CreateHttpContext(); + + var optionsAccessor = CreateOptionsAccessor(); + + var exceptionHandlers = new List + { + new TestExceptionHandler(false, "1"), + new TestExceptionHandler(true, "2"), + new TestExceptionHandler(true, "3"), + }; + + var middleware = CreateMiddleware(_ => throw new InvalidOperationException(), optionsAccessor, exceptionHandlers); + + // Act & Assert + await middleware.Invoke(httpContext); + + Assert.True(httpContext.Items.ContainsKey("1")); + Assert.True(httpContext.Items.ContainsKey("2")); + Assert.False(httpContext.Items.ContainsKey("3")); + } + + [Fact] + public async Task IExceptionHandlers_CallOptionExceptionHandlerIfNobodyHandles() + { + // Arrange + var httpContext = CreateHttpContext(); + + var optionsAccessor = CreateOptionsAccessor( + (context) => + { + context.Items["ExceptionHandler"] = true; + return Task.CompletedTask; + }); + + var exceptionHandlers = new List + { + new TestExceptionHandler(false, "1"), + new TestExceptionHandler(false, "2"), + new TestExceptionHandler(false, "3"), + }; + + var middleware = CreateMiddleware(_ => throw new InvalidOperationException(), optionsAccessor, exceptionHandlers); + + // Act & Assert + await middleware.Invoke(httpContext); + + Assert.True(httpContext.Items.ContainsKey("1")); + Assert.True(httpContext.Items.ContainsKey("2")); + Assert.True(httpContext.Items.ContainsKey("3")); + Assert.True(httpContext.Items.ContainsKey("ExceptionHandler")); + } + + private class TestExceptionHandler : IExceptionHandler + { + private readonly bool _handle; + private readonly string _name; + + public TestExceptionHandler(bool handle, string name) + { + _handle = handle; + _name = name; + } + + public ValueTask TryHandleAsync(HttpContext httpContext, Exception exception, CancellationToken cancellationToken) + { + httpContext.Items[_name] = true; + return ValueTask.FromResult(_handle); + } + } + private HttpContext CreateHttpContext() { var httpContext = new DefaultHttpContext @@ -138,18 +237,20 @@ private IOptions CreateOptionsAccessor( return optionsAccessor; } - private ExceptionHandlerMiddleware CreateMiddleware( + private ExceptionHandlerMiddlewareImpl CreateMiddleware( RequestDelegate next, - IOptions options) + IOptions options, + IEnumerable exceptionHandlers = null) { next ??= c => Task.CompletedTask; var listener = new DiagnosticListener("Microsoft.AspNetCore"); - var middleware = new ExceptionHandlerMiddleware( + var middleware = new ExceptionHandlerMiddlewareImpl( next, NullLoggerFactory.Instance, options, - listener); + listener, + exceptionHandlers ?? Enumerable.Empty()); return middleware; }