⚡️ perf: optimize rendering performance for long text (lobehub#3754)
* ⚡️ perf: optimize rendering performance when rendering long text

* 👷 build: fix vercel build

* ⚡️ perf: increase the per-frame speed of the smooth animation

* ✅ test: add test for tokenizer edge runtime

* 💚 build: fix build

* ♻️ refactor: refactor with webapi

* 🚨 ci: improve lint

* ✅ test: fix test

* ⚡️ perf: try o200k_base
arvinxx committed Sep 3, 2024
1 parent 63113f9 commit 51c6b62
Showing 17 changed files with 162 additions and 25 deletions.
1 change: 1 addition & 0 deletions src/app/(main)/profile/[[...slugs]]/Client.tsx
@@ -56,6 +56,7 @@ export const useStyles = createStyles(
border-radius: unset;
`,
}) as Partial<{
+ // eslint-disable-next-line unused-imports/no-unused-vars
[k in keyof ElementsConfig]: any;
}>,
);
32 changes: 32 additions & 0 deletions src/app/webapi/tokenizer/index.test.ts
@@ -0,0 +1,32 @@
// @vitest-environment edge-runtime
import { describe, expect, it } from 'vitest';

import { POST } from './route';

describe('tokenizer Route', () => {
it('count hello world', async () => {
const txt = 'Hello, world!';
const request = new Request('https://test.com', {
method: 'POST',
body: txt,
});

const response = await POST(request);

const data = await response.json();
expect(data.count).toEqual(4);
});

it('count Chinese', async () => {
const txt = '今天天气真好';
const request = new Request('https://test.com', {
method: 'POST',
body: txt,
});

const response = await POST(request);

const data = await response.json();
expect(data.count).toEqual(5);
});
});
8 changes: 8 additions & 0 deletions src/app/webapi/tokenizer/route.ts
@@ -0,0 +1,8 @@
import { encode } from 'gpt-tokenizer/encoding/o200k_base';
import { NextResponse } from 'next/server';

export const POST = async (req: Request) => {
const str = await req.text();

return NextResponse.json({ count: encode(str).length });
};
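For reference, o200k_base is the encoding used by GPT-4o, and importing the gpt-tokenizer/encoding/o200k_base subpath pulls only that encoding into the edge bundle. Below is a minimal sketch of calling the new route from the browser, assuming only the /webapi/tokenizer path and the { count } response shape shown above; the serverEncodeAsync helper added later in this commit wraps the same call with a fallback.

const countTokensRemotely = async (text: string): Promise<number> => {
  // POST the raw text; the edge route answers with { count: number }
  const res = await fetch('/webapi/tokenizer', { body: text, method: 'POST' });
  const data: { count: number } = await res.json();
  return data.count;
};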
2 changes: 1 addition & 1 deletion src/database/server/models/session.ts
@@ -170,7 +170,7 @@ export class SessionModel {

if (!result) return;

- // eslint-disable-next-line @typescript-eslint/no-unused-vars
+ // eslint-disable-next-line @typescript-eslint/no-unused-vars,unused-imports/no-unused-vars
const { agent, clientId, ...session } = result;
const sessionId = this.genId();

26 changes: 19 additions & 7 deletions src/hooks/useTokenCount.ts
@@ -1,20 +1,32 @@
- import { startTransition, useEffect, useState } from 'react';
+ import { debounce } from 'lodash-es';
+ import { startTransition, useCallback, useEffect, useState } from 'react';

import { encodeAsync } from '@/utils/tokenizer';

export const useTokenCount = (input: string = '') => {
const [value, setNum] = useState(0);

- useEffect(() => {
- startTransition(() => {
- encodeAsync(input || '')
+ const debouncedEncode = useCallback(
+ debounce((text: string) => {
+ encodeAsync(text)
.then(setNum)
.catch(() => {
// fall back to the character count
- setNum(input.length);
+ setNum(text.length);
});
+ }, 300),
+ [],
+ );

+ useEffect(() => {
+ startTransition(() => {
+ debouncedEncode(input || '');
});
- }, [input]);

+ // cleanup
+ return () => {
+ debouncedEncode.cancel();
+ };
+ }, [input, debouncedEncode]);

return value;
};
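A minimal sketch of how a component might consume the reworked hook (TokenBadge is hypothetical; only useTokenCount comes from this commit): fast typing now triggers at most one tokenizer run per 300 ms pause, with the raw character count as the error fallback.

import { memo } from 'react';

import { useTokenCount } from '@/hooks/useTokenCount';

// Hypothetical consumer: shows the debounced token count while the user types.
const TokenBadge = memo<{ input: string }>(({ input }) => {
  const tokenCount = useTokenCount(input);

  return <span>{tokenCount} tokens</span>;
});

export default TokenBadge;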
1 change: 1 addition & 0 deletions src/layout/AuthProvider/Clerk/useAppearance.ts
@@ -89,6 +89,7 @@ export const useStyles = createStyles(
order: -1;
`,
}) as Partial<{
+ // eslint-disable-next-line unused-imports/no-unused-vars
[k in keyof ElementsConfig]: any;
}>,
);
4 changes: 3 additions & 1 deletion src/server/context.ts
@@ -62,7 +62,9 @@ export const createContext = async (request: NextRequest): Promise<Context> => {
userId = session.user.id;
}
return createContextInner({ authorizationHeader: authorization, nextAuth: auth, userId });
- } catch {}
+ } catch (e) {
+ console.error('next auth err', e);
+ }
}

return createContextInner({ authorizationHeader: authorization, userId });
2 changes: 1 addition & 1 deletion src/server/routers/edge/index.ts
@@ -1,5 +1,5 @@
/**
- * This file contains the root router of Lobe Chat tRPC-backend
+ * This file contains the edge router of Lobe Chat tRPC-backend
*/
import { publicProcedure, router } from '@/libs/trpc';

2 changes: 1 addition & 1 deletion src/services/user/client.ts
@@ -50,7 +50,7 @@ export class ClientService implements IUserService {
await this.preferenceStorage.saveToLocalStorage(preference);
}

- // eslint-disable-next-line @typescript-eslint/no-unused-vars
+ // eslint-disable-next-line @typescript-eslint/no-unused-vars,unused-imports/no-unused-vars
async updateGuide(guide: Partial<UserGuide>) {
throw new Error('Method not implemented.');
}
7 changes: 7 additions & 0 deletions src/types/worker.d.ts
@@ -0,0 +1,7 @@
declare module '*.worker.ts' {
class WebpackWorker extends Worker {
constructor();
}

export default WebpackWorker;
}
8 changes: 4 additions & 4 deletions src/utils/fetch/__tests__/fetchSSE.test.ts
@@ -139,8 +139,8 @@ describe('fetchSSE', () => {
onFinish: mockOnFinish,
});

- expect(mockOnMessageHandle).toHaveBeenNthCalledWith(1, { text: 'He', type: 'text' });
- expect(mockOnMessageHandle).toHaveBeenNthCalledWith(2, { text: 'llo World', type: 'text' });
+ expect(mockOnMessageHandle).toHaveBeenNthCalledWith(1, { text: 'Hell', type: 'text' });
+ expect(mockOnMessageHandle).toHaveBeenNthCalledWith(2, { text: 'o World', type: 'text' });
// more assertions for each character...
expect(mockOnFinish).toHaveBeenCalledWith('Hello World', {
observationId: null,
@@ -232,8 +232,8 @@ describe('fetchSSE', () => {
signal: abortController.signal,
});

- expect(mockOnMessageHandle).toHaveBeenNthCalledWith(1, { text: 'He', type: 'text' });
- expect(mockOnMessageHandle).toHaveBeenNthCalledWith(2, { text: 'llo World', type: 'text' });
+ expect(mockOnMessageHandle).toHaveBeenNthCalledWith(1, { text: 'Hell', type: 'text' });
+ expect(mockOnMessageHandle).toHaveBeenNthCalledWith(2, { text: 'o World', type: 'text' });

expect(mockOnFinish).toHaveBeenCalledWith('Hello World', {
type: 'done',
14 changes: 9 additions & 5 deletions src/utils/fetch/fetchSSE.ts
@@ -44,6 +44,10 @@ export interface FetchSSEOptions {
smoothing?: boolean;
}

+ const START_ANIMATION_SPEED = 4;

+ const END_ANIMATION_SPEED = 15;

const createSmoothMessage = (params: { onTextUpdate: (delta: string, text: string) => void }) => {
let buffer = '';
// why use queue: https://shareg.pt/GLBrjpK
@@ -64,7 +68,7 @@ const createSmoothMessage = (params: { onTextUpdate: (delta: string, text: strin

// define startAnimation function to display the text in buffer smooth
// when you need to start the animation, call this function
- const startAnimation = (speed = 2) =>
+ const startAnimation = (speed = START_ANIMATION_SPEED) =>
new Promise<void>((resolve) => {
if (isAnimationActive) {
resolve();
@@ -137,7 +141,7 @@ const createSmoothToolCalls = (params: {
}
};

- const startAnimation = (index: number, speed = 2) =>
+ const startAnimation = (index: number, speed = START_ANIMATION_SPEED) =>
new Promise<void>((resolve) => {
if (isAnimationActives[index]) {
resolve();
@@ -191,7 +195,7 @@ const createSmoothToolCalls = (params: {
});
};

- const startAnimations = async (speed = 2) => {
+ const startAnimations = async (speed = START_ANIMATION_SPEED) => {
const pools = toolCallsBuffer.map(async (_, index) => {
if (outputQueues[index].length > 0 && !isAnimationActives[index]) {
await startAnimation(index, speed);
@@ -365,11 +369,11 @@ export const fetchSSE = async (url: string, options: RequestInit & FetchSSEOptio
const observationId = response.headers.get(LOBE_CHAT_OBSERVATION_ID);

if (textController.isTokenRemain()) {
- await textController.startAnimation(15);
+ await textController.startAnimation(END_ANIMATION_SPEED);
}

if (toolCallsController.isTokenRemain()) {
- await toolCallsController.startAnimations(15);
+ await toolCallsController.startAnimations(END_ANIMATION_SPEED);
}

await options?.onFinish?.(output, { observationId, toolCalls, traceId, type: finishedType });
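For context, the new constants feed the character-queue animation: each frame emits roughly "speed" characters from the buffered output, so raising the start speed from 2 to 4 drains long responses in about half the frames, and the end speed of 15 quickly flushes whatever remains once the stream finishes. A rough sketch of that loop under those assumptions (simplified; not the actual createSmoothMessage implementation):

const animateQueue = (
  queue: string[],
  onTextUpdate: (delta: string) => void,
  speed = 4, // START_ANIMATION_SPEED above
) =>
  new Promise<void>((resolve) => {
    const frame = () => {
      if (queue.length === 0) {
        resolve();
        return;
      }
      // emit `speed` characters this frame, then schedule the next one
      onTextUpdate(queue.splice(0, speed).join(''));
      requestAnimationFrame(frame);
    };

    requestAnimationFrame(frame);
  });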
5 changes: 0 additions & 5 deletions src/utils/tokenizer.ts

This file was deleted.

35 changes: 35 additions & 0 deletions src/utils/tokenizer/client.ts
@@ -0,0 +1,35 @@
let worker: Worker | null = null;

const getWorker = () => {
if (!worker && typeof Worker !== 'undefined') {
worker = new Worker(new URL('tokenizer.worker.ts', import.meta.url));
}
return worker;
};

export const clientEncodeAsync = (str: string): Promise<number> =>
new Promise((resolve, reject) => {
const worker = getWorker();

if (!worker) {
// if the Web Worker is unavailable, fall back to the character count
resolve(str.length);
return;
}

const id = Date.now().toString();

const handleMessage = (event: MessageEvent) => {
if (event.data.id === id) {
worker.removeEventListener('message', handleMessage);
if (event.data.error) {
reject(new Error(event.data.error));
} else {
resolve(event.data.result);
}
}
};

worker.addEventListener('message', handleMessage);
worker.postMessage({ id, str });
});
15 changes: 15 additions & 0 deletions src/utils/tokenizer/index.ts
@@ -0,0 +1,15 @@
export const encodeAsync = async (str: string): Promise<number> => {
if (str.length === 0) return 0;

// 50_000 is the limit of the client
// if the string is longer than 50_000, we will use the server
if (str.length <= 50_000) {
const { clientEncodeAsync } = await import('./client');

return await clientEncodeAsync(str);
} else {
const { serverEncodeAsync } = await import('./server');

return await serverEncodeAsync(str);
}
};
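Call sites only see this async facade; the Web Worker / server split and the 50_000-character threshold stay internal. A hypothetical call site (safeTokenCount is illustrative; only encodeAsync comes from this commit), mirroring the hook's convention of falling back to the raw length on failure:

import { encodeAsync } from '@/utils/tokenizer';

// Hypothetical helper: token count with the same character-length fallback the hook uses.
const safeTokenCount = (text: string): Promise<number> =>
  encodeAsync(text).catch(() => text.length);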
11 changes: 11 additions & 0 deletions src/utils/tokenizer/server.ts
@@ -0,0 +1,11 @@
export const serverEncodeAsync = async (str: string): Promise<number> => {
try {
const res = await fetch('/webapi/tokenizer', { body: str, method: 'POST' });
const data = await res.json();

return data.count;
} catch (e) {
console.error('serverEncodeAsync:', e);
return str.length;
}
};
14 changes: 14 additions & 0 deletions src/utils/tokenizer/tokenizer.worker.ts
@@ -0,0 +1,14 @@
addEventListener('message', async (event) => {
const { id, str } = event.data;
try {
const { encode } = await import('gpt-tokenizer');

console.time('client tokenizer');
const tokenCount = encode(str).length;
console.timeEnd('client tokenizer');

postMessage({ id, result: tokenCount });
} catch (error) {
postMessage({ error: (error as Error).message, id });
}
});
