Skip to content

Commit

Permalink
🐛 fix: fix the missing user id in chat compeletition and fix remove u…
Browse files Browse the repository at this point in the history
…nstarred topic not working (lobehub#2677)

* 🐛 fix: fix remove all topic not working

* 🐛 fix: fix remove all topic not working

* 🐛 fix: fix user id missing in chat competition
  • Loading branch information
arvinxx authored May 27, 2024
1 parent 3fc4265 commit c9fb2de
Show file tree
Hide file tree
Showing 12 changed files with 76 additions and 51 deletions.
3 changes: 2 additions & 1 deletion src/app/api/chat/[provider]/route.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,7 @@ describe('POST handler', () => {
accessCode: 'test-access-code',
apiKey: 'test-api-key',
azureApiVersion: 'v1',
userId: 'abc',
});

const mockParams = { provider: 'test-provider' };
Expand All @@ -176,7 +177,7 @@ describe('POST handler', () => {
const response = await POST(request as unknown as Request, { params: mockParams });

expect(response).toEqual(mockChatResponse);
expect(AgentRuntime.prototype.chat).toHaveBeenCalledWith(mockChatPayload);
expect(AgentRuntime.prototype.chat).toHaveBeenCalledWith(mockChatPayload, { user: 'abc' });
});

it('should return an error response when chat completion fails', async () => {
Expand Down
15 changes: 7 additions & 8 deletions src/app/api/chat/[provider]/route.ts
Original file line number Diff line number Diff line change
Expand Up @@ -25,17 +25,16 @@ export const POST = checkAuth(async (req: Request, { params, jwtPayload }) => {

const tracePayload = getTracePayload(req);

let traceOptions = {};
// If user enable trace
if (tracePayload?.enabled) {
return await agentRuntime.chat(
data,
createTraceOptions(data, {
provider,
trace: tracePayload,
}),
);
traceOptions = createTraceOptions(data, {
provider,
trace: tracePayload,
});
}
return await agentRuntime.chat(data);

return await agentRuntime.chat(data, { user: jwtPayload.userId, ...traceOptions });
} catch (e) {
const {
errorType = ChatErrorType.InternalServerError,
Expand Down
6 changes: 6 additions & 0 deletions src/const/auth.ts
Original file line number Diff line number Diff line change
Expand Up @@ -35,5 +35,11 @@ export interface JWTPayload {
awsAccessKeyId?: string;
awsRegion?: string;
awsSecretAccessKey?: string;
/**
* user id
* in client db mode it's a uuid
* in server db mode it's a user id
*/
userId?: string;
}
/* eslint-enable */
4 changes: 4 additions & 0 deletions src/libs/agent-runtime/types/chat.ts
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,10 @@ export interface ChatCompetitionOptions {
callback?: ChatStreamCallbacks;
headers?: Record<string, any>;
signal?: AbortSignal;
/**
* userId for the chat completion
*/
user?: string;
}

export interface ChatCompletionFunctions {
Expand Down
17 changes: 10 additions & 7 deletions src/libs/agent-runtime/utils/openaiCompatibleFactory/index.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -195,18 +195,21 @@ describe('LobeOpenAICompatibleFactory', () => {
});

describe('handlePayload option', () => {
it('should modify request payload correctly', async () => {
it('should add user in payload correctly', async () => {
const mockCreateMethod = vi.spyOn(instance['client'].chat.completions, 'create');

await instance.chat({
messages: [{ content: 'Hello', role: 'user' }],
model: 'mistralai/mistral-7b-instruct:free',
temperature: 0,
});
await instance.chat(
{
messages: [{ content: 'Hello', role: 'user' }],
model: 'mistralai/mistral-7b-instruct:free',
temperature: 0,
},
{ user: 'abc' },
);

expect(mockCreateMethod).toHaveBeenCalledWith(
expect.objectContaining({
// 根据实际的 handlePayload 函数,添加断言
user: 'abc',
}),
expect.anything(),
);
Expand Down
13 changes: 8 additions & 5 deletions src/libs/agent-runtime/utils/openaiCompatibleFactory/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -79,11 +79,14 @@ export const LobeOpenAICompatibleFactory = ({
stream: payload.stream ?? true,
} as OpenAI.ChatCompletionCreateParamsStreaming);

const response = await this.client.chat.completions.create(postPayload, {
// https://github.com/lobehub/lobe-chat/pull/318
headers: { Accept: '*/*' },
signal: options?.signal,
});
const response = await this.client.chat.completions.create(
{ ...postPayload, user: options?.user },
{
// https://github.com/lobehub/lobe-chat/pull/318
headers: { Accept: '*/*' },
signal: options?.signal,
},
);

if (postPayload.stream) {
const [prod, useForDebug] = response.tee();
Expand Down
9 changes: 7 additions & 2 deletions src/services/_auth.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
import { JWTPayload, LOBE_CHAT_AUTH_HEADER } from '@/const/auth';
import { ModelProvider } from '@/libs/agent-runtime';
import { useUserStore } from '@/store/user';
import { keyVaultsConfigSelectors, settingsSelectors } from '@/store/user/selectors';
import {
keyVaultsConfigSelectors,
settingsSelectors,
userProfileSelectors,
} from '@/store/user/selectors';
import { GlobalLLMProviderKey } from '@/types/user/settings';
import { createJWT } from '@/utils/jwt';

Expand Down Expand Up @@ -48,8 +52,9 @@ export const getProviderAuthPayload = (provider: string) => {

const createAuthTokenWithPayload = async (payload = {}) => {
const accessCode = settingsSelectors.password(useUserStore.getState());
const userId = userProfileSelectors.userId(useUserStore.getState());

return await createJWT<JWTPayload>({ accessCode, ...payload });
return await createJWT<JWTPayload>({ accessCode, userId, ...payload });
};

interface AuthParams {
Expand Down
18 changes: 12 additions & 6 deletions src/store/chat/slices/topic/action.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -384,11 +384,14 @@ describe('topic action', () => {
// Set up mock state with unstarred topics
await act(async () => {
useChatStore.setState({
topics: [
{ id: 'topic-1', favorite: false },
{ id: 'topic-2', favorite: true },
{ id: 'topic-3', favorite: false },
] as ChatTopic[],
activeId: 'abc',
topicMaps: {
abc: [
{ id: 'topic-1', favorite: false },
{ id: 'topic-2', favorite: true },
{ id: 'topic-3', favorite: false },
] as ChatTopic[],
},
});
});
const refreshTopicSpy = vi.spyOn(result.current, 'refreshTopic');
Expand Down Expand Up @@ -431,7 +434,10 @@ describe('topic action', () => {
});

// Mock the `updateTopicTitleInSummary` and `refreshTopic` for spying
const updateTopicTitleInSummarySpy = vi.spyOn(result.current, 'updateTopicTitleInSummary');
const updateTopicTitleInSummarySpy = vi.spyOn(
result.current,
'internal_updateTopicTitleInSummary',
);
const refreshTopicSpy = vi.spyOn(result.current, 'refreshTopic');

// Mock the `chatService.fetchPresetTaskResult` to simulate the AI response
Expand Down
35 changes: 17 additions & 18 deletions src/store/chat/slices/topic/action.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
// DON'T REMOVE THE FIRST LINE
import isEqual from 'fast-deep-equal';
import { t } from 'i18next';
import { produce } from 'immer';
import useSWR, { SWRResponse, mutate } from 'swr';
import { StateCreator } from 'zustand/vanilla';

Expand Down Expand Up @@ -37,19 +36,19 @@ export interface ChatTopicAction {
removeAllTopics: () => Promise<void>;
removeSessionTopics: () => Promise<void>;
removeTopic: (id: string) => Promise<void>;
removeUnstarredTopic: () => void;
removeUnstarredTopic: () => Promise<void>;
saveToTopic: () => Promise<string | undefined>;
createTopic: () => Promise<string | undefined>;

autoRenameTopicTitle: (id: string) => Promise<void>;
duplicateTopic: (id: string) => Promise<void>;
summaryTopicTitle: (topicId: string, messages: ChatMessage[]) => Promise<void>;
switchTopic: (id?: string, skipRefreshMessage?: boolean) => Promise<void>;
updateTopicTitleInSummary: (id: string, title: string) => void;
updateTopicTitle: (id: string, title: string) => Promise<void>;
useFetchTopics: (sessionId: string) => SWRResponse<ChatTopic[]>;
useSearchTopics: (keywords?: string, sessionId?: string) => SWRResponse<ChatTopic[]>;

internal_updateTopicTitleInSummary: (id: string, title: string) => void;
internal_updateTopicLoading: (id: string, loading: boolean) => void;
internal_createTopic: (params: CreateTopicParams) => Promise<string>;
internal_updateTopic: (id: string, data: Partial<ChatTopic>) => Promise<void>;
Expand Down Expand Up @@ -133,18 +132,18 @@ export const chatTopic: StateCreator<
},
// update
summaryTopicTitle: async (topicId, messages) => {
const { updateTopicTitleInSummary, internal_updateTopicLoading } = get();
const { internal_updateTopicTitleInSummary, internal_updateTopicLoading } = get();
const topic = topicSelectors.getTopicById(topicId)(get());
if (!topic) return;

updateTopicTitleInSummary(topicId, LOADING_FLAT);
internal_updateTopicTitleInSummary(topicId, LOADING_FLAT);

let output = '';

// 自动总结话题标题
await chatService.fetchPresetTaskResult({
onError: () => {
updateTopicTitleInSummary(topicId, topic.title);
internal_updateTopicTitleInSummary(topicId, topic.title);
},
onFinish: async (text) => {
await get().internal_updateTopic(topicId, { title: text });
Expand All @@ -159,7 +158,7 @@ export const chatTopic: StateCreator<
}
}

updateTopicTitleInSummary(topicId, output);
internal_updateTopicTitleInSummary(topicId, output);
},
params: await chainSummaryTitle(messages),
trace: get().getCurrentTracePayload({ traceName: TraceNameMap.SummaryTopicTitle, topicId }),
Expand Down Expand Up @@ -264,15 +263,11 @@ export const chatTopic: StateCreator<
},

// Internal process method of the topics
updateTopicTitleInSummary: (id, title) => {
const topics = produce(get().topics, (draftState) => {
const topic = draftState.find((i) => i.id === id);

if (!topic) return;
topic.title = title;
});

set({ topics }, false, n(`updateTopicTitleInSummary`, { id, title }));
internal_updateTopicTitleInSummary: (id, title) => {
get().internal_dispatchTopic(
{ type: 'updateTopic', id, value: { title } },
'updateTopicTitleInSummary',
);
},
refreshTopic: async () => {
return mutate([SWR_USE_FETCH_TOPIC, get().activeId]);
Expand Down Expand Up @@ -317,8 +312,12 @@ export const chatTopic: StateCreator<
},

internal_dispatchTopic: (payload, action) => {
const nextTopics = topicReducer(get().topics, payload);
const nextTopics = topicReducer(topicSelectors.currentTopics(get()), payload);
const nextMap = { ...get().topicMaps, [get().activeId]: nextTopics };

// no need to update map if is the same
if (isEqual(nextMap, get().topicMaps)) return;

set({ topics: nextTopics }, false, action ?? n(`dispatchTopic/${payload.type}`));
set({ topicMaps: nextMap }, false, action ?? n(`dispatchTopic/${payload.type}`));
},
});
2 changes: 0 additions & 2 deletions src/store/chat/slices/topic/initialState.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ export interface ChatTopicState {
topicMaps: Record<string, ChatTopic[]>;
topicRenamingId?: string;
topicSearchKeywords: string;
topics: ChatTopic[];
/**
* whether topics have fetched
*/
Expand All @@ -23,6 +22,5 @@ export const initialTopicState: ChatTopicState = {
topicLoadingIds: [],
topicMaps: {},
topicSearchKeywords: '',
topics: [],
topicsInit: false,
};
2 changes: 1 addition & 1 deletion src/store/chat/slices/topic/selectors.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ describe('topicSelectors', () => {

describe('currentUnFavTopics', () => {
it('should return all unfavorited topics', () => {
const state = merge(initialStore, { topics: topicMaps.test });
const state = merge(initialStore, { topicMaps, activeId: 'test' });
const topics = topicSelectors.currentUnFavTopics(state);
expect(topics).toEqual([topicMaps.test[1]]);
});
Expand Down
3 changes: 2 additions & 1 deletion src/store/chat/slices/topic/selectors.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@ const searchTopics = (s: ChatStore): ChatTopic[] => s.searchTopics;
const displayTopics = (s: ChatStore): ChatTopic[] | undefined =>
s.isSearchingTopic ? searchTopics(s) : currentTopics(s);

const currentUnFavTopics = (s: ChatStore): ChatTopic[] => s.topics.filter((s) => !s.favorite);
const currentUnFavTopics = (s: ChatStore): ChatTopic[] =>
currentTopics(s)?.filter((s) => !s.favorite) || [];

const currentTopicLength = (s: ChatStore): number => currentTopics(s)?.length || 0;

Expand Down

0 comments on commit c9fb2de

Please sign in to comment.