Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -11,155 +11,102 @@
namespace Microsoft.Extensions.DependencyInjection;

/// <summary>
/// Helper to cleanup expired server side sessions.
/// Helper to clean up expired server side sessions.
/// </summary>
public class ServerSideSessionCleanupHost : IHostedService
public class ServerSideSessionCleanupHost(
IServiceProvider serviceProvider,
IdentityServerOptions options,
ILogger<ServerSideSessionCleanupHost> logger) : BackgroundService
{
private readonly IServiceProvider _serviceProvider;
private readonly IdentityServerOptions _options;
private readonly ILogger<ServerSideSessionCleanupHost> _logger;

private CancellationTokenSource _source;

/// <summary>
/// Constructor for ServerSideSessionCleanupHost.
/// </summary>
/// <param name="serviceProvider"></param>
/// <param name="options"></param>
/// <param name="logger"></param>
public ServerSideSessionCleanupHost(IServiceProvider serviceProvider, IdentityServerOptions options, ILogger<ServerSideSessionCleanupHost> logger)
/// <inheritdoc />
public override Task StartAsync(CancellationToken cancellationToken) =>
!options.ServerSideSessions.RemoveExpiredSessions
? Task.CompletedTask
: base.StartAsync(cancellationToken);

/// <inheritdoc />
protected override async Task ExecuteAsync(CancellationToken stoppingToken)
{
_serviceProvider = serviceProvider ?? throw new ArgumentNullException(nameof(serviceProvider));
_options = options ?? throw new ArgumentNullException(nameof(options));
_logger = logger;
}

/// <summary>
/// Starts the token cleanup polling.
/// </summary>
public Task StartAsync(CancellationToken cancellationToken)
{
if (_options.ServerSideSessions.RemoveExpiredSessions)
{
if (_source != null)
{
throw new InvalidOperationException("Already started. Call Stop first.");
}

_logger.LogDebug("Starting server-side session removal");

_source = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken);

Task.Factory.StartNew(() => StartInternalAsync(_source.Token), cancellationToken, TaskCreationOptions.None, TaskScheduler.Default);
}

return Task.CompletedTask;
}

/// <summary>
/// Stops the token cleanup polling.
/// </summary>
public async Task StopAsync(CancellationToken cancellationToken)
{
if (_options.ServerSideSessions.RemoveExpiredSessions)
{
if (_source == null)
{
throw new InvalidOperationException("Not started. Call Start first.");
}

_logger.LogDebug("Stopping server-side session removal");

await _source.CancelAsync();
_source = null;
}
}
logger.LogDebug("Starting server-side session removal");

private async Task StartInternalAsync(CancellationToken cancellationToken)
{
var removalFrequencySeconds = (int)_options.ServerSideSessions.RemoveExpiredSessionsFrequency.TotalSeconds;
var removalFrequencySeconds = (int)options.ServerSideSessions.RemoveExpiredSessionsFrequency.TotalSeconds;

// Start the first run at a random interval.
var delay = _options.ServerSideSessions.FuzzExpiredSessionRemovalStart
var delay = options.ServerSideSessions.FuzzExpiredSessionRemovalStart
#pragma warning disable CA5394 // Randomness for security does not apply here
? TimeSpan.FromSeconds(Random.Shared.Next(removalFrequencySeconds))
#pragma warning restore CA5394
: _options.ServerSideSessions.RemoveExpiredSessionsFrequency;
: options.ServerSideSessions.RemoveExpiredSessionsFrequency;

while (true)
while (!stoppingToken.IsCancellationRequested)
{
if (cancellationToken.IsCancellationRequested)
{
_logger.LogDebug("CancellationRequested. Exiting.");
break;
}

try
{
await Task.Delay(delay, cancellationToken);
await Task.Delay(delay, stoppingToken);
}
catch (TaskCanceledException)
{
_logger.LogDebug("TaskCanceledException. Exiting.");
logger.LogDebug("TaskCanceledException. Exiting.");
break;
}
catch (Exception ex)
{
_logger.LogError("Task.Delay exception: {ExceptionMessage}. Exiting.", ex.Message);
logger.LogError("Task.Delay exception: {ExceptionMessage}. Exiting.", ex.Message);
break;
}

if (cancellationToken.IsCancellationRequested)
if (stoppingToken.IsCancellationRequested)
{
_logger.LogDebug("CancellationRequested. Exiting.");
break;
}

await RunAsync(cancellationToken);
await RunAsync(stoppingToken);

delay = _options.ServerSideSessions.RemoveExpiredSessionsFrequency;
delay = options.ServerSideSessions.RemoveExpiredSessionsFrequency;
}

logger.LogDebug("Stopping server-side session removal");
}

private async Task RunAsync(CancellationToken cancellationToken = default)
{
// this is here for testing
if (!_options.ServerSideSessions.RemoveExpiredSessions)
if (!options.ServerSideSessions.RemoveExpiredSessions)
{
return;
}

try
{
await using (var serviceScope = _serviceProvider.GetRequiredService<IServiceScopeFactory>().CreateAsyncScope())
await using var serviceScope = serviceProvider.GetRequiredService<IServiceScopeFactory>().CreateAsyncScope();
var scopedLogger = serviceScope.ServiceProvider.GetRequiredService<ILogger<ServerSideSessionCleanupHost>>();
var scopedOptions = serviceScope.ServiceProvider.GetRequiredService<IdentityServerOptions>();
var serverSideTicketStore = serviceScope.ServiceProvider.GetRequiredService<IServerSideTicketStore>();
var sessionCoordinationService = serviceScope.ServiceProvider.GetRequiredService<ISessionCoordinationService>();

var found = int.MaxValue;

while (found > 0)
{
var logger = serviceScope.ServiceProvider.GetRequiredService<ILogger<ServerSideSessionCleanupHost>>();
var options = serviceScope.ServiceProvider.GetRequiredService<IdentityServerOptions>();
var serverSideTicketStore = serviceScope.ServiceProvider.GetRequiredService<IServerSideTicketStore>();
var sessionCoordinationService = serviceScope.ServiceProvider.GetRequiredService<ISessionCoordinationService>();
var sessions = await serverSideTicketStore.GetAndRemoveExpiredSessionsAsync(scopedOptions.ServerSideSessions.RemoveExpiredSessionsBatchSize, cancellationToken);
found = sessions.Count;

if (found <= 0)
{
continue;
}

var found = int.MaxValue;
scopedLogger.LogDebug("Processing expiration for {count} expired server-side sessions.", found);

while (found > 0)
foreach (var session in sessions)
{
var sessions = await serverSideTicketStore.GetAndRemoveExpiredSessionsAsync(options.ServerSideSessions.RemoveExpiredSessionsBatchSize, cancellationToken);
found = sessions.Count;

if (found > 0)
{
logger.LogDebug("Processing expiration for {count} expired server-side sessions.", found);

foreach (var session in sessions)
{
await sessionCoordinationService.ProcessExpirationAsync(session);
}
}
await sessionCoordinationService.ProcessExpirationAsync(session);
}
}
}
catch (Exception ex)
{
_logger.LogError(ex, "Exception removing expired sessions");
logger.LogError(ex, "Exception removing expired sessions");
}
}
}
Loading