Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(core): set oidc issuer to custom domain #5509

Merged
merged 2 commits into from
Mar 18, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 8 additions & 3 deletions packages/core/src/app/init.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import type Koa from 'koa';
import { EnvSet } from '#src/env-set/index.js';
import { TenantNotFoundError, tenantPool } from '#src/tenants/index.js';
import { consoleLog } from '#src/utils/console.js';
import { getTenantId } from '#src/utils/tenant.js';
import { getTenantId, getTenantIdFromCustomDomain } from '#src/utils/tenant.js';

const logListening = (type: 'core' | 'admin' = 'core') => {
const urlSet = type === 'core' ? EnvSet.values.urlSet : EnvSet.values.adminUrlSet;
Expand All @@ -29,15 +29,20 @@ export default async function initApp(app: Koa): Promise<void> {
return next();
}

const tenantId = await getTenantId(ctx.URL);
const tenantIdFromCustomDomain = await getTenantIdFromCustomDomain(ctx.URL);
const tenantId = tenantIdFromCustomDomain ?? (await getTenantId(ctx.URL, true));

if (!tenantId) {
ctx.status = 404;

return next();
}

const tenant = await trySafe(tenantPool.get(tenantId), (error) => {
// If the request is a custom domain of the tenant, use the custom endpoint to build "OIDC issuer"
// otherwise, build from the default endpoint (subdomain).
const customEndpoint = tenantIdFromCustomDomain ? ctx.URL.origin : undefined;

const tenant = await trySafe(tenantPool.get(tenantId, customEndpoint), (error) => {
ctx.status = error instanceof TenantNotFoundError ? 404 : 500;
void appInsights.trackException(error);
});
Expand Down
6 changes: 4 additions & 2 deletions packages/core/src/env-set/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ export class EnvSet {
return this.#oidc;
}

async load() {
async load(customDomain?: string) {
gao-sun marked this conversation as resolved.
Show resolved Hide resolved
const pool = await createPoolByEnv(
this.databaseUrl,
EnvSet.values.isUnitTest,
Expand All @@ -77,7 +77,9 @@ export class EnvSet {
});

const oidcConfigs = await getOidcConfigs();
const endpoint = getTenantEndpoint(this.tenantId, EnvSet.values);
const endpoint = customDomain
? new URL(customDomain)
: getTenantEndpoint(this.tenantId, EnvSet.values);
this.#oidc = await loadOidcValues(appendPath(endpoint, '/oidc').href, oidcConfigs);
}

Expand Down
5 changes: 3 additions & 2 deletions packages/core/src/tenants/Tenant.ts
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,11 @@ import type TenantContext from './TenantContext.js';
import { getTenantDatabaseDsn } from './utils.js';

export default class Tenant implements TenantContext {
static async create(id: string, redisCache: RedisCache): Promise<Tenant> {
static async create(id: string, redisCache: RedisCache, customDomain?: string): Promise<Tenant> {
// Treat the default database URL as the management URL
const envSet = new EnvSet(id, await getTenantDatabaseDsn(id));
await envSet.load();
// Custom endpoint is used for building OIDC issuer URL when the request is a custom domain
await envSet.load(customDomain);

return new Tenant(envSet, id, new WellKnownCache(id, redisCache));
}
Expand Down
11 changes: 6 additions & 5 deletions packages/core/src/tenants/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,9 @@ export class TenantPool {
},
});

async get(tenantId: string): Promise<Tenant> {
const tenantPromise = this.cache.get(tenantId);
async get(tenantId: string, customDomain?: string): Promise<Tenant> {
const cacheKey = `${tenantId}-${customDomain ?? 'default'}`;
const tenantPromise = this.cache.get(cacheKey);

if (tenantPromise) {
const tenant = await tenantPromise;
Expand All @@ -27,9 +28,9 @@ export class TenantPool {
// Otherwise, create a new tenant instance and store in LRU cache, using the code below.
}

consoleLog.info('Init tenant:', tenantId);
const newTenantPromise = Tenant.create(tenantId, redisCache);
this.cache.set(tenantId, newTenantPromise);
consoleLog.info('Init tenant:', tenantId, customDomain);
const newTenantPromise = Tenant.create(tenantId, redisCache, customDomain);
this.cache.set(cacheKey, newTenantPromise);

return newTenantPromise;
}
Expand Down
10 changes: 10 additions & 0 deletions packages/core/src/utils/tenant.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -153,4 +153,14 @@ describe('getTenantId()', () => {
findActiveDomain.mockResolvedValueOnce({ domain: 'logto.mock.com', tenantId: 'mock' });
await expect(getTenantId(new URL('https://logto.mock.com'))).resolves.toBe('mock');
});

it('should skip custom domain searching', async () => {
process.env = {
...backupEnv,
ENDPOINT: 'https://foo.*.logto.mock/app',
NODE_ENV: 'production',
};
findActiveDomain.mockResolvedValueOnce({ domain: 'logto.mock.com', tenantId: 'mock' });
await expect(getTenantId(new URL('https://logto.mock.com'), true)).resolves.toBeUndefined();
});
});
29 changes: 22 additions & 7 deletions packages/core/src/utils/tenant.ts
Original file line number Diff line number Diff line change
Expand Up @@ -53,17 +53,20 @@ export const clearCustomDomainCache = async (url: URL | string) => {
await trySafe(async () => redisCache.delete(getDomainCacheKey(url)));
};

const getTenantIdFromCustomDomain = async (
/**
* Get tenant ID from the custom domain URL.
*/
export const getTenantIdFromCustomDomain = async (
url: URL,
pool: CommonQueryMethods
pool?: CommonQueryMethods
): Promise<string | undefined> => {
const cachedValue = await trySafe(async () => redisCache.get(getDomainCacheKey(url)));

if (cachedValue) {
return cachedValue;
}

const { findActiveDomain } = createDomainsQueries(pool);
const { findActiveDomain } = createDomainsQueries(pool ?? (await EnvSet.sharedPool));

const domain = await findActiveDomain(url.hostname);

Expand All @@ -74,7 +77,17 @@ const getTenantIdFromCustomDomain = async (
return domain?.tenantId;
};

export const getTenantId = async (url: URL) => {
/**
* Get tenant ID from the current request's URL.
*
* @param url The current request's URL
* @param skipCustomDomain Indicating whether to skip looking for custom domain
* @returns tenantId or undefined
*/
export const getTenantId = async (
url: URL,
skipCustomDomain?: boolean
): Promise<string | undefined> => {
wangsijie marked this conversation as resolved.
Show resolved Hide resolved
const {
values: {
isMultiTenancy,
Expand Down Expand Up @@ -107,10 +120,12 @@ export const getTenantId = async (url: URL) => {
return matchPathBasedTenantId(urlSet, url);
}

const customDomainTenantId = await getTenantIdFromCustomDomain(url, pool);
if (!skipCustomDomain) {
const customDomainTenantId = await getTenantIdFromCustomDomain(url, pool);

if (customDomainTenantId) {
return customDomainTenantId;
if (customDomainTenantId) {
return customDomainTenantId;
}
}

return matchDomainBasedTenantId(urlSet.endpoint, url);
Expand Down
Loading