Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(core): set oidc issuer to custom domain #5509

Merged
merged 2 commits into from
Mar 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions packages/core/src/app/init.ts
Original file line number Diff line number Diff line change
Expand Up @@ -29,15 +29,19 @@ export default async function initApp(app: Koa): Promise<void> {
return next();
}

const tenantId = await getTenantId(ctx.URL);
const [tenantId, isCustomDomain] = await getTenantId(ctx.URL);

if (!tenantId) {
ctx.status = 404;

return next();
}

const tenant = await trySafe(tenantPool.get(tenantId), (error) => {
// If the request is a custom domain of the tenant, use the custom endpoint to build "OIDC issuer"
// otherwise, build from the default endpoint (subdomain).
const customEndpoint = isCustomDomain ? ctx.URL.origin : undefined;

const tenant = await trySafe(tenantPool.get(tenantId, customEndpoint), (error) => {
ctx.status = error instanceof TenantNotFoundError ? 404 : 500;
void appInsights.trackException(error);
});
Expand Down
6 changes: 4 additions & 2 deletions packages/core/src/env-set/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ export class EnvSet {
return this.#oidc;
}

async load() {
async load(customDomain?: string) {
gao-sun marked this conversation as resolved.
Show resolved Hide resolved
const pool = await createPoolByEnv(
this.databaseUrl,
EnvSet.values.isUnitTest,
Expand All @@ -77,7 +77,9 @@ export class EnvSet {
});

const oidcConfigs = await getOidcConfigs();
const endpoint = getTenantEndpoint(this.tenantId, EnvSet.values);
const endpoint = customDomain
? new URL(customDomain)
: getTenantEndpoint(this.tenantId, EnvSet.values);
this.#oidc = await loadOidcValues(appendPath(endpoint, '/oidc').href, oidcConfigs);
}

Expand Down
2 changes: 1 addition & 1 deletion packages/core/src/middleware/koa-spa-session-guard.ts
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ export default function koaSpaSessionGuard<
return;
}

const tenantId = await getTenantId(ctx.URL);
const [tenantId] = await getTenantId(ctx.URL);

if (!tenantId) {
throw new RequestError({ code: 'session.not_found', status: 404 });
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ mockEsm('@logto/shared', () => ({
}));

mockEsm('#src/utils/tenant.js', () => ({
getTenantId: () => adminTenantId,
getTenantId: () => [adminTenantId],
}));

const userQueries = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ mockEsm('@logto/shared', () => ({
}));

mockEsm('#src/utils/tenant.js', () => ({
getTenantId: () => adminTenantId,
getTenantId: () => [adminTenantId],
}));

const userQueries = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,8 @@ async function handleSubmitRegister(
const { client_id } = ctx.interactionDetails.params;

const { isCloud } = EnvSet.values;
const isInAdminTenant = (await getTenantId(ctx.URL)) === adminTenantId;
const [currentTenantId] = await getTenantId(ctx.URL);
const isInAdminTenant = currentTenantId === adminTenantId;
const isCreatingFirstAdminUser =
isInAdminTenant && String(client_id) === adminConsoleApplicationId && !(await hasActiveUsers());

Expand Down
2 changes: 1 addition & 1 deletion packages/core/src/routes/user-assets.ts
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ export default function userAssetsRoutes<T extends AuthedRouter>(...[router]: Ro
'guard.mime_type_not_allowed'
);

const tenantId = await getTenantId(ctx.URL);
const [tenantId] = await getTenantId(ctx.URL);
assertThat(tenantId, 'guard.can_not_get_tenant_id');

const { storageProviderConfig } = SystemContext.shared;
Expand Down
5 changes: 3 additions & 2 deletions packages/core/src/tenants/Tenant.ts
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,11 @@ import type TenantContext from './TenantContext.js';
import { getTenantDatabaseDsn } from './utils.js';

export default class Tenant implements TenantContext {
static async create(id: string, redisCache: RedisCache): Promise<Tenant> {
static async create(id: string, redisCache: RedisCache, customDomain?: string): Promise<Tenant> {
// Treat the default database URL as the management URL
const envSet = new EnvSet(id, await getTenantDatabaseDsn(id));
await envSet.load();
// Custom endpoint is used for building OIDC issuer URL when the request is a custom domain
await envSet.load(customDomain);

return new Tenant(envSet, id, new WellKnownCache(id, redisCache));
}
Expand Down
11 changes: 6 additions & 5 deletions packages/core/src/tenants/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,9 @@ export class TenantPool {
},
});

async get(tenantId: string): Promise<Tenant> {
const tenantPromise = this.cache.get(tenantId);
async get(tenantId: string, customDomain?: string): Promise<Tenant> {
const cacheKey = `${tenantId}-${customDomain ?? 'default'}`;
const tenantPromise = this.cache.get(cacheKey);

if (tenantPromise) {
const tenant = await tenantPromise;
Expand All @@ -27,9 +28,9 @@ export class TenantPool {
// Otherwise, create a new tenant instance and store in LRU cache, using the code below.
}

consoleLog.info('Init tenant:', tenantId);
const newTenantPromise = Tenant.create(tenantId, redisCache);
this.cache.set(tenantId, newTenantPromise);
consoleLog.info('Init tenant:', tenantId, customDomain);
const newTenantPromise = Tenant.create(tenantId, redisCache, customDomain);
this.cache.set(cacheKey, newTenantPromise);

return newTenantPromise;
}
Expand Down
109 changes: 67 additions & 42 deletions packages/core/src/utils/tenant.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,11 @@ mockEsm('#src/queries/domains.js', () => ({

const { getTenantId } = await import('./tenant.js');

const getTenantIdFirstElement = async (url: URL) => {
const [tenantId] = await getTenantId(url);
return tenantId;
};

describe('getTenantId()', () => {
const backupEnv = process.env;

Expand All @@ -37,7 +42,7 @@ describe('getTenantId()', () => {
DEVELOPMENT_TENANT_ID: 'foo',
};

await expect(getTenantId(new URL('https://some.random.url'))).resolves.toBe('foo');
await expect(getTenantIdFirstElement(new URL('https://some.random.url'))).resolves.toBe('foo');

process.env = {
...backupEnv,
Expand All @@ -46,20 +51,22 @@ describe('getTenantId()', () => {
DEVELOPMENT_TENANT_ID: 'bar',
};

await expect(getTenantId(new URL('https://some.random.url'))).resolves.toBe('bar');
await expect(getTenantIdFirstElement(new URL('https://some.random.url'))).resolves.toBe('bar');
});

it('should resolve proper tenant ID for similar localhost endpoints', async () => {
await expect(getTenantId(new URL('http://localhost:3002/some/path////'))).resolves.toBe(
adminTenantId
);
await expect(getTenantId(new URL('http://localhost:30021/some/path'))).resolves.toBe(
defaultTenantId
);
await expect(getTenantId(new URL('http://localhostt:30021/some/path'))).resolves.toBe(
await expect(
getTenantIdFirstElement(new URL('http://localhost:3002/some/path////'))
).resolves.toBe(adminTenantId);
await expect(
getTenantIdFirstElement(new URL('http://localhost:30021/some/path'))
).resolves.toBe(defaultTenantId);
await expect(
getTenantIdFirstElement(new URL('http://localhostt:30021/some/path'))
).resolves.toBe(defaultTenantId);
await expect(getTenantIdFirstElement(new URL('https://localhost:3002'))).resolves.toBe(
defaultTenantId
);
await expect(getTenantId(new URL('https://localhost:3002'))).resolves.toBe(defaultTenantId);
});

it('should resolve proper tenant ID for similar domain endpoints', async () => {
Expand All @@ -69,24 +76,30 @@ describe('getTenantId()', () => {
ENDPOINT: 'https://foo.*.logto.mock/app',
};

await expect(getTenantId(new URL('https://foo.foo.logto.mock/app///asdasd'))).resolves.toBe(
'foo'
);
await expect(getTenantId(new URL('https://foo.*.logto.mock/app'))).resolves.toBe(undefined);
await expect(getTenantId(new URL('https://foo.foo.logto.mockk/app///asdasd'))).resolves.toBe(
undefined
);
await expect(getTenantId(new URL('https://foo.foo.logto.mock/appp'))).resolves.toBe(undefined);
await expect(getTenantId(new URL('https://foo.foo.logto.mock:1/app/'))).resolves.toBe(
await expect(
getTenantIdFirstElement(new URL('https://foo.foo.logto.mock/app///asdasd'))
).resolves.toBe('foo');
await expect(getTenantIdFirstElement(new URL('https://foo.*.logto.mock/app'))).resolves.toBe(
undefined
);
await expect(getTenantId(new URL('http://foo.foo.logto.mock/app'))).resolves.toBe(undefined);
await expect(getTenantId(new URL('https://user.foo.bar.logto.mock/app'))).resolves.toBe(
await expect(
getTenantIdFirstElement(new URL('https://foo.foo.logto.mockk/app///asdasd'))
).resolves.toBe(undefined);
await expect(getTenantIdFirstElement(new URL('https://foo.foo.logto.mock/appp'))).resolves.toBe(
undefined
);
await expect(getTenantId(new URL('https://foo.bar.bar.logto.mock/app'))).resolves.toBe(
await expect(
getTenantIdFirstElement(new URL('https://foo.foo.logto.mock:1/app/'))
).resolves.toBe(undefined);
await expect(getTenantIdFirstElement(new URL('http://foo.foo.logto.mock/app'))).resolves.toBe(
undefined
);
await expect(
getTenantIdFirstElement(new URL('https://user.foo.bar.logto.mock/app'))
).resolves.toBe(undefined);
await expect(
getTenantIdFirstElement(new URL('https://foo.bar.bar.logto.mock/app'))
).resolves.toBe(undefined);
});

it('should resolve proper tenant ID if admin localhost is disabled', async () => {
Expand All @@ -99,17 +112,21 @@ describe('getTenantId()', () => {
ADMIN_DISABLE_LOCALHOST: '1',
};

await expect(getTenantId(new URL('http://localhost:5000/app///asdasd'))).resolves.toBe(
undefined
await expect(
getTenantIdFirstElement(new URL('http://localhost:5000/app///asdasd'))
).resolves.toBe(undefined);
await expect(
getTenantIdFirstElement(new URL('http://localhost:3002/app///asdasd'))
).resolves.toBe(undefined);
await expect(getTenantIdFirstElement(new URL('https://user.foo.logto.mock/app'))).resolves.toBe(
'foo'
);
await expect(getTenantId(new URL('http://localhost:3002/app///asdasd'))).resolves.toBe(
undefined
await expect(
getTenantIdFirstElement(new URL('https://user.admin.logto.mock/app//'))
).resolves.toBe(undefined); // Admin endpoint is explicitly set
await expect(getTenantIdFirstElement(new URL('https://admin.logto.mock/app'))).resolves.toBe(
adminTenantId
);
await expect(getTenantId(new URL('https://user.foo.logto.mock/app'))).resolves.toBe('foo');
await expect(getTenantId(new URL('https://user.admin.logto.mock/app//'))).resolves.toBe(
undefined
); // Admin endpoint is explicitly set
await expect(getTenantId(new URL('https://admin.logto.mock/app'))).resolves.toBe(adminTenantId);

process.env = {
...backupEnv,
Expand All @@ -118,9 +135,9 @@ describe('getTenantId()', () => {
ENDPOINT: 'https://user.*.logto.mock/app',
ADMIN_DISABLE_LOCALHOST: '1',
};
await expect(getTenantId(new URL('https://user.admin.logto.mock/app//'))).resolves.toBe(
'admin'
);
await expect(
getTenantIdFirstElement(new URL('https://user.admin.logto.mock/app//'))
).resolves.toBe('admin');
});

it('should resolve proper tenant ID for path-based multi-tenancy', async () => {
Expand All @@ -132,16 +149,24 @@ describe('getTenantId()', () => {
PATH_BASED_MULTI_TENANCY: '1',
};

await expect(getTenantId(new URL('http://localhost:5000/app///asdasd'))).resolves.toBe('app');
await expect(getTenantId(new URL('http://localhost:3002///bar///asdasd'))).resolves.toBe(
adminTenantId
await expect(
getTenantIdFirstElement(new URL('http://localhost:5000/app///asdasd'))
).resolves.toBe('app');
await expect(
getTenantIdFirstElement(new URL('http://localhost:3002///bar///asdasd'))
).resolves.toBe(adminTenantId);
await expect(getTenantIdFirstElement(new URL('https://user.foo.logto.mock/app'))).resolves.toBe(
undefined
);
await expect(getTenantId(new URL('https://user.foo.logto.mock/app'))).resolves.toBe(undefined);
await expect(getTenantId(new URL('https://user.admin.logto.mock/app//'))).resolves.toBe(
await expect(
getTenantIdFirstElement(new URL('https://user.admin.logto.mock/app//'))
).resolves.toBe(undefined);
await expect(getTenantIdFirstElement(new URL('https://user.logto.mock/app'))).resolves.toBe(
undefined
);
await expect(getTenantId(new URL('https://user.logto.mock/app'))).resolves.toBe(undefined);
await expect(getTenantId(new URL('https://user.logto.mock/app/admin'))).resolves.toBe('admin');
await expect(
getTenantIdFirstElement(new URL('https://user.logto.mock/app/admin'))
).resolves.toBe('admin');
});

it('should resolve proper custom domain', async () => {
Expand All @@ -151,6 +176,6 @@ describe('getTenantId()', () => {
NODE_ENV: 'production',
};
findActiveDomain.mockResolvedValueOnce({ domain: 'logto.mock.com', tenantId: 'mock' });
await expect(getTenantId(new URL('https://logto.mock.com'))).resolves.toBe('mock');
await expect(getTenantIdFirstElement(new URL('https://logto.mock.com'))).resolves.toBe('mock');
});
});
25 changes: 18 additions & 7 deletions packages/core/src/utils/tenant.ts
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,9 @@ export const clearCustomDomainCache = async (url: URL | string) => {
await trySafe(async () => redisCache.delete(getDomainCacheKey(url)));
};

/**
* Get tenant ID from the custom domain URL.
*/
const getTenantIdFromCustomDomain = async (
url: URL,
pool: CommonQueryMethods
Expand All @@ -74,7 +77,15 @@ const getTenantIdFromCustomDomain = async (
return domain?.tenantId;
};

export const getTenantId = async (url: URL) => {
/**
* Get tenant ID from the current request's URL.
*
* @param url The current request's URL
* @returns The tenant ID and whether the URL is a custom domain
*/
export const getTenantId = async (
url: URL
): Promise<[tenantId: string | undefined, isCustomDomain: boolean]> => {
const {
values: {
isMultiTenancy,
Expand All @@ -90,28 +101,28 @@ export const getTenantId = async (url: URL) => {
const pool = await sharedPool;

if (adminUrlSet.deduplicated().some((endpoint) => isEndpointOf(url, endpoint))) {
return adminTenantId;
return [adminTenantId, false];
}

if ((!isProduction || isIntegrationTest) && developmentTenantId) {
consoleLog.warn(`Found dev tenant ID ${developmentTenantId}.`);

return developmentTenantId;
return [developmentTenantId, false];
}

if (!isMultiTenancy) {
return defaultTenantId;
return [defaultTenantId, false];
}

if (isPathBasedMultiTenancy) {
return matchPathBasedTenantId(urlSet, url);
return [matchPathBasedTenantId(urlSet, url), false];
}

const customDomainTenantId = await getTenantIdFromCustomDomain(url, pool);

if (customDomainTenantId) {
return customDomainTenantId;
return [customDomainTenantId, true];
}

return matchDomainBasedTenantId(urlSet.endpoint, url);
return [matchDomainBasedTenantId(urlSet.endpoint, url), false];
};
Loading