From 29a7b7a040d87c265d5045d5dc77d35852622e47 Mon Sep 17 00:00:00 2001 From: H <43509927+guoyuhao2330@users.noreply.github.com> Date: Mon, 22 Jul 2024 14:52:08 +0800 Subject: [PATCH] Add sequence2txt model.py (#1633) ### What problem does this PR solve? #1514 ### Type of change - [x] New Feature (non-breaking change which adds functionality) --- rag/llm/__init__.py | 11 ++++- rag/llm/sequence2txt_model.py | 89 +++++++++++++++++++++++++++++++++++ 2 files changed, 99 insertions(+), 1 deletion(-) create mode 100644 rag/llm/sequence2txt_model.py diff --git a/rag/llm/__init__.py b/rag/llm/__init__.py index 539257f48c..50e2938a39 100644 --- a/rag/llm/__init__.py +++ b/rag/llm/__init__.py @@ -17,7 +17,7 @@ from .chat_model import * from .cv_model import * from .rerank_model import * - +from .sequence2txt_model import * EmbeddingModel = { "Ollama": OllamaEmbed, @@ -81,3 +81,12 @@ "Youdao": YoudaoRerank, "Xinference": XInferenceRerank } + + +Seq2txtModel = { + "OpenAI": GPTSeq2txt, + "Tongyi-Qianwen": QWenSeq2txt, + "Ollama": OllamaSeq2txt, + "Azure-OpenAI": AzureSeq2txt, + "Xinference": XinferenceSeq2txt +} diff --git a/rag/llm/sequence2txt_model.py b/rag/llm/sequence2txt_model.py new file mode 100644 index 0000000000..08a2b84f0d --- /dev/null +++ b/rag/llm/sequence2txt_model.py @@ -0,0 +1,89 @@ +# +# Copyright 2024 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from openai.lib.azure import AzureOpenAI +from zhipuai import ZhipuAI +import io +from abc import ABC +from ollama import Client +from openai import OpenAI +import os +import json +from rag.utils import num_tokens_from_string + + +class Base(ABC): + def __init__(self, key, model_name): + pass + + def transcription(self, audio, **kwargs): + transcription = self.client.audio.transcriptions.create( + model=self.model_name, + file=audio, + response_format="text" + ) + return transcription.text.strip(), num_tokens_from_string(transcription.text.strip()) + + +class GPTSeq2txt(Base): + def __init__(self, key, model_name="whisper-1", base_url="https://api.openai.com/v1"): + if not base_url: base_url = "https://api.openai.com/v1" + self.client = OpenAI(api_key=key, base_url=base_url) + self.model_name = model_name + + +class QWenSeq2txt(Base): + def __init__(self, key, model_name="paraformer-realtime-8k-v1", **kwargs): + import dashscope + dashscope.api_key = key + self.model_name = model_name + + def transcription(self, audio, format): + from http import HTTPStatus + from dashscope.audio.asr import Recognition + + recognition = Recognition(model=self.model_name, + format=format, + sample_rate=16000, + callback=None) + result = recognition.call(audio) + + ans = "" + if result.status_code == HTTPStatus.OK: + for sentence in result.get_sentence(): + ans += str(sentence + '\n') + return ans, num_tokens_from_string(ans) + + return "**ERROR**: " + result.message, 0 + + +class OllamaSeq2txt(Base): + def __init__(self, key, model_name, lang="Chinese", **kwargs): + self.client = Client(host=kwargs["base_url"]) + self.model_name = model_name + self.lang = lang + + +class AzureSeq2txt(Base): + def __init__(self, key, model_name, lang="Chinese", **kwargs): + self.client = AzureOpenAI(api_key=key, azure_endpoint=kwargs["base_url"], api_version="2024-02-01") + self.model_name = model_name + self.lang = lang + + +class XinferenceSeq2txt(Base): + def __init__(self, key, model_name="", base_url=""): + self.client = OpenAI(api_key="xxx", base_url=base_url) + self.model_name = model_name