Skip to content

Commit

Permalink
Fix get client & create "DistributedCache" table (postgre)
Browse files Browse the repository at this point in the history
  • Loading branch information
thabart committed Sep 27, 2024
1 parent 1a8d216 commit a90323f
Show file tree
Hide file tree
Showing 10 changed files with 93 additions and 8 deletions.
42 changes: 40 additions & 2 deletions src/IdServer/SimpleIdServer.IdServer.Startup/Program.cs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
using Microsoft.Extensions.Configuration;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Hosting;
using MySqlConnector;
using NeoSmart.Caching.Sqlite.AspNetCore;
using SimpleIdServer.Configuration;
using SimpleIdServer.Did.Key;
Expand All @@ -39,6 +40,7 @@
using SimpleIdServer.IdServer.WsFederation;
using System;
using System.Collections.Generic;
using System.Data.Common;
using System.Linq;
using System.Net;
using System.Security.Cryptography.X509Certificates;
Expand All @@ -63,6 +65,20 @@
"KEY `Index_ExpiresAtTime` (`ExpiresAtTime`)" +
")";

const string PostgreCreateSchemaAndTableSql =
$"""
CREATE SCHEMA IF NOT EXISTS "public";
CREATE TABLE IF NOT EXISTS "public"."DistributedCache"
(
"Id" text COLLATE pg_catalog."default" NOT NULL,
"Value" bytea,
"ExpiresAtTime" timestamp with time zone,
"SlidingExpirationInSeconds" double precision,
"AbsoluteExpiration" timestamp with time zone,
CONSTRAINT "DistCache_pkey" PRIMARY KEY ("Id")
)
""";

ServicePointManager.ServerCertificateValidationCallback += (o, c, ch, er) => true;
var builder = WebApplication.CreateBuilder(args);
builder.Configuration
Expand Down Expand Up @@ -498,6 +514,13 @@ async void SeedData(WebApplication application, string scimBaseUrl)
void EnableIsolationLevel(StoreDbContext dbContext)
{
if (dbContext.Database.IsInMemory()) return;
EnableSqlServer(dbContext);
EnableMysql(dbContext);
EnablePostgre(dbContext);
}

void EnableSqlServer(StoreDbContext dbContext)
{
var dbConnection = dbContext.Database.GetDbConnection();
var sqlConnection = dbConnection as SqlConnection;
if (sqlConnection != null)
Expand All @@ -509,17 +532,32 @@ void EnableIsolationLevel(StoreDbContext dbContext)
cmd = sqlConnection.CreateCommand();
cmd.CommandText = SQLServerCreateTableFormat;
cmd.ExecuteNonQuery();
return;
}
}

void EnableMysql(StoreDbContext dbContext)
{
var dbConnection = dbContext.Database.GetDbConnection();
var mysqlConnection = dbConnection as MySqlConnector.MySqlConnection;
if (mysqlConnection != null)
{
if (mysqlConnection.State != System.Data.ConnectionState.Open) mysqlConnection.Open();
var cmd = mysqlConnection.CreateCommand();
cmd.CommandText = MYSQLCreateTableFormat;
cmd.ExecuteNonQuery();
return;
}
}

void EnablePostgre(StoreDbContext dbContext)
{
var dbConnection = dbContext.Database.GetDbConnection();
var postgreconnection = dbConnection as Npgsql.NpgsqlConnection;
if(postgreconnection != null)
{
if (postgreconnection.State != System.Data.ConnectionState.Open) postgreconnection.Open();
var cmd = postgreconnection.CreateCommand();
cmd.CommandText = PostgreCreateSchemaAndTableSql;
cmd.ExecuteNonQuery();
}
}

Expand Down
10 changes: 10 additions & 0 deletions src/IdServer/SimpleIdServer.IdServer.Store.EF/ClientRepository.cs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,16 @@ public ClientRepository(StoreDbContext dbContext)
_dbContext = dbContext;
}

public Task<Client> GetById(string realm, string id, CancellationToken cancellationToken)
{
return _dbContext.Clients
.Include(c => c.Scopes).ThenInclude(s => s.ClaimMappers)
.Include(c => c.SerializedJsonWebKeys)
.Include(c => c.Translations)
.Include(c => c.Realms)
.SingleOrDefaultAsync(c => c.Id == id && c.Realms.Any(r => r.Name == realm), cancellationToken);
}

public Task<Client> GetByClientId(string realm, string clientId, CancellationToken cancellationToken)
{
return _dbContext.Clients
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,17 @@ public async Task<List<Client>> GetAll(string realm, List<string> clientIds, Can
return result.Select(r => r.ToDomain()).ToList();
}

public async Task<Client> GetById(string realm, string id, CancellationToken cancellationToken)
{
var result = await _dbContext.Client.Queryable<SugarClient>()
.Includes(c => c.ClientScopes, c => c.Scope, s => s.ClaimMappers)
.Includes(c => c.SerializedJsonWebKeys)
.Includes(c => c.Translations)
.Includes(c => c.Realms)
.FirstAsync(c => c.Id == id && c.Realms.Any(r => r.RealmsName == realm), cancellationToken);
return result?.ToDomain();
}

public async Task<Client> GetByClientId(string realm, string clientId, CancellationToken cancellationToken)
{
var result = await _dbContext.Client.Queryable<SugarClient>()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@
notificationService.Notify(new NotificationMessage { Severity = NotificationSeverity.Error, Summary = act.ErrorMessage });
StateHasChanged();
});
dispatcher.Dispatch(new GetClientAction { ClientId = id });
dispatcher.Dispatch(new GetClientAction { Id = id });
}

void OnChange(int index)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@
</RadzenDataGridColumn>
<RadzenDataGridColumn TItem="SelectableClient" Filterable="false" Sortable="true" SortProperty="Value.ClientId" Title="@Global.Identifier" Width="80px">
<Template Context="data">
<RadzenLink Text="@data.Value.ClientId" Path="@(urlHelper.GetUrl($"/clients/{System.Web.HttpUtility.UrlEncode(data.Value.ClientId)}/settings"))" />
<RadzenLink Text="@data.Value.ClientId" Path="@(urlHelper.GetUrl($"/clients/{data.Value.Id}/settings"))" />
</Template>
</RadzenDataGridColumn>
<RadzenDataGridColumn TItem="SelectableClient" Property="Value.ClientName" Filterable="false" Sortable="false" Title="@Global.Name" Width="80px" />
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -272,7 +272,7 @@ public async Task Handle(GetClientAction action, IDispatcher dispatcher)
var httpClient = await _websiteHttpClientFactory.Build();
var requestMessage = new HttpRequestMessage
{
RequestUri = new Uri($"{baseUrl}/{System.Web.HttpUtility.UrlEncode(action.ClientId)}"),
RequestUri = new Uri($"{baseUrl}/bytechnicalid/{action.Id}"),
Method = HttpMethod.Get
};
var httpResult = await httpClient.SendAsync(requestMessage);
Expand Down Expand Up @@ -758,7 +758,8 @@ private async Task CreateClient(Domains.Client client, IDispatcher dispatcher, s
try
{
httpResult.EnsureSuccessStatusCode();
dispatcher.Dispatch(new AddClientSuccessAction { ClientId = client.ClientId, ClientName = client.ClientName, Language = client.Translations.FirstOrDefault()?.Language, ClientType = clientType, Pem = pemResult, JsonWebKeyStr = jsonWebKey });
var newClient = JsonSerializer.Deserialize<Client>(json);
dispatcher.Dispatch(new AddClientSuccessAction { Id = newClient.Id, ClientId = client.ClientId, ClientName = client.ClientName, Language = client.Translations.FirstOrDefault()?.Language, ClientType = clientType, Pem = pemResult, JsonWebKeyStr = jsonWebKey });
}
catch
{
Expand Down Expand Up @@ -901,6 +902,7 @@ public class AddClientFailureAction

public class AddClientSuccessAction
{
public string Id { get; set; }
public string ClientId { get; set; } = null!;
public string? ClientName { get; set; } = null;
public string? Language { get; set; } = null;
Expand Down Expand Up @@ -932,7 +934,7 @@ public class ToggleAllClientSelectionAction

public class GetClientAction
{
public string ClientId { get; set; } = null!;
public string Id { get; set; } = null!;
}

public class GetClientFailureAction
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ public static SearchClientsState ReduceAddClientSuccessAction(SearchClientsState
{
var clients = state.Clients?.ToList();
if (clients == null) return state;
var newClient = new Domains.Client { ClientId = act.ClientId, CreateDateTime = DateTime.Now, UpdateDateTime = DateTime.Now, ClientType = act.ClientType };
var newClient = new Domains.Client { Id = act.Id, ClientId = act.ClientId, CreateDateTime = DateTime.Now, UpdateDateTime = DateTime.Now, ClientType = act.ClientType };
if(!string.IsNullOrWhiteSpace(act.ClientName))
newClient.Translations.Add(new Translation
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ public async Task<IActionResult> Add([FromRoute] string prefix, [FromBody] Clien
throw new OAuthException(ErrorCodes.INVALID_REQUEST, string.Format(Global.ClientIdentifierAlreadyExists, request.ClientId));
request.Scopes = await GetScopes(prefix, request.Scope, CancellationToken.None);
var realm = await _realmRepository.Get(prefix, cancellationToken);
request.Realms.Clear();
request.Realms.Add(realm);
await _registerClientRequestValidator.Validate(prefix, request, CancellationToken.None);
_clientRepository.Add(request);
Expand Down Expand Up @@ -166,6 +167,25 @@ public async Task<IActionResult> Get([FromRoute] string prefix, string id, Cance
}
}

[HttpGet]
public async Task<IActionResult> GetByTechnicalId([FromRoute] string prefix, string id, CancellationToken cancellationToken)
{
prefix = prefix ?? Constants.DefaultRealm;
try
{
id = System.Web.HttpUtility.UrlDecode(id);
await CheckAccessToken(prefix, Constants.StandardScopes.Clients.Name);
var result = await _clientRepository.GetById(prefix, id, cancellationToken);
if (result == null) throw new OAuthException(HttpStatusCode.NotFound, ErrorCodes.NOT_FOUND, string.Format(Global.UnknownClient, id));
return new OkObjectResult(result);
}
catch (OAuthException ex)
{
_logger.LogError(ex.ToString());
return BuildError(ex);
}
}

[HttpDelete]
public async Task<IActionResult> Delete([FromRoute] string prefix, string id, CancellationToken cancellationToken)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ namespace SimpleIdServer.IdServer.Stores
{
public interface IClientRepository
{
Task<Client> GetById(string realm, string id, CancellationToken cancellationToken);
Task<Client> GetByClientId(string realm, string clientId, CancellationToken cancellationToken);
Task<List<Client>> GetByClientIds(string realm, List<string> clientIds, CancellationToken cancellationToken);
Task<List<Client>> GetByClientIdsAndExistingBackchannelLogoutUri(string realm, List<string> clientIds, CancellationToken cancellationToken);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -476,6 +476,9 @@ public static WebApplication UseSID(this WebApplication webApplication, bool coo
webApplication.SidMapControllerRoute("addClient",
pattern: (usePrefix ? "{prefix}/" : string.Empty) + Constants.EndPoints.Clients,
defaults: new { controller = "Clients", action = "Add" });
webApplication.SidMapControllerRoute("getClientByTechnicalId",
pattern: (usePrefix ? "{prefix}/" : string.Empty) + Constants.EndPoints.Clients + "/bytechnicalid/{id}",
defaults: new { controller = "Clients", action = "GetByTechnicalId" });
webApplication.SidMapControllerRoute("getClient",
pattern: (usePrefix ? "{prefix}/" : string.Empty) + Constants.EndPoints.Clients + "/{id}",
defaults: new { controller = "Clients", action = "Get" });
Expand Down

0 comments on commit a90323f

Please sign in to comment.