diff --git a/api/db/services/llm_service.py b/api/db/services/llm_service.py
index 89e5593b36..74515f9951 100644
--- a/api/db/services/llm_service.py
+++ b/api/db/services/llm_service.py
@@ -195,7 +195,7 @@ def __init__(self, tenant_id, llm_type, llm_name=None, lang="Chinese"):
         self.llm_name = llm_name
         self.mdl = TenantLLMService.model_instance(
             tenant_id, llm_type, llm_name, lang=lang)
-        assert self.mdl, "Can't find mole for {}/{}/{}".format(
+        assert self.mdl, "Can't find model for {}/{}/{}".format(
             tenant_id, llm_type, llm_name)
         self.max_length = 8192
         for lm in LLMService.query(llm_name=llm_name):
diff --git a/rag/llm/__init__.py b/rag/llm/__init__.py
index 441a2a553b..c7a820d2fd 100644
--- a/rag/llm/__init__.py
+++ b/rag/llm/__init__.py
@@ -47,10 +47,9 @@
     "Replicate": ReplicateEmbed,
     "BaiduYiyan": BaiduYiyanEmbed,
     "Voyage AI": VoyageEmbed,
-    "HuggingFace":HuggingFaceEmbed,
+    "HuggingFace": HuggingFaceEmbed,
 }
 
-
 CvModel = {
     "OpenAI": GptV4,
     "Azure-OpenAI": AzureGptV4,
@@ -64,14 +63,13 @@
     "LocalAI": LocalAICV,
     "NVIDIA": NvidiaCV,
     "LM-Studio": LmStudioCV,
-    "StepFun":StepFunCV,
+    "StepFun": StepFunCV,
     "OpenAI-API-Compatible": OpenAI_APICV,
     "TogetherAI": TogetherAICV,
     "01.AI": YiCV,
     "Tencent Hunyuan": HunyuanCV
 }
 
-
 ChatModel = {
     "OpenAI": GptTurbo,
     "Azure-OpenAI": AzureChat,
@@ -99,7 +97,7 @@
     "LeptonAI": LeptonAIChat,
     "TogetherAI": TogetherAIChat,
     "PerfXCloud": PerfXCloudChat,
-    "Upstage":UpstageChat,
+    "Upstage": UpstageChat,
     "novita.ai": NovitaAIChat,
     "SILICONFLOW": SILICONFLOWChat,
     "01.AI": YiChat,
@@ -111,7 +109,6 @@
     "Google Cloud": GoogleChat,
 }
 
-
 RerankModel = {
     "BAAI": DefaultRerank,
     "Jina": JinaRerank,
@@ -127,11 +124,9 @@
     "Voyage AI": VoyageRerank
 }
 
-
 Seq2txtModel = {
     "OpenAI": GPTSeq2txt,
     "Tongyi-Qianwen": QWenSeq2txt,
-    "Ollama": OllamaSeq2txt,
     "Azure-OpenAI": AzureSeq2txt,
     "Xinference": XinferenceSeq2txt,
     "Tencent Cloud": TencentCloudSeq2txt
@@ -140,6 +135,7 @@
 TTSModel = {
     "Fish Audio": FishAudioTTS,
     "Tongyi-Qianwen": QwenTTS,
-    "OpenAI":OpenAITTS,
-    "XunFei Spark":SparkTTS
-}
\ No newline at end of file
+    "OpenAI": OpenAITTS,
+    "XunFei Spark": SparkTTS,
+    "Xinference": XinferenceTTS,
+}
diff --git a/rag/llm/sequence2txt_model.py b/rag/llm/sequence2txt_model.py
index a3f7f5af11..e2f76e16d1 100644
--- a/rag/llm/sequence2txt_model.py
+++ b/rag/llm/sequence2txt_model.py
@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+import requests
 from openai.lib.azure import AzureOpenAI
 from zhipuai import ZhipuAI
 import io
@@ -25,6 +26,7 @@
 import base64
 import re
 
+
 class Base(ABC):
     def __init__(self, key, model_name):
         pass
@@ -36,8 +38,8 @@ def transcription(self, audio, **kwargs):
             response_format="text"
         )
         return transcription.text.strip(), num_tokens_from_string(transcription.text.strip())
-
-    def audio2base64(self,audio):
+
+    def audio2base64(self, audio):
         if isinstance(audio, bytes):
             return base64.b64encode(audio).decode("utf-8")
         if isinstance(audio, io.BytesIO):
@@ -77,13 +79,6 @@ def transcription(self, audio, format):
         return "**ERROR**: " + result.message, 0
 
 
-class OllamaSeq2txt(Base):
-    def __init__(self, key, model_name, lang="Chinese", **kwargs):
-        self.client = Client(host=kwargs["base_url"])
-        self.model_name = model_name
-        self.lang = lang
-
-
 class AzureSeq2txt(Base):
     def __init__(self, key, model_name, lang="Chinese", **kwargs):
         self.client = AzureOpenAI(api_key=key, azure_endpoint=kwargs["base_url"], api_version="2024-02-01")
@@ -92,16 +87,54 @@ def __init__(self, key, model_name, lang="Chinese", **kwargs):
 
 
 class XinferenceSeq2txt(Base):
-    def __init__(self, key, model_name="", base_url=""):
-        if base_url.split("/")[-1] != "v1":
-            base_url = os.path.join(base_url, "v1")
-        self.client = OpenAI(api_key="xxx", base_url=base_url)
+    def __init__(self, key, model_name="whisper-small", **kwargs):
+        self.base_url = kwargs.get('base_url', None)
         self.model_name = model_name
 
+    def transcription(self, audio, language="zh", prompt=None, response_format="json", temperature=0.7):
+        # "audio" may be a file path on disk or raw audio bytes.
+        if isinstance(audio, str):
+            with open(audio, 'rb') as audio_file:
+                audio_data = audio_file.read()
+            audio_file_name = audio.split("/")[-1]
+        else:
+            audio_data = audio
+            audio_file_name = "audio.wav"
+
+        payload = {
+            "model": self.model_name,
+            "language": language,
+            "prompt": prompt,
+            "response_format": response_format,
+            "temperature": temperature
+        }
+
+        files = {
+            "file": (audio_file_name, audio_data, 'audio/wav')
+        }
+
+        try:
+            response = requests.post(
+                f"{self.base_url}/v1/audio/transcriptions",
+                files=files,
+                data=payload
+            )
+            response.raise_for_status()
+            result = response.json()
+
+            if 'text' in result:
+                transcription_text = result['text'].strip()
+                return transcription_text, num_tokens_from_string(transcription_text)
+            else:
+                return "**ERROR**: Failed to retrieve transcription.", 0
+
+        except requests.exceptions.RequestException as e:
+            return f"**ERROR**: {str(e)}", 0
+
 
 class TencentCloudSeq2txt(Base):
     def __init__(
-        self, key, model_name="16k_zh", base_url="https://asr.tencentcloudapi.com"
+            self, key, model_name="16k_zh", base_url="https://asr.tencentcloudapi.com"
     ):
         from tencentcloud.common import credential
         from tencentcloud.asr.v20190614 import asr_client
diff --git a/rag/llm/tts_model.py b/rag/llm/tts_model.py
index bfdb8762c8..814a5dfc37 100644
--- a/rag/llm/tts_model.py
+++ b/rag/llm/tts_model.py
@@ -297,3 +297,37 @@ def run(*args):
                 break
             status_code = 1
             yield audio_chunk
+
+
+
+class XinferenceTTS:
+    def __init__(self, key, model_name, **kwargs):
+        self.base_url = kwargs.get("base_url", None)
+        self.model_name = model_name
+        self.headers = {
+            "accept": "application/json",
+            "Content-Type": "application/json"
+        }
+
+    def tts(self, text, voice="中文女", stream=True):
+        payload = {
+            "model": self.model_name,
+            "input": text,
+            "voice": voice
+        }
+
+        # The speech endpoint streams raw audio bytes, so keep stream=True.
+        response = requests.post(
+            f"{self.base_url}/v1/audio/speech",
+            headers=self.headers,
+            json=payload,
+            stream=stream
+        )
+
+        if response.status_code != 200:
raise Exception(f"**Error**: {response.status_code}, {response.text}") + + for chunk in response.iter_content(chunk_size=1024): + if chunk: + yield chunk diff --git a/web/src/pages/user-setting/setting-model/ollama-modal/index.tsx b/web/src/pages/user-setting/setting-model/ollama-modal/index.tsx index cb9bd2546e..c880ec254b 100644 --- a/web/src/pages/user-setting/setting-model/ollama-modal/index.tsx +++ b/web/src/pages/user-setting/setting-model/ollama-modal/index.tsx @@ -53,6 +53,26 @@ const OllamaModal = ({ const url = llmFactoryToUrlMap[llmFactory as LlmFactory] || 'https://github.com/infiniflow/ragflow/blob/main/docs/guides/deploy_local_llm.mdx'; + const optionsMap = { + HuggingFace: [{ value: 'embedding', label: 'embedding' }], + Xinference: [ + { value: 'chat', label: 'chat' }, + { value: 'embedding', label: 'embedding' }, + { value: 'rerank', label: 'rerank' }, + { value: 'image2text', label: 'image2text' }, + { value: 'speech2text', label: 'sequence2text' }, + { value: 'tts', label: 'tts' }, + ], + Default: [ + { value: 'chat', label: 'chat' }, + { value: 'embedding', label: 'embedding' }, + { value: 'rerank', label: 'rerank' }, + { value: 'image2text', label: 'image2text' }, + ], + }; + const getOptions = (factory: string) => { + return optionsMap[factory as keyof typeof optionsMap] || optionsMap.Default; + }; return (