From 60dcf19037bc818bc1287b6ece96b0a7ebda3d6f Mon Sep 17 00:00:00 2001 From: Arvin Xu Date: Fri, 20 Sep 2024 00:33:08 +0800 Subject: [PATCH] =?UTF-8?q?=E2=99=BB=EF=B8=8F=20refactor:=20refactor=20the?= =?UTF-8?q?=20tts=20route=20url=20(#4030)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * ♻️ refactor: refactor the tts to new route * ♻️ refactor: refactor the tts to new route --- src/app/api/openai/createBizOpenAI/index.ts | 1 + .../openai/stt => webapi/stt/openai}/route.ts | 0 .../edge-speech => webapi/tts/edge}/route.ts | 0 .../tts/microsoft}/route.ts | 0 .../openai/tts => webapi/tts/openai}/route.ts | 1 + src/const/fetch.ts | 4 +++- src/libs/agent-runtime/AgentRuntime.ts | 4 ++++ src/libs/agent-runtime/BaseAI.ts | 9 ++++++++- src/libs/agent-runtime/types/index.ts | 1 + src/libs/agent-runtime/types/tts.ts | 14 ++++++++++++++ .../utils/openaiCompatibleFactory/index.ts | 17 ++++++++++++++++- src/services/_header.ts | 13 ++++++++++--- src/services/_url.ts | 14 ++++++++------ src/store/file/slices/tts/action.ts | 2 +- src/store/file/slices/upload/action.ts | 16 +++++++++++----- 15 files changed, 78 insertions(+), 18 deletions(-) rename src/app/{api/openai/stt => webapi/stt/openai}/route.ts (100%) rename src/app/{api/tts/edge-speech => webapi/tts/edge}/route.ts (100%) rename src/app/{api/tts/microsoft-speech => webapi/tts/microsoft}/route.ts (100%) rename src/app/{api/openai/tts => webapi/tts/openai}/route.ts (94%) create mode 100644 src/libs/agent-runtime/types/tts.ts diff --git a/src/app/api/openai/createBizOpenAI/index.ts b/src/app/api/openai/createBizOpenAI/index.ts index 0742ca512d14..ce95a858d39c 100644 --- a/src/app/api/openai/createBizOpenAI/index.ts +++ b/src/app/api/openai/createBizOpenAI/index.ts @@ -8,6 +8,7 @@ import { checkAuth } from './auth'; import { createOpenai } from './createOpenai'; /** + * @deprecated * createOpenAI Instance with Auth and azure openai support * if auth not pass ,just return error 
response */ diff --git a/src/app/api/openai/stt/route.ts b/src/app/webapi/stt/openai/route.ts similarity index 100% rename from src/app/api/openai/stt/route.ts rename to src/app/webapi/stt/openai/route.ts diff --git a/src/app/api/tts/edge-speech/route.ts b/src/app/webapi/tts/edge/route.ts similarity index 100% rename from src/app/api/tts/edge-speech/route.ts rename to src/app/webapi/tts/edge/route.ts diff --git a/src/app/api/tts/microsoft-speech/route.ts b/src/app/webapi/tts/microsoft/route.ts similarity index 100% rename from src/app/api/tts/microsoft-speech/route.ts rename to src/app/webapi/tts/microsoft/route.ts diff --git a/src/app/api/openai/tts/route.ts b/src/app/webapi/tts/openai/route.ts similarity index 94% rename from src/app/api/openai/tts/route.ts rename to src/app/webapi/tts/openai/route.ts index f263c8c65d2d..4b1ac6ada76a 100644 --- a/src/app/api/openai/tts/route.ts +++ b/src/app/webapi/tts/openai/route.ts @@ -28,6 +28,7 @@ export const preferredRegion = [ export const POST = async (req: Request) => { const payload = (await req.json()) as OpenAITTSPayload; + // need to be refactored with jwt auth mode const openaiOrErrResponse = createBizOpenAI(req); // if resOrOpenAI is a Response, it means there is an error,just return it diff --git a/src/const/fetch.ts b/src/const/fetch.ts index 5d6a5361dc49..483a283b86b0 100644 --- a/src/const/fetch.ts +++ b/src/const/fetch.ts @@ -1,5 +1,6 @@ export const OPENAI_END_POINT = 'X-openai-end-point'; export const OPENAI_API_KEY_HEADER_KEY = 'X-openai-api-key'; +export const LOBE_USER_ID = 'X-lobe-user-id'; export const USE_AZURE_OPENAI = 'X-use-azure-openai'; @@ -19,9 +20,10 @@ export const getOpenAIAuthFromRequest = (req: Request) => { const useAzureStr = req.headers.get(USE_AZURE_OPENAI); const apiVersion = req.headers.get(AZURE_OPENAI_API_VERSION); const oauthAuthorizedStr = req.headers.get(OAUTH_AUTHORIZED); + const userId = req.headers.get(LOBE_USER_ID); const oauthAuthorized = !!oauthAuthorizedStr; const useAzure 
= !!useAzureStr; - return { accessCode, apiKey, apiVersion, endpoint, oauthAuthorized, useAzure }; + return { accessCode, apiKey, apiVersion, endpoint, oauthAuthorized, useAzure, userId }; }; diff --git a/src/libs/agent-runtime/AgentRuntime.ts b/src/libs/agent-runtime/AgentRuntime.ts index 05ce1a7ef525..6a7fedba0329 100644 --- a/src/libs/agent-runtime/AgentRuntime.ts +++ b/src/libs/agent-runtime/AgentRuntime.ts @@ -35,6 +35,7 @@ import { EmbeddingsPayload, ModelProvider, TextToImagePayload, + TextToSpeechPayload, } from './types'; import { LobeUpstageAI } from './upstage'; import { LobeZeroOneAI } from './zeroone'; @@ -97,6 +98,9 @@ class AgentRuntime { async embeddings(payload: EmbeddingsPayload, options?: EmbeddingsOptions) { return this._runtime.embeddings?.(payload, options); } + async textToSpeech(payload: TextToSpeechPayload, options?: EmbeddingsOptions) { + return this._runtime.textToSpeech?.(payload, options); + } /** * @description Initialize the runtime with the provider and the options diff --git a/src/libs/agent-runtime/BaseAI.ts b/src/libs/agent-runtime/BaseAI.ts index c529491dfcc2..3783ea56f8f7 100644 --- a/src/libs/agent-runtime/BaseAI.ts +++ b/src/libs/agent-runtime/BaseAI.ts @@ -1,6 +1,5 @@ import OpenAI from 'openai'; -import { TextToImagePayload } from '@/libs/agent-runtime/types/textToImage'; import { ChatModelCard } from '@/types/llm'; import { @@ -9,6 +8,9 @@ import { EmbeddingItem, EmbeddingsOptions, EmbeddingsPayload, + TextToImagePayload, + TextToSpeechOptions, + TextToSpeechPayload, } from './types'; export interface LobeRuntimeAI { @@ -20,6 +22,11 @@ export interface LobeRuntimeAI { models?(): Promise; textToImage?: (payload: TextToImagePayload) => Promise; + + textToSpeech?: ( + payload: TextToSpeechPayload, + options?: TextToSpeechOptions, + ) => Promise<ArrayBuffer>; } export abstract class LobeOpenAICompatibleRuntime { diff --git a/src/libs/agent-runtime/types/index.ts b/src/libs/agent-runtime/types/index.ts index e6ea1ca3a7af..e17d3a324076 100644
--- a/src/libs/agent-runtime/types/index.ts +++ b/src/libs/agent-runtime/types/index.ts @@ -1,4 +1,5 @@ export * from './chat'; export * from './embeddings'; export * from './textToImage'; +export * from './tts'; export * from './type'; diff --git a/src/libs/agent-runtime/types/tts.ts b/src/libs/agent-runtime/types/tts.ts new file mode 100644 index 000000000000..a72ee96c1971 --- /dev/null +++ b/src/libs/agent-runtime/types/tts.ts @@ -0,0 +1,14 @@ +export interface TextToSpeechPayload { + input: string; + model: string; + voice: string; +} + +export interface TextToSpeechOptions { + headers?: Record<string, any>; + signal?: AbortSignal; + /** + * userId associated with the text-to-speech request + */ + user?: string; +} diff --git a/src/libs/agent-runtime/utils/openaiCompatibleFactory/index.ts b/src/libs/agent-runtime/utils/openaiCompatibleFactory/index.ts index f324b9c68f38..90331b73c63f 100644 --- a/src/libs/agent-runtime/utils/openaiCompatibleFactory/index.ts +++ b/src/libs/agent-runtime/utils/openaiCompatibleFactory/index.ts @@ -1,7 +1,6 @@ import OpenAI, { ClientOptions } from 'openai'; import { LOBE_DEFAULT_MODEL_LIST } from '@/config/modelProviders'; -import { TextToImagePayload } from '@/libs/agent-runtime/types/textToImage'; import { ChatModelCard } from '@/types/llm'; import { LobeRuntimeAI } from '../../BaseAI'; @@ -13,6 +12,9 @@ import { EmbeddingItem, EmbeddingsOptions, EmbeddingsPayload, + TextToImagePayload, + TextToSpeechOptions, + TextToSpeechPayload, } from '../../types'; import { AgentRuntimeError } from '../createError'; import { debugResponse, debugStream } from '../debugStream'; @@ -253,6 +255,19 @@ export const LobeOpenAICompatibleFactory = <T extends Record<string, any> = any> } } + async textToSpeech(payload: TextToSpeechPayload, options?: TextToSpeechOptions) { + try { + const mp3 = await this.client.audio.speech.create(payload as any, { + headers: options?.headers, + signal: options?.signal, + }); + + return mp3.arrayBuffer(); + } catch (error) { + throw this.handleError(error); + } + } + private
handleError(error: any): ChatCompletionErrorPayload { let desensitizedEndpoint = this.baseURL; diff --git a/src/services/_header.ts b/src/services/_header.ts index 0795176510bb..e4c2e6d63740 100644 --- a/src/services/_header.ts +++ b/src/services/_header.ts @@ -1,4 +1,9 @@ -import { LOBE_CHAT_ACCESS_CODE, OPENAI_API_KEY_HEADER_KEY, OPENAI_END_POINT } from '@/const/fetch'; +import { + LOBE_CHAT_ACCESS_CODE, + LOBE_USER_ID, + OPENAI_API_KEY_HEADER_KEY, + OPENAI_END_POINT, +} from '@/const/fetch'; import { useUserStore } from '@/store/user'; import { keyVaultsConfigSelectors } from '@/store/user/selectors'; @@ -8,12 +13,14 @@ import { keyVaultsConfigSelectors } from '@/store/user/selectors'; */ // eslint-disable-next-line no-undef export const createHeaderWithOpenAI = (header?: HeadersInit): HeadersInit => { - const openAIConfig = keyVaultsConfigSelectors.openAIConfig(useUserStore.getState()); + const state = useUserStore.getState(); + const openAIConfig = keyVaultsConfigSelectors.openAIConfig(state); // eslint-disable-next-line no-undef return { ...header, - [LOBE_CHAT_ACCESS_CODE]: keyVaultsConfigSelectors.password(useUserStore.getState()), + [LOBE_CHAT_ACCESS_CODE]: keyVaultsConfigSelectors.password(state), + [LOBE_USER_ID]: state.user?.id || '', [OPENAI_API_KEY_HEADER_KEY]: openAIConfig.apiKey || '', [OPENAI_END_POINT]: openAIConfig.baseURL || '', }; diff --git a/src/services/_url.ts b/src/services/_url.ts index 8f03525472e6..1bbcee7672a7 100644 --- a/src/services/_url.ts +++ b/src/services/_url.ts @@ -1,4 +1,4 @@ -// TODO: 未来所有路由需要全部迁移到 trpc +// TODO: 未来路由需要迁移到 trpc or /webapi /* eslint-disable sort-keys-fix/sort-keys-fix */ import { transform } from 'lodash-es'; @@ -38,9 +38,11 @@ export const API_ENDPOINTS = mapWithBasePath({ // image images: '/api/text-to-image/openai', - // TTS & STT - stt: '/api/openai/stt', - tts: '/api/openai/tts', - edge: '/api/tts/edge-speech', - microsoft: '/api/tts/microsoft-speech', + // STT + stt: '/webapi/stt/openai', + + // TTS + 
tts: '/webapi/tts/openai', + edge: '/webapi/tts/edge', + microsoft: '/webapi/tts/microsoft', }); diff --git a/src/store/file/slices/tts/action.ts b/src/store/file/slices/tts/action.ts index 90b5c7c39d0b..940d01f97900 100644 --- a/src/store/file/slices/tts/action.ts +++ b/src/store/file/slices/tts/action.ts @@ -39,7 +39,7 @@ export const createTTSFileSlice: StateCreator< }; const file = new File([blob], fileName, fileOptions); - const res = await get().uploadWithProgress({ file }); + const res = await get().uploadWithProgress({ file, skipCheckFileType: true }); return res?.id; }, diff --git a/src/store/file/slices/upload/action.ts b/src/store/file/slices/upload/action.ts index 0353bee8f5ce..1134d7907f43 100644 --- a/src/store/file/slices/upload/action.ts +++ b/src/store/file/slices/upload/action.ts @@ -29,6 +29,12 @@ interface UploadWithProgressParams { type: 'removeFile'; }, ) => void; + /** + * Optional flag to indicate whether to skip the file type check. + * When set to `true`, any file type checks will be bypassed. + * Default is `false`, which means file type checks will be performed. 
+ */ + skipCheckFileType?: boolean; } interface UploadWithProgressResult { @@ -52,8 +58,8 @@ export const createFileUploadSlice: StateCreator< [], FileUploadAction > = (set, get) => ({ - internal_uploadToClientDB: async ({ file, onStatusUpdate }) => { - if (!file.type.startsWith('image')) { + internal_uploadToClientDB: async ({ file, onStatusUpdate, skipCheckFileType }) => { + if (!skipCheckFileType && !file.type.startsWith('image')) { onStatusUpdate?.({ id: file.name, type: 'removeFile' }); message.info({ content: t('upload.fileOnlySupportInServerMode', { @@ -158,11 +164,11 @@ export const createFileUploadSlice: StateCreator< return data; }, - uploadWithProgress: async ({ file, onStatusUpdate, knowledgeBaseId }) => { + uploadWithProgress: async (payload) => { const { internal_uploadToServer, internal_uploadToClientDB } = get(); - if (isServerMode) return internal_uploadToServer({ file, knowledgeBaseId, onStatusUpdate }); + if (isServerMode) return internal_uploadToServer(payload); - return internal_uploadToClientDB({ file, onStatusUpdate }); + return internal_uploadToClientDB(payload); }, });