Skip to content

Commit

Permalink
⚡️ perf: fix stream smoothing (lobehub#3758)
Browse files Browse the repository at this point in the history
  • Loading branch information
arvinxx authored Sep 4, 2024
1 parent 7b70773 commit 0e19893
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 28 deletions.
15 changes: 10 additions & 5 deletions src/utils/fetch/__tests__/fetchSSE.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ import { afterEach, describe, expect, it, vi } from 'vitest';

import { MESSAGE_CANCEL_FLAT } from '@/const/message';
import { ChatMessageError } from '@/types/message';
import { sleep } from '@/utils/sleep';

import { FetchEventSourceInit } from '../fetchEventSource';
import { fetchEventSource } from '../fetchEventSource';
Expand Down Expand Up @@ -127,9 +128,10 @@ describe('fetchSSE', () => {
const mockOnFinish = vi.fn();

(fetchEventSource as any).mockImplementationOnce(
(url: string, options: FetchEventSourceInit) => {
async (url: string, options: FetchEventSourceInit) => {
options.onopen!({ clone: () => ({ ok: true, headers: new Headers() }) } as any);
options.onmessage!({ event: 'text', data: JSON.stringify('Hello') } as any);
await sleep(100);
options.onmessage!({ event: 'text', data: JSON.stringify(' World') } as any);
},
);
Expand All @@ -140,7 +142,8 @@ describe('fetchSSE', () => {
});

expect(mockOnMessageHandle).toHaveBeenNthCalledWith(1, { text: 'Hell', type: 'text' });
expect(mockOnMessageHandle).toHaveBeenNthCalledWith(2, { text: 'o World', type: 'text' });
expect(mockOnMessageHandle).toHaveBeenNthCalledWith(2, { text: 'o', type: 'text' });
expect(mockOnMessageHandle).toHaveBeenNthCalledWith(3, { text: ' World', type: 'text' });
// more assertions for each character...
expect(mockOnFinish).toHaveBeenCalledWith('Hello World', {
observationId: null,
Expand Down Expand Up @@ -184,7 +187,7 @@ describe('fetchSSE', () => {

// TODO: need to check whether the `aarg1` is correct
expect(mockOnMessageHandle).toHaveBeenNthCalledWith(1, {
isAnimationActives: [true],
isAnimationActives: [true, true],
tool_calls: [
{ id: '1', type: 'function', function: { name: 'func1', arguments: 'aarg1' } },
{ function: { arguments: 'aarg2', name: 'func2' }, id: '2', type: 'function' },
Expand Down Expand Up @@ -218,9 +221,10 @@ describe('fetchSSE', () => {
const abortController = new AbortController();

(fetchEventSource as any).mockImplementationOnce(
(url: string, options: FetchEventSourceInit) => {
async (url: string, options: FetchEventSourceInit) => {
options.onopen!({ clone: () => ({ ok: true, headers: new Headers() }) } as any);
options.onmessage!({ event: 'text', data: JSON.stringify('Hello') } as any);
await sleep(100);
abortController.abort();
options.onmessage!({ event: 'text', data: JSON.stringify(' World') } as any);
},
Expand All @@ -233,7 +237,8 @@ describe('fetchSSE', () => {
});

expect(mockOnMessageHandle).toHaveBeenNthCalledWith(1, { text: 'Hell', type: 'text' });
expect(mockOnMessageHandle).toHaveBeenNthCalledWith(2, { text: 'o World', type: 'text' });
expect(mockOnMessageHandle).toHaveBeenNthCalledWith(2, { text: 'o', type: 'text' });
expect(mockOnMessageHandle).toHaveBeenNthCalledWith(3, { text: ' World', type: 'text' });

expect(mockOnFinish).toHaveBeenCalledWith('Hello World', {
type: 'done',
Expand Down
45 changes: 22 additions & 23 deletions src/utils/fetch/fetchSSE.ts
Original file line number Diff line number Diff line change
Expand Up @@ -52,17 +52,15 @@ const createSmoothMessage = (params: { onTextUpdate: (delta: string, text: strin
let buffer = '';
// why use queue: https://shareg.pt/GLBrjpK
let outputQueue: string[] = [];

// eslint-disable-next-line no-undef
let animationTimeoutId: NodeJS.Timeout | null = null;
let isAnimationActive = false;
let animationFrameId: number | null = null;

// when you need to stop the animation, call this function
const stopAnimation = () => {
isAnimationActive = false;
if (animationTimeoutId !== null) {
clearTimeout(animationTimeoutId);
animationTimeoutId = null;
if (animationFrameId !== null) {
cancelAnimationFrame(animationFrameId);
animationFrameId = null;
}
};

Expand All @@ -80,32 +78,33 @@ const createSmoothMessage = (params: { onTextUpdate: (delta: string, text: strin
const updateText = () => {
// 如果动画已经不再激活,则停止更新文本
if (!isAnimationActive) {
clearTimeout(animationTimeoutId!);
animationTimeoutId = null;
cancelAnimationFrame(animationFrameId!);
animationFrameId = null;
resolve();
return;
}

// 如果还有文本没有显示
// 检查队列中是否有字符待显示
if (outputQueue.length > 0) {
// 从队列中获取前两个字符(如果存在)
// 从队列中获取前 n 个字符(如果存在)
const charsToAdd = outputQueue.splice(0, speed).join('');
buffer += charsToAdd;

// 更新消息内容,这里可能需要结合实际情况调整
params.onTextUpdate(charsToAdd, buffer);

// 设置下一个字符的延迟
animationTimeoutId = setTimeout(updateText, 16); // 16 毫秒的延迟模拟打字机效果
} else {
// 当所有字符都显示完毕时,清除动画状态
isAnimationActive = false;
animationTimeoutId = null;
animationFrameId = null;
resolve();
return;
}

animationFrameId = requestAnimationFrame(updateText);
};

updateText();
animationFrameId = requestAnimationFrame(updateText);
});

const pushToQueue = (text: string) => {
Expand All @@ -128,16 +127,15 @@ const createSmoothToolCalls = (params: {

// 为每个 tool_call 维护一个输出队列和动画控制器

// eslint-disable-next-line no-undef
const animationTimeoutIds: (NodeJS.Timeout | null)[] = [];
const outputQueues: string[][] = [];
const isAnimationActives: boolean[] = [];
const animationFrameIds: (number | null)[] = [];

const stopAnimation = (index: number) => {
isAnimationActives[index] = false;
if (animationTimeoutIds[index] !== null) {
clearTimeout(animationTimeoutIds[index]!);
animationTimeoutIds[index] = null;
if (animationFrameIds[index] !== null) {
cancelAnimationFrame(animationFrameIds[index]!);
animationFrameIds[index] = null;
}
};

Expand All @@ -153,6 +151,7 @@ const createSmoothToolCalls = (params: {
const updateToolCall = () => {
if (!isAnimationActives[index]) {
resolve();
return;
}

if (outputQueues[index].length > 0) {
Expand All @@ -167,15 +166,15 @@ const createSmoothToolCalls = (params: {
params.onToolCallsUpdate(toolCallsBuffer, [...isAnimationActives]);
}

animationTimeoutIds[index] = setTimeout(updateToolCall, 16);
animationFrameIds[index] = requestAnimationFrame(() => updateToolCall());
} else {
isAnimationActives[index] = false;
animationTimeoutIds[index] = null;
animationFrameIds[index] = null;
resolve();
}
};

updateToolCall();
animationFrameIds[index] = requestAnimationFrame(() => updateToolCall());
});

const pushToQueue = (toolCallChunks: MessageToolCallChunk[]) => {
Expand All @@ -188,7 +187,7 @@ const createSmoothToolCalls = (params: {
if (!outputQueues[chunk.index]) {
outputQueues[chunk.index] = [];
isAnimationActives[chunk.index] = false;
animationTimeoutIds[chunk.index] = null;
animationFrameIds[chunk.index] = null;
}

outputQueues[chunk.index].push(...(chunk.function?.arguments || '').split(''));
Expand Down

0 comments on commit 0e19893

Please sign in to comment.