-
Notifications
You must be signed in to change notification settings - Fork 2.3k
feat: Support iFLYTEK large model for Chinese-English speech recognition #3952
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
# coding=utf-8 | ||
import traceback | ||
from typing import Dict | ||
|
||
from django.utils.translation import gettext as _ | ||
|
||
from common import forms | ||
from common.exception.app_exception import AppApiException | ||
from common.forms import BaseForm | ||
from models_provider.base_model_provider import BaseModelCredential, ValidCode | ||
|
||
|
||
|
||
class ZhEnXunFeiSTTModelCredential(BaseForm, BaseModelCredential):
    """Credential form and validator for the iFLYTEK Chinese-English STT model.

    The four form fields below are the values the user must supply; the same
    four keys are required in the credential dict passed to ``is_valid``.
    """
    spark_api_url = forms.TextInputField('API URL', required=True, default_value='wss://iat.xf-yun.com/v1')
    spark_app_id = forms.TextInputField('APP ID', required=True)
    spark_api_key = forms.PasswordInputField("API Key", required=True)
    spark_api_secret = forms.PasswordInputField('API Secret', required=True)

    def is_valid(self,
                 model_type: str,
                 model_name,
                 model_credential: Dict[str, object],
                 model_params, provider,
                 raise_exception=False):
        """Check that the credential is complete and accepted by the model.

        Args:
            model_type: Model type value; must match the ``value`` of one entry
                returned by ``provider.get_model_type_list()``.
            model_name: Name passed through to ``provider.get_model``.
            model_credential: Credential dict holding the four ``spark_*`` keys.
            model_params: Not used by this validator.
            provider: Object exposing ``get_model_type_list`` and ``get_model``.
            raise_exception: When true, failures raise ``AppApiException``
                instead of returning ``False``.

        Returns:
            ``True`` when every key is present and ``model.check_auth()``
            completes without raising; ``False`` otherwise (only when
            ``raise_exception`` is false).

        Raises:
            AppApiException: Always for an unsupported model type; for a
                missing key or a failed verification when ``raise_exception``
                is true; and whenever the verification itself raised an
                ``AppApiException`` (re-raised unchanged).
        """
        model_type_list = provider.get_model_type_list()
        # An unsupported model type is reported regardless of raise_exception.
        if not any(mt.get('value') == model_type for mt in model_type_list):
            raise AppApiException(ValidCode.valid_error.value,
                                  _('{model_type} Model type is not supported').format(model_type=model_type))

        for key in ['spark_api_url', 'spark_app_id', 'spark_api_key', 'spark_api_secret']:
            if key not in model_credential:
                if raise_exception:
                    raise AppApiException(ValidCode.valid_error.value, _('{key} is required').format(key=key))
                return False
        try:
            model = provider.get_model(model_type, model_name, model_credential)
            model.check_auth()
        except Exception as e:
            traceback.print_exc()
            if isinstance(e, AppApiException):
                # Already in the API's error format: propagate untouched.
                raise
            if raise_exception:
                # Chain the cause so the original traceback stays attached.
                raise AppApiException(ValidCode.valid_error.value,
                                      _('Verification failed, please check whether the parameters are correct: {error}').format(
                                          error=str(e))) from e
            return False
        return True

    def encryption_dict(self, model: Dict[str, object]):
        """Return a copy of ``model`` with ``spark_api_secret`` passed through ``encryption``.

        Only ``spark_api_secret`` is transformed; every other key, including
        ``spark_api_key``, is copied as-is.
        """
        return {**model, 'spark_api_secret': super().encryption(model.get('spark_api_secret', ''))}

    def get_model_params_setting_form(self, model_name):
        """Return ``None``: this model defines no parameter-setting form."""
        return None
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,192 @@ | ||
import asyncio | ||
import json | ||
import base64 | ||
import hmac | ||
import hashlib | ||
import ssl | ||
import traceback | ||
from typing import Dict | ||
from urllib.parse import urlencode | ||
from datetime import datetime, timezone, UTC | ||
import websockets | ||
import os | ||
|
||
from future.backports.urllib.parse import urlparse | ||
|
||
from common.utils.logger import maxkb_logger | ||
from models_provider.base_model_provider import MaxKBBaseModel | ||
from models_provider.impl.base_stt import BaseSpeechToText | ||
|
||
# NOTE(review): this TLS context turns off hostname checking and certificate
# verification, so the WebSocket connection that uses it does not authenticate
# the server. Prefer ssl.create_default_context() unless the endpoint really
# requires an unverified connection -- confirm before relying on this.
ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) | ||
ssl_context.check_hostname = False | ||
ssl_context.verify_mode = ssl.CERT_NONE | ||
|
||
|
||
class XFZhEnSparkSpeechToText(MaxKBBaseModel, BaseSpeechToText):
    """Speech-to-text client for the iFLYTEK Chinese-English recognition WebSocket API.

    Audio is streamed to ``spark_api_url`` as base64-encoded JSON frames and the
    recognised words are collected from the JSON messages sent back.
    """
    spark_app_id: str
    spark_api_key: str
    spark_api_secret: str
    spark_api_url: str

    def __init__(self, **kwargs):
        """Store the four ``spark_*`` settings taken from ``kwargs``."""
        super().__init__(**kwargs)
        self.spark_api_url = kwargs.get('spark_api_url')
        self.spark_app_id = kwargs.get('spark_app_id')
        self.spark_api_key = kwargs.get('spark_api_key')
        self.spark_api_secret = kwargs.get('spark_api_secret')

    @staticmethod
    def is_cache_model():
        """Return ``False``: instances of this model are not cached."""
        return False

    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
        """Build an instance from a credential dict.

        ``max_tokens`` and ``temperature`` are forwarded to the constructor only
        when present in ``model_kwargs`` with a non-``None`` value.
        """
        optional_params = {
            name: model_kwargs[name]
            for name in ('max_tokens', 'temperature')
            if model_kwargs.get(name) is not None
        }
        return XFZhEnSparkSpeechToText(
            spark_app_id=model_credential.get('spark_app_id'),
            spark_api_key=model_credential.get('spark_api_key'),
            spark_api_secret=model_credential.get('spark_api_secret'),
            spark_api_url=model_credential.get('spark_api_url'),
            **optional_params
        )

    def create_url(self):
        """Return ``spark_api_url`` with the signed authentication query string appended.

        The signature is an HMAC-SHA256 (keyed with ``spark_api_secret``) over the
        host, the current GMT date and the request line, base64-encoded and
        wrapped in a base64-encoded ``authorization`` parameter.

        Raises:
            ValueError: If no host can be parsed from ``spark_api_url``.
        """
        url = self.spark_api_url
        host = urlparse(url).hostname
        if not host:
            # Fail with a readable message instead of a TypeError on "str + None".
            raise ValueError(f'Invalid spark_api_url, no host found: {url!r}')

        gmt_format = '%a, %d %b %Y %H:%M:%S GMT'
        date = datetime.now(UTC).strftime(gmt_format)
        # Assemble the string to sign.
        # NOTE(review): the request line is fixed to "/v1" and is not derived from
        # the path of spark_api_url -- confirm this is intended for custom URLs.
        signature_origin = "host: " + host + "\n"
        signature_origin += "date: " + date + "\n"
        signature_origin += "GET " + "/v1 HTTP/1.1"
        # Sign it with HMAC-SHA256.
        signature_sha = hmac.new(
            self.spark_api_secret.encode('utf-8'),
            signature_origin.encode('utf-8'),
            hashlib.sha256
        ).digest()
        signature = base64.b64encode(signature_sha).decode(encoding='utf-8')

        authorization_origin = (
            f'api_key="{self.spark_api_key}", algorithm="hmac-sha256", '
            f'headers="host date request-line", signature="{signature}"'
        )
        authorization = base64.b64encode(authorization_origin.encode('utf-8')).decode(encoding='utf-8')

        params = {
            'authorization': authorization,
            'date': date,
            'host': host
        }
        return url + '?' + urlencode(params)

    def check_auth(self):
        """Verify the credentials by recognising the bundled sample audio file.

        Uses the raising code path, so a connection, authentication or service
        error propagates to the caller instead of being logged and swallowed as
        it is in ``speech_to_text``.
        """
        cwd = os.path.dirname(os.path.abspath(__file__))
        with open(f'{cwd}/iat_mp3_16k.mp3', 'rb') as f:
            self._recognize(f)

    def _recognize(self, audio_file):
        """Stream ``audio_file`` to the service and return the recognised text; errors propagate."""

        async def handle():
            async with websockets.connect(self.create_url(), max_size=1000000000, ssl=ssl_context) as ws:
                # Send the audio data, then collect the recognition result.
                await self.send_audio(ws, audio_file)
                return await self.handle_message(ws)

        return asyncio.run(handle())

    def speech_to_text(self, audio_file_path):
        """Recognise speech and return the text, or ``""`` on any error.

        Args:
            audio_file_path: Despite its name, a binary file-like object; it is
                consumed through ``read(size)`` in ``send_audio``.

        Returns:
            The recognised text. Any exception is logged and turned into an
            empty string (best-effort behaviour for callers).
        """
        try:
            return self._recognize(audio_file_path)
        except Exception as err:
            maxkb_logger.error(f"语音识别错误: {str(err)}: {traceback.format_exc()}")
            return ""

    def _audio_frame(self, seq, status, audio_base64):
        """Build one request frame; ``status`` is 0 (first), 1 (middle) or 2 (last).

        Only the first frame carries the ``parameter`` section with the
        recognition settings.
        """
        frame = {"header": {"app_id": self.spark_app_id, "status": status}}
        if status == 0:
            frame["parameter"] = {
                "iat": {
                    "domain": "slm", "language": "zh_cn", "accent": "mandarin",
                    "eos": 10000, "vinfo": 1,
                    "result": {"encoding": "utf8", "compress": "raw", "format": "json"}
                }
            }
        frame["payload"] = {
            "audio": {
                "encoding": "lame", "sample_rate": 16000, "channels": 1,
                "bit_depth": 16, "seq": seq, "status": status, "audio": audio_base64
            }
        }
        return frame

    async def send_audio(self, ws, audio_file):
        """Send the audio to ``ws`` in 4000-byte chunks, followed by an end frame.

        At most 10000 data frames are sent; any audio beyond that is not sent.
        The end frame (status 2, empty audio) is always sent, even for empty input.
        """
        chunk_size = 4000
        seq = 1
        max_chunks = 10000
        while True:
            chunk = audio_file.read(chunk_size)
            if not chunk or seq > max_chunks:
                break

            chunk_base64 = base64.b64encode(chunk).decode('utf-8')
            # The first chunk opens the stream (status 0); the rest are middle frames.
            status = 0 if seq == 1 else 1
            await ws.send(json.dumps(self._audio_frame(seq, status, chunk_base64)))
            seq += 1

        # Closing frame.
        await ws.send(json.dumps(self._audio_frame(seq, 2, "")))

    async def handle_message(self, ws):
        """Receive result messages from ``ws`` and return the concatenated text.

        Reading stops when a message reports header status 2 or when no message
        arrives within 30 seconds; in the latter case the text collected so far
        is returned.

        Raises:
            Exception: If a message carries a non-zero ``header.code``; the
                message includes the code and the service's error text.
        """
        pieces = []
        while True:
            try:
                message = await asyncio.wait_for(ws.recv(), timeout=30.0)
            except asyncio.TimeoutError:
                # Deliberate best effort: keep whatever was recognised so far.
                break
            data = json.loads(message)
            header = data['header']

            if header['code'] != 0:
                raise Exception(
                    f"iFLYTEK speech recognition failed: code={header['code']}, "
                    f"message={header.get('message', '')}"
                )

            if 'payload' in data and 'result' in data['payload']:
                text = data['payload']['result'].get('text', '')
                if text:
                    # The text field is base64-encoded JSON holding the word lattice.
                    text_data = json.loads(base64.b64decode(text).decode('utf-8'))
                    for ws_item in text_data.get('ws', []):
                        for cw in ws_item.get('cw', []):
                            for sw in cw.get('sw', []):
                                pieces.append(sw['w'])

            if header.get('status') == 2:
                break

        return "".join(pieces)
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The provided Python code seems to be an implementation of a web service client for interfacing with a Spark-based speech-to-text API using the XFZhEnSparkSpeechToText class. Here are some general comments on the code:
Overall, the code provides a solid foundation for interacting with the Spark API through websockets asynchronously. Continuous testing and improvement are recommended to address performance bottlenecks and robustness issues. |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -17,6 +17,7 @@ | |
from models_provider.impl.xf_model_provider.credential.llm import XunFeiLLMModelCredential | ||
from models_provider.impl.xf_model_provider.credential.stt import XunFeiSTTModelCredential | ||
from models_provider.impl.xf_model_provider.credential.tts import XunFeiTTSModelCredential | ||
from models_provider.impl.xf_model_provider.credential.zh_en_stt import ZhEnXunFeiSTTModelCredential | ||
from models_provider.impl.xf_model_provider.model.embedding import XFEmbedding | ||
from models_provider.impl.xf_model_provider.model.image import XFSparkImage | ||
from models_provider.impl.xf_model_provider.model.llm import XFChatSparkLLM | ||
|
@@ -25,18 +26,24 @@ | |
from maxkb.conf import PROJECT_DIR | ||
from django.utils.translation import gettext as _ | ||
|
||
from models_provider.impl.xf_model_provider.model.zh_en_stt import XFZhEnSparkSpeechToText | ||
|
||
ssl._create_default_https_context = ssl.create_default_context() | ||
|
||
xunfei_model_credential = XunFeiLLMModelCredential() | ||
stt_model_credential = XunFeiSTTModelCredential() | ||
zh_en_stt_credential = ZhEnXunFeiSTTModelCredential() | ||
image_model_credential = XunFeiImageModelCredential() | ||
tts_model_credential = XunFeiTTSModelCredential() | ||
embedding_model_credential = XFEmbeddingCredential() | ||
model_info_list = [ | ||
ModelInfo('generalv3.5', '', ModelTypeConst.LLM, xunfei_model_credential, XFChatSparkLLM), | ||
ModelInfo('generalv3', '', ModelTypeConst.LLM, xunfei_model_credential, XFChatSparkLLM), | ||
ModelInfo('generalv2', '', ModelTypeConst.LLM, xunfei_model_credential, XFChatSparkLLM), | ||
ModelInfo('iat', _('Chinese and English recognition'), ModelTypeConst.STT, stt_model_credential, XFSparkSpeechToText), | ||
ModelInfo('iat', _('Chinese and English recognition'), ModelTypeConst.STT, stt_model_credential, | ||
XFSparkSpeechToText), | ||
ModelInfo('slm', _('Chinese and English recognition'), ModelTypeConst.STT, zh_en_stt_credential, | ||
XFZhEnSparkSpeechToText), | ||
ModelInfo('tts', '', ModelTypeConst.TTS, tts_model_credential, XFSparkTextToSpeech), | ||
ModelInfo('embedding', '', ModelTypeConst.EMBEDDING, embedding_model_credential, XFEmbedding) | ||
] | ||
|
@@ -47,7 +54,8 @@ | |
.append_default_model_info( | ||
ModelInfo('generalv3.5', '', ModelTypeConst.LLM, xunfei_model_credential, XFChatSparkLLM)) | ||
.append_default_model_info( | ||
ModelInfo('iat', _('Chinese and English recognition'), ModelTypeConst.STT, stt_model_credential, XFSparkSpeechToText), | ||
ModelInfo('iat', _('Chinese and English recognition'), ModelTypeConst.STT, stt_model_credential, | ||
XFSparkSpeechToText), | ||
) | ||
.append_default_model_info( | ||
ModelInfo('tts', '', ModelTypeConst.TTS, tts_model_credential, XFSparkTextToSpeech)) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There are no significant issues with the provided Python code snippet. However, I have a few optimizations and improvements you might consider:
Here’s an updated version of the code with these considerations: from models_provider.impl.xf_model_provider.credential.llm import XunFeiLLMModelCredential
from models_provider.impl.xf_model_provider.credential.stt import XunFeiSTTModelCredential
from models_provider.impl.xf_model_provider.credential.tts import XunFeiTTSModelCredential
from models_provider.impl.xf_model_provider.credential.zh_en_stt import ZhEnXunFeiSTTModelCredential
from models_provider.impl.xf_model_provider.model.embedding import XFEmbedding
from models_provider.impl.xf_model_provider.model.image import XFSparkImage
from models_provider.impl.xf_model_provider.model.llm import XFChatSparkLLM
from maxkb.conf import PROJECT_DIR
from django.utils.translation import gettext as _
import ssl
ssl._create_default_https_context = ssl.create_default_context()
xunfei_model_credential = XunFeiLLMModelCredential()
stt_model_credential = XunFeiSTTModelCredential()
zh_en_stt_credential = ZhEnXunFeiSTTModelCredential()
image_model_credential = XunFeiImageModelCredential()
tts_model_credential = XunFeiTTSModelCredential()
embedding_model_credential = XFEmbeddingCredential()
model_info_list = [
ModelInfo('generalv3.5', "", ModelTypeConst.LLM, xunfei_model_credential, XFChatSparkLLM),
ModelInfo('', 'General Version v3.0', ModelTypeConst.LLM, xunfei_model_credential, XFChatSparkLLM),
ModelInfo('', 'General Version v2.0', ModelTypeConst.LLM, xunfei_model_credential, XFChatSparkLLM),
# Simplify Chinese and English Recognition model info
ModelInfo('iat', _('Chinese and English Recognition'), ModelTypeConst.STT,
stt_model_credential, XFSparkSpeechToText),
# New STT model info for Chinese and English
ModelInfo('slm', _('Chinese and English Recognition'), ModelTypeConst.STT,
zh_en_stt_credential, XFZhEnSparkSpeechToText),
ModelInfo("", "Text-to-Speech", ModelTypeConst.TTS, tts_model_credential, XFSparkTextToSpeech),
ModelInfo("embedding", "Sentence Embeddings", ModelTypeConst.EMBEDDING,
embedding_model_credential, XFEmbedding)
] By applying these changes, the code becomes more readable and consistent in terms of string handling and overall structure. |
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This Python code snippet appears to be part of a Django application that validates and manages model credentials for speech-to-text processing using iFLYTEK's (XunFei) API, via the ZhEnXunFeiSTTModelCredential class. Here are some potential issues and optimizations:
Potential Issues:
Type Annotations: The use of `Dict[str, object]` for `model_credential` can lead to runtime type errors, since the values could be of unexpected types. Using a more specific annotation such as `dict[str, Any]` would improve clarity.
String Formatting with `_`: While using Unicode literals (`u'...'`) is no longer necessary in modern versions of Python, they might cause syntax warnings or performance issues in certain environments. Ensure all string formats are correctly specified without explicit encoding hints.
Traceback Logging: Although logging the traceback can be helpful during development, it should not be done directly in production unless you need it to debug exceptions.
Empty List Check: The line `if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):` converts the filtered `model_type_list` to a list before passing it to `any`, which seems unnecessary here. Consider also checking whether `model_type_list` is empty beforehand.
Exception Handling in `get_model_params_setting_form`: This method does nothing useful; its implementation can be removed or modified based on actual requirements.
Unnecessary Empty Line at End: There is an empty line at the end of the file that doesn't serve a purpose.
Optimizations:
Use Type Annotations Accurately:
Remove Unnecessary Conversion:
Simplify Exception Handling:
Instead of raising exceptions within exception handling blocks, catch them separately and handle each case appropriately. For example:
Consider Adding More Specific Exceptions:
Depending on the complexity, you might want to introduce specific exception classes or wrap existing ones to better describe the context of each failure.
Apply these improvements where appropriate to ensure the code maintains correctness, readability, and maintainability while also considering future scalability needs.