Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
102 changes: 101 additions & 1 deletion __tests__/session.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ import { NextRequest, NextResponse } from 'next/server';
import { cookies, headers } from 'next/headers';
import { redirect } from 'next/navigation';
import { generateTestToken } from './test-helpers.js';
import { withAuth, updateSession, refreshSession, updateSessionMiddleware } from '../src/session.js';
import { withAuth, updateSession, refreshSession, updateSessionMiddleware, getCustomClaims } from '../src/session.js';
import { getWorkOS } from '../src/workos.js';
import * as envVariables from '../src/env-variables.js';

Expand Down Expand Up @@ -842,4 +842,104 @@ describe('session.ts', () => {
await expect(refreshSession()).rejects.toThrow('error');
});
});

describe('getCustomClaims', () => {
beforeEach(async () => {
const nextCookies = await cookies();
// @ts-expect-error - _reset is part of the mock
nextCookies._reset();
jest.clearAllMocks();
});

it('should return custom claims when accessToken is provided', async () => {
const customClaims = { department: 'engineering', level: 5, metadata: { theme: 'dark' } };
const token = await generateTestToken({
sub: 'user_123',
org_id: 'org_123',
role: 'admin',
permissions: ['read', 'write'],
entitlements: ['feature_a'],
...customClaims,
});

const result = await getCustomClaims(token);

expect(result).toEqual(customClaims);
});

it('should return null when no accessToken is provided and no session exists', async () => {
const result = await getCustomClaims();

expect(result).toBeNull();
});

it('should return empty object when token has no custom claims', async () => {
const token = await generateTestToken({
sub: 'user_123',
org_id: 'org_123',
role: 'admin',
permissions: ['read', 'write'],
entitlements: ['feature_a'],
});

const result = await getCustomClaims(token);

expect(result).toEqual({});
});

it('should filter out all standard JWT claims', async () => {
const customClaims = { customField: 'value', anotherCustom: 42 };
const token = await generateTestToken({
aud: 'audience',
exp: Math.floor(Date.now() / 1000) + 3600,
iat: Math.floor(Date.now() / 1000),
iss: 'issuer',
sub: 'user_123',
sid: 'session_123',
org_id: 'org_123',
role: 'admin',
permissions: ['read', 'write'],
entitlements: ['feature_a'],
jti: 'jwt_123',
nbf: Math.floor(Date.now() / 1000),
...customClaims,
});

const result = await getCustomClaims(token);

expect(result).toEqual(customClaims);
expect(result).not.toHaveProperty('aud');
expect(result).not.toHaveProperty('exp');
expect(result).not.toHaveProperty('iat');
expect(result).not.toHaveProperty('iss');
expect(result).not.toHaveProperty('sub');
expect(result).not.toHaveProperty('sid');
expect(result).not.toHaveProperty('org_id');
expect(result).not.toHaveProperty('role');
expect(result).not.toHaveProperty('permissions');
expect(result).not.toHaveProperty('entitlements');
expect(result).not.toHaveProperty('jti');
expect(result).not.toHaveProperty('nbf');
});

it('should handle complex nested custom claims', async () => {
const customClaims = {
metadata: {
preferences: { theme: 'dark', language: 'en' },
settings: ['setting1', 'setting2'],
},
tags: ['tag1', 'tag2'],
permissions_custom: { read: true, write: false },
};
const token = await generateTestToken({
sub: 'user_123',
org_id: 'org_123',
...customClaims,
});

const result = await getCustomClaims(token);

expect(result).toEqual(customClaims);
});
});
});
6 changes: 3 additions & 3 deletions __tests__/useAccessToken.spec.tsx
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import React from 'react';
import { render, waitFor, act } from '@testing-library/react';
import '@testing-library/jest-dom';
import { useAccessToken } from '../src/components/useAccessToken.js';
import { act, render, waitFor } from '@testing-library/react';
import React from 'react';
import { getAccessTokenAction, refreshAccessTokenAction } from '../src/actions.js';
import { useAuth } from '../src/components/authkit-provider.js';
import { useAccessToken } from '../src/components/useAccessToken.js';

jest.mock('../src/actions.js', () => ({
getAccessTokenAction: jest.fn(),
Expand Down
193 changes: 193 additions & 0 deletions __tests__/useCustomClaims.spec.tsx
Original file line number Diff line number Diff line change
@@ -0,0 +1,193 @@
import '@testing-library/jest-dom';
import { render, waitFor } from '@testing-library/react';
import React from 'react';
import { getAccessTokenAction } from '../src/actions.js';
import { useAuth } from '../src/components/authkit-provider.js';
import { useCustomClaims } from '../src/components/useCustomClaims.js';

jest.mock('../src/actions.js', () => ({
getAccessTokenAction: jest.fn(),
refreshAccessTokenAction: jest.fn(),
}));

jest.mock('../src/components/authkit-provider.js', () => {
const originalModule = jest.requireActual('../src/components/authkit-provider.js');
return {
...originalModule,
useAuth: jest.fn(),
};
});

jest.mock('jose', () => ({
decodeJwt: jest.fn((token: string) => {
try {
const parts = token.split('.');
if (parts.length !== 3) return null;
const payload = JSON.parse(atob(parts[1]));
return payload;
} catch {
return null;
}
}),
}));

describe('useCustomClaims', () => {
beforeEach(() => {
jest.clearAllMocks();
jest.useFakeTimers();

(useAuth as jest.Mock).mockImplementation(() => ({
user: { id: 'user_123' },
sessionId: 'session_123',
refreshAuth: jest.fn().mockResolvedValue({}),
}));
});

afterEach(() => {
jest.useRealTimers();
});

const CustomClaimsTestComponent = () => {
const customClaims = useCustomClaims();
return (
<div>
<div data-testid="claims">{JSON.stringify(customClaims)}</div>
</div>
);
};

it('should return null when no access token is available', async () => {
(getAccessTokenAction as jest.Mock).mockResolvedValue(undefined);

const { getByTestId } = render(<CustomClaimsTestComponent />);

await waitFor(() => {
expect(getByTestId('claims')).toHaveTextContent('null');
});
});

it('should return custom claims when access token is available', async () => {
const payload = {
aud: 'audience',
exp: 9999999999,
iat: 1234567800,
iss: 'issuer',
sub: 'user_123',
sid: 'session_123',
org_id: 'org_123',
role: 'admin',
permissions: ['read', 'write'],
entitlements: ['feature_a'],
jti: 'jwt_123',
nbf: 1234567800,
// Custom claims
customField1: 'value1',
customField2: 42,
customObject: { nested: 'data' },
};
const token = `eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.${btoa(JSON.stringify(payload))}.mock-signature`;

(getAccessTokenAction as jest.Mock).mockResolvedValue(token);

const { getByTestId } = render(<CustomClaimsTestComponent />);

await waitFor(() => {
const expectedCustomClaims = {
customField1: 'value1',
customField2: 42,
customObject: { nested: 'data' },
};
expect(getByTestId('claims')).toHaveTextContent(JSON.stringify(expectedCustomClaims));
});
});

it('should return empty object when token has no custom claims', async () => {
const payload = {
aud: 'audience',
exp: 9999999999,
iat: 1234567800,
iss: 'issuer',
sub: 'user_123',
sid: 'session_123',
org_id: 'org_123',
role: 'admin',
permissions: ['read', 'write'],
entitlements: ['feature_a'],
jti: 'jwt_123',
nbf: 1234567800,
};
const token = `eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.${btoa(JSON.stringify(payload))}.mock-signature`;

(getAccessTokenAction as jest.Mock).mockResolvedValue(token);

const { getByTestId } = render(<CustomClaimsTestComponent />);

await waitFor(() => {
expect(getByTestId('claims')).toHaveTextContent('{}');
});
});

it('should handle partial standard claims', async () => {
const payload = {
sub: 'user_123',
exp: 9999999999,
customField: 'value',
anotherCustom: true,
};
const token = `eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.${btoa(JSON.stringify(payload))}.mock-signature`;

(getAccessTokenAction as jest.Mock).mockResolvedValue(token);

const { getByTestId } = render(<CustomClaimsTestComponent />);

await waitFor(() => {
const expectedCustomClaims = {
customField: 'value',
anotherCustom: true,
};
expect(getByTestId('claims')).toHaveTextContent(JSON.stringify(expectedCustomClaims));
});
});

it('should handle complex nested custom claims', async () => {
const payload = {
sub: 'user_123',
exp: 9999999999,
metadata: {
preferences: {
theme: 'dark',
language: 'en',
},
settings: ['setting1', 'setting2'],
},
tags: ['tag1', 'tag2'],
permissions_custom: {
read: true,
write: false,
},
};
const token = `eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.${btoa(JSON.stringify(payload))}.mock-signature`;

(getAccessTokenAction as jest.Mock).mockResolvedValue(token);

const { getByTestId } = render(<CustomClaimsTestComponent />);

await waitFor(() => {
const expectedCustomClaims = {
metadata: {
preferences: {
theme: 'dark',
language: 'en',
},
settings: ['setting1', 'setting2'],
},
tags: ['tag1', 'tag2'],
permissions_custom: {
read: true,
write: false,
},
};
expect(getByTestId('claims')).toHaveTextContent(JSON.stringify(expectedCustomClaims));
});
});
});
3 changes: 2 additions & 1 deletion src/components/index.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import { Impersonation } from './impersonation.js';
import { AuthKitProvider, useAuth } from './authkit-provider.js';
import { useAccessToken } from './useAccessToken.js';
import { useCustomClaims } from './useCustomClaims.js';

export { Impersonation, AuthKitProvider, useAuth, useAccessToken };
export { Impersonation, AuthKitProvider, useAuth, useAccessToken, useCustomClaims };
24 changes: 24 additions & 0 deletions src/components/useCustomClaims.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
import { useMemo } from 'react';
import { useAccessToken } from './useAccessToken.js';
import { decodeJwt } from 'jose';

/**
* Extracts custom claims from the access token.
* @returns The custom claims as a record of key-value pairs.
*/
export function useCustomClaims<T = Record<string, unknown>>() {
const { accessToken } = useAccessToken();

return useMemo(() => {
if (!accessToken) {
return null;
}

const decoded = decodeJwt(accessToken);

// eslint-disable-next-line @typescript-eslint/no-unused-vars
const { aud, exp, iat, iss, sub, sid, org_id, role, permissions, entitlements, jti, nbf, ...custom } = decoded;

return custom as T;
}, [accessToken]);
}
3 changes: 2 additions & 1 deletion src/index.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import { getSignInUrl, getSignUpUrl, signOut, switchToOrganization } from './auth.js';
import { handleAuth } from './authkit-callback-route.js';
import { authkit, authkitMiddleware } from './middleware.js';
import { refreshSession, saveSession, withAuth } from './session.js';
import { getCustomClaims, refreshSession, saveSession, withAuth } from './session.js';
import { getWorkOS } from './workos.js';

export * from './interfaces.js';
Expand All @@ -18,4 +18,5 @@ export {
signOut,
switchToOrganization,
withAuth,
getCustomClaims,
};
12 changes: 12 additions & 0 deletions src/session.ts
Original file line number Diff line number Diff line change
Expand Up @@ -355,6 +355,18 @@ async function redirectToSignIn() {
redirect(await getAuthorizationUrl({ returnPathname, screenHint }));
}

export async function getCustomClaims<T = Record<string, unknown>>(accessToken?: string) {
const token = accessToken ?? (await withAuth()).accessToken;
if (!token) {
return null;
}

const decoded = decodeJwt(token);
// eslint-disable-next-line @typescript-eslint/no-unused-vars
const { aud, exp, iat, iss, sub, sid, org_id, role, permissions, entitlements, jti, nbf, ...custom } = decoded;
return custom as T;
}

async function withAuth(options: { ensureSignedIn: true }): Promise<UserInfo>;
async function withAuth(options?: { ensureSignedIn?: true | false }): Promise<UserInfo | NoUserInfo>;
async function withAuth(options?: { ensureSignedIn?: boolean }): Promise<UserInfo | NoUserInfo> {
Expand Down