Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for OAuth login #1122

Draft
wants to merge 20 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .changeset/sour-eggs-tan.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
"@xata.io/cli": patch
---

Add support for OAuth login
2 changes: 2 additions & 0 deletions cli/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
"lodash.compact": "^3.0.1",
"lodash.get": "^4.4.2",
"lodash.set": "^4.3.2",
"nanoid": "^4.0.2",
"node-fetch": "^3.3.2",
"open": "^9.1.0",
"prompts": "^2.4.2",
Expand All @@ -59,6 +60,7 @@
"@types/lodash.compact": "^3.0.9",
"@types/lodash.get": "^4.4.9",
"@types/lodash.set": "^4.3.9",
"@types/nanoid": "^3.0.0",
"@types/relaxed-json": "^1.0.4",
"@types/text-table": "^0.2.5",
"@types/tmp": "^0.2.6",
Expand Down
19 changes: 10 additions & 9 deletions cli/src/auth-server.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,15 @@ import url from 'url';
import { describe, expect, test, vi } from 'vitest';
import { generateKeys, generateURL, handler } from './auth-server.js';

const domain = 'https://app.xata.io';
const port = 1234;
const { publicKey, privateKey, passphrase } = generateKeys();

describe('generateURL', () => {
test('generates a URL', async () => {
const uiURL = generateURL(port, publicKey);
const uiURL = generateURL({ port, publicKey, domain });

expect(uiURL.startsWith('https://app.xata.io/new-api-key?')).toBe(true);
expect(uiURL.startsWith(`${domain}/new-api-key?`)).toBe(true);

const parsed = url.parse(uiURL, true);
const { pub, name, redirect } = parsed.query;
Expand All @@ -25,7 +26,7 @@ describe('generateURL', () => {
describe('handler', () => {
test('405s if the method is not GET', async () => {
const callback = vi.fn();
const httpHandler = handler(publicKey, privateKey, passphrase, callback);
const httpHandler = handler({ publicKey, privateKey, passphrase, callback, domain });

const req = { method: 'POST', url: '/' } as unknown as IncomingMessage;
const res = {
Expand All @@ -42,7 +43,7 @@ describe('handler', () => {

test('redirects if the path is /new', async () => {
const callback = vi.fn();
const httpHandler = handler(publicKey, privateKey, passphrase, callback);
const httpHandler = handler({ publicKey, privateKey, passphrase, callback, domain });

const writeHead = vi.fn();
const req = { method: 'GET', url: '/new', socket: { localPort: 9999 } } as unknown as IncomingMessage;
Expand All @@ -55,15 +56,15 @@ describe('handler', () => {

const [status, headers] = writeHead.mock.calls[0];
expect(status).toEqual(302);
expect(String(headers.location).startsWith('https://app.xata.io/new-api-key?pub=')).toBeTruthy();
expect(String(headers.location).startsWith(`${domain}/new-api-key?pub=`)).toBeTruthy();
expect(String(headers.location).includes('9999')).toBeTruthy();
expect(res.end).toHaveBeenCalledWith();
expect(callback).not.toHaveBeenCalled();
});

test('404s if the path is not the root path', async () => {
const callback = vi.fn();
const httpHandler = handler(publicKey, privateKey, passphrase, callback);
const httpHandler = handler({ publicKey, privateKey, passphrase, callback, domain });

const req = { method: 'GET', url: '/foo' } as unknown as IncomingMessage;
const res = {
Expand All @@ -80,7 +81,7 @@ describe('handler', () => {

test('returns 400 if resource is called with the wrong parameters', async () => {
const callback = vi.fn();
const httpHandler = handler(publicKey, privateKey, passphrase, callback);
const httpHandler = handler({ publicKey, privateKey, passphrase, callback, domain });

const req = { method: 'GET', url: '/' } as unknown as IncomingMessage;
const res = {
Expand All @@ -97,7 +98,7 @@ describe('handler', () => {

test('hadles errors correctly', async () => {
const callback = vi.fn();
const httpHandler = handler(publicKey, privateKey, passphrase, callback);
const httpHandler = handler({ publicKey, privateKey, passphrase, callback, domain });

const req = { method: 'GET', url: '/?key=malformed-key' } as unknown as IncomingMessage;
const res = {
Expand All @@ -115,7 +116,7 @@ describe('handler', () => {

test('receives the API key if everything is fine', async () => {
const callback = vi.fn();
const httpHandler = handler(publicKey, privateKey, passphrase, callback);
const httpHandler = handler({ publicKey, privateKey, passphrase, callback, domain });
const apiKey = 'abcdef1234';
const encryptedKey = crypto.publicEncrypt(publicKey, Buffer.from(apiKey));

Expand Down
84 changes: 66 additions & 18 deletions cli/src/auth-server.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,32 @@ import { AddressInfo } from 'net';
import open from 'open';
import path, { dirname } from 'path';
import url, { fileURLToPath } from 'url';
import { z } from 'zod';

const __filename = fileURLToPath(import.meta.url);
const __dirname = dirname(__filename);

export function handler(publicKey: string, privateKey: string, passphrase: string, callback: (apiKey: string) => void) {
const ResponseSchema = z.object({
accessToken: z.string(),
refreshToken: z.string(),
expires: z.string()
});

type OAuthResponse = z.infer<typeof ResponseSchema>;

export function handler({
domain,
publicKey,
privateKey,
passphrase,
callback
}: {
domain: string;
publicKey: string;
privateKey: string;
passphrase: string;
callback: (response: OAuthResponse) => void;
}) {
return (req: http.IncomingMessage, res: http.ServerResponse) => {
try {
if (req.method !== 'GET') {
Expand All @@ -22,7 +43,7 @@ export function handler(publicKey: string, privateKey: string, passphrase: strin
if (parsedURL.pathname === '/new') {
const port = req.socket.localPort ?? 80;
res.writeHead(302, {
location: generateURL(port, publicKey)
location: generateURL({ port, publicKey, domain })
});
res.end();
return;
Expand All @@ -37,38 +58,45 @@ export function handler(publicKey: string, privateKey: string, passphrase: strin
return res.end('Missing key parameter');
}
const privKey = crypto.createPrivateKey({ key: privateKey, passphrase });
const apiKey = crypto
const response = crypto
.privateDecrypt(privKey, Buffer.from(String(parsedURL.query.key).replace(/ /g, '+'), 'base64'))
.toString('utf8');
renderSuccessPage(req, res, String(parsedURL.query['color-mode']));
req.destroy();
callback(apiKey);
callback(ResponseSchema.parse(JSON.parse(response)));
} catch (err) {
res.writeHead(500);
res.end(`Something went wrong: ${err instanceof Error ? err.message : String(err)}`);
}
};
}

function renderSuccessPage(req: http.IncomingMessage, res: http.ServerResponse, colorMode: string) {
function renderSuccessPage(_req: http.IncomingMessage, res: http.ServerResponse, colorMode: string) {
res.writeHead(200, {
'Content-Type': 'text/html'
});
const html = readFileSync(path.join(__dirname, 'api-key-success.html'), 'utf-8');
res.end(html.replace('data-color-mode=""', `data-color-mode="${colorMode}"`));
}

export function generateURL(port: number, publicKey: string) {
export function generateURL({ port, publicKey, domain }: { port: number; publicKey: string; domain: string }) {
const name = 'Xata CLI';
const serverRedirect = `${domain}/api/integrations/cli/callback`;
const cliRedirect = `http://localhost:${port}`;
const pub = publicKey
.replace(/\n/g, '')
.replace('-----BEGIN PUBLIC KEY-----', '')
.replace('-----END PUBLIC KEY-----', '');
const name = 'Xata CLI';
const redirect = `http://localhost:${port}`;
const url = new URL('https://app.xata.io/new-api-key');
url.searchParams.append('pub', pub);
url.searchParams.append('name', name);
url.searchParams.append('redirect', redirect);

const url = new URL(`${domain}/integrations/oauth/authorize`);
url.searchParams.set('client_id', 'b7msdmpun91q33vpihpk3vs39g');
url.searchParams.set('redirect_uri', serverRedirect);
url.searchParams.set('response_type', 'code');
url.searchParams.set('scope', 'admin:all');
url.searchParams.set(
'state',
Buffer.from(JSON.stringify({ name, pub, cliRedirect, serverRedirect })).toString('base64')
);
return url.toString();
}

Expand All @@ -90,19 +118,25 @@ export function generateKeys() {
return { publicKey, privateKey, passphrase };
}

export async function createAPIKeyThroughWebUI() {
export async function loginWithWebUI(domain: string) {
const { publicKey, privateKey, passphrase } = generateKeys();

return new Promise<string>((resolve) => {
return new Promise<OAuthResponse>((resolve) => {
const server = http.createServer(
handler(publicKey, privateKey, passphrase, (apiKey) => {
resolve(apiKey);
server.close();
handler({
domain,
publicKey,
privateKey,
passphrase,
callback: (credentials) => {
resolve(credentials);
server.close();
}
})
);
server.listen(() => {
const { port } = server.address() as AddressInfo;
const openURL = generateURL(port, publicKey);
const openURL = generateURL({ port, publicKey, domain });
console.log(
`We are opening your default browser. If your browser doesn't open automatically, please copy and paste the following URL into your browser: ${chalk.bold(
`http://localhost:${port}/new`
Expand All @@ -112,3 +146,17 @@ export async function createAPIKeyThroughWebUI() {
});
});
}

export async function refreshAccessToken(domain: string, refreshToken: string) {
const response = await fetch(`${domain}/api/integrations/cli/refresh`, {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({ refreshToken })
});

if (!response.ok) {
throw new Error(`Failed to refresh access token: ${response.status} ${response.statusText}`);
}

return ResponseSchema.parse(await response.json());
}
Loading
Loading