diff --git a/src/app/(main)/settings/llm/components/ProviderConfig/index.tsx b/src/app/(main)/settings/llm/components/ProviderConfig/index.tsx index dbc6f6f45a8e4..1f3d9f0f75255 100644 --- a/src/app/(main)/settings/llm/components/ProviderConfig/index.tsx +++ b/src/app/(main)/settings/llm/components/ProviderConfig/index.tsx @@ -50,6 +50,8 @@ interface ProviderConfigProps { canDeactivate?: boolean; checkModel?: string; checkerItem?: FormItemProps; + className?: string; + hideSwitch?: boolean; modelList?: { azureDeployName?: boolean; notFoundContent?: ReactNode; @@ -81,11 +83,12 @@ const ProviderConfig = memo( checkerItem, modelList, showBrowserRequest, + className, }) => { const { t } = useTranslation('setting'); const { t: modelT } = useTranslation('modelProvider'); const [form] = Form.useForm(); - const { styles } = useStyles(); + const { cx, styles } = useStyles(); const [ toggleProviderEnabled, setSettings, @@ -192,7 +195,7 @@ const ProviderConfig = memo( return (
(({ id, provider }) => { e.stopPropagation(); e.preventDefault(); - const isConfirm = await modal.confirm({ + await modal.confirm({ centered: true, content: s('llm.customModelCards.confirmDelete'), okButtonProps: { danger: true }, + onOk: async () => { + // delete model and deactivate id + await dispatchCustomModelCards(provider, { id, type: 'delete' }); + await removeEnabledModels(provider, id); + }, type: 'warning', }); - - // delete model and deactive id - if (isConfirm) { - await dispatchCustomModelCards(provider, { id, type: 'delete' }); - await removeEnabledModels(provider, id); - } }} title={t('delete')} /> diff --git a/src/app/(main)/settings/llm/components/ProviderModelList/ModelConfigModal.tsx b/src/app/(main)/settings/llm/components/ProviderModelList/ModelConfigModal/Form.tsx similarity index 56% rename from src/app/(main)/settings/llm/components/ProviderModelList/ModelConfigModal.tsx rename to src/app/(main)/settings/llm/components/ProviderModelList/ModelConfigModal/Form.tsx index 267fc7889edbf..f2e715bb69e4c 100644 --- a/src/app/(main)/settings/llm/components/ProviderModelList/ModelConfigModal.tsx +++ b/src/app/(main)/settings/llm/components/ProviderModelList/ModelConfigModal/Form.tsx @@ -1,71 +1,28 @@ -import { Modal } from '@lobehub/ui'; -import { Button, Checkbox, Form, Input } from 'antd'; -import isEqual from 'fast-deep-equal'; -import { memo } from 'react'; +import { Checkbox, Form, FormInstance, Input } from 'antd'; +import { memo, useEffect } from 'react'; import { useTranslation } from 'react-i18next'; -import { useUserStore } from '@/store/user'; -import { modelConfigSelectors } from '@/store/user/selectors'; +import { ChatModelCard } from '@/types/llm'; import MaxTokenSlider from './MaxTokenSlider'; -interface ModelConfigModalProps { - provider?: string; +interface ModelConfigFormProps { + initialValues?: ChatModelCard; + onFormInstanceReady: (instance: FormInstance) => void; showAzureDeployName?: boolean; } -const ModelConfigModal = memo(({ showAzureDeployName, provider }) => { - const [formInstance] = Form.useForm(); - const { t } = useTranslation('setting'); - const { t: tc } = useTranslation('common'); - const [open, id, editingProvider, dispatchCustomModelCards, toggleEditingCustomModelCard] = - useUserStore((s) => [ - !!s.editingCustomCardModel && provider === s.editingCustomCardModel?.provider, - s.editingCustomCardModel?.id, - s.editingCustomCardModel?.provider, - s.dispatchCustomModelCards, - s.toggleEditingCustomModelCard, - ]); +const ModelConfigForm = memo( + ({ showAzureDeployName, onFormInstanceReady, initialValues }) => { + const { t } = useTranslation('setting'); - const modelCard = useUserStore( - modelConfigSelectors.getCustomModelCard({ id, provider: editingProvider }), - isEqual, - ); + const [formInstance] = Form.useForm(); - const closeModal = () => { - toggleEditingCustomModelCard(undefined); - }; + useEffect(() => { + onFormInstanceReady(formInstance); + }, []); - return ( - - {tc('cancel')} - , - - , - ]} - maskClosable - onCancel={closeModal} - open={open} - title={t('llm.customModelCards.modelConfig.modalTitle')} - zIndex={1051} // Select is 1050 - > + return (
{ e.stopPropagation(); @@ -77,9 +34,8 @@ const ModelConfigModal = memo(({ showAzureDeployName, pro @@ -136,7 +92,7 @@ const ModelConfigModal = memo(({ showAzureDeployName, pro
-
- ); -}); -export default ModelConfigModal; + ); + }, +); +export default ModelConfigForm; diff --git a/src/app/(main)/settings/llm/components/ProviderModelList/MaxTokenSlider.tsx b/src/app/(main)/settings/llm/components/ProviderModelList/ModelConfigModal/MaxTokenSlider.tsx similarity index 100% rename from src/app/(main)/settings/llm/components/ProviderModelList/MaxTokenSlider.tsx rename to src/app/(main)/settings/llm/components/ProviderModelList/ModelConfigModal/MaxTokenSlider.tsx diff --git a/src/app/(main)/settings/llm/components/ProviderModelList/ModelConfigModal/index.tsx b/src/app/(main)/settings/llm/components/ProviderModelList/ModelConfigModal/index.tsx new file mode 100644 index 0000000000000..aa022076d590d --- /dev/null +++ b/src/app/(main)/settings/llm/components/ProviderModelList/ModelConfigModal/index.tsx @@ -0,0 +1,78 @@ +import { Modal } from '@lobehub/ui'; +import { Button, FormInstance } from 'antd'; +import isEqual from 'fast-deep-equal'; +import { memo, useState } from 'react'; +import { useTranslation } from 'react-i18next'; + +import { useUserStore } from '@/store/user'; +import { modelConfigSelectors } from '@/store/user/slices/modelList/selectors'; + +import ModelConfigForm from './Form'; + +interface ModelConfigModalProps { + provider?: string; + showAzureDeployName?: boolean; +} + +const ModelConfigModal = memo(({ showAzureDeployName, provider }) => { + const { t } = useTranslation('setting'); + const { t: tc } = useTranslation('common'); + const [formInstance, setFormInstance] = useState(); + + const [open, id, editingProvider, dispatchCustomModelCards, toggleEditingCustomModelCard] = + useUserStore((s) => [ + !!s.editingCustomCardModel && provider === s.editingCustomCardModel?.provider, + s.editingCustomCardModel?.id, + s.editingCustomCardModel?.provider, + s.dispatchCustomModelCards, + s.toggleEditingCustomModelCard, + ]); + + const modelCard = useUserStore( + modelConfigSelectors.getCustomModelCard({ id, provider: editingProvider }), + isEqual, + ); + + const closeModal = () => { + toggleEditingCustomModelCard(undefined); + }; + + return ( + + {tc('cancel')} + , + + , + ]} + maskClosable + onCancel={closeModal} + open={open} + title={t('llm.customModelCards.modelConfig.modalTitle')} + zIndex={1251} // Select is 1150 + > + + + ); +}); +export default ModelConfigModal; diff --git a/src/app/(main)/settings/llm/components/ProviderModelList/Option.tsx b/src/app/(main)/settings/llm/components/ProviderModelList/Option.tsx index 401499b22eea7..f47f23404a91b 100644 --- a/src/app/(main)/settings/llm/components/ProviderModelList/Option.tsx +++ b/src/app/(main)/settings/llm/components/ProviderModelList/Option.tsx @@ -1,6 +1,10 @@ +import { ActionIcon, Tooltip } from '@lobehub/ui'; import { Typography } from 'antd'; +import { useTheme } from 'antd-style'; import isEqual from 'fast-deep-equal'; +import { Recycle } from 'lucide-react'; import { memo } from 'react'; +import { useTranslation } from 'react-i18next'; import { Flexbox } from 'react-layout-kit'; import ModelIcon from '@/components/ModelIcon'; @@ -16,25 +20,45 @@ interface OptionRenderProps { id: string; isAzure?: boolean; provider: GlobalLLMProviderKey; + removed?: boolean; } -const OptionRender = memo(({ displayName, id, provider, isAzure }) => { +const OptionRender = memo(({ displayName, id, provider, isAzure, removed }) => { const model = useUserStore((s) => modelProviderSelectors.getModelCardById(id)(s), isEqual); - + const { t } = useTranslation('components'); + const theme = useTheme(); // if there is isCustom, it means it is a user defined custom model if (model?.isCustom || isAzure) return ; return ( - - - - - {displayName} - + + + + + + {displayName} + + + + {id} + - - {id} - + {removed && ( + + + + )} ); }); diff --git a/src/app/(main)/settings/llm/components/ProviderModelList/index.tsx b/src/app/(main)/settings/llm/components/ProviderModelList/index.tsx index 00173d33a57a9..fdebb77f4867f 100644 --- a/src/app/(main)/settings/llm/components/ProviderModelList/index.tsx +++ b/src/app/(main)/settings/llm/components/ProviderModelList/index.tsx @@ -51,10 +51,9 @@ const ProviderModelListSelect = memo( ({ showModelFetcher = false, provider, showAzureDeployName, notFoundContent, placeholder }) => { const { t } = useTranslation('common'); const { t: transSetting } = useTranslation('setting'); - const [setModelProviderConfig, dispatchCustomModelCards] = useUserStore((s) => [ + const [setModelProviderConfig, updateEnabledModels] = useUserStore((s) => [ s.setModelProviderConfig, - s.dispatchCustomModelCards, - s.useFetchProviderModelList, + s.updateEnabledModels, ]); const chatModelCards = useUserStore( @@ -94,21 +93,7 @@ const ProviderModelListSelect = memo( mode="tags" notFoundContent={notFoundContent} onChange={(value, options) => { - setModelProviderConfig(provider, { enabledModels: value.filter(Boolean) }); - - // if there is a new model, add it to `customModelCards` - options.forEach((option: { label?: string; value?: string }, index: number) => { - // if is a known model, it should have value - // if is an unknown model, the option will be {} - if (option.value) return; - - const modelId = value[index]; - - dispatchCustomModelCards(provider, { - modelCard: { id: modelId }, - type: 'add', - }); - }); + updateEnabledModels(provider, value, options as any[]); }} optionFilterProp="label" optionRender={({ label, value }) => { @@ -123,6 +108,18 @@ const ProviderModelListSelect = memo( /> ); + if (enabledModels?.some((m) => value === m)) { + return ( + + ); + } + // model is defined by user in client return ( diff --git a/src/components/ModelProviderIcon/index.tsx b/src/components/ModelProviderIcon/index.tsx index 6313a4ba866fc..3738e973d3dc5 100644 --- a/src/components/ModelProviderIcon/index.tsx +++ b/src/components/ModelProviderIcon/index.tsx @@ -5,6 +5,7 @@ import { DeepSeek, Google, Groq, + LobeHub, Minimax, Mistral, Moonshot, @@ -16,7 +17,6 @@ import { ZeroOne, Zhipu, } from '@lobehub/icons'; -import { Logo } from '@lobehub/ui'; import { memo } from 'react'; import { Center } from 'react-layout-kit'; @@ -29,7 +29,7 @@ interface ModelProviderIconProps { const ModelProviderIcon = memo(({ provider }) => { switch (provider) { case 'lobehub': { - return ; + return ; } case ModelProvider.ZhiPu: { diff --git a/src/components/ModelSelect/index.tsx b/src/components/ModelSelect/index.tsx index 76260314ef4f8..7c5fc94d8713c 100644 --- a/src/components/ModelSelect/index.tsx +++ b/src/components/ModelSelect/index.tsx @@ -79,10 +79,10 @@ export const ModelInfoTags = memo( return ( {model.files && ( -
@@ -90,9 +90,9 @@ export const ModelInfoTags = memo( )} {model.vision && ( -
@@ -128,15 +128,6 @@ export const ModelInfoTags = memo( )} - {/*{model.isCustom && (*/} - {/* */} - {/*
DIY
*/} - {/* */} - {/*)}*/} ); }, diff --git a/src/features/User/UserPanel/PanelContent.tsx b/src/features/User/UserPanel/PanelContent.tsx index 65d00e502a92a..4d6ba4e0b05a9 100644 --- a/src/features/User/UserPanel/PanelContent.tsx +++ b/src/features/User/UserPanel/PanelContent.tsx @@ -22,7 +22,7 @@ const PanelContent = memo<{ closePopover: () => void }>(({ closePopover }) => { s.logout, s.openUserProfile, s.enableAuth(), - s.enabledNextAuth(), + s.enabledNextAuth, ]); const { mainItems, logoutItems } = useMenu(); diff --git a/src/hooks/useSyncData.ts b/src/hooks/useSyncData.ts index 5148b25e12cbb..93698216e3b3b 100644 --- a/src/hooks/useSyncData.ts +++ b/src/hooks/useSyncData.ts @@ -1,6 +1,7 @@ import { useCallback } from 'react'; import { useChatStore } from '@/store/chat'; +import { featureFlagsSelectors, useServerConfigStore } from '@/store/serverConfig'; import { useSessionStore } from '@/store/session'; import { useUserStore } from '@/store/user'; import { syncSettingsSelectors } from '@/store/user/selectors'; @@ -42,7 +43,8 @@ export const useEnabledDataSync = () => { s.useEnabledSync, ]); + const { enableWebrtc } = useServerConfigStore(featureFlagsSelectors); const syncEvent = useSyncEvent(); - useEnabledSync(userEnableSync, userId, syncEvent); + useEnabledSync(enableWebrtc, { onEvent: syncEvent, userEnableSync, userId }); }; diff --git a/src/layout/GlobalProvider/StoreInitialization.tsx b/src/layout/GlobalProvider/StoreInitialization.tsx index 0959b2d221547..4683458f57e98 100644 --- a/src/layout/GlobalProvider/StoreInitialization.tsx +++ b/src/layout/GlobalProvider/StoreInitialization.tsx @@ -10,32 +10,40 @@ import { useIsMobile } from '@/hooks/useIsMobile'; import { useEnabledDataSync } from '@/hooks/useSyncData'; import { useAgentStore } from '@/store/agent'; import { useGlobalStore } from '@/store/global'; +import { useServerConfigStore } from '@/store/serverConfig'; import { useUserStore } from '@/store/user'; +import { authSelectors } from '@/store/user/selectors'; const StoreInitialization = memo(() => { - const [useFetchServerConfig, useFetchUserConfig, useInitPreference] = useUserStore((s) => [ - s.useFetchServerConfig, - s.useFetchUserConfig, - s.useInitPreference, + const router = useRouter(); + + const [useInitUserState, isLogin] = useUserStore((s) => [ + s.useInitUserState, + authSelectors.isLogin(s), ]); + + const { serverConfig } = useServerConfigStore(); + const useInitGlobalPreference = useGlobalStore((s) => s.useInitGlobalPreference); const useFetchDefaultAgentConfig = useAgentStore((s) => s.useFetchDefaultAgentConfig); // init the system preference - useInitPreference(); useInitGlobalPreference(); - useFetchDefaultAgentConfig(); - const { isLoading } = useFetchServerConfig(); - useFetchUserConfig(!isLoading); + useInitUserState(isLogin, serverConfig, { + onSuccess: (state) => { + if (state.isOnboard === false) { + router.push('/onboard'); + } + }, + }); useEnabledDataSync(); const useStoreUpdater = createStoreUpdater(useGlobalStore); const mobile = useIsMobile(); - const router = useRouter(); useStoreUpdater('isMobile', mobile); useStoreUpdater('router', router); diff --git a/src/layout/GlobalProvider/index.tsx b/src/layout/GlobalProvider/index.tsx index 532d419537da4..fefccd7d20f88 100644 --- a/src/layout/GlobalProvider/index.tsx +++ b/src/layout/GlobalProvider/index.tsx @@ -58,13 +58,13 @@ const GlobalLayout = async ({ children }: GlobalLayoutProps) => { defaultNeutralColor={neutralColor?.value as any} defaultPrimaryColor={primaryColor?.value as any} > - {children} + diff --git a/src/locales/default/components.ts b/src/locales/default/components.ts index 96e9c4ac44e66..d4dd0468dbd85 100644 --- a/src/locales/default/components.ts +++ b/src/locales/default/components.ts @@ -7,6 +7,7 @@ export default { tokens: '该模型单个会话最多支持 {{tokens}} Tokens', vision: '该模型支持视觉识别', }, + removed: '该模型不在列表中,若取消选中将会自动移除', }, ModelSwitchPanel: { emptyModel: '没有启用的模型,请前往设置开启', diff --git a/src/services/message/client.test.ts b/src/services/message/client.test.ts index fbdb6bacded29..1c24dbac80374 100644 --- a/src/services/message/client.test.ts +++ b/src/services/message/client.test.ts @@ -343,28 +343,4 @@ describe('MessageClientService', () => { expect(result).toBe(false); }); }); - - describe('messageCountToCheckTrace', () => { - it('should return true if message count is greater than or equal to 4', async () => { - // Setup - (MessageModel.count as Mock).mockResolvedValue(5); - - // Execute - const result = await messageService.messageCountToCheckTrace(); - - // Assert - expect(result).toBe(true); - }); - - it('should return false if message count is less than 4', async () => { - // Setup - (MessageModel.count as Mock).mockResolvedValue(3); - - // Execute - const result = await messageService.messageCountToCheckTrace(); - - // Assert - expect(result).toBe(false); - }); - }); }); diff --git a/src/services/message/client.ts b/src/services/message/client.ts index 638da5684c063..945b4a794c18f 100644 --- a/src/services/message/client.ts +++ b/src/services/message/client.ts @@ -80,9 +80,4 @@ export class ClientService implements IMessageService { const number = await this.countMessages(); return number > 0; } - - async messageCountToCheckTrace() { - const number = await this.countMessages(); - return number >= 4; - } } diff --git a/src/services/message/type.ts b/src/services/message/type.ts index 064f0f32c018d..a169b4b0441cf 100644 --- a/src/services/message/type.ts +++ b/src/services/message/type.ts @@ -43,5 +43,4 @@ export interface IMessageService { removeAllMessages(): Promise; hasMessages(): Promise; - messageCountToCheckTrace(): Promise; } diff --git a/src/services/user/client.test.ts b/src/services/user/client.test.ts new file mode 100644 index 0000000000000..bbaff3cdcada9 --- /dev/null +++ b/src/services/user/client.test.ts @@ -0,0 +1,100 @@ +import { DeepPartial } from 'utility-types'; +import { Mock, beforeEach, describe, expect, it, vi } from 'vitest'; + +import { UserModel } from '@/database/client/models/user'; +import { GlobalSettings } from '@/types/settings'; +import { UserPreference } from '@/types/user'; +import { AsyncLocalStorage } from '@/utils/localStorage'; + +import { ClientService } from './client'; + +vi.mock('@/database/client/models/user', () => ({ + UserModel: { + getUser: vi.fn(), + updateSettings: vi.fn(), + resetSettings: vi.fn(), + updateAvatar: vi.fn(), + }, +})); + +const mockUser = { + avatar: 'avatar.png', + settings: { themeMode: 'light' } as unknown as GlobalSettings, + uuid: 'user-id', +}; + +const mockPreference = { + useCmdEnterToSend: true, +} as UserPreference; + +describe('ClientService', () => { + let clientService: ClientService; + + beforeEach(() => { + vi.clearAllMocks(); + clientService = new ClientService(); + }); + + it('should get user state correctly', async () => { + (UserModel.getUser as Mock).mockResolvedValue(mockUser); + const spyOn = vi + .spyOn(clientService['preferenceStorage'], 'getFromLocalStorage') + .mockResolvedValue(mockPreference); + + const userState = await clientService.getUserState(); + + expect(userState).toEqual({ + avatar: mockUser.avatar, + isOnboard: true, + canEnableTrace: false, + preference: mockPreference, + settings: mockUser.settings, + userId: mockUser.uuid, + }); + expect(UserModel.getUser).toHaveBeenCalledTimes(1); + expect(spyOn).toHaveBeenCalledTimes(1); + }); + + it('should update user settings correctly', async () => { + const settingsPatch: DeepPartial = { themeMode: 'dark' }; + (UserModel.updateSettings as Mock).mockResolvedValue(undefined); + + await clientService.updateUserSettings(settingsPatch); + + expect(UserModel.updateSettings).toHaveBeenCalledWith(settingsPatch); + expect(UserModel.updateSettings).toHaveBeenCalledTimes(1); + }); + + it('should reset user settings correctly', async () => { + (UserModel.resetSettings as Mock).mockResolvedValue(undefined); + + await clientService.resetUserSettings(); + + expect(UserModel.resetSettings).toHaveBeenCalledTimes(1); + }); + + it('should update user avatar correctly', async () => { + const newAvatar = 'new-avatar.png'; + (UserModel.updateAvatar as Mock).mockResolvedValue(undefined); + + await clientService.updateAvatar(newAvatar); + + expect(UserModel.updateAvatar).toHaveBeenCalledWith(newAvatar); + expect(UserModel.updateAvatar).toHaveBeenCalledTimes(1); + }); + + it('should update user preference correctly', async () => { + const newPreference = { + useCmdEnterToSend: false, + } as UserPreference; + + const spyOn = vi + .spyOn(clientService['preferenceStorage'], 'saveToLocalStorage') + .mockResolvedValue(undefined); + + await clientService.updatePreference(newPreference); + + expect(spyOn).toHaveBeenCalledWith(newPreference); + expect(spyOn).toHaveBeenCalledTimes(1); + }); +}); diff --git a/src/services/user/client.ts b/src/services/user/client.ts index 848b3c19beb7e..bddc5479d9b1b 100644 --- a/src/services/user/client.ts +++ b/src/services/user/client.ts @@ -1,27 +1,33 @@ import { DeepPartial } from 'utility-types'; +import { MessageModel } from '@/database/client/models/message'; import { UserModel } from '@/database/client/models/user'; -import { IUserService } from '@/services/user/type'; import { GlobalSettings } from '@/types/settings'; -import { UserPreference } from '@/types/user'; +import { UserInitializationState, UserPreference } from '@/types/user'; import { AsyncLocalStorage } from '@/utils/localStorage'; -export interface UserConfig { - avatar?: string; - settings: DeepPartial; - uuid: string; -} +import { IUserService } from './type'; export class ClientService implements IUserService { private preferenceStorage: AsyncLocalStorage; + constructor() { this.preferenceStorage = new AsyncLocalStorage('LOBE_PREFERENCE'); } - getUserConfig = async () => { + async getUserState(): Promise { const user = await UserModel.getUser(); - return user as unknown as UserConfig; - }; + const messageCount = await MessageModel.count(); + + return { + avatar: user.avatar, + canEnableTrace: messageCount >= 4, + isOnboard: true, + preference: await this.preferenceStorage.getFromLocalStorage(), + settings: user.settings as GlobalSettings, + userId: user.uuid, + }; + } updateUserSettings = async (patch: DeepPartial) => { return UserModel.updateSettings(patch); @@ -35,10 +41,6 @@ export class ClientService implements IUserService { return UserModel.updateAvatar(avatar); } - async getPreference() { - return this.preferenceStorage.getFromLocalStorage(); - } - async updatePreference(preference: UserPreference) { await this.preferenceStorage.saveToLocalStorage(preference); } diff --git a/src/services/user/index.ts b/src/services/user/index.ts index c09a2f1b128d5..be41b3b3a2f7e 100644 --- a/src/services/user/index.ts +++ b/src/services/user/index.ts @@ -8,6 +8,4 @@ // export const userService = ENABLED_SERVER_SERVICE ? new ServerService() : new ClientService(); import { ClientService } from './client'; -export type { UserConfig } from './client'; - export const userService = new ClientService(); diff --git a/src/services/user/type.ts b/src/services/user/type.ts index 0c3837891115d..18d2c23fbdc9f 100644 --- a/src/services/user/type.ts +++ b/src/services/user/type.ts @@ -1,12 +1,10 @@ import { DeepPartial } from 'utility-types'; -import { UserConfig } from '@/services/user/client'; import { GlobalSettings } from '@/types/settings'; -import { UserPreference } from '@/types/user'; +import { UserInitializationState, UserPreference } from '@/types/user'; export interface IUserService { - getPreference: () => Promise; - getUserConfig: () => Promise; + getUserState: () => Promise; resetUserSettings: () => Promise; updateAvatar: (avatar: string) => Promise; updatePreference: (preference: UserPreference) => Promise; diff --git a/src/store/user/initialState.ts b/src/store/user/initialState.ts index 7ed7a6ff9a835..5dd17f901d8b6 100644 --- a/src/store/user/initialState.ts +++ b/src/store/user/initialState.ts @@ -1,13 +1,22 @@ import { UserAuthState, initialAuthState } from './slices/auth/initialState'; +import { CommonState, initialCommonState } from './slices/common/initialState'; +import { ModelListState, initialModelListState } from './slices/modelList/initialState'; import { UserPreferenceState, initialPreferenceState } from './slices/preference/initialState'; import { UserSettingsState, initialSettingsState } from './slices/settings/initialState'; import { UserSyncState, initialSyncState } from './slices/sync/initialState'; -export type UserState = UserSyncState & UserSettingsState & UserPreferenceState & UserAuthState; +export type UserState = UserSyncState & + UserSettingsState & + UserPreferenceState & + UserAuthState & + ModelListState & + CommonState; export const initialState: UserState = { ...initialSyncState, ...initialSettingsState, ...initialPreferenceState, ...initialAuthState, + ...initialCommonState, + ...initialModelListState, }; diff --git a/src/store/user/selectors.ts b/src/store/user/selectors.ts index 06bbe7036a6f1..4ea96617db636 100644 --- a/src/store/user/selectors.ts +++ b/src/store/user/selectors.ts @@ -1,9 +1,5 @@ export { authSelectors, userProfileSelectors } from './slices/auth/selectors'; +export { modelConfigSelectors, modelProviderSelectors } from './slices/modelList/selectors'; export { preferenceSelectors } from './slices/preference/selectors'; -export { - modelConfigSelectors, - modelProviderSelectors, - settingsSelectors, - syncSettingsSelectors, - systemAgentSelectors, -} from './slices/settings/selectors'; +export { settingsSelectors, systemAgentSelectors } from './slices/settings/selectors'; +export { syncSettingsSelectors } from './slices/sync/selectors'; diff --git a/src/store/user/slices/auth/action.test.ts b/src/store/user/slices/auth/action.test.ts index dc1866a9d63d3..2cd9b889192aa 100644 --- a/src/store/user/slices/auth/action.test.ts +++ b/src/store/user/slices/auth/action.test.ts @@ -9,10 +9,6 @@ import { switchLang } from '@/utils/client/switchLang'; vi.mock('zustand/traditional'); -vi.mock('@/utils/client/switchLang', () => ({ - switchLang: vi.fn(), -})); - vi.mock('swr', async (importOriginal) => { const modules = await importOriginal(); return { @@ -54,93 +50,15 @@ vi.mock('next-auth/react', async () => { }); describe('createAuthSlice', () => { - describe('refreshUserConfig', () => { + describe('refreshUserState', () => { it('should refresh user config', async () => { const { result } = renderHook(() => useUserStore()); await act(async () => { - await result.current.refreshUserConfig(); + await result.current.refreshUserState(); }); - expect(mutate).toHaveBeenCalledWith(['fetchUserConfig', true]); - }); - }); - - describe('useFetchUserConfig', () => { - it('should not fetch user config if initServer is false', async () => { - const mockUserConfig: any = undefined; // 模拟未初始化服务器的情况 - vi.spyOn(userService, 'getUserConfig').mockResolvedValueOnce(mockUserConfig); - - const { result } = renderHook(() => useUserStore().useFetchUserConfig(false), { - wrapper: withSWR, - }); - - // 因为 initServer 为 false,所以不会触发 getUserConfig 的调用 - expect(userService.getUserConfig).not.toHaveBeenCalled(); - // 确保状态未改变 - expect(result.current.data).toBeUndefined(); - }); - - it('should fetch user config correctly when initServer is true', async () => { - const mockUserConfig: any = { - avatar: 'new-avatar-url', - settings: { - language: 'en', - }, - }; - vi.spyOn(userService, 'getUserConfig').mockResolvedValueOnce(mockUserConfig); - - const { result } = renderHook(() => useUserStore().useFetchUserConfig(true), { - wrapper: withSWR, - }); - - // 等待 SWR 完成数据获取 - await waitFor(() => expect(result.current.data).toEqual(mockUserConfig)); - - // 验证状态是否正确更新 - expect(useUserStore.getState().avatar).toBe(mockUserConfig.avatar); - expect(useUserStore.getState().settings).toEqual(mockUserConfig.settings); - - // 验证是否正确处理了语言设置 - expect(switchLang).not.toHaveBeenCalledWith('auto'); - }); - it('should call switch language when language is auto', async () => { - const mockUserConfig: any = { - avatar: 'new-avatar-url', - settings: { - language: 'auto', - }, - }; - vi.spyOn(userService, 'getUserConfig').mockResolvedValueOnce(mockUserConfig); - - const { result } = renderHook(() => useUserStore().useFetchUserConfig(true), { - wrapper: withSWR, - }); - - // 等待 SWR 完成数据获取 - await waitFor(() => expect(result.current.data).toEqual(mockUserConfig)); - - // 验证状态是否正确更新 - expect(useUserStore.getState().avatar).toBe(mockUserConfig.avatar); - expect(useUserStore.getState().settings).toEqual(mockUserConfig.settings); - - // 验证是否正确处理了语言设置 - expect(switchLang).toHaveBeenCalledWith('auto'); - }); - - it('should handle the case when user config is null', async () => { - vi.spyOn(userService, 'getUserConfig').mockResolvedValueOnce(null as any); - - const { result } = renderHook(() => useUserStore().useFetchUserConfig(true), { - wrapper: withSWR, - }); - - // 等待 SWR 完成数据获取 - await waitFor(() => expect(result.current.data).toBeNull()); - - // 验证状态未被错误更新 - expect(useUserStore.getState().avatar).toBeUndefined(); - expect(useUserStore.getState().settings).toEqual({}); + expect(mutate).toHaveBeenCalledWith('initUserState'); }); }); @@ -174,7 +92,7 @@ describe('createAuthSlice', () => { }); it('should call next-auth signOut when NextAuth is enabled', async () => { - useUserStore.setState({ enabledNextAuth: () => true }); + useUserStore.setState({ enabledNextAuth: true }); const { result } = renderHook(() => useUserStore()); @@ -228,7 +146,7 @@ describe('createAuthSlice', () => { }); it('should call next-auth signIn when NextAuth is enabled', async () => { - useUserStore.setState({ enabledNextAuth: () => true }); + useUserStore.setState({ enabledNextAuth: true }); const { result } = renderHook(() => useUserStore()); diff --git a/src/store/user/slices/auth/action.ts b/src/store/user/slices/auth/action.ts index 24b2080c9967d..f501793b5fc9d 100644 --- a/src/store/user/slices/auth/action.ts +++ b/src/store/user/slices/auth/action.ts @@ -1,21 +1,11 @@ -import useSWR, { SWRResponse, mutate } from 'swr'; import { StateCreator } from 'zustand/vanilla'; import { enableClerk } from '@/const/auth'; -import { UserConfig, userService } from '@/services/user'; -import { switchLang } from '@/utils/client/switchLang'; -import { setNamespace } from '@/utils/storeDebug'; import { UserStore } from '../../store'; -import { settingsSelectors } from '../settings/selectors'; - -const n = setNamespace('auth'); -const USER_CONFIG_FETCH_KEY = 'fetchUserConfig'; export interface UserAuthAction { enableAuth: () => boolean; - enabledNextAuth: () => boolean; - getUserConfig: () => void; /** * universal logout method */ @@ -25,9 +15,6 @@ export interface UserAuthAction { */ openLogin: () => Promise; openUserProfile: () => Promise; - - refreshUserConfig: () => Promise; - useFetchUserConfig: (initServer: boolean) => SWRResponse; } export const createAuthSlice: StateCreator< @@ -37,13 +24,7 @@ export const createAuthSlice: StateCreator< UserAuthAction > = (set, get) => ({ enableAuth: () => { - return enableClerk || get()?.enabledNextAuth(); - }, - enabledNextAuth: () => { - return !!get()?.serverConfig.enabledOAuthSSO; - }, - getUserConfig: () => { - console.log(n('userconfig')); + return enableClerk || get()?.enabledNextAuth || false; }, logout: async () => { if (enableClerk) { @@ -52,7 +33,7 @@ export const createAuthSlice: StateCreator< return; } - const enableNextAuth = get().enabledNextAuth(); + const enableNextAuth = get().enabledNextAuth; if (enableNextAuth) { const { signOut } = await import('next-auth/react'); signOut(); @@ -60,14 +41,12 @@ export const createAuthSlice: StateCreator< }, openLogin: async () => { if (enableClerk) { - console.log('fallbackRedirectUrl:', location.toString()); - get().clerkSignIn?.({ fallbackRedirectUrl: location.toString() }); return; } - const enableNextAuth = get().enabledNextAuth(); + const enableNextAuth = get().enabledNextAuth; if (enableNextAuth) { const { signIn } = await import('next-auth/react'); signIn(); @@ -81,38 +60,4 @@ export const createAuthSlice: StateCreator< return; } }, - refreshUserConfig: async () => { - await mutate([USER_CONFIG_FETCH_KEY, true]); - - // when get the user config ,refresh the model provider list to the latest - get().refreshModelProviderList(); - }, - useFetchUserConfig: (initServer) => - useSWR( - [USER_CONFIG_FETCH_KEY, initServer], - async () => { - if (!initServer) return; - return userService.getUserConfig(); - }, - { - onSuccess: (data) => { - if (!data) return; - - set( - { avatar: data.avatar, settings: data.settings, userId: data.uuid }, - false, - n('fetchUserConfig', data), - ); - - // when get the user config ,refresh the model provider list to the latest - get().refreshDefaultModelProviderList({ trigger: 'fetchUserConfig' }); - - const { language } = settingsSelectors.currentSettings(get()); - if (language === 'auto') { - switchLang('auto'); - } - }, - revalidateOnFocus: false, - }, - ), }); diff --git a/src/store/user/slices/auth/initialState.ts b/src/store/user/slices/auth/initialState.ts index 94d2cedd817c1..6e9b59311872e 100644 --- a/src/store/user/slices/auth/initialState.ts +++ b/src/store/user/slices/auth/initialState.ts @@ -15,12 +15,13 @@ export interface UserAuthState { * @deprecated */ avatar?: string; - clerkOpenUserProfile?: (props?: UserProfileProps) => void; + clerkSession?: ActiveSessionResource; clerkSignIn?: (props?: SignInProps) => void; clerkSignOut?: SignOut; clerkUser?: UserResource; + enabledNextAuth?: boolean; isLoaded?: boolean; isSignedIn?: boolean; diff --git a/src/store/user/slices/common/action.test.ts b/src/store/user/slices/common/action.test.ts index d8f3463f4a1f8..a557afec856e9 100644 --- a/src/store/user/slices/common/action.test.ts +++ b/src/store/user/slices/common/action.test.ts @@ -3,15 +3,20 @@ import { mutate } from 'swr'; import { afterEach, describe, expect, it, vi } from 'vitest'; import { withSWR } from '~test-utils'; -import { globalService } from '@/services/global'; -import { messageService } from '@/services/message'; +import { DEFAULT_PREFERENCE } from '@/const/user'; import { userService } from '@/services/user'; import { useUserStore } from '@/store/user'; import { preferenceSelectors } from '@/store/user/selectors'; import { GlobalServerConfig } from '@/types/serverConfig'; +import { UserInitializationState, UserPreference } from '@/types/user'; +import { switchLang } from '@/utils/client/switchLang'; vi.mock('zustand/traditional'); +vi.mock('@/utils/client/switchLang', () => ({ + switchLang: vi.fn(), +})); + vi.mock('swr', async (importOriginal) => { const modules = await importOriginal(); return { @@ -30,7 +35,7 @@ describe('createCommonSlice', () => { const { result } = renderHook(() => useUserStore()); const avatar = 'new-avatar'; - const spyOn = vi.spyOn(result.current, 'refreshUserConfig'); + const spyOn = vi.spyOn(result.current, 'refreshUserState'); const updateAvatarSpy = vi.spyOn(userService, 'updateAvatar'); await act(async () => { @@ -42,28 +47,197 @@ describe('createCommonSlice', () => { }); }); - describe('useFetchServerConfig', () => { - it('should fetch server config correctly', async () => { - const mockServerConfig = { - defaultAgent: 'agent1', - languageModel: 'model1', - telemetry: {}, - } as GlobalServerConfig; - vi.spyOn(globalService, 'getGlobalConfig').mockResolvedValueOnce(mockServerConfig); + describe('useInitUserState', () => { + const mockServerConfig = { + defaultAgent: 'agent1', + languageModel: 'model1', + telemetry: {}, + } as GlobalServerConfig; + + it('should not fetch user state if user is not login', async () => { + const mockUserConfig: any = undefined; // 模拟未初始化服务器的情况 + vi.spyOn(userService, 'getUserState').mockResolvedValueOnce(mockUserConfig); + const successCallback = vi.fn(); + + const { result } = renderHook( + () => + useUserStore().useInitUserState(false, mockServerConfig, { + onSuccess: successCallback, + }), + { wrapper: withSWR }, + ); + + // 因为 initServer 为 false,所以不会触发 getUserState 的调用 + expect(userService.getUserState).not.toHaveBeenCalled(); + // 也不会触发 onSuccess 回调 + expect(successCallback).not.toHaveBeenCalled(); + // 确保状态未改变 + expect(result.current.data).toBeUndefined(); + }); + + it('should fetch user state correctly when user is login', async () => { + const mockUserState: UserInitializationState = { + userId: 'user-id', + isOnboard: true, + preference: { + telemetry: true, + }, + settings: { + language: 'en-US', + }, + }; + + vi.spyOn(userService, 'getUserState').mockResolvedValueOnce(mockUserState); + const successCallback = vi.fn(); + + const { result } = renderHook( + () => + useUserStore().useInitUserState(true, mockServerConfig, { + onSuccess: successCallback, + }), + { + wrapper: withSWR, + }, + ); + + // 等待 SWR 完成数据获取 + await waitFor(() => expect(result.current.data).toEqual(mockUserState)); + + // 验证状态是否正确更新 + expect(useUserStore.getState().avatar).toBe(mockUserState.avatar); + expect(useUserStore.getState().settings).toEqual(mockUserState.settings); + expect(successCallback).toHaveBeenCalledWith(mockUserState); + + // 验证是否正确处理了语言设置 + expect(switchLang).not.toHaveBeenCalledWith('auto'); + }); + + it('should call switch language when language is auto', async () => { + const mockUserState: UserInitializationState = { + userId: 'user-id', + isOnboard: true, + preference: { + telemetry: true, + }, + settings: {}, + }; + + vi.spyOn(userService, 'getUserState').mockResolvedValueOnce(mockUserState); + + const { result } = renderHook(() => useUserStore().useInitUserState(true, mockServerConfig), { + wrapper: withSWR, + }); + + // 等待 SWR 完成数据获取 + await waitFor(() => expect(result.current.data).toEqual(mockUserState)); + + // 验证是否正确处理了语言设置 + expect(switchLang).toHaveBeenCalledWith('auto'); + }); + + it('should fetch use server config correctly', async () => { + const mockUserState: UserInitializationState = { + userId: 'user-id', + isOnboard: true, + preference: { + telemetry: true, + }, + settings: {}, + }; + vi.spyOn(userService, 'getUserState').mockResolvedValueOnce(mockUserState); + + const { result } = renderHook(() => useUserStore().useInitUserState(true, mockServerConfig)); + + await waitFor(() => expect(result.current.data).toEqual(mockUserState)); + }); + + it('should return saved preference when local storage has data', async () => { + const { result } = renderHook(() => useUserStore()); + + const savedPreference: UserPreference = { + ...DEFAULT_PREFERENCE, + hideSyncAlert: true, + guide: { topic: false, moveSettingsToAvatar: true }, + }; + + const mockUserState: UserInitializationState = { + userId: 'user-id', + isOnboard: true, + preference: savedPreference, + settings: { + language: 'en-US', + }, + }; + vi.spyOn(userService, 'getUserState').mockResolvedValueOnce(mockUserState); + + const { result: preference } = renderHook( + () => result.current.useInitUserState(true, mockServerConfig), + { wrapper: withSWR }, + ); + + await waitFor(() => { + expect(preference.current.data.preference).toEqual(savedPreference); + expect(result.current.isUserStateInit).toBeTruthy(); + expect(result.current.preference).toEqual(savedPreference); + }); + }); + + it('should handle the case when user config is null', async () => { + const { result } = renderHook(() => useUserStore()); + const mockUserState: UserInitializationState = { + userId: 'user-id', + isOnboard: true, + preference: undefined as any, + settings: null as any, + }; + + vi.spyOn(userService, 'getUserState').mockResolvedValueOnce(mockUserState); + + renderHook(() => result.current.useInitUserState(true, mockServerConfig), { + wrapper: withSWR, + }); + + // 等待 SWR 完成数据获取 + await waitFor(() => { + expect(result.current.isUserStateInit).toBeTruthy(); + // 验证状态未被错误更新 + expect(result.current.avatar).toBeUndefined(); + expect(result.current.settings).toEqual({}); + }); + }); + + it('should return default preference when local storage is empty', async () => { + const { result } = renderHook(() => useUserStore()); - const { result } = renderHook(() => useUserStore().useFetchServerConfig()); + const mockUserState: UserInitializationState = { + userId: 'user-id', + isOnboard: true, + preference: {} as any, + settings: { + language: 'en-US', + }, + }; - await waitFor(() => expect(result.current.data).toEqual(mockServerConfig)); + vi.spyOn(userService, 'getUserState').mockResolvedValueOnce(mockUserState); + + renderHook(() => result.current.useInitUserState(true, mockServerConfig), { + wrapper: withSWR, + }); + + await waitFor(() => { + expect(result.current.isUserStateInit).toBeTruthy(); + expect(result.current.preference).toEqual(DEFAULT_PREFERENCE); + }); }); }); describe('useCheckTrace', () => { - it('should return false when shouldFetch is false', async () => { + it('should return undefined when shouldFetch is false', async () => { const { result } = renderHook(() => useUserStore().useCheckTrace(false), { wrapper: withSWR, }); - await waitFor(() => expect(result.current.data).toBe(false)); + await waitFor(() => expect(result.current.data).toBeUndefined()); }); it('should return false when userAllowTrace is already set', async () => { @@ -78,16 +252,18 @@ describe('createCommonSlice', () => { it('should call messageService.messageCountToCheckTrace when needed', async () => { vi.spyOn(preferenceSelectors, 'userAllowTrace').mockReturnValueOnce(null); - const messageCountToCheckTraceSpy = vi - .spyOn(messageService, 'messageCountToCheckTrace') - .mockResolvedValueOnce(true); - const { result } = renderHook(() => useUserStore().useCheckTrace(true), { + act(() => { + useUserStore.setState({ + isUserCanEnableTrace: true, + }); + }); + + const { result } = renderHook(() => useUserStore.getState().useCheckTrace(true), { wrapper: withSWR, }); await waitFor(() => expect(result.current.data).toBe(true)); - expect(messageCountToCheckTraceSpy).toHaveBeenCalled(); }); }); }); diff --git a/src/store/user/slices/common/action.ts b/src/store/user/slices/common/action.ts index 038a5768df33a..4f40f30797783 100644 --- a/src/store/user/slices/common/action.ts +++ b/src/store/user/slices/common/action.ts @@ -1,26 +1,38 @@ -import useSWR, { SWRResponse } from 'swr'; +import useSWR, { SWRResponse, mutate } from 'swr'; import { DeepPartial } from 'utility-types'; import type { StateCreator } from 'zustand/vanilla'; -import { messageService } from '@/services/message'; +import { DEFAULT_PREFERENCE } from '@/const/user'; import { userService } from '@/services/user'; import type { UserStore } from '@/store/user'; import type { GlobalServerConfig } from '@/types/serverConfig'; import type { GlobalSettings } from '@/types/settings'; +import { UserInitializationState } from '@/types/user'; +import { switchLang } from '@/utils/client/switchLang'; import { merge } from '@/utils/merge'; import { setNamespace } from '@/utils/storeDebug'; import { preferenceSelectors } from '../preference/selectors'; +import { settingsSelectors } from '../settings/selectors'; const n = setNamespace('common'); +const GET_USER_STATE_KEY = 'initUserState'; /** * 设置操作 */ export interface CommonAction { + refreshUserState: () => Promise; + updateAvatar: (avatar: string) => Promise; useCheckTrace: (shouldFetch: boolean) => SWRResponse; - useFetchServerConfig: () => SWRResponse; + useInitUserState: ( + isLogin: boolean | undefined, + serverConfig: GlobalServerConfig, + options?: { + onSuccess: (data: UserInitializationState) => void; + }, + ) => SWRResponse; } export const createCommonSlice: StateCreator< @@ -29,55 +41,72 @@ export const createCommonSlice: StateCreator< [], CommonAction > = (set, get) => ({ + refreshUserState: async () => { + await mutate(GET_USER_STATE_KEY); + }, updateAvatar: async (avatar) => { await userService.updateAvatar(avatar); - await get().refreshUserConfig(); + await get().refreshUserState(); }, useCheckTrace: (shouldFetch) => useSWR( - ['checkTrace', shouldFetch], + shouldFetch ? 'checkTrace' : null, () => { const userAllowTrace = preferenceSelectors.userAllowTrace(get()); - // if not init with server side, return false - if (!shouldFetch) return Promise.resolve(false); // if user have set the trace, return false if (typeof userAllowTrace === 'boolean') return Promise.resolve(false); - return messageService.messageCountToCheckTrace(); + return Promise.resolve(get().isUserCanEnableTrace); }, { revalidateOnFocus: false, }, ), - /** - * TODO: need to be removed in the future - * the serverConfig should be fetched only in the serverConfigStore - * @deprecated - */ - useFetchServerConfig: () => - useSWR( - 'fetchGlobalConfig', - async () => { - const { globalService } = await import('@/services/global'); - - return globalService.getGlobalConfig(); - }, + useInitUserState: (isLogin, serverConfig, options) => + useSWR( + !!isLogin ? GET_USER_STATE_KEY : null, + () => userService.getUserState(), { onSuccess: (data) => { + options?.onSuccess?.(data); + if (data) { + // merge settings const serverSettings: DeepPartial = { - defaultAgent: data.defaultAgent, - languageModel: data.languageModel, + defaultAgent: serverConfig.defaultAgent, + languageModel: serverConfig.languageModel, }; - const defaultSettings = merge(get().defaultSettings, serverSettings); - set({ defaultSettings, serverConfig: data }, false, n('initGlobalConfig')); + // merge preference + const isEmpty = Object.keys(data.preference || {}).length === 0; + const preference = isEmpty ? DEFAULT_PREFERENCE : data.preference; + + set( + { + defaultSettings, + enabledNextAuth: serverConfig.enabledOAuthSSO, + isUserCanEnableTrace: data.canEnableTrace, + isUserStateInit: true, + preference, + serverLanguageModel: serverConfig.languageModel, + settings: data.settings || {}, + userId: data.userId, + }, + false, + n('initUserState'), + ); + + get().refreshDefaultModelProviderList({ trigger: 'fetchUserState' }); - get().refreshDefaultModelProviderList({ trigger: 'fetchServerConfig' }); + // auto switch language + const { language } = settingsSelectors.currentSettings(get()); + if (language === 'auto') { + switchLang('auto'); + } } }, revalidateOnFocus: false, diff --git a/src/store/user/slices/common/initialState.ts b/src/store/user/slices/common/initialState.ts new file mode 100644 index 0000000000000..53ae71737eabd --- /dev/null +++ b/src/store/user/slices/common/initialState.ts @@ -0,0 +1,9 @@ +export interface CommonState { + isUserCanEnableTrace: boolean; + isUserStateInit: boolean; +} + +export const initialCommonState: CommonState = { + isUserCanEnableTrace: false, + isUserStateInit: false, +}; diff --git a/src/store/user/slices/modelList/action.test.ts b/src/store/user/slices/modelList/action.test.ts new file mode 100644 index 0000000000000..5ac73c02e7d05 --- /dev/null +++ b/src/store/user/slices/modelList/action.test.ts @@ -0,0 +1,363 @@ +import { act, renderHook, waitFor } from '@testing-library/react'; +import { describe, expect, it, vi } from 'vitest'; + +import { modelsService } from '@/services/models'; +import { userService } from '@/services/user'; +import { useUserStore } from '@/store/user'; +import { GeneralModelProviderConfig } from '@/types/settings'; + +import { settingsSelectors } from '../settings/selectors'; +import { CustomModelCardDispatch } from './reducers/customModelCard'; +import { modelProviderSelectors } from './selectors'; + +// Mock userService +vi.mock('@/services/user', () => ({ + userService: { + updateUserSettings: vi.fn(), + resetUserSettings: vi.fn(), + }, +})); + +vi.mock('zustand/traditional'); + +describe('LLMSettingsSliceAction', () => { + describe('setModelProviderConfig', () => { + it('should set OpenAI configuration', async () => { + const { result } = renderHook(() => useUserStore()); + const openAIConfig: Partial = { apiKey: 'test-key' }; + + // Perform the action + await act(async () => { + await result.current.setModelProviderConfig('openai', openAIConfig); + }); + + // Assert that updateUserSettings was called with the correct OpenAI configuration + expect(userService.updateUserSettings).toHaveBeenCalledWith({ + languageModel: { + openai: openAIConfig, + }, + }); + }); + }); + + describe('dispatchCustomModelCards', () => { + it('should return early when prevState does not exist', async () => { + const { result } = renderHook(() => useUserStore()); + const provider = 'openai'; + const payload: CustomModelCardDispatch = { type: 'add', modelCard: { id: 'test-id' } }; + + // Mock the selector to return undefined + vi.spyOn(settingsSelectors, 'providerConfig').mockReturnValueOnce(() => undefined); + vi.spyOn(result.current, 'setModelProviderConfig'); + + await act(async () => { + await result.current.dispatchCustomModelCards(provider, payload); + }); + + // Assert that setModelProviderConfig was not called + expect(result.current.setModelProviderConfig).not.toHaveBeenCalled(); + }); + }); + + describe('refreshDefaultModelProviderList', () => { + it('default', async () => { + const { result } = renderHook(() => useUserStore()); + + act(() => { + useUserStore.setState({ + serverLanguageModel: { + azure: { serverModelCards: [{ id: 'abc', deploymentName: 'abc' }] }, + }, + }); + }); + + act(() => { + result.current.refreshDefaultModelProviderList(); + }); + + // Assert that setModelProviderConfig was not called + const azure = result.current.defaultModelProviderList.find((m) => m.id === 'azure'); + expect(azure?.chatModels).toEqual([{ id: 'abc', deploymentName: 'abc' }]); + }); + + it('openai', () => { + const { result } = renderHook(() => useUserStore()); + act(() => { + useUserStore.setState({ + serverLanguageModel: { + openai: { + enabled: true, + enabledModels: ['gpt-4-0125-preview', 'gpt-4-turbo-2024-04-09'], + serverModelCards: [ + { + displayName: 'ChatGPT-4', + functionCall: true, + id: 'gpt-4-0125-preview', + tokens: 128000, + enabled: true, + }, + { + displayName: 'ChatGPT-4 Vision', + functionCall: true, + id: 'gpt-4-turbo-2024-04-09', + tokens: 128000, + vision: true, + enabled: true, + }, + ], + }, + }, + }); + }); + + act(() => { + result.current.refreshDefaultModelProviderList(); + }); + + // Assert that setModelProviderConfig was not called + const openai = result.current.defaultModelProviderList.find((m) => m.id === 'openai'); + expect(openai?.chatModels).toEqual([ + { + displayName: 'ChatGPT-4', + enabled: true, + functionCall: true, + id: 'gpt-4-0125-preview', + tokens: 128000, + }, + { + displayName: 'ChatGPT-4 Vision', + enabled: true, + functionCall: true, + id: 'gpt-4-turbo-2024-04-09', + tokens: 128000, + vision: true, + }, + ]); + }); + }); + + describe('refreshModelProviderList', () => { + it('visible', async () => { + const { result } = renderHook(() => useUserStore()); + act(() => { + useUserStore.setState({ + settings: { + languageModel: { + ollama: { enabledModels: ['llava'] }, + }, + }, + }); + }); + + act(() => { + result.current.refreshModelProviderList(); + }); + + const ollamaList = result.current.modelProviderList.find((r) => r.id === 'ollama'); + // Assert that setModelProviderConfig was not called + expect(ollamaList?.chatModels.find((c) => c.id === 'llava')).toEqual({ + displayName: 'LLaVA 7B', + enabled: true, + id: 'llava', + tokens: 4096, + vision: true, + }); + }); + + it('modelProviderListForModelSelect should return only enabled providers', () => { + const { result } = renderHook(() => useUserStore()); + + act(() => { + useUserStore.setState({ + settings: { + languageModel: { + perplexity: { enabled: true }, + azure: { enabled: false }, + }, + }, + }); + }); + + act(() => { + result.current.refreshModelProviderList(); + }); + + const enabledProviders = modelProviderSelectors.modelProviderListForModelSelect( + result.current, + ); + expect(enabledProviders).toHaveLength(3); + expect(enabledProviders.at(-1)!.id).toBe('perplexity'); + }); + }); + + describe('removeEnabledModels', () => { + it('should remove the specified model from enabledModels', async () => { + const { result } = renderHook(() => useUserStore()); + const model = 'gpt-3.5-turbo'; + + const spyOn = vi.spyOn(userService, 'updateUserSettings'); + + act(() => { + useUserStore.setState({ + settings: { + languageModel: { + azure: { enabledModels: ['gpt-3.5-turbo', 'gpt-4'] }, + }, + }, + }); + }); + + await act(async () => { + console.log(JSON.stringify(result.current.settings)); + await result.current.removeEnabledModels('azure', model); + }); + + expect(spyOn).toHaveBeenCalledWith({ + languageModel: { + azure: { enabledModels: ['gpt-4'] }, + }, + }); + }); + }); + + describe('toggleEditingCustomModelCard', () => { + it('should update editingCustomCardModel when params are provided', () => { + const { result } = renderHook(() => useUserStore()); + + act(() => { + result.current.toggleEditingCustomModelCard({ id: 'test-id', provider: 'openai' }); + }); + + expect(result.current.editingCustomCardModel).toEqual({ id: 'test-id', provider: 'openai' }); + }); + + it('should reset editingCustomCardModel when no params are provided', () => { + const { result } = renderHook(() => useUserStore()); + + act(() => { + result.current.toggleEditingCustomModelCard(); + }); + + expect(result.current.editingCustomCardModel).toBeUndefined(); + }); + }); + + describe('toggleProviderEnabled', () => { + it('should enable the provider', async () => { + const { result } = renderHook(() => useUserStore()); + + await act(async () => { + await result.current.toggleProviderEnabled('minimax', true); + }); + + expect(userService.updateUserSettings).toHaveBeenCalledWith({ + languageModel: { + minimax: { enabled: true }, + }, + }); + }); + + it('should disable the provider', async () => { + const { result } = renderHook(() => useUserStore()); + const provider = 'openai'; + + await act(async () => { + await result.current.toggleProviderEnabled(provider, false); + }); + + expect(userService.updateUserSettings).toHaveBeenCalledWith({ + languageModel: { + openai: { enabled: false }, + }, + }); + }); + }); + + describe('updateEnabledModels', () => { + // TODO: 有待 updateEnabledModels 实现的同步改造 + it('should add new custom model to customModelCards', async () => { + const { result } = renderHook(() => useUserStore()); + const provider = 'openai'; + const modelKeys = ['gpt-3.5-turbo', 'custom-model']; + const options = [{ value: 'gpt-3.5-turbo' }, {}]; + + await act(async () => { + await result.current.updateEnabledModels(provider, modelKeys, options); + }); + + expect(userService.updateUserSettings).toHaveBeenCalledWith({ + languageModel: { + openai: { + customModelCards: [{ id: 'custom-model' }], + // TODO:目标单测中需要包含下面这一行 + // enabledModels: ['gpt-3.5-turbo', 'custom-model'], + }, + }, + }); + }); + + it('should not add removed model to customModelCards', async () => { + const { result } = renderHook(() => useUserStore()); + const provider = 'openai'; + const modelKeys = ['gpt-3.5-turbo']; + const options = [{ value: 'gpt-3.5-turbo' }]; + + act(() => { + useUserStore.setState({ + settings: { + languageModel: { + openai: { enabledModels: ['gpt-3.5-turbo', 'gpt-4'] }, + }, + }, + }); + }); + + await act(async () => { + await result.current.updateEnabledModels(provider, modelKeys, options); + }); + + expect(userService.updateUserSettings).toHaveBeenCalledWith({ + languageModel: { + openai: { enabledModels: ['gpt-3.5-turbo'] }, + }, + }); + }); + }); + + describe('useFetchProviderModelList', () => { + it('should fetch data when enabledAutoFetch is true', async () => { + const { result } = renderHook(() => useUserStore()); + const provider = 'openai'; + const enabledAutoFetch = true; + + const spyOn = vi.spyOn(result.current, 'refreshDefaultModelProviderList'); + + vi.spyOn(modelsService, 'getChatModels').mockResolvedValueOnce([]); + + renderHook(() => result.current.useFetchProviderModelList(provider, enabledAutoFetch)); + + await waitFor(() => { + expect(spyOn).toHaveBeenCalled(); + }); + + // expect(result.current.settings.languageModel.openai?.latestFetchTime).toBeDefined(); + // expect(result.current.settings.languageModel.openai?.remoteModelCards).toBeDefined(); + }); + + it('should not fetch data when enabledAutoFetch is false', async () => { + const { result } = renderHook(() => useUserStore()); + const provider = 'openai'; + const enabledAutoFetch = false; + + const spyOn = vi.spyOn(result.current, 'refreshDefaultModelProviderList'); + + vi.spyOn(modelsService, 'getChatModels').mockResolvedValueOnce([]); + + renderHook(() => result.current.useFetchProviderModelList(provider, enabledAutoFetch)); + + await waitFor(() => { + expect(spyOn).not.toHaveBeenCalled(); + }); + }); + }); +}); diff --git a/src/store/user/slices/settings/actions/llm.ts b/src/store/user/slices/modelList/action.ts similarity index 60% rename from src/store/user/slices/settings/actions/llm.ts rename to src/store/user/slices/modelList/action.ts index 9a0dfb3d00e46..0cecf3dc18b70 100644 --- a/src/store/user/slices/settings/actions/llm.ts +++ b/src/store/user/slices/modelList/action.ts @@ -1,39 +1,21 @@ +import { produce } from 'immer'; import useSWR, { SWRResponse } from 'swr'; import type { StateCreator } from 'zustand/vanilla'; -import { - AnthropicProviderCard, - AzureProviderCard, - BedrockProviderCard, - DeepSeekProviderCard, - GoogleProviderCard, - GroqProviderCard, - MinimaxProviderCard, - MistralProviderCard, - MoonshotProviderCard, - OllamaProviderCard, - OpenAIProviderCard, - OpenRouterProviderCard, - PerplexityProviderCard, - TogetherAIProviderCard, - ZeroOneProviderCard, - ZhiPuProviderCard, -} from '@/config/modelProviders'; +import { DEFAULT_MODEL_PROVIDER_LIST } from '@/config/modelProviders'; +import { ModelProvider } from '@/libs/agent-runtime'; import { UserStore } from '@/store/user'; import { ChatModelCard } from '@/types/llm'; import { GlobalLLMConfig, GlobalLLMProviderKey } from '@/types/settings'; -import { setNamespace } from '@/utils/storeDebug'; -import { CustomModelCardDispatch, customModelCardsReducer } from '../reducers/customModelCard'; -import { modelProviderSelectors } from '../selectors/modelProvider'; -import { settingsSelectors } from '../selectors/settings'; - -const n = setNamespace('settings'); +import { settingsSelectors } from '../settings/selectors'; +import { CustomModelCardDispatch, customModelCardsReducer } from './reducers/customModelCard'; +import { modelProviderSelectors } from './selectors/modelProvider'; /** * 设置操作 */ -export interface LLMSettingsAction { +export interface ModelListAction { dispatchCustomModelCards: ( provider: GlobalLLMProviderKey, payload: CustomModelCardDispatch, @@ -48,21 +30,28 @@ export interface LLMSettingsAction { provider: T, config: Partial, ) => Promise; + toggleEditingCustomModelCard: (params?: { id: string; provider: GlobalLLMProviderKey }) => void; toggleProviderEnabled: (provider: GlobalLLMProviderKey, enabled: boolean) => Promise; + updateEnabledModels: ( + provider: GlobalLLMProviderKey, + modelKeys: string[], + options: { label?: string; value?: string }[], + ) => Promise; + useFetchProviderModelList: ( provider: GlobalLLMProviderKey, enabledAutoFetch: boolean, ) => SWRResponse; } -export const llmSettingsSlice: StateCreator< +export const createModelListSlice: StateCreator< UserStore, [['zustand/devtools', never]], [], - LLMSettingsAction + ModelListAction > = (set, get) => ({ dispatchCustomModelCards: async (provider, payload) => { const prevState = settingsSelectors.providerConfig(provider)(get()); @@ -73,7 +62,6 @@ export const llmSettingsSlice: StateCreator< await get().setModelProviderConfig(provider, { customModelCards: nextState }); }, - refreshDefaultModelProviderList: (params) => { /** * Because we have several model cards sources, we need to merge the model cards @@ -92,39 +80,27 @@ export const llmSettingsSlice: StateCreator< return serverChatModels ?? remoteChatModels ?? defaultChatModels; }; - const defaultModelProviderList = [ - { - ...OpenAIProviderCard, - chatModels: mergeModels('openai', OpenAIProviderCard.chatModels), - }, - { ...AzureProviderCard, chatModels: mergeModels('azure', []) }, - { ...OllamaProviderCard, chatModels: mergeModels('ollama', OllamaProviderCard.chatModels) }, - AnthropicProviderCard, - GoogleProviderCard, - { - ...OpenRouterProviderCard, - chatModels: mergeModels('openrouter', OpenRouterProviderCard.chatModels), - }, - { - ...TogetherAIProviderCard, - chatModels: mergeModels('togetherai', TogetherAIProviderCard.chatModels), - }, - BedrockProviderCard, - DeepSeekProviderCard, - PerplexityProviderCard, - MinimaxProviderCard, - MistralProviderCard, - GroqProviderCard, - MoonshotProviderCard, - ZeroOneProviderCard, - ZhiPuProviderCard, - ]; - - set({ defaultModelProviderList }, false, n(`refreshDefaultModelList - ${params?.trigger}`)); + const defaultModelProviderList = produce(DEFAULT_MODEL_PROVIDER_LIST, (draft) => { + const openai = draft.find((d) => d.id === ModelProvider.OpenAI); + if (openai) openai.chatModels = mergeModels('openai', openai.chatModels); + + const azure = draft.find((d) => d.id === ModelProvider.Azure); + if (azure) azure.chatModels = mergeModels('azure', azure.chatModels); + + const ollama = draft.find((d) => d.id === ModelProvider.Ollama); + if (ollama) ollama.chatModels = mergeModels('ollama', ollama.chatModels); + + const openrouter = draft.find((d) => d.id === ModelProvider.OpenRouter); + if (openrouter) openrouter.chatModels = mergeModels('openrouter', openrouter.chatModels); + + const togetherai = draft.find((d) => d.id === ModelProvider.TogetherAI); + if (togetherai) togetherai.chatModels = mergeModels('togetherai', togetherai.chatModels); + }); + + set({ defaultModelProviderList }, false, `refreshDefaultModelList - ${params?.trigger}`); get().refreshModelProviderList({ trigger: 'refreshDefaultModelList' }); }, - refreshModelProviderList: (params) => { const modelProviderList = get().defaultModelProviderList.map((list) => ({ ...list, @@ -143,7 +119,7 @@ export const llmSettingsSlice: StateCreator< enabled: modelProviderSelectors.isProviderEnabled(list.id as any)(get()), })); - set({ modelProviderList }, false, n(`refreshModelList - ${params?.trigger}`)); + set({ modelProviderList }, false, `refreshModelList - ${params?.trigger}`); }, removeEnabledModels: async (provider, model) => { @@ -157,14 +133,44 @@ export const llmSettingsSlice: StateCreator< setModelProviderConfig: async (provider, config) => { await get().setSettings({ languageModel: { [provider]: config } }); }, + toggleEditingCustomModelCard: (params) => { set({ editingCustomCardModel: params }, false, 'toggleEditingCustomModelCard'); }, - toggleProviderEnabled: async (provider, enabled) => { await get().setSettings({ languageModel: { [provider]: { enabled } } }); }, + updateEnabledModels: async (provider, value, options) => { + const { dispatchCustomModelCards, setModelProviderConfig } = get(); + const enabledModels = modelProviderSelectors.getEnableModelsById(provider)(get()); + + // if there is a new model, add it to `customModelCards` + const pools = options.map(async (option: { label?: string; value?: string }, index: number) => { + // if is a known model, it should have value + // if is an unknown model, the option will be {} + if (option.value) return; + + const modelId = value[index]; + + // if is in enabledModels, it means it's a removed model + if (enabledModels?.some((m) => modelId === m)) return; + + await dispatchCustomModelCards(provider, { + modelCard: { id: modelId }, + type: 'add', + }); + }); + + // TODO: 当前的这个 pool 方法并不是最好的实现,因为它会触发 setModelProviderConfig 的多次更新。 + // 理论上应该合并这些变更,然后最后只做一次触发 + // 因此后续的做法应该是将 dispatchCustomModelCards 改造为同步方法,并在最后做一次异步更新 + // 对应需要改造 'should add new custom model to customModelCards' 这一个单测 + await Promise.all(pools); + + await setModelProviderConfig(provider, { enabledModels: value.filter(Boolean) }); + }, + useFetchProviderModelList: (provider, enabledAutoFetch) => useSWR( [provider, enabledAutoFetch], diff --git a/src/store/user/slices/modelList/initialState.ts b/src/store/user/slices/modelList/initialState.ts new file mode 100644 index 0000000000000..cbef9dfba45f8 --- /dev/null +++ b/src/store/user/slices/modelList/initialState.ts @@ -0,0 +1,15 @@ +import { DEFAULT_MODEL_PROVIDER_LIST } from '@/config/modelProviders'; +import { ModelProviderCard } from '@/types/llm'; +import { ServerLanguageModel } from '@/types/serverConfig'; + +export interface ModelListState { + defaultModelProviderList: ModelProviderCard[]; + editingCustomCardModel?: { id: string; provider: string } | undefined; + modelProviderList: ModelProviderCard[]; + serverLanguageModel?: ServerLanguageModel; +} + +export const initialModelListState: ModelListState = { + defaultModelProviderList: DEFAULT_MODEL_PROVIDER_LIST, + modelProviderList: DEFAULT_MODEL_PROVIDER_LIST, +}; diff --git a/src/store/user/slices/settings/reducers/customModelCard.test.ts b/src/store/user/slices/modelList/reducers/customModelCard.test.ts similarity index 100% rename from src/store/user/slices/settings/reducers/customModelCard.test.ts rename to src/store/user/slices/modelList/reducers/customModelCard.test.ts diff --git a/src/store/user/slices/settings/reducers/customModelCard.ts b/src/store/user/slices/modelList/reducers/customModelCard.ts similarity index 100% rename from src/store/user/slices/settings/reducers/customModelCard.ts rename to src/store/user/slices/modelList/reducers/customModelCard.ts diff --git a/src/store/user/slices/modelList/selectors/index.ts b/src/store/user/slices/modelList/selectors/index.ts new file mode 100644 index 0000000000000..ef666f33aa5e5 --- /dev/null +++ b/src/store/user/slices/modelList/selectors/index.ts @@ -0,0 +1,2 @@ +export { modelConfigSelectors } from './modelConfig'; +export { modelProviderSelectors } from './modelProvider'; diff --git a/src/store/user/slices/settings/selectors/modelConfig.test.ts b/src/store/user/slices/modelList/selectors/modelConfig.test.ts similarity index 96% rename from src/store/user/slices/settings/selectors/modelConfig.test.ts rename to src/store/user/slices/modelList/selectors/modelConfig.test.ts index 11d8543edf1de..06c7ce35e773c 100644 --- a/src/store/user/slices/settings/selectors/modelConfig.test.ts +++ b/src/store/user/slices/modelList/selectors/modelConfig.test.ts @@ -3,7 +3,8 @@ import { describe, expect, it } from 'vitest'; import { UserStore } from '@/store/user'; import { merge } from '@/utils/merge'; -import { UserSettingsState, initialSettingsState } from '../initialState'; +import { UserState } from '../../../initialState'; +import { UserSettingsState, initialSettingsState } from '../../settings/initialState'; import { modelConfigSelectors } from './modelConfig'; describe('modelConfigSelectors', () => { @@ -121,7 +122,7 @@ describe('modelConfigSelectors', () => { id: 'custom-model-2', provider: 'perplexity', }, - } as UserSettingsState) as unknown as UserStore; + } as UserState) as unknown as UserStore; const currentEditingModelCard = modelConfigSelectors.currentEditingCustomModelCard(s); diff --git a/src/store/user/slices/settings/selectors/modelConfig.ts b/src/store/user/slices/modelList/selectors/modelConfig.ts similarity index 95% rename from src/store/user/slices/settings/selectors/modelConfig.ts rename to src/store/user/slices/modelList/selectors/modelConfig.ts index 4e3b8b113cb31..47ed0720a8f5b 100644 --- a/src/store/user/slices/settings/selectors/modelConfig.ts +++ b/src/store/user/slices/modelList/selectors/modelConfig.ts @@ -1,7 +1,7 @@ import { GlobalLLMProviderKey } from '@/types/settings'; import { UserStore } from '../../../store'; -import { currentLLMSettings, getProviderConfigById } from './settings'; +import { currentLLMSettings, getProviderConfigById } from '../../settings/selectors/settings'; const isProviderEnabled = (provider: GlobalLLMProviderKey) => (s: UserStore) => getProviderConfigById(provider)(s)?.enabled || false; diff --git a/src/store/user/slices/settings/selectors/modelProvider.test.ts b/src/store/user/slices/modelList/selectors/modelProvider.test.ts similarity index 90% rename from src/store/user/slices/settings/selectors/modelProvider.test.ts rename to src/store/user/slices/modelList/selectors/modelProvider.test.ts index 87489eb5bc6f5..7ec7024670bec 100644 --- a/src/store/user/slices/settings/selectors/modelProvider.test.ts +++ b/src/store/user/slices/modelList/selectors/modelProvider.test.ts @@ -2,21 +2,21 @@ import { describe, expect, it } from 'vitest'; import { merge } from '@/utils/merge'; +import { UserState, initialState } from '../../../initialState'; import { UserStore, useUserStore } from '../../../store'; -import { UserSettingsState, initialSettingsState } from '../initialState'; import { getDefaultModeProviderById, modelProviderSelectors } from './modelProvider'; describe('modelProviderSelectors', () => { describe('getDefaultModeProviderById', () => { it('should return the correct ModelProviderCard when provider ID matches', () => { - const s = merge(initialSettingsState, {}) as unknown as UserStore; + const s = merge(initialState, {}) as unknown as UserStore; const result = getDefaultModeProviderById('openai')(s); expect(result).not.toBeUndefined(); }); it('should return undefined when provider ID does not exist', () => { - const s = merge(initialSettingsState, {}) as unknown as UserStore; + const s = merge(initialState, {}) as unknown as UserStore; const result = getDefaultModeProviderById('nonExistingProvider')(s); expect(result).toBeUndefined(); }); @@ -24,7 +24,7 @@ describe('modelProviderSelectors', () => { describe('getModelCardsById', () => { it('should return model cards including custom model cards', () => { - const s = merge(initialSettingsState, { + const s = merge(initialState, { settings: { languageModel: { perplexity: { @@ -32,7 +32,7 @@ describe('modelProviderSelectors', () => { }, }, }, - } as UserSettingsState) as unknown as UserStore; + } as UserState) as unknown as UserStore; const modelCards = modelProviderSelectors.getModelCardsById('perplexity')(s); @@ -46,14 +46,14 @@ describe('modelProviderSelectors', () => { describe('defaultEnabledProviderModels', () => { it('should return enabled models for a given provider', () => { - const s = merge(initialSettingsState, {}) as unknown as UserStore; + const s = merge(initialState, {}) as unknown as UserStore; const result = modelProviderSelectors.getDefaultEnabledModelsById('openai')(s); expect(result).toEqual(['gpt-3.5-turbo', 'gpt-4-turbo', 'gpt-4o']); }); it('should return undefined for a non-existing provider', () => { - const s = merge(initialSettingsState, {}) as unknown as UserStore; + const s = merge(initialState, {}) as unknown as UserStore; const result = modelProviderSelectors.getDefaultEnabledModelsById('nonExistingProvider')(s); expect(result).toBeUndefined(); diff --git a/src/store/user/slices/settings/selectors/modelProvider.ts b/src/store/user/slices/modelList/selectors/modelProvider.ts similarity index 95% rename from src/store/user/slices/settings/selectors/modelProvider.ts rename to src/store/user/slices/modelList/selectors/modelProvider.ts index 7f47f77ddf972..d75da0cf28806 100644 --- a/src/store/user/slices/settings/selectors/modelProvider.ts +++ b/src/store/user/slices/modelList/selectors/modelProvider.ts @@ -6,7 +6,7 @@ import { ServerModelProviderConfig } from '@/types/serverConfig'; import { GlobalLLMProviderKey } from '@/types/settings'; import { UserStore } from '../../../store'; -import { currentSettings, getProviderConfigById } from './settings'; +import { currentSettings, getProviderConfigById } from '../../settings/selectors/settings'; /** * get the server side model cards @@ -14,9 +14,7 @@ import { currentSettings, getProviderConfigById } from './settings'; const serverProviderModelCards = (provider: GlobalLLMProviderKey) => (s: UserStore): ChatModelCard[] | undefined => { - const config = s.serverConfig.languageModel?.[provider] as - | ServerModelProviderConfig - | undefined; + const config = s.serverLanguageModel?.[provider] as ServerModelProviderConfig | undefined; if (!config) return; diff --git a/src/store/user/slices/preference/action.test.ts b/src/store/user/slices/preference/action.test.ts index b2b7434f2566a..ec94f7133dc49 100644 --- a/src/store/user/slices/preference/action.test.ts +++ b/src/store/user/slices/preference/action.test.ts @@ -40,56 +40,4 @@ describe('createPreferenceSlice', () => { expect(result.current.preference.hideSyncAlert).toEqual(true); }); }); - - describe('useInitPreference', () => { - it('should return false when userId is empty', async () => { - const { result } = renderHook(() => useUserStore()); - - vi.spyOn(userService, 'getPreference').mockResolvedValueOnce({} as any); - - const { result: prefernce } = renderHook(() => result.current.useInitPreference(), { - wrapper: withSWR, - }); - - await waitFor(() => { - expect(prefernce.current.data).toEqual({}); - expect(result.current.isPreferenceInit).toBeTruthy(); - }); - }); - it('should return default preference when local storage is empty', async () => { - const { result } = renderHook(() => useUserStore()); - - vi.spyOn(userService, 'getPreference').mockResolvedValueOnce({} as any); - - renderHook(() => result.current.useInitPreference(), { - wrapper: withSWR, - }); - - await waitFor(() => { - expect(result.current.preference).toEqual(DEFAULT_PREFERENCE); - expect(result.current.isPreferenceInit).toBeTruthy(); - }); - }); - - it('should return saved preference when local storage has data', async () => { - const { result } = renderHook(() => useUserStore()); - const savedPreference: UserPreference = { - ...DEFAULT_PREFERENCE, - hideSyncAlert: true, - guide: { topic: false, moveSettingsToAvatar: true }, - }; - - vi.spyOn(userService, 'getPreference').mockResolvedValueOnce(savedPreference); - - const { result: prefernce } = renderHook(() => result.current.useInitPreference(), { - wrapper: withSWR, - }); - - await waitFor(() => { - expect(prefernce.current.data).toEqual(savedPreference); - expect(result.current.isPreferenceInit).toBeTruthy(); - expect(result.current.preference).toEqual(savedPreference); - }); - }); - }); }); diff --git a/src/store/user/slices/preference/action.ts b/src/store/user/slices/preference/action.ts index 010e3cb141cbf..5ed6460d6b79f 100644 --- a/src/store/user/slices/preference/action.ts +++ b/src/store/user/slices/preference/action.ts @@ -1,8 +1,5 @@ -import { SWRResponse } from 'swr'; import type { StateCreator } from 'zustand/vanilla'; -import { DEFAULT_PREFERENCE } from '@/const/user'; -import { useClientDataSWR } from '@/libs/swr'; import { userService } from '@/services/user'; import type { UserStore } from '@/store/user'; import { UserGuide, UserPreference } from '@/types/user'; @@ -14,7 +11,6 @@ const n = setNamespace('preference'); export interface PreferenceAction { updateGuideState: (guide: Partial) => Promise; updatePreference: (preference: Partial, action?: any) => Promise; - useInitPreference: () => SWRResponse; } export const createPreferenceSlice: StateCreator< @@ -28,6 +24,7 @@ export const createPreferenceSlice: StateCreator< const nextGuide = merge(get().preference.guide, guide); await updatePreference({ guide: nextGuide }); }, + updatePreference: async (preference, action) => { const nextPreference = merge(get().preference, preference); @@ -35,17 +32,4 @@ export const createPreferenceSlice: StateCreator< await userService.updatePreference(nextPreference); }, - - useInitPreference: () => - useClientDataSWR('initUserPreference', userService.getPreference, { - onSuccess: (preference) => { - const isEmpty = Object.keys(preference).length === 0; - - set( - { isPreferenceInit: true, preference: isEmpty ? DEFAULT_PREFERENCE : preference }, - false, - n('initPreference'), - ); - }, - }), }); diff --git a/src/store/user/slices/preference/initialState.ts b/src/store/user/slices/preference/initialState.ts index 6b40176fe5acb..5df9afe5e6bbc 100644 --- a/src/store/user/slices/preference/initialState.ts +++ b/src/store/user/slices/preference/initialState.ts @@ -1,18 +1,13 @@ import { DEFAULT_PREFERENCE } from '@/const/user'; import { UserPreference } from '@/types/user'; -import { AsyncLocalStorage } from '@/utils/localStorage'; export interface UserPreferenceState { - isPreferenceInit: boolean; /** * the user preference, which only store in local storage */ preference: UserPreference; - preferenceStorage: AsyncLocalStorage; } export const initialPreferenceState: UserPreferenceState = { - isPreferenceInit: false, preference: DEFAULT_PREFERENCE, - preferenceStorage: new AsyncLocalStorage('LOBE_PREFERENCE'), }; diff --git a/src/store/user/slices/preference/selectors.test.ts b/src/store/user/slices/preference/selectors.test.ts index a28403ef57e53..9881a88755fc4 100644 --- a/src/store/user/slices/preference/selectors.test.ts +++ b/src/store/user/slices/preference/selectors.test.ts @@ -72,10 +72,10 @@ describe('preferenceSelectors', () => { describe('isPreferenceInit', () => { it('should return the value of isPreferenceInit state', () => { - store.isPreferenceInit = true; + store.isUserStateInit = true; expect(preferenceSelectors.isPreferenceInit(store)).toBe(true); - store.isPreferenceInit = false; + store.isUserStateInit = false; expect(preferenceSelectors.isPreferenceInit(store)).toBe(false); }); }); diff --git a/src/store/user/slices/preference/selectors.ts b/src/store/user/slices/preference/selectors.ts index 53c3eef35e9d5..ef7060b92ab11 100644 --- a/src/store/user/slices/preference/selectors.ts +++ b/src/store/user/slices/preference/selectors.ts @@ -8,7 +8,7 @@ const hideSyncAlert = (s: UserStore) => s.preference.hideSyncAlert; const hideSettingsMoveGuide = (s: UserStore) => s.preference.guide?.moveSettingsToAvatar; -const isPreferenceInit = (s: UserStore) => s.isPreferenceInit; +const isPreferenceInit = (s: UserStore) => s.isUserStateInit; export const preferenceSelectors = { hideSettingsMoveGuide, diff --git a/src/store/user/slices/settings/actions/general.test.ts b/src/store/user/slices/settings/action.test.ts similarity index 100% rename from src/store/user/slices/settings/actions/general.test.ts rename to src/store/user/slices/settings/action.test.ts diff --git a/src/store/user/slices/settings/actions/general.ts b/src/store/user/slices/settings/action.ts similarity index 92% rename from src/store/user/slices/settings/actions/general.ts rename to src/store/user/slices/settings/action.ts index e855c71e5e5db..8e9085f01c1b2 100644 --- a/src/store/user/slices/settings/actions/general.ts +++ b/src/store/user/slices/settings/action.ts @@ -12,7 +12,7 @@ import { switchLang } from '@/utils/client/switchLang'; import { difference } from '@/utils/difference'; import { merge } from '@/utils/merge'; -export interface GeneralSettingsAction { +export interface UserSettingsAction { importAppSettings: (settings: GlobalSettings) => Promise; resetSettings: () => Promise; setSettings: (settings: DeepPartial) => Promise; @@ -22,11 +22,11 @@ export interface GeneralSettingsAction { updateDefaultAgent: (agent: DeepPartial) => Promise; } -export const generalSettingsSlice: StateCreator< +export const createSettingsSlice: StateCreator< UserStore, [['zustand/devtools', never]], [], - GeneralSettingsAction + UserSettingsAction > = (set, get) => ({ importAppSettings: async (importAppSettings) => { const { setSettings } = get(); @@ -37,7 +37,7 @@ export const generalSettingsSlice: StateCreator< }, resetSettings: async () => { await userService.resetUserSettings(); - await get().refreshUserConfig(); + await get().refreshUserState(); }, setSettings: async (settings) => { const { settings: prevSetting, defaultSettings } = get(); @@ -49,7 +49,7 @@ export const generalSettingsSlice: StateCreator< const diffs = difference(nextSettings, defaultSettings); await userService.updateUserSettings(diffs); - await get().refreshUserConfig(); + await get().refreshUserState(); }, setTranslationSystemAgent: async (provider, model) => { await get().setSettings({ diff --git a/src/store/user/slices/settings/actions/index.ts b/src/store/user/slices/settings/actions/index.ts deleted file mode 100644 index 2200e0ee08a3c..0000000000000 --- a/src/store/user/slices/settings/actions/index.ts +++ /dev/null @@ -1,18 +0,0 @@ -import type { StateCreator } from 'zustand/vanilla'; - -import type { UserStore } from '@/store/user'; - -import { GeneralSettingsAction, generalSettingsSlice } from './general'; -import { LLMSettingsAction, llmSettingsSlice } from './llm'; - -export interface SettingsAction extends LLMSettingsAction, GeneralSettingsAction {} - -export const createSettingsSlice: StateCreator< - UserStore, - [['zustand/devtools', never]], - [], - SettingsAction -> = (...params) => ({ - ...llmSettingsSlice(...params), - ...generalSettingsSlice(...params), -}); diff --git a/src/store/user/slices/settings/actions/llm.test.ts b/src/store/user/slices/settings/actions/llm.test.ts deleted file mode 100644 index 8e0af72a2d371..0000000000000 --- a/src/store/user/slices/settings/actions/llm.test.ts +++ /dev/null @@ -1,136 +0,0 @@ -import { act, renderHook } from '@testing-library/react'; -import { describe, expect, it, vi } from 'vitest'; - -import { userService } from '@/services/user'; -import { useUserStore } from '@/store/user'; -import { GeneralModelProviderConfig } from '@/types/settings'; - -import { CustomModelCardDispatch } from '../reducers/customModelCard'; -import { modelProviderSelectors, settingsSelectors } from '../selectors'; - -// Mock userService -vi.mock('@/services/user', () => ({ - userService: { - updateUserSettings: vi.fn(), - resetUserSettings: vi.fn(), - }, -})); - -describe('LLMSettingsSliceAction', () => { - describe('setModelProviderConfig', () => { - it('should set OpenAI configuration', async () => { - const { result } = renderHook(() => useUserStore()); - const openAIConfig: Partial = { apiKey: 'test-key' }; - - // Perform the action - await act(async () => { - await result.current.setModelProviderConfig('openai', openAIConfig); - }); - - // Assert that updateUserSettings was called with the correct OpenAI configuration - expect(userService.updateUserSettings).toHaveBeenCalledWith({ - languageModel: { - openai: openAIConfig, - }, - }); - }); - }); - - describe('dispatchCustomModelCards', () => { - it('should return early when prevState does not exist', async () => { - const { result } = renderHook(() => useUserStore()); - const provider = 'openai'; - const payload: CustomModelCardDispatch = { type: 'add', modelCard: { id: 'test-id' } }; - - // Mock the selector to return undefined - vi.spyOn(settingsSelectors, 'providerConfig').mockReturnValue(() => undefined); - vi.spyOn(result.current, 'setModelProviderConfig'); - - await act(async () => { - await result.current.dispatchCustomModelCards(provider, payload); - }); - - // Assert that setModelProviderConfig was not called - expect(result.current.setModelProviderConfig).not.toHaveBeenCalled(); - }); - }); - - describe('refreshDefaultModelProviderList', () => { - it('default', async () => { - const { result } = renderHook(() => useUserStore()); - - act(() => { - useUserStore.setState({ - serverConfig: { - languageModel: { - azure: { serverModelCards: [{ id: 'abc', deploymentName: 'abc' }] }, - }, - telemetry: {}, - }, - }); - }); - - act(() => { - result.current.refreshDefaultModelProviderList(); - }); - - // Assert that setModelProviderConfig was not called - const azure = result.current.defaultModelProviderList.find((m) => m.id === 'azure'); - expect(azure?.chatModels).toEqual([{ id: 'abc', deploymentName: 'abc' }]); - }); - }); - - describe('refreshModelProviderList', () => { - it('visible', async () => { - const { result } = renderHook(() => useUserStore()); - act(() => { - useUserStore.setState({ - settings: { - languageModel: { - ollama: { enabledModels: ['llava'] }, - }, - }, - }); - }); - - act(() => { - result.current.refreshModelProviderList(); - }); - - const ollamaList = result.current.modelProviderList.find((r) => r.id === 'ollama'); - // Assert that setModelProviderConfig was not called - expect(ollamaList?.chatModels.find((c) => c.id === 'llava')).toEqual({ - displayName: 'LLaVA 7B', - enabled: true, - id: 'llava', - tokens: 4096, - vision: true, - }); - }); - - it('modelProviderListForModelSelect should return only enabled providers', () => { - const { result } = renderHook(() => useUserStore()); - - act(() => { - useUserStore.setState({ - settings: { - languageModel: { - perplexity: { enabled: true }, - azure: { enabled: false }, - }, - }, - }); - }); - - act(() => { - result.current.refreshModelProviderList(); - }); - - const enabledProviders = modelProviderSelectors.modelProviderListForModelSelect( - result.current, - ); - expect(enabledProviders).toHaveLength(3); - expect(enabledProviders.at(-1)!.id).toBe('perplexity'); - }); - }); -}); diff --git a/src/store/user/slices/settings/initialState.ts b/src/store/user/slices/settings/initialState.ts index 7e47d1735e52f..85e41c3a3b7e8 100644 --- a/src/store/user/slices/settings/initialState.ts +++ b/src/store/user/slices/settings/initialState.ts @@ -1,26 +1,14 @@ import { DeepPartial } from 'utility-types'; -import { DEFAULT_MODEL_PROVIDER_LIST } from '@/config/modelProviders'; import { DEFAULT_SETTINGS } from '@/const/settings'; -import { ModelProviderCard } from '@/types/llm'; -import { GlobalServerConfig } from '@/types/serverConfig'; import { GlobalSettings } from '@/types/settings'; export interface UserSettingsState { - defaultModelProviderList: ModelProviderCard[]; defaultSettings: GlobalSettings; - editingCustomCardModel?: { id: string; provider: string } | undefined; - modelProviderList: ModelProviderCard[]; - serverConfig: GlobalServerConfig; settings: DeepPartial; } export const initialSettingsState: UserSettingsState = { - defaultModelProviderList: DEFAULT_MODEL_PROVIDER_LIST, defaultSettings: DEFAULT_SETTINGS, - modelProviderList: DEFAULT_MODEL_PROVIDER_LIST, - serverConfig: { - telemetry: {}, - }, settings: {}, }; diff --git a/src/store/user/slices/settings/selectors/__snapshots__/selectors.test.ts.snap b/src/store/user/slices/settings/selectors/__snapshots__/settings.test.ts.snap similarity index 100% rename from src/store/user/slices/settings/selectors/__snapshots__/selectors.test.ts.snap rename to src/store/user/slices/settings/selectors/__snapshots__/settings.test.ts.snap diff --git a/src/store/user/slices/settings/selectors/index.ts b/src/store/user/slices/settings/selectors/index.ts index 9e625b82c4d10..a436d51938f73 100644 --- a/src/store/user/slices/settings/selectors/index.ts +++ b/src/store/user/slices/settings/selectors/index.ts @@ -1,5 +1,2 @@ -export { modelConfigSelectors } from './modelConfig'; -export { modelProviderSelectors } from './modelProvider'; export { settingsSelectors } from './settings'; -export { syncSettingsSelectors } from './sync'; export { systemAgentSelectors } from './systemAgent'; diff --git a/src/store/user/slices/settings/selectors/selectors.test.ts b/src/store/user/slices/settings/selectors/settings.test.ts similarity index 100% rename from src/store/user/slices/settings/selectors/selectors.test.ts rename to src/store/user/slices/settings/selectors/settings.test.ts diff --git a/src/store/user/slices/sync/action.test.ts b/src/store/user/slices/sync/action.test.ts index 72a1aa972889d..fd44759a270b9 100644 --- a/src/store/user/slices/sync/action.test.ts +++ b/src/store/user/slices/sync/action.test.ts @@ -107,9 +107,17 @@ describe('createSyncSlice', () => { describe('useEnabledSync', () => { it('should return false when userId is empty', async () => { - const { result } = renderHook(() => useUserStore().useEnabledSync(true, undefined, vi.fn()), { - wrapper: withSWR, - }); + const { result } = renderHook( + () => + useUserStore().useEnabledSync(true, { + userEnableSync: true, + userId: undefined, + onEvent: vi.fn(), + }), + { + wrapper: withSWR, + }, + ); await waitFor(() => expect(result.current.data).toBe(false)); }); @@ -118,7 +126,13 @@ describe('createSyncSlice', () => { const disableSyncSpy = vi.spyOn(syncService, 'disableSync').mockResolvedValueOnce(false); const { result } = renderHook( - () => useUserStore().useEnabledSync(false, 'user-id', vi.fn()), + () => + useUserStore().useEnabledSync(true, { + userEnableSync: false, + userId: 'user-id', + onEvent: vi.fn(), + }), + { wrapper: withSWR }, ); @@ -137,7 +151,7 @@ describe('createSyncSlice', () => { result.current.triggerEnableSync = triggerEnableSyncSpy; const { result: swrResult } = renderHook( - () => result.current.useEnabledSync(true, userId, onEvent), + () => result.current.useEnabledSync(true, { userEnableSync: true, userId, onEvent }), { wrapper: withSWR, }, diff --git a/src/store/user/slices/sync/action.ts b/src/store/user/slices/sync/action.ts index caf2a4d95a784..4020589f2e3e6 100644 --- a/src/store/user/slices/sync/action.ts +++ b/src/store/user/slices/sync/action.ts @@ -8,7 +8,7 @@ import { browserInfo } from '@/utils/platform'; import { setNamespace } from '@/utils/storeDebug'; import { userProfileSelectors } from '../auth/selectors'; -import { syncSettingsSelectors } from '../settings/selectors'; +import { syncSettingsSelectors } from './selectors'; const n = setNamespace('sync'); @@ -19,9 +19,12 @@ export interface SyncAction { refreshConnection: (onEvent: OnSyncEvent) => Promise; triggerEnableSync: (userId: string, onEvent: OnSyncEvent) => Promise; useEnabledSync: ( - userEnableSync: boolean, - userId: string | undefined, - onEvent: OnSyncEvent, + systemEnable: boolean | undefined, + params: { + onEvent: OnSyncEvent; + userEnableSync: boolean; + userId: string | undefined; + }, ) => SWRResponse; } @@ -72,9 +75,9 @@ export const createSyncSlice: StateCreator< }); }, - useEnabledSync: (userEnableSync, userId, onEvent) => + useEnabledSync: (systemEnable, { userEnableSync, userId, onEvent }) => useSWR( - ['enableSync', userEnableSync, userId], + systemEnable ? ['enableSync', userEnableSync, userId] : null, async () => { // if user don't enable sync or no userId ,don't start sync if (!userId) return false; diff --git a/src/store/user/slices/settings/selectors/sync.ts b/src/store/user/slices/sync/selectors.ts similarity index 78% rename from src/store/user/slices/settings/selectors/sync.ts rename to src/store/user/slices/sync/selectors.ts index b9dd988431f3d..4fa3db52e42fd 100644 --- a/src/store/user/slices/settings/selectors/sync.ts +++ b/src/store/user/slices/sync/selectors.ts @@ -1,5 +1,5 @@ -import { UserStore } from '../../../store'; -import { currentSettings } from './settings'; +import { UserStore } from '../../store'; +import { currentSettings } from '../settings/selectors/settings'; const webrtcConfig = (s: UserStore) => currentSettings(s).sync.webrtc; const webrtcChannelName = (s: UserStore) => webrtcConfig(s).channelName; diff --git a/src/store/user/store.ts b/src/store/user/store.ts index f1a1e5f895b56..d6e7e0fb4136f 100644 --- a/src/store/user/store.ts +++ b/src/store/user/store.ts @@ -8,16 +8,18 @@ import { isDev } from '@/utils/env'; import { type UserState, initialState } from './initialState'; import { type UserAuthAction, createAuthSlice } from './slices/auth/action'; import { type CommonAction, createCommonSlice } from './slices/common/action'; +import { type ModelListAction, createModelListSlice } from './slices/modelList/action'; import { type PreferenceAction, createPreferenceSlice } from './slices/preference/action'; -import { type SettingsAction, createSettingsSlice } from './slices/settings/actions'; +import { type UserSettingsAction, createSettingsSlice } from './slices/settings/action'; import { type SyncAction, createSyncSlice } from './slices/sync/action'; // =============== 聚合 createStoreFn ============ // export type UserStore = SyncAction & UserState & - SettingsAction & + UserSettingsAction & PreferenceAction & + ModelListAction & UserAuthAction & CommonAction; @@ -28,6 +30,7 @@ const createStore: StateCreator = (... ...createPreferenceSlice(...parameters), ...createAuthSlice(...parameters), ...createCommonSlice(...parameters), + ...createModelListSlice(...parameters), }); // =============== 实装 useStore ============ // diff --git a/src/types/serverConfig.ts b/src/types/serverConfig.ts index 871733ae9fa8b..742f657f8ac72 100644 --- a/src/types/serverConfig.ts +++ b/src/types/serverConfig.ts @@ -13,12 +13,14 @@ export interface ServerModelProviderConfig { serverModelCards?: ChatModelCard[]; } +export type ServerLanguageModel = Partial>; + export interface GlobalServerConfig { defaultAgent?: DeepPartial; enableUploadFileToServer?: boolean; enabledAccessCode?: boolean; enabledOAuthSSO?: boolean; - languageModel?: Partial>; + languageModel?: ServerLanguageModel; telemetry: { langfuse?: boolean; }; diff --git a/src/types/user/index.ts b/src/types/user/index.ts index 2fa4b9e800632..795ed903023a4 100644 --- a/src/types/user/index.ts +++ b/src/types/user/index.ts @@ -1,3 +1,7 @@ +import { DeepPartial } from 'utility-types'; + +import { GlobalSettings } from '@/types/settings'; + export interface LobeUser { avatar?: string; email?: string | null; @@ -27,3 +31,12 @@ export interface UserPreference { */ useCmdEnterToSend?: boolean; } + +export interface UserInitializationState { + avatar?: string; + canEnableTrace?: boolean; + isOnboard?: boolean; + preference: UserPreference; + settings: DeepPartial; + userId: string; +} diff --git a/src/utils/parseModels.test.ts b/src/utils/parseModels.test.ts index d04d035281046..20c48baa3d9e6 100644 --- a/src/utils/parseModels.test.ts +++ b/src/utils/parseModels.test.ts @@ -1,6 +1,9 @@ import { describe, expect, it } from 'vitest'; -import { parseModelString } from './parseModels'; +import { LOBE_DEFAULT_MODEL_LIST, OpenAIProviderCard } from '@/config/modelProviders'; +import { ChatModelCard } from '@/types/llm'; + +import { parseModelString, transformToChatModelCards } from './parseModels'; describe('parseModelString', () => { it('custom deletion, addition, and renaming of models', () => { @@ -67,6 +70,29 @@ describe('parseModelString', () => { ]); }); + it('should have file with builtin models like gpt-4-0125-preview', () => { + const result = parseModelString( + '-all,+gpt-4-0125-preview=ChatGPT-4<128000:fc:file>,+gpt-4-turbo-2024-04-09=ChatGPT-4 Vision<128000:fc:vision:file>', + ); + expect(result.add).toEqual([ + { + displayName: 'ChatGPT-4', + files: true, + functionCall: true, + id: 'gpt-4-0125-preview', + tokens: 128000, + }, + { + displayName: 'ChatGPT-4 Vision', + files: true, + functionCall: true, + id: 'gpt-4-turbo-2024-04-09', + tokens: 128000, + vision: true, + }, + ]); + }); + it('should handle empty extension capability value', () => { const result = parseModelString('model1<1024:>'); expect(result.add[0]).toEqual({ id: 'model1', tokens: 1024 }); @@ -168,3 +194,97 @@ describe('parseModelString', () => { }); }); }); + +describe('transformToChatModelCards', () => { + const defaultChatModels: ChatModelCard[] = [ + { id: 'model1', displayName: 'Model 1', enabled: true }, + { id: 'model2', displayName: 'Model 2', enabled: false }, + ]; + + it('should return undefined when modelString is empty', () => { + const result = transformToChatModelCards({ + modelString: '', + defaultChatModels, + }); + expect(result).toBeUndefined(); + }); + + it('should remove all models when removeAll is true', () => { + const result = transformToChatModelCards({ + modelString: '-all', + defaultChatModels, + }); + expect(result).toEqual([]); + }); + + it('should remove specified models', () => { + const result = transformToChatModelCards({ + modelString: '-model1', + defaultChatModels, + }); + expect(result).toEqual([{ id: 'model2', displayName: 'Model 2', enabled: false }]); + }); + + it('should add a new known model', () => { + const knownModel = LOBE_DEFAULT_MODEL_LIST[0]; + const result = transformToChatModelCards({ + modelString: `${knownModel.id}`, + defaultChatModels, + }); + expect(result).toContainEqual({ + ...knownModel, + displayName: knownModel.displayName || knownModel.id, + enabled: true, + }); + }); + + it('should update an existing known model', () => { + const knownModel = LOBE_DEFAULT_MODEL_LIST[0]; + const result = transformToChatModelCards({ + modelString: `+${knownModel.id}=Updated Model`, + defaultChatModels: [knownModel], + }); + expect(result![0]).toEqual({ ...knownModel, displayName: 'Updated Model', enabled: true }); + }); + + it('should add a new custom model', () => { + const result = transformToChatModelCards({ + modelString: '+custom_model=Custom Model', + defaultChatModels, + }); + expect(result).toContainEqual({ + id: 'custom_model', + displayName: 'Custom Model', + enabled: true, + }); + }); + + it('should have file with builtin models like gpt-4-0125-preview', () => { + const result = transformToChatModelCards({ + modelString: + '-all,+gpt-4-0125-preview=ChatGPT-4<128000:fc:file>,+gpt-4-turbo-2024-04-09=ChatGPT-4 Vision<128000:fc:vision:file>', + defaultChatModels: OpenAIProviderCard.chatModels, + }); + + expect(result).toEqual([ + { + displayName: 'ChatGPT-4', + files: true, + functionCall: true, + enabled: true, + id: 'gpt-4-0125-preview', + tokens: 128000, + }, + { + description: 'GPT-4 Turbo 视觉版 (240409)', + displayName: 'ChatGPT-4 Vision', + files: true, + functionCall: true, + enabled: true, + id: 'gpt-4-turbo-2024-04-09', + tokens: 128000, + vision: true, + }, + ]); + }); +}); diff --git a/src/utils/parseModels.ts b/src/utils/parseModels.ts index 0ea17cfc2b4dc..63507afe1b1d1 100644 --- a/src/utils/parseModels.ts +++ b/src/utils/parseModels.ts @@ -114,17 +114,22 @@ export const transformToChatModelCards = ({ // if the model is known, update it based on the known model if (knownModel) { - const modelInList = draft.find((model) => model.id === toAddModel.id); + const index = draft.findIndex((model) => model.id === toAddModel.id); + const modelInList = draft[index]; // if the model is already in chatModels, update it if (modelInList) { - // if (modelInList.hidden) delete modelInList.hidden; - modelInList.enabled = true; - if (toAddModel.displayName) modelInList.displayName = toAddModel.displayName; + draft[index] = { + ...modelInList, + ...toAddModel, + displayName: toAddModel.displayName || modelInList.displayName || modelInList.id, + enabled: true, + }; } else { // if the model is not in chatModels, add it draft.push({ ...knownModel, + ...toAddModel, displayName: toAddModel.displayName || knownModel.displayName || knownModel.id, enabled: true, });