Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
# coding=utf-8
import traceback
from typing import Dict

from django.utils.translation import gettext as _

from common import forms
from common.exception.app_exception import AppApiException
from common.forms import BaseForm
from models_provider.base_model_provider import BaseModelCredential, ValidCode



class ZhEnXunFeiSTTModelCredential(BaseForm, BaseModelCredential):
spark_api_url = forms.TextInputField('API URL', required=True, default_value='wss://iat.xf-yun.com/v1')
spark_app_id = forms.TextInputField('APP ID', required=True)
spark_api_key = forms.PasswordInputField("API Key", required=True)
spark_api_secret = forms.PasswordInputField('API Secret', required=True)

def is_valid(self,
model_type: str,
model_name,
model_credential: Dict[str, object],
model_params, provider,
raise_exception=False):
model_type_list = provider.get_model_type_list()
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
raise AppApiException(ValidCode.valid_error.value,
_('{model_type} Model type is not supported').format(model_type=model_type))

for key in ['spark_api_url', 'spark_app_id', 'spark_api_key', 'spark_api_secret']:
if key not in model_credential:
if raise_exception:
raise AppApiException(ValidCode.valid_error.value, _('{key} is required').format(key=key))
else:
return False
try:
model = provider.get_model(model_type, model_name, model_credential)
model.check_auth()
except Exception as e:
traceback.print_exc()
if isinstance(e, AppApiException):
raise e
if raise_exception:
raise AppApiException(ValidCode.valid_error.value,
_('Verification failed, please check whether the parameters are correct: {error}').format(
error=str(e)))
else:
return False
return True

def encryption_dict(self, model: Dict[str, object]):
return {**model, 'spark_api_secret': super().encryption(model.get('spark_api_secret', ''))}

def get_model_params_setting_form(self, model_name):
pass
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This Python code snippet appears to be part of a Django application that validates and manages model credentials for speech-to-text processing using Xiaofengyun's API (ZhEnXunFei). Here are some potential issues and optimizations:

Potential Issues:

  1. Type Annotations: The use of Dict[str, object] for model_credential can lead to runtime type errors since objects could contain unexpected types. Using specific annotation types like dict[str, Any] would improve clarity.

  2. String Formatting with _: While using Unicode literals (u'...') is no longer necessary in modern versions of Python, it might cause syntax warnings or performance issues in certain environments. Ensure all string formats are correctly specified without explicit encoding hints.

  3. Traceback Logging: Although logging the traceback can be helpful during development, it should not be done directly in production unless you have a need to debug exceptions.

  4. Empty List Check: The line if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):

    • Is there a reason why you're converting model_type_list to a list before filtering? It seems unnecessary here.
    • A more idiomatic way would be to skip this step if model_type_list is empty beforehand.
  5. Exception Handling in get_model_params_setting_form: This method does nothing useful; its implementation can be removed or modified based on actual requirements.

  6. Unnecessary Empty File at End: There is an empty line at the end of the file that doesn't serve a purpose.

Optimizations:

  1. Use Type Annotations Accurately:

    from typing import Dict, Optional, Any
    
    class ZhEnXunFeiSTTModelCredential(BaseForm, BaseModelCredential):
        # ...
  2. Remove Unnecessary Conversion:

    model_type_list = provider.get_model_type_list()
    if not model_type_list:
        raise AppApiException(ValidCode.valid_error.value, _("No valid model types found"))
    # ... rest of the validation logic remains unchanged ...
  3. Simplify Exception Handling:
    Instead of raising exceptions within exception handling blocks, catch them separately and handle each case appropriately. For example:

    try:
        model = provider.get_model(model_type, model_name, credential)
        model.check_auth()
    except AppApiException as e:
        if not raise_exception:
            return False
        raise
    except Exception as e:
        traceback.print_exc()
        if raise_exception:
            raise AppApiException(ValidCode.valid_error.value, _('Verification failed'))
        return False
  4. Consider Adding More Specific Exceptions:
    Depending on the complexity, you might want to introduce specific exception classes or wrap existing ones to better describe the context of each failure.

Apply these improvements where appropriate to ensure the code maintains correctness, readability, and maintainability while also considering future scalability needs.

192 changes: 192 additions & 0 deletions apps/models_provider/impl/xf_model_provider/model/zh_en_stt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,192 @@
import asyncio
import json
import base64
import hmac
import hashlib
import ssl
import traceback
from typing import Dict
from urllib.parse import urlencode
from datetime import datetime, timezone, UTC
import websockets
import os

from future.backports.urllib.parse import urlparse

from common.utils.logger import maxkb_logger
from models_provider.base_model_provider import MaxKBBaseModel
from models_provider.impl.base_stt import BaseSpeechToText

ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
ssl_context.check_hostname = False
ssl_context.verify_mode = ssl.CERT_NONE


class XFZhEnSparkSpeechToText(MaxKBBaseModel, BaseSpeechToText):
spark_app_id: str
spark_api_key: str
spark_api_secret: str
spark_api_url: str

def __init__(self, **kwargs):
super().__init__(**kwargs)
self.spark_api_url = kwargs.get('spark_api_url')
self.spark_app_id = kwargs.get('spark_app_id')
self.spark_api_key = kwargs.get('spark_api_key')
self.spark_api_secret = kwargs.get('spark_api_secret')

@staticmethod
def is_cache_model():
return False

@staticmethod
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
optional_params = {}
if 'max_tokens' in model_kwargs and model_kwargs['max_tokens'] is not None:
optional_params['max_tokens'] = model_kwargs['max_tokens']
if 'temperature' in model_kwargs and model_kwargs['temperature'] is not None:
optional_params['temperature'] = model_kwargs['temperature']
return XFZhEnSparkSpeechToText(
spark_app_id=model_credential.get('spark_app_id'),
spark_api_key=model_credential.get('spark_api_key'),
spark_api_secret=model_credential.get('spark_api_secret'),
spark_api_url=model_credential.get('spark_api_url'),
**optional_params
)

# 生成url
def create_url(self):
url = self.spark_api_url
host = urlparse(url).hostname

gmt_format = '%a, %d %b %Y %H:%M:%S GMT'
date = datetime.now(UTC).strftime(gmt_format)
# 拼接字符串
signature_origin = "host: " + host + "\n"
signature_origin += "date: " + date + "\n"
signature_origin += "GET " + "/v1 HTTP/1.1"
# 进行hmac-sha256进行加密
signature_sha = hmac.new(
self.spark_api_secret.encode('utf-8'),
signature_origin.encode('utf-8'),
hashlib.sha256
).digest()
signature = base64.b64encode(signature_sha).decode(encoding='utf-8')

authorization_origin = (
f'api_key="{self.spark_api_key}", algorithm="hmac-sha256", '
f'headers="host date request-line", signature="{signature}"'
)
authorization = base64.b64encode(authorization_origin.encode('utf-8')).decode(encoding='utf-8')

params = {
'authorization': authorization,
'date': date,
'host': host
}
auth_url = url + '?' + urlencode(params)
return auth_url

def check_auth(self):
cwd = os.path.dirname(os.path.abspath(__file__))
with open(f'{cwd}/iat_mp3_16k.mp3', 'rb') as f:
self.speech_to_text(f)

def speech_to_text(self, audio_file_path):
async def handle():
async with websockets.connect(self.create_url(), max_size=1000000000, ssl=ssl_context) as ws:
# print("连接成功")
# 发送音频数据
await self.send_audio(ws, audio_file_path)
# 接收识别结果
return await self.handle_message(ws)
try:
return asyncio.run(handle())
except Exception as err:
maxkb_logger.error(f"语音识别错误: {str(err)}: {traceback.format_exc()}")
return ""

async def send_audio(self, ws, audio_file):
"""发送音频数据"""
chunk_size = 4000
seq = 1
max_chunks = 10000
while True:
chunk = audio_file.read(chunk_size)
if not chunk or seq > max_chunks:
break

chunk_base64 = base64.b64encode(chunk).decode('utf-8')
# 第一帧
if seq == 1:
frame = {
"header": {"app_id": self.spark_app_id, "status": 0},
"parameter": {
"iat": {
"domain": "slm", "language": "zh_cn", "accent": "mandarin",
"eos": 10000, "vinfo": 1,
"result": {"encoding": "utf8", "compress": "raw", "format": "json"}
}
},
"payload": {
"audio": {
"encoding": "lame", "sample_rate": 16000, "channels": 1,
"bit_depth": 16, "seq": seq, "status": 0, "audio": chunk_base64
}
}
}
# 中间帧
else:
frame = {
"header": {"app_id": self.spark_app_id, "status": 1},
"payload": {
"audio": {
"encoding": "lame", "sample_rate": 16000, "channels": 1,
"bit_depth": 16, "seq": seq, "status": 1, "audio": chunk_base64
}
}
}

await ws.send(json.dumps(frame))
seq += 1

# 发送结束帧
end_frame = {
"header": {"app_id": self.spark_app_id, "status": 2},
"payload": {
"audio": {
"encoding": "lame", "sample_rate": 16000, "channels": 1,
"bit_depth": 16, "seq": seq, "status": 2, "audio": ""
}
}
}
await ws.send(json.dumps(end_frame))


# 接受信息处理器
async def handle_message(self, ws):
result_text = ""
while True:
try:
message = await asyncio.wait_for(ws.recv(), timeout=30.0)
data = json.loads(message)

if data['header']['code'] != 0:
raise Exception("")

if 'payload' in data and 'result' in data['payload']:
result = data['payload']['result']
text = result.get('text', '')
if text:
text_data = json.loads(base64.b64decode(text).decode('utf-8'))
for ws_item in text_data.get('ws', []):
for cw in ws_item.get('cw', []):
for sw in cw.get('sw', []):
result_text += sw['w']

if data['header'].get('status') == 2:
break
except asyncio.TimeoutError:
break

return result_text
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The provided Python code seems to be an implementation of a web service client for interfacing with a Spark-based speech-to-text API using the XFZhEnSparkSpeechToText class. Here are some general comments on the code:

  1. SSL Context: The use of SSLContext without specific settings might pose security risks. You should configure your SSL context more securely based on your production requirements.

  2. URL Creation: The method create_url() generates an authorization header by concatenating headers and signing them with the key. This approach is fine but could benefit from better exception handling when errors occur during cryptographic operations.

  3. Audio Sending Logic: The loop that sends chunks breaks prematurely after reading zero bytes, which isn't ideal if you haven't reached end-of-stream in chunks. Consider implementing proper detection for EOF or raising an error accordingly.

  4. WebSocket Connection Handling: It's good practice to have cleaner separation between sending audio and receiving responses. However, the current design has a single function for both.

  5. Error Handling: Currently, all exceptions within methods like handle_audio(), send_audio(), and speech_to_text() catch them locally, leading to unhandled exceptions being logged at the global level (maxkb_logger.error). Improving this ensures exceptions don't silently fail.

  6. Asynchronous vs Synchronous Calls: Most parts of the code assume asynchronous execution (e.g., asyncio.run(handle())). Ensure it aligns with your application architecture. Some functions may still require synchronization due to their nature.

  7. Logging: Using logging instead of printing error messages directly can make debugging easier and adhere to best practices by reducing clutter in console output.

  8. Configuration Management: If you're storing keys/configurations in files, consider moving sensitive information such as API keys into environment variables or secure vault configurations rather than plain text files.

  9. Documentation: Adding docstrings across various methods would improve readability and maintainability of the codebase.

Overall, the code provides a solid foundation for interacting with the Spark API through websockets asynchronously. Continuous testing and improvement are recommended to address performance bottlenecks and robustness issues.

12 changes: 10 additions & 2 deletions apps/models_provider/impl/xf_model_provider/xf_model_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from models_provider.impl.xf_model_provider.credential.llm import XunFeiLLMModelCredential
from models_provider.impl.xf_model_provider.credential.stt import XunFeiSTTModelCredential
from models_provider.impl.xf_model_provider.credential.tts import XunFeiTTSModelCredential
from models_provider.impl.xf_model_provider.credential.zh_en_stt import ZhEnXunFeiSTTModelCredential
from models_provider.impl.xf_model_provider.model.embedding import XFEmbedding
from models_provider.impl.xf_model_provider.model.image import XFSparkImage
from models_provider.impl.xf_model_provider.model.llm import XFChatSparkLLM
Expand All @@ -25,18 +26,24 @@
from maxkb.conf import PROJECT_DIR
from django.utils.translation import gettext as _

from models_provider.impl.xf_model_provider.model.zh_en_stt import XFZhEnSparkSpeechToText

ssl._create_default_https_context = ssl.create_default_context()

xunfei_model_credential = XunFeiLLMModelCredential()
stt_model_credential = XunFeiSTTModelCredential()
zh_en_stt_credential = ZhEnXunFeiSTTModelCredential()
image_model_credential = XunFeiImageModelCredential()
tts_model_credential = XunFeiTTSModelCredential()
embedding_model_credential = XFEmbeddingCredential()
model_info_list = [
ModelInfo('generalv3.5', '', ModelTypeConst.LLM, xunfei_model_credential, XFChatSparkLLM),
ModelInfo('generalv3', '', ModelTypeConst.LLM, xunfei_model_credential, XFChatSparkLLM),
ModelInfo('generalv2', '', ModelTypeConst.LLM, xunfei_model_credential, XFChatSparkLLM),
ModelInfo('iat', _('Chinese and English recognition'), ModelTypeConst.STT, stt_model_credential, XFSparkSpeechToText),
ModelInfo('iat', _('Chinese and English recognition'), ModelTypeConst.STT, stt_model_credential,
XFSparkSpeechToText),
ModelInfo('slm', _('Chinese and English recognition'), ModelTypeConst.STT, zh_en_stt_credential,
XFZhEnSparkSpeechToText),
ModelInfo('tts', '', ModelTypeConst.TTS, tts_model_credential, XFSparkTextToSpeech),
ModelInfo('embedding', '', ModelTypeConst.EMBEDDING, embedding_model_credential, XFEmbedding)
]
Expand All @@ -47,7 +54,8 @@
.append_default_model_info(
ModelInfo('generalv3.5', '', ModelTypeConst.LLM, xunfei_model_credential, XFChatSparkLLM))
.append_default_model_info(
ModelInfo('iat', _('Chinese and English recognition'), ModelTypeConst.STT, stt_model_credential, XFSparkSpeechToText),
ModelInfo('iat', _('Chinese and English recognition'), ModelTypeConst.STT, stt_model_credential,
XFSparkSpeechToText),
)
.append_default_model_info(
ModelInfo('tts', '', ModelTypeConst.TTS, tts_model_credential, XFSparkTextToSpeech))
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are no significant issues with the provided Python code snippet. However, I have a few optimizations and improvements you might consider:

  1. Avoid Using append_default_model_info: If this is intended to add default model information, ensure that the class DefaultModelInfoProvider supports appending without causing side effects.

  2. Consistent Use of Quotation Marks: In some places, double quotes (") and single quotes ('') are used interchangeably. Consistency is preferred, especially if you're using string literals throughout.

  3. Line Length: The line lengths in the file can be improved for better readability and maintainability. Consider breaking down long lines into multiple parts or using triple quotes for larger strings.

Here’s an updated version of the code with these considerations:

from models_provider.impl.xf_model_provider.credential.llm import XunFeiLLMModelCredential
from models_provider.impl.xf_model_provider.credential.stt import XunFeiSTTModelCredential
from models_provider.impl.xf_model_provider.credential.tts import XunFeiTTSModelCredential
from models_provider.impl.xf_model_provider.credential.zh_en_stt import ZhEnXunFeiSTTModelCredential

from models_provider.impl.xf_model_provider.model.embedding import XFEmbedding
from models_provider.impl.xf_model_provider.model.image import XFSparkImage
from models_provider.impl.xf_model_provider.model.llm import XFChatSparkLLM

from maxkb.conf import PROJECT_DIR
from django.utils.translation import gettext as _

import ssl

ssl._create_default_https_context = ssl.create_default_context()

xunfei_model_credential = XunFeiLLMModelCredential()
stt_model_credential = XunFeiSTTModelCredential()
zh_en_stt_credential = ZhEnXunFeiSTTModelCredential()
image_model_credential = XunFeiImageModelCredential()
tts_model_credential = XunFeiTTSModelCredential()
embedding_model_credential = XFEmbeddingCredential()

model_info_list = [
    ModelInfo('generalv3.5', "", ModelTypeConst.LLM, xunfei_model_credential, XFChatSparkLLM),
    ModelInfo('', 'General Version v3.0', ModelTypeConst.LLM, xunfei_model_credential, XFChatSparkLLM),
    ModelInfo('', 'General Version v2.0', ModelTypeConst.LLM, xunfei_model_credential, XFChatSparkLLM),

    # Simplify Chinese and English Recognition model info
    ModelInfo('iat', _('Chinese and English Recognition'), ModelTypeConst.STT,
              stt_model_credential, XFSparkSpeechToText),
    
    # New STT model info for Chinese and English
    ModelInfo('slm', _('Chinese and English Recognition'), ModelTypeConst.STT,
              zh_en_stt_credential, XFZhEnSparkSpeechToText),

    ModelInfo("", "Text-to-Speech", ModelTypeConst.TTS, tts_model_credential, XFSparkTextToSpeech),
    ModelInfo("embedding", "Sentence Embeddings", ModelTypeConst.EMBEDDING,
              embedding_model_credential, XFEmbedding)
]

By applying these changes, the code becomes more readable and consistent in terms of string handling and overall structure.

Expand Down
Loading