Skip to content

Commit

Permalink
⚡ perf: improve message loading (lobehub#2097)
Browse files Browse the repository at this point in the history
* ⚡ perf: improve message loading

* ⚡ perf: improve topic loading
  • Loading branch information
arvinxx committed Apr 19, 2024
1 parent 630433e commit 148825b
Show file tree
Hide file tree
Showing 12 changed files with 399 additions and 65 deletions.
4 changes: 2 additions & 2 deletions src/app/chat/features/TopicListContent/Topic/TopicContent.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -114,8 +114,8 @@ const TopicContent = memo<TopicContentProps>(({ id, title, fav, showMore }) => {
modal.confirm({
centered: true,
okButtonProps: { danger: true },
onOk: () => {
removeTopic(id);
onOk: async () => {
await removeTopic(id);
},
title: t('topic.confirmRemoveTopic', { ns: 'chat' }),
});
Expand Down
8 changes: 6 additions & 2 deletions src/features/ChatInput/Topic/index.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import { useTranslation } from 'react-i18next';

import HotKeys from '@/components/HotKeys';
import { PREFIX_KEY, SAVE_TOPIC_KEY } from '@/const/hotkeys';
import { useActionSWR } from '@/libs/swr';
import { useChatStore } from '@/store/chat';

const SaveTopic = memo<{ mobile?: boolean }>(({ mobile }) => {
Expand All @@ -16,20 +17,23 @@ const SaveTopic = memo<{ mobile?: boolean }>(({ mobile }) => {
s.openNewTopicOrSaveTopic,
]);

const { mutate, isValidating } = useActionSWR('openNewTopicOrSaveTopic', openNewTopicOrSaveTopic);

const icon = hasTopic ? LucideMessageSquarePlus : LucideGalleryVerticalEnd;
const Render = mobile ? ActionIcon : Button;
const iconRender: any = mobile ? icon : <Icon icon={icon} />;
const desc = t(hasTopic ? 'topic.openNewTopic' : 'topic.saveCurrentMessages');

const hotkeys = [PREFIX_KEY, SAVE_TOPIC_KEY].join('+');
useHotkeys(hotkeys, openNewTopicOrSaveTopic, {

useHotkeys(hotkeys, () => mutate(), {
enableOnFormTags: true,
preventDefault: true,
});

return (
<Tooltip title={<HotKeys desc={desc} keys={hotkeys} />}>
<Render aria-label={desc} icon={iconRender} onClick={openNewTopicOrSaveTopic} />
<Render aria-label={desc} icon={iconRender} loading={isValidating} onClick={() => mutate()} />
</Tooltip>
);
});
Expand Down
11 changes: 8 additions & 3 deletions src/features/Conversation/components/ChatItem/index.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@ import ActionsBar from './ActionsBar';
import HistoryDivider from './HistoryDivider';

const useStyles = createStyles(({ css, prefixCls }) => ({
loading: css`
opacity: 0.6;
`,
message: css`
// prevent the textarea too long
.${prefixCls}-input {
Expand All @@ -35,7 +38,7 @@ export interface ChatListItemProps {
const Item = memo<ChatListItemProps>(({ index, id }) => {
const fontSize = useGlobalStore((s) => settingsSelectors.currentSettings(s).fontSize);
const { t } = useTranslation('common');
const { styles } = useStyles();
const { styles, cx } = useStyles();
const [editing, setEditing] = useState(false);
const [type = 'chat'] = useSessionStore((s) => {
const config = agentSelectors.currentAgentConfig(s);
Expand All @@ -54,10 +57,12 @@ const Item = memo<ChatListItemProps>(({ index, id }) => {
const historyLength = useChatStore((s) => chatSelectors.currentChats(s).length);

const [loading, updateMessageContent] = useChatStore((s) => [
s.chatLoadingId === id,
s.chatLoadingId === id || s.messageLoadingIds.includes(id),
s.modifyMessageContent,
]);

const [isMessageLoading] = useChatStore((s) => [s.messageLoadingIds.includes(id)]);

const onAvatarsClick = useAvatarsClick();

const RenderMessage = useCallback(
Expand Down Expand Up @@ -110,7 +115,7 @@ const Item = memo<ChatListItemProps>(({ index, id }) => {
<ChatItem
actions={<ActionsBar index={index} setEditing={setEditing} />}
avatar={item.meta}
className={styles.message}
className={cx(styles.message, isMessageLoading && styles.loading)}
editing={editing}
error={error}
errorMessage={<ErrorMessageExtra data={item} />}
Expand Down
9 changes: 9 additions & 0 deletions src/libs/swr/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -32,3 +32,12 @@ export const useActionSWR: SWRHook = (key, fetch, config) =>
revalidateOnReconnect: false,
...config,
});

export interface SWRRefreshParams<T, A = (...args: any[]) => any> {
action: A;
optimisticData?: (data: T | undefined) => T;
}

export type SWRefreshMethod<T> = <A extends (...args: any[]) => Promise<any>>(
params?: SWRRefreshParams<T, A>,
) => ReturnType<A>;
122 changes: 80 additions & 42 deletions src/store/chat/slices/message/action.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
/* eslint-disable sort-keys-fix/sort-keys-fix, typescript-sort-keys/interface */
// Disable the auto sort key eslint rule to make the code more logic and readable
import { copyToClipboard } from '@lobehub/ui';
import { produce } from 'immer';
import { template } from 'lodash-es';
import { SWRResponse, mutate } from 'swr';
import { StateCreator } from 'zustand/vanilla';
Expand All @@ -19,6 +20,7 @@ import { agentSelectors } from '@/store/session/selectors';
import { ChatMessage } from '@/types/message';
import { TraceEventPayloads } from '@/types/trace';
import { setNamespace } from '@/utils/storeDebug';
import { nanoid } from '@/utils/uuid';

import { chatSelectors } from '../../selectors';
import { MessageDispatch, messagesReducer } from './reducer';
Expand Down Expand Up @@ -97,6 +99,7 @@ export interface ChatMessageAction {
id?: string,
action?: string,
) => AbortController | undefined;
toggleMessageLoading: (loading: boolean, id: string) => void;
refreshMessages: () => Promise<void>;
// TODO: 后续 smoothMessage 实现考虑落到 sse 这一层
createSmoothMessage: (id: string) => {
Expand All @@ -111,6 +114,7 @@ export interface ChatMessageAction {
* @param content
*/
internalUpdateMessageContent: (id: string, content: string) => Promise<void>;
internalCreateMessage: (params: CreateMessageParams) => Promise<string>;
internalResendMessage: (id: string, traceId?: string) => Promise<void>;
internalTraceMessage: (id: string, payload: TraceEventPayloads) => Promise<void>;
}
Expand All @@ -130,6 +134,7 @@ export const chatMessage: StateCreator<
ChatMessageAction
> = (set, get) => ({
deleteMessage: async (id) => {
get().dispatchMessage({ type: 'deleteMessage', id });
await messageService.removeMessage(id);
await get().refreshMessages();
},
Expand Down Expand Up @@ -167,43 +172,6 @@ export const chatMessage: StateCreator<
await messageService.removeAllMessages();
await refreshMessages();
},
internalResendMessage: async (messageId, traceId) => {
// 1. 构造所有相关的历史记录
const chats = chatSelectors.currentChats(get());

const currentIndex = chats.findIndex((c) => c.id === messageId);
if (currentIndex < 0) return;

const currentMessage = chats[currentIndex];

let contextMessages: ChatMessage[] = [];

switch (currentMessage.role) {
case 'function':
case 'user': {
contextMessages = chats.slice(0, currentIndex + 1);
break;
}
case 'assistant': {
// 消息是 AI 发出的因此需要找到它的 user 消息
const userId = currentMessage.parentId;
const userIndex = chats.findIndex((c) => c.id === userId);
// 如果消息没有 parentId,那么同 user/function 模式
contextMessages = chats.slice(0, userIndex < 0 ? currentIndex + 1 : userIndex + 1);
break;
}
}

if (contextMessages.length <= 0) return;

const { coreProcessMessage } = get();

const latestMsg = contextMessages.filter((s) => s.role === 'user').at(-1);

if (!latestMsg) return;

await coreProcessMessage(contextMessages, latestMsg.id, traceId);
},
sendMessage: async ({ message, files, onlyAddUserMessage }) => {
const { coreProcessMessage, activeTopicId, activeId } = get();
if (!activeId) return;
Expand All @@ -223,8 +191,7 @@ export const chatMessage: StateCreator<
topicId: activeTopicId,
};

const id = await messageService.createMessage(newMessage);
await get().refreshMessages();
const id = await get().internalCreateMessage(newMessage);

// if only add user message, then stop
if (onlyAddUserMessage) return;
Expand Down Expand Up @@ -315,8 +282,7 @@ export const chatMessage: StateCreator<
topicId: activeTopicId, // if there is activeTopicId,then add it to topicId
};

const mid = await messageService.createMessage(assistantMessage);
await refreshMessages();
const mid = await get().internalCreateMessage(assistantMessage);

// 2. fetch the AI response
const { isFunctionCall, content, functionCallAtEnd, functionCallContent, traceId } =
Expand Down Expand Up @@ -344,7 +310,7 @@ export const chatMessage: StateCreator<
traceId,
};

functionId = await messageService.createMessage(functionMessage);
functionId = await get().internalCreateMessage(functionMessage);
}

await refreshMessages();
Expand Down Expand Up @@ -533,6 +499,62 @@ export const chatMessage: StateCreator<
window.removeEventListener('beforeunload', preventLeavingFn);
}
},
toggleMessageLoading: (loading, id) => {
set(
{
messageLoadingIds: produce(get().messageLoadingIds, (draft) => {
if (loading) {
draft.push(id);
} else {
const index = draft.indexOf(id);

if (index >= 0) draft.splice(index, 1);
}
}),
},
false,
'toggleMessageLoading',
);
},

internalResendMessage: async (messageId, traceId) => {
// 1. 构造所有相关的历史记录
const chats = chatSelectors.currentChats(get());

const currentIndex = chats.findIndex((c) => c.id === messageId);
if (currentIndex < 0) return;

const currentMessage = chats[currentIndex];

let contextMessages: ChatMessage[] = [];

switch (currentMessage.role) {
case 'function':
case 'user': {
contextMessages = chats.slice(0, currentIndex + 1);
break;
}
case 'assistant': {
// 消息是 AI 发出的因此需要找到它的 user 消息
const userId = currentMessage.parentId;
const userIndex = chats.findIndex((c) => c.id === userId);
// 如果消息没有 parentId,那么同 user/function 模式
contextMessages = chats.slice(0, userIndex < 0 ? currentIndex + 1 : userIndex + 1);
break;
}
}

if (contextMessages.length <= 0) return;

const { coreProcessMessage } = get();

const latestMsg = contextMessages.filter((s) => s.role === 'user').at(-1);

if (!latestMsg) return;

await coreProcessMessage(contextMessages, latestMsg.id, traceId);
},

internalUpdateMessageContent: async (id, content) => {
const { dispatchMessage, refreshMessages } = get();

Expand All @@ -545,6 +567,22 @@ export const chatMessage: StateCreator<
await refreshMessages();
},

internalCreateMessage: async (message) => {
const { dispatchMessage, refreshMessages, toggleMessageLoading } = get();

// use optimistic update to avoid the slow waiting
const tempId = 'tmp_' + nanoid();
dispatchMessage({ type: 'createMessage', id: tempId, value: message });

toggleMessageLoading(true, tempId);
const id = await messageService.createMessage(message);

await refreshMessages();
toggleMessageLoading(false, tempId);

return id;
},

createSmoothMessage: (id) => {
const { dispatchMessage } = get();

Expand Down
2 changes: 1 addition & 1 deletion src/store/chat/slices/message/initialState.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ export interface ChatMessageState {
activeId: string;
chatLoadingId?: string;
inputMessage: string;
messageLoadingIds: [];
messageLoadingIds: string[];
messages: ChatMessage[];
/**
* whether messages have fetched
Expand Down
33 changes: 32 additions & 1 deletion src/store/chat/slices/message/reducer.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import isEqual from 'fast-deep-equal';
import { produce } from 'immer';

import { CreateMessageParams } from '@/services/message';
import { ChatMessage } from '@/types/message';
import { merge } from '@/utils/merge';

Expand All @@ -10,6 +11,15 @@ interface UpdateMessage {
type: 'updateMessage';
value: ChatMessage[keyof ChatMessage];
}
interface CreateMessage {
id: string;
type: 'createMessage';
value: CreateMessageParams;
}
interface DeleteMessage {
id: string;
type: 'deleteMessage';
}

interface UpdatePluginState {
id: string;
Expand All @@ -24,7 +34,12 @@ interface UpdateMessageExtra {
value: any;
}

export type MessageDispatch = UpdateMessage | UpdatePluginState | UpdateMessageExtra;
export type MessageDispatch =
| CreateMessage
| UpdateMessage
| UpdatePluginState
| UpdateMessageExtra
| DeleteMessage;

export const messagesReducer = (state: ChatMessage[], payload: MessageDispatch): ChatMessage[] => {
switch (payload.type) {
Expand Down Expand Up @@ -76,6 +91,22 @@ export const messagesReducer = (state: ChatMessage[], payload: MessageDispatch):
});
}

case 'createMessage': {
return produce(state, (draftState) => {
const { value, id } = payload;

draftState.push({ ...value, createdAt: Date.now(), id, meta: {}, updatedAt: Date.now() });
});
}
case 'deleteMessage': {
return produce(state, (draft) => {
const { id } = payload;

const index = draft.findIndex((m) => m.id === id);

if (index >= 0) draft.splice(index, 1);
});
}
default: {
throw new Error('暂未实现的 type,请检查 reducer');
}
Expand Down
Loading

0 comments on commit 148825b

Please sign in to comment.