Skip to content

Commit

Permalink
choreAdd nullable reference types to AspNetCore project (#503)
Browse files Browse the repository at this point in the history
* Initial nullable reference type implementations

* Add more NRTs

* Fix additional NRT issues

* Add NRTs to test project

* Cleanup

* Update method to return false rather than throw

* Initial NRT implementation for AspNetCore project

* Address the remaining project nullability warnings

* Add NRTs to unit tests

Co-authored-by: Brandon Foss <brandon_foss@selinc.com>
  • Loading branch information
fossbrandon and Brandon Foss authored Jan 13, 2022
1 parent 9ddb7f0 commit cb4c9f4
Show file tree
Hide file tree
Showing 24 changed files with 160 additions and 158 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,5 @@ public static class FinbuckleMultiTenantApplicationBuilderExtensions
/// <returns>The same IApplicationBuilder passed into the method.</returns>
public static IApplicationBuilder UseMultiTenant(this IApplicationBuilder builder)
=> builder.UseMiddleware<MultiTenantMiddleware>();

}
}
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ public static FinbuckleMultiTenantBuilder<TTenantInfo> WithPerTenantAuthenticati
/// <returns>The same MultiTenantBuilder passed into the method.</returns>
[SuppressMessage("ReSharper", "EmptyGeneralCatchClause")]
public static FinbuckleMultiTenantBuilder<TTenantInfo> WithPerTenantAuthenticationConventions<TTenantInfo>(
this FinbuckleMultiTenantBuilder<TTenantInfo> builder, Action<MultiTenantAuthenticationOptions> config = null)
this FinbuckleMultiTenantBuilder<TTenantInfo> builder, Action<MultiTenantAuthenticationOptions>? config = null)
where TTenantInfo : class, ITenantInfo, new()
{
// Set events to set and validate tenant for each cookie based authentication principal.
Expand All @@ -68,25 +68,24 @@ public static FinbuckleMultiTenantBuilder<TTenantInfo> WithPerTenantAuthenticati
// Skip if bypass set (e.g. ClaimsStrategy in effect)
if(context.HttpContext.Items.Keys.Contains($"{Constants.TenantToken}__bypass_validate_principal__"))
return;

var currentTenant = context.HttpContext.GetMultiTenantContext<TTenantInfo>()?.TenantInfo?.Identifier;
string authTenant = null;
string? authTenant = null;
if (context.Properties.Items.ContainsKey(Constants.TenantToken))
{
authTenant = context.Properties.Items[Constants.TenantToken];
}
else
{
var loggerFactory = context.HttpContext.RequestServices.GetService<ILoggerFactory>();
loggerFactory.CreateLogger<FinbuckleMultiTenantBuilder<TTenantInfo>>().LogWarning("No tenant found in authentication properties.");
loggerFactory?.CreateLogger<FinbuckleMultiTenantBuilder<TTenantInfo>>().LogWarning("No tenant found in authentication properties.");
}

// Does the current tenant match the auth property tenant?
if(!string.Equals(currentTenant, authTenant, StringComparison.OrdinalIgnoreCase))
context.RejectPrincipal();

if(origOnValidatePrincipal != null)
await origOnValidatePrincipal(context);

await origOnValidatePrincipal(context);
};
});

Expand All @@ -107,14 +106,14 @@ public static FinbuckleMultiTenantBuilder<TTenantInfo> WithPerTenantAuthenticati
try { options.ClientId = ((string)d.OpenIdConnectClientId).Replace(Constants.TenantToken, tc.Identifier); } catch { }
try { options.ClientSecret = ((string)d.OpenIdConnectClientSecret).Replace(Constants.TenantToken, tc.Identifier); } catch { }
});

var challengeSchemeProp = typeof(TTenantInfo).GetProperty("ChallengeScheme");
if (challengeSchemeProp != null && challengeSchemeProp.PropertyType == typeof(string))
{
builder.WithPerTenantOptions<AuthenticationOptions>((options, tc)
=> options.DefaultChallengeScheme = (string)challengeSchemeProp.GetValue(tc) ?? options.DefaultChallengeScheme);
=> options.DefaultChallengeScheme = (string?)challengeSchemeProp.GetValue(tc) ?? options.DefaultChallengeScheme);
}

return builder;
}

Expand All @@ -125,21 +124,21 @@ public static FinbuckleMultiTenantBuilder<TTenantInfo> WithPerTenantAuthenticati
/// <param name="config">Authentication options config</param>
/// <returns>The same MultiTenantBuilder passed into the method.</returns>
public static FinbuckleMultiTenantBuilder<TTenantInfo> WithPerTenantAuthenticationCore<TTenantInfo>(
this FinbuckleMultiTenantBuilder<TTenantInfo> builder, Action<MultiTenantAuthenticationOptions> config =
this FinbuckleMultiTenantBuilder<TTenantInfo> builder, Action<MultiTenantAuthenticationOptions>? config =
null)
where TTenantInfo : class, ITenantInfo, new()
{

config ??= _ => { };
builder.Services.Configure(config);

// We need to "decorate" IAuthenticationService so callbacks so that
// remote authentication can get the tenant from the authentication
// properties in the state parameter.
if (builder.Services.All(s => s.ServiceType != typeof(IAuthenticationService)))
throw new MultiTenantException("WithPerTenantAuthenticationCore() must be called after AddAuthentication() in ConfigureServices.");
builder.Services.DecorateService<IAuthenticationService, MultiTenantAuthenticationService<TTenantInfo>>();

// Replace IAuthenticationSchemeProvider so that the options aren't
// cached and can be used per-tenant.
builder.Services.Replace(ServiceDescriptor.Singleton<IAuthenticationSchemeProvider, MultiTenantAuthenticationSchemeProvider>());
Expand Down Expand Up @@ -183,7 +182,7 @@ public static FinbuckleMultiTenantBuilder<TTenantInfo> WithRemoteAuthenticationC
public static FinbuckleMultiTenantBuilder<TTenantInfo> WithBasePathStrategy<TTenantInfo>(this FinbuckleMultiTenantBuilder<TTenantInfo> builder)
where TTenantInfo : class, ITenantInfo, new()
=> builder.WithStrategy<BasePathStrategy>(ServiceLifetime.Singleton);

/// <summary>
/// Adds and configures a RouteStrategy with a route parameter Constants.TenantToken to the application.
/// </summary>
Expand Down Expand Up @@ -243,7 +242,7 @@ public static FinbuckleMultiTenantBuilder<TTenantInfo> WithHostStrategy<TTenantI
{
return builder.WithClaimStrategy(Constants.TenantToken);
}

/// <summary>
/// Adds and configures a ClaimStrategy to the application. Uses the default authentication handler scheme.
/// </summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ public static class FinbuckleHttpContextExtensions
/// <summary>
/// Returns the current MultiTenantContext or null if there is none.
/// </summary>
public static IMultiTenantContext<T> GetMultiTenantContext<T>(this HttpContext httpContext)
public static IMultiTenantContext<T>? GetMultiTenantContext<T>(this HttpContext httpContext)
where T : class, ITenantInfo, new()
{
return httpContext.RequestServices.GetRequiredService<IMultiTenantContextAccessor<T>>().MultiTenantContext;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
<TargetFrameworks>net6.0;net5.0;netcoreapp3.1</TargetFrameworks>
<Title>Finbuckle.MultiTenant.AspNetCore</Title>
<Description>ASP.NET Core support for Finbuckle.MultiTenant.</Description>
<Nullable>enable</Nullable>
</PropertyGroup>

<ItemGroup>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,18 +57,18 @@ public MultiTenantAuthenticationSchemeProvider(IOptions<AuthenticationOptions> o
private readonly IDictionary<string, AuthenticationScheme> _schemes;
private readonly List<AuthenticationScheme> _requestHandlers;

private Task<AuthenticationScheme> GetDefaultSchemeAsync()
private Task<AuthenticationScheme?> GetDefaultSchemeAsync()
=> _optionsProvider.Value.DefaultScheme != null
? GetSchemeAsync(_optionsProvider.Value.DefaultScheme)
: Task.FromResult<AuthenticationScheme>(null);
: Task.FromResult<AuthenticationScheme?>(null);

/// <summary>
/// Returns the scheme for this tenant that will be used by default for <see cref="IAuthenticationService.AuthenticateAsync(HttpContext, string)"/>.
/// This is typically specified via <see cref="AuthenticationOptions.DefaultAuthenticateScheme"/>.
/// Otherwise, this will fallback to <see cref="AuthenticationOptions.DefaultScheme"/>.
/// </summary>
/// <returns>The scheme that will be used by default for <see cref="IAuthenticationService.AuthenticateAsync(HttpContext, string)"/>.</returns>
public virtual Task<AuthenticationScheme> GetDefaultAuthenticateSchemeAsync()
/// <returns>The scheme that will be used by default for <see cref="IAuthenticationService.AuthenticateAsync(HttpContext, string)"/> or null if not found.</returns>
public virtual Task<AuthenticationScheme?> GetDefaultAuthenticateSchemeAsync()
=> _optionsProvider.Value.DefaultAuthenticateScheme != null
? GetSchemeAsync(_optionsProvider.Value.DefaultAuthenticateScheme)
: GetDefaultSchemeAsync();
Expand All @@ -78,8 +78,8 @@ public virtual Task<AuthenticationScheme> GetDefaultAuthenticateSchemeAsync()
/// This is typically specified via <see cref="AuthenticationOptions.DefaultChallengeScheme"/>.
/// Otherwise, this will fallback to <see cref="AuthenticationOptions.DefaultScheme"/>.
/// </summary>
/// <returns>The scheme that will be used by default for <see cref="IAuthenticationService.ChallengeAsync(HttpContext, string, AuthenticationProperties)"/>.</returns>
public virtual Task<AuthenticationScheme> GetDefaultChallengeSchemeAsync()
/// <returns>The scheme that will be used by default for <see cref="IAuthenticationService.ChallengeAsync(HttpContext, string, AuthenticationProperties)"/> or null if not found.</returns>
public virtual Task<AuthenticationScheme?> GetDefaultChallengeSchemeAsync()
=> _optionsProvider.Value.DefaultChallengeScheme != null
? GetSchemeAsync(_optionsProvider.Value.DefaultChallengeScheme)
: GetDefaultSchemeAsync();
Expand All @@ -89,8 +89,8 @@ public virtual Task<AuthenticationScheme> GetDefaultChallengeSchemeAsync()
/// This is typically specified via <see cref="AuthenticationOptions.DefaultForbidScheme"/>.
/// Otherwise, this will fallback to <see cref="GetDefaultChallengeSchemeAsync"/> .
/// </summary>
/// <returns>The scheme that will be used by default for <see cref="IAuthenticationService.ForbidAsync(HttpContext, string, AuthenticationProperties)"/>.</returns>
public virtual Task<AuthenticationScheme> GetDefaultForbidSchemeAsync()
/// <returns>The scheme that will be used by default for <see cref="IAuthenticationService.ForbidAsync(HttpContext, string, AuthenticationProperties)"/> or null if not found.</returns>
public virtual Task<AuthenticationScheme?> GetDefaultForbidSchemeAsync()
=> _optionsProvider.Value.DefaultForbidScheme != null
? GetSchemeAsync(_optionsProvider.Value.DefaultForbidScheme)
: GetDefaultChallengeSchemeAsync();
Expand All @@ -100,8 +100,8 @@ public virtual Task<AuthenticationScheme> GetDefaultForbidSchemeAsync()
/// This is typically specified via <see cref="AuthenticationOptions.DefaultSignInScheme"/>.
/// Otherwise, this will fallback to <see cref="AuthenticationOptions.DefaultScheme"/>.
/// </summary>
/// <returns>The scheme that will be used by default for <see cref="IAuthenticationService.SignInAsync(HttpContext, string, System.Security.Claims.ClaimsPrincipal, AuthenticationProperties)"/>.</returns>
public virtual Task<AuthenticationScheme> GetDefaultSignInSchemeAsync()
/// <returns>The scheme that will be used by default for <see cref="IAuthenticationService.SignInAsync(HttpContext, string, System.Security.Claims.ClaimsPrincipal, AuthenticationProperties)"/> or null if not found.</returns>
public virtual Task<AuthenticationScheme?> GetDefaultSignInSchemeAsync()
=> _optionsProvider.Value.DefaultSignInScheme != null
? GetSchemeAsync(_optionsProvider.Value.DefaultSignInScheme)
: GetDefaultSchemeAsync();
Expand All @@ -111,8 +111,8 @@ public virtual Task<AuthenticationScheme> GetDefaultSignInSchemeAsync()
/// This is typically specified via <see cref="AuthenticationOptions.DefaultSignOutScheme"/>.
/// Otherwise this will fallback to <see cref="GetDefaultSignInSchemeAsync"/> if that supoorts sign out.
/// </summary>
/// <returns>The scheme that will be used by default for <see cref="IAuthenticationService.SignOutAsync(HttpContext, string, AuthenticationProperties)"/>.</returns>
public virtual Task<AuthenticationScheme> GetDefaultSignOutSchemeAsync()
/// <returns>The scheme that will be used by default for <see cref="IAuthenticationService.SignOutAsync(HttpContext, string, AuthenticationProperties)"/> or null if not found.</returns>
public virtual Task<AuthenticationScheme?> GetDefaultSignOutSchemeAsync()
=> _optionsProvider.Value.DefaultSignOutScheme != null
? GetSchemeAsync(_optionsProvider.Value.DefaultSignOutScheme)
: GetDefaultSignInSchemeAsync();
Expand All @@ -122,7 +122,7 @@ public virtual Task<AuthenticationScheme> GetDefaultSignOutSchemeAsync()
/// </summary>
/// <param name="name">The name of the authenticationScheme.</param>
/// <returns>The scheme or null if not found.</returns>
public virtual Task<AuthenticationScheme> GetSchemeAsync(string name)
public virtual Task<AuthenticationScheme?> GetSchemeAsync(string name)
=> Task.FromResult(_schemes.ContainsKey(name) ? _schemes[name] : null);

/// <summary>
Expand All @@ -135,7 +135,7 @@ public virtual Task<IEnumerable<AuthenticationScheme>> GetRequestHandlerSchemesA
=> Task.FromResult<IEnumerable<AuthenticationScheme>>(_requestHandlers);

/// <summary>
/// Registers a scheme for use by <see cref="IAuthenticationService"/>.
/// Registers a scheme for use by <see cref="IAuthenticationService"/>.
/// </summary>
/// <param name="scheme">The scheme.</param>
public virtual void AddScheme(AuthenticationScheme scheme)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,22 +23,22 @@ public MultiTenantAuthenticationService(IAuthenticationService inner, IOptionsMo
this._multiTenantAuthenticationOptions = multiTenantAuthenticationOptions;
}

private static void AddTenantIdentifierToProperties(HttpContext context, ref AuthenticationProperties properties)
private static void AddTenantIdentifierToProperties(HttpContext context, ref AuthenticationProperties? properties)
{
// Add tenant identifier to the properties so on the callback we can use it to set the multitenant context.
var multiTenantContext = context.GetMultiTenantContext<TTenantInfo>();
if (multiTenantContext?.TenantInfo != null)
{
properties = properties ?? new AuthenticationProperties();
properties ??= new AuthenticationProperties();
if(!properties.Items.Keys.Contains(Constants.TenantToken))
properties.Items.Add(Constants.TenantToken, multiTenantContext.TenantInfo.Identifier);
}
}

public Task<AuthenticateResult> AuthenticateAsync(HttpContext context, string scheme)
public Task<AuthenticateResult> AuthenticateAsync(HttpContext context, string? scheme)
=> _inner.AuthenticateAsync(context, scheme);

public async Task ChallengeAsync(HttpContext context, string scheme, AuthenticationProperties properties)
public async Task ChallengeAsync(HttpContext context, string? scheme, AuthenticationProperties? properties)
{
if (_multiTenantAuthenticationOptions.CurrentValue.SkipChallengeIfTenantNotResolved)
{
Expand All @@ -50,19 +50,19 @@ public async Task ChallengeAsync(HttpContext context, string scheme, Authenticat
await _inner.ChallengeAsync(context, scheme, properties);
}

public async Task ForbidAsync(HttpContext context, string scheme, AuthenticationProperties properties)
public async Task ForbidAsync(HttpContext context, string? scheme, AuthenticationProperties? properties)
{
AddTenantIdentifierToProperties(context, ref properties);
await _inner.ForbidAsync(context, scheme, properties);
}

public async Task SignInAsync(HttpContext context, string scheme, ClaimsPrincipal principal, AuthenticationProperties properties)
public async Task SignInAsync(HttpContext context, string? scheme, ClaimsPrincipal principal, AuthenticationProperties? properties)
{
AddTenantIdentifierToProperties(context, ref properties);
await _inner.SignInAsync(context, scheme, principal, properties);
}

public async Task SignOutAsync(HttpContext context, string scheme, AuthenticationProperties properties)
public async Task SignOutAsync(HttpContext context, string? scheme, AuthenticationProperties? properties)
{
AddTenantIdentifierToProperties(context, ref properties);
await _inner.SignOutAsync(context, scheme, properties);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,7 @@ public async Task Invoke(HttpContext context)
accessor.MultiTenantContext = multiTenantContext;
}

if (next != null)
{
await next(context);
}
await next(context);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

using System;
using System.Collections.Generic;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Builder;
using Microsoft.AspNetCore.Routing;

Expand All @@ -11,21 +12,16 @@ namespace Finbuckle.MultiTenant.AspNetCore
internal class MultiTenantRouteBuilder : IRouteBuilder
{
private readonly IServiceProvider serviceProvider;
private IRouter defaultHandler = new RouteHandler(context => null);
private IRouter defaultHandler = new RouteHandler(_ => Task.CompletedTask);

public MultiTenantRouteBuilder(IServiceProvider ServiceProvider)
public MultiTenantRouteBuilder(IServiceProvider serviceProvider)
{
if (ServiceProvider == null)
{
throw new ArgumentNullException(nameof(ServiceProvider));
}

serviceProvider = ServiceProvider;
this.serviceProvider = serviceProvider ?? throw new ArgumentNullException(nameof(serviceProvider));
}

public IApplicationBuilder ApplicationBuilder => throw new NotImplementedException();

public IRouter DefaultHandler { get => defaultHandler; set => throw new NotImplementedException(); }
public IRouter? DefaultHandler { get => defaultHandler; set => throw new NotImplementedException(); }

public IServiceProvider ServiceProvider => serviceProvider;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,18 +9,18 @@ namespace Finbuckle.MultiTenant.Strategies
{
public class BasePathStrategy : IMultiTenantStrategy
{
public async Task<string> GetIdentifierAsync(object context)
public async Task<string?> GetIdentifierAsync(object context)
{
if(!(context is HttpContext))
if(!(context is HttpContext httpContext))
throw new MultiTenantException(null,
new ArgumentException($"\"{nameof(context)}\" type must be of type HttpContext", nameof(context)));

var path = (context as HttpContext).Request.Path;
var path = httpContext.Request.Path;

var pathSegments =
path.Value.Split(new char[] { '/' }, StringSplitOptions.RemoveEmptyEntries);
path.Value?.Split(new char[] { '/' }, StringSplitOptions.RemoveEmptyEntries);

if (pathSegments.Length == 0)
if (pathSegments is null || pathSegments.Length == 0)
return null;

string identifier = pathSegments[0];
Expand Down
Loading

0 comments on commit cb4c9f4

Please sign in to comment.