|
| 1 | +# coding=utf-8 |
| 2 | +""" |
| 3 | + @project: maxkb |
| 4 | + @Author:虎 |
| 5 | + @file: zhipu_model_provider.py |
| 6 | + @date:2024/04/19 13:5 |
| 7 | + @desc: |
| 8 | +""" |
| 9 | +import os |
| 10 | +from typing import Dict |
| 11 | + |
| 12 | +from langchain.schema import HumanMessage |
| 13 | +from langchain_community.chat_models import ChatZhipuAI |
| 14 | + |
| 15 | +from common import forms |
| 16 | +from common.exception.app_exception import AppApiException |
| 17 | +from common.forms import BaseForm |
| 18 | +from common.util.file_util import get_file_content |
| 19 | +from setting.models_provider.base_model_provider import ModelProvideInfo, ModelTypeConst, BaseModelCredential, \ |
| 20 | + ModelInfo, IModelProvider, ValidCode |
| 21 | +from smartdoc.conf import PROJECT_DIR |
| 22 | + |
| 23 | + |
| 24 | +class ZhiPuLLMModelCredential(BaseForm, BaseModelCredential): |
| 25 | + |
| 26 | + def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], raise_exception=False): |
| 27 | + model_type_list = ZhiPuModelProvider().get_model_type_list() |
| 28 | + if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))): |
| 29 | + raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持') |
| 30 | + for key in ['api_key']: |
| 31 | + if key not in model_credential: |
| 32 | + if raise_exception: |
| 33 | + raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段') |
| 34 | + else: |
| 35 | + return False |
| 36 | + try: |
| 37 | + model = ZhiPuModelProvider().get_model(model_type, model_name, model_credential) |
| 38 | + model.invoke([HumanMessage(content='你好')]) |
| 39 | + except Exception as e: |
| 40 | + if isinstance(e, AppApiException): |
| 41 | + raise e |
| 42 | + if raise_exception: |
| 43 | + raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}') |
| 44 | + else: |
| 45 | + return False |
| 46 | + return True |
| 47 | + |
| 48 | + def encryption_dict(self, model: Dict[str, object]): |
| 49 | + return {**model, 'api_key': super().encryption(model.get('api_key', ''))} |
| 50 | + |
| 51 | + api_key = forms.PasswordInputField('API Key', required=True) |
| 52 | + |
| 53 | + |
| 54 | +qwen_model_credential = ZhiPuLLMModelCredential() |
| 55 | + |
| 56 | +model_dict = { |
| 57 | + 'glm-4': ModelInfo('glm-4', '', ModelTypeConst.LLM, qwen_model_credential), |
| 58 | + 'glm-4v': ModelInfo('glm-4v', '', ModelTypeConst.LLM, qwen_model_credential), |
| 59 | + 'glm-3-turbo': ModelInfo('glm-3-turbo', '', ModelTypeConst.LLM, qwen_model_credential) |
| 60 | +} |
| 61 | + |
| 62 | + |
| 63 | +class ZhiPuModelProvider(IModelProvider): |
| 64 | + |
| 65 | + def get_dialogue_number(self): |
| 66 | + return 3 |
| 67 | + |
| 68 | + def get_model(self, model_type, model_name, model_credential: Dict[str, object], **model_kwargs) -> ChatZhipuAI: |
| 69 | + zhipuai_chat = ChatZhipuAI( |
| 70 | + temperature=0.5, |
| 71 | + api_key=model_credential.get('api_key'), |
| 72 | + model=model_name |
| 73 | + ) |
| 74 | + return zhipuai_chat |
| 75 | + |
| 76 | + def get_model_credential(self, model_type, model_name): |
| 77 | + if model_name in model_dict: |
| 78 | + return model_dict.get(model_name).model_credential |
| 79 | + return qwen_model_credential |
| 80 | + |
| 81 | + def get_model_provide_info(self): |
| 82 | + return ModelProvideInfo(provider='model_zhipu_provider', name='智谱AI', icon=get_file_content( |
| 83 | + os.path.join(PROJECT_DIR, "apps", "setting", 'models_provider', 'impl', 'zhipu_model_provider', 'icon', |
| 84 | + 'zhipuai_icon_svg'))) |
| 85 | + |
| 86 | + def get_model_list(self, model_type: str): |
| 87 | + if model_type is None: |
| 88 | + raise AppApiException(500, '模型类型不能为空') |
| 89 | + return [model_dict.get(key).to_dict() for key in |
| 90 | + list(filter(lambda key: model_dict.get(key).model_type == model_type, model_dict.keys()))] |
| 91 | + |
| 92 | + def get_model_type_list(self): |
| 93 | + return [{'key': "大语言模型", 'value': "LLM"}] |
0 commit comments