Skip to content

feat: 增加了DeepSeek大模型支持 && fix: 修复了系统设置->模型设置中的动态表单复用时,错误地显示之前动态表单内容的bug && perf: 修改Azure OpenAI模型表单描述,简化对应代码实现 #431

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
May 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from setting.models_provider.impl.kimi_model_provider.kimi_model_provider import KimiModelProvider
from setting.models_provider.impl.xf_model_provider.xf_model_provider import XunFeiModelProvider
from setting.models_provider.impl.zhipu_model_provider.zhipu_model_provider import ZhiPuModelProvider
from setting.models_provider.impl.deepseek_model_provider.deepseek_model_provider import DeepSeekModelProvider


class ModelProvideConstants(Enum):
Expand All @@ -27,3 +28,4 @@ class ModelProvideConstants(Enum):
model_qwen_provider = QwenModelProvider()
model_zhipu_provider = ZhiPuModelProvider()
model_xf_provider = XunFeiModelProvider()
model_deepseek_provider = DeepSeekModelProvider()
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
from typing import Dict

from langchain.schema import HumanMessage
from langchain_community.chat_models.azure_openai import AzureChatOpenAI

from common import forms
from common.exception.app_exception import AppApiException
Expand All @@ -22,7 +21,7 @@
from setting.models_provider.impl.azure_model_provider.model.azure_chat_model import AzureChatModel
from smartdoc.conf import PROJECT_DIR


"""
class AzureLLMModelCredential(BaseForm, BaseModelCredential):

def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], raise_exception=False):
Expand Down Expand Up @@ -52,11 +51,12 @@ def is_valid(self, model_type: str, model_name, model_credential: Dict[str, obje
def encryption_dict(self, model: Dict[str, object]):
return {**model, 'api_key': super().encryption(model.get('api_key', ''))}

api_base = forms.TextInputField('API 域名', required=True)
api_base = forms.TextInputField('API 版本 (api_version)', required=True)

api_key = forms.PasswordInputField("API Key", required=True)
api_key = forms.PasswordInputField("API Key(API 密钥)", required=True)

deployment_name = forms.TextInputField("部署名", required=True)
deployment_name = forms.TextInputField("部署名(deployment_name)", required=True)
"""


class DefaultAzureLLMModelCredential(BaseForm, BaseModelCredential):
Expand Down Expand Up @@ -88,28 +88,23 @@ def is_valid(self, model_type: str, model_name, model_credential: Dict[str, obje
def encryption_dict(self, model: Dict[str, object]):
return {**model, 'api_key': super().encryption(model.get('api_key', ''))}

api_version = forms.TextInputField("api_version", required=True)
api_version = forms.TextInputField("API 版本 (api_version)", required=True)

api_base = forms.TextInputField('API 域名', required=True)
api_base = forms.TextInputField('API 域名 (azure_endpoint)', required=True)

api_key = forms.PasswordInputField("API Key", required=True)
api_key = forms.PasswordInputField("API Key (api_key)", required=True)

deployment_name = forms.TextInputField("部署名", required=True)
deployment_name = forms.TextInputField("部署名 (deployment_name)", required=True)


azure_llm_model_credential = AzureLLMModelCredential()
# azure_llm_model_credential: AzureLLMModelCredential = AzureLLMModelCredential()

base_azure_llm_model_credential = DefaultAzureLLMModelCredential()

model_dict = {
'gpt-3.5-turbo-0613': ModelInfo('gpt-3.5-turbo-0613', '', ModelTypeConst.LLM, azure_llm_model_credential,
api_version='2023-07-01-preview'),
'gpt-3.5-turbo-0301': ModelInfo('gpt-3.5-turbo-0301', '', ModelTypeConst.LLM, azure_llm_model_credential,
api_version='2023-07-01-preview'),
'gpt-3.5-turbo-16k-0613': ModelInfo('gpt-3.5-turbo-16k-0613', '', ModelTypeConst.LLM, azure_llm_model_credential,
api_version='2023-07-01-preview'),
'gpt-4-0613': ModelInfo('gpt-4-0613', '', ModelTypeConst.LLM, azure_llm_model_credential,
api_version='2023-07-01-preview'),
'deployment_name': ModelInfo('Azure OpenAI', '具体的基础模型由部署名决定', ModelTypeConst.LLM,
base_azure_llm_model_credential, api_version='2024-02-15-preview'
)
}


Expand All @@ -118,12 +113,11 @@ class AzureModelProvider(IModelProvider):
def get_dialogue_number(self):
return 3

def get_model(self, model_type, model_name, model_credential: Dict[str, object], **model_kwargs) -> AzureChatOpenAI:
def get_model(self, model_type, model_name, model_credential: Dict[str, object], **model_kwargs) -> AzureChatModel:
model_info: ModelInfo = model_dict.get(model_name)
azure_chat_open_ai = AzureChatModel(
azure_endpoint=model_credential.get('api_base'),
openai_api_version=model_info.api_version if model_name in model_dict else model_credential.get(
'api_version'),
openai_api_version=model_credential.get('api_version', '2024-02-15-preview'),
deployment_name=model_credential.get('deployment_name'),
openai_api_key=model_credential.get('api_key'),
openai_api_type="azure"
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
#!/usr/bin/env python
# -*- coding: UTF-8 -*-
"""
@Project :MaxKB
@File :__init__.py.py
@Author :Brian Yang
@Date :5/12/24 7:38 AM
"""
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
#!/usr/bin/env python
# -*- coding: UTF-8 -*-
"""
@Project :MaxKB
@File :deepseek_model_provider.py
@Author :Brian Yang
@Date :5/12/24 7:40 AM
"""
import os
from typing import Dict

from langchain.schema import HumanMessage

from common import forms
from common.exception.app_exception import AppApiException
from common.forms import BaseForm
from common.util.file_util import get_file_content
from setting.models_provider.base_model_provider import IModelProvider, ModelProvideInfo, BaseModelCredential, \
ModelInfo, ModelTypeConst, ValidCode
from setting.models_provider.impl.deepseek_model_provider.model.deepseek_chat_model import DeepSeekChatModel
from smartdoc.conf import PROJECT_DIR


class DeepSeekLLMModelCredential(BaseForm, BaseModelCredential):

def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], raise_exception=False):
model_type_list = DeepSeekModelProvider().get_model_type_list()
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')

for key in ['api_key']:
if key not in model_credential:
if raise_exception:
raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段')
else:
return False
try:
model = DeepSeekModelProvider().get_model(model_type, model_name, model_credential)
model.invoke([HumanMessage(content='你好')])
except Exception as e:
if isinstance(e, AppApiException):
raise e
if raise_exception:
raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}')
else:
return False
return True

def encryption_dict(self, model: Dict[str, object]):
return {**model, 'api_key': super().encryption(model.get('api_key', ''))}

api_key = forms.PasswordInputField('API Key', required=True)


deepseek_llm_model_credential = DeepSeekLLMModelCredential()

model_dict = {
'deepseek-chat': ModelInfo('deepseek-chat', '擅长通用对话任务,支持 32K 上下文', ModelTypeConst.LLM,
deepseek_llm_model_credential,
),
'deepseek-coder': ModelInfo('deepseek-coder', '擅长处理编程任务,支持 16K 上下文', ModelTypeConst.LLM,
deepseek_llm_model_credential,
),
}


class DeepSeekModelProvider(IModelProvider):

def get_dialogue_number(self):
return 3

def get_model(self, model_type, model_name, model_credential: Dict[str, object], **model_kwargs) -> DeepSeekChatModel:
deepseek_chat_open_ai = DeepSeekChatModel(
model=model_name,
openai_api_base='https://api.deepseek.com',
openai_api_key=model_credential.get('api_key')
)
return deepseek_chat_open_ai

def get_model_credential(self, model_type, model_name):
if model_name in model_dict:
return model_dict.get(model_name).model_credential
return deepseek_llm_model_credential

def get_model_provide_info(self):
return ModelProvideInfo(provider='model_deepseek_provider', name='DeepSeek', icon=get_file_content(
os.path.join(PROJECT_DIR, "apps", "setting", 'models_provider', 'impl', 'deepseek_model_provider', 'icon',
'deepseek_icon_svg')))

def get_model_list(self, model_type: str):
if model_type is None:
raise AppApiException(500, '模型类型不能为空')
return [model_dict.get(key).to_dict() for key in
list(filter(lambda key: model_dict.get(key).model_type == model_type, model_dict.keys()))]

def get_model_type_list(self):
return [{'key': "大语言模型", 'value': "LLM"}]
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
<svg width="100%" height="100%" viewBox="0 0 50 50" fill="none" xmlns="http://www.w3.org/2000/svg"
xmlns:xlink="http://www.w3.org/1999/xlink">
<path id="path"
d="M48.8354 10.0479C48.3232 9.79199 48.1025 10.2798 47.8032 10.5278C47.7007 10.6079 47.6143 10.7119 47.5273 10.8076C46.7793 11.624 45.9048 12.1597 44.7622 12.0957C43.0923 12 41.666 12.5356 40.4058 13.8398C40.1377 12.2319 39.2476 11.272 37.8926 10.6558C37.1836 10.3359 36.4668 10.0156 35.9702 9.31982C35.6235 8.82373 35.5293 8.27197 35.356 7.72754C35.2456 7.3999 35.1353 7.06396 34.7651 7.00781C34.3633 6.94385 34.2056 7.2876 34.0479 7.57568C33.418 8.75195 33.1733 10.0479 33.1973 11.3599C33.2524 14.312 34.4736 16.6641 36.8999 18.3359C37.1758 18.5278 37.2466 18.7197 37.1597 19C36.9946 19.5757 36.7974 20.1357 36.624 20.7119C36.5137 21.0801 36.3486 21.1597 35.9624 21C34.6309 20.4321 33.481 19.5918 32.4644 18.5757C30.7393 16.8721 29.1792 14.9917 27.2334 13.52C26.7764 13.1758 26.3193 12.856 25.8467 12.5518C23.8618 10.584 26.1069 8.96777 26.627 8.77588C27.1704 8.57568 26.8159 7.8877 25.0591 7.896C23.3022 7.90381 21.6953 8.50391 19.647 9.30371C19.3477 9.42383 19.0322 9.51172 18.7095 9.58398C16.8501 9.22363 14.9199 9.14355 12.9033 9.37598C9.10596 9.80762 6.07275 11.6396 3.84326 14.7681C1.16455 18.5278 0.53418 22.7998 1.30664 27.2559C2.11768 31.9521 4.46582 35.8398 8.07373 38.8799C11.8159 42.0322 16.1255 43.5762 21.041 43.2803C24.0269 43.104 27.3516 42.6963 31.1016 39.4561C32.0469 39.936 33.0396 40.1279 34.686 40.272C35.9546 40.3921 37.1758 40.208 38.1211 40.0078C39.6021 39.688 39.4995 38.2881 38.9639 38.0322C34.623 35.9678 35.5762 36.8081 34.71 36.1279C36.9155 33.4639 40.2402 30.6958 41.54 21.728C41.6426 21.0161 41.5557 20.5679 41.54 19.9917C41.5322 19.6396 41.6108 19.5039 42.0049 19.4639C43.0923 19.3359 44.1479 19.0317 45.1167 18.4878C47.9292 16.9199 49.064 14.3438 49.3315 11.2559C49.3711 10.7837 49.3237 10.2959 48.8354 10.0479ZM24.3262 37.8398C20.1196 34.4639 18.0791 33.3521 17.2358 33.3999C16.4482 33.4482 16.5898 34.3682 16.7632 34.9678C16.9443 35.5601 17.1812 35.9683 17.5117 36.4878C17.7402 36.832 17.8979 37.3442 17.2832 37.728C15.9282 38.584 13.5728 37.4399 13.4624 37.3838C10.7207 35.7358 8.42822 33.5601 6.81348 30.584C5.25342 27.7197 4.34766 24.6479 4.19775 21.3677C4.1582 20.5757 4.38672 20.2959 5.15869 20.1519C6.17529 19.96 7.22314 19.9199 8.23926 20.0718C12.5327 20.7119 16.1885 22.6719 19.2529 25.7759C21.002 27.5439 22.3252 29.6558 23.6885 31.7202C25.1377 33.9121 26.6978 36 28.6831 37.7119C29.3843 38.312 29.9434 38.7681 30.479 39.104C28.8643 39.2881 26.1699 39.3281 24.3262 37.8398ZM26.3433 24.6001C26.3433 24.248 26.6191 23.9678 26.9658 23.9678C27.0444 23.9678 27.1152 23.9839 27.1782 24.0078C27.2651 24.04 27.3438 24.0879 27.4067 24.1602C27.5171 24.272 27.5801 24.4321 27.5801 24.6001C27.5801 24.9521 27.3042 25.2319 26.9575 25.2319C26.6108 25.2319 26.3433 24.9521 26.3433 24.6001ZM32.6064 27.8799C32.2046 28.0479 31.8027 28.1919 31.4165 28.208C30.8179 28.2397 30.1641 27.9922 29.8096 27.688C29.2583 27.2158 28.8643 26.9521 28.6987 26.1279C28.6279 25.7759 28.6675 25.2319 28.7305 24.9199C28.8721 24.248 28.7144 23.8159 28.2495 23.4238C27.8716 23.104 27.3911 23.0161 26.8633 23.0161C26.666 23.0161 26.4849 22.9277 26.3511 22.856C26.1304 22.7441 25.9492 22.4639 26.1226 22.1201C26.1777 22.0078 26.4458 21.7358 26.5088 21.688C27.2256 21.272 28.0527 21.4077 28.8169 21.7197C29.5259 22.0161 30.0615 22.5601 30.834 23.3281C31.6216 24.2559 31.7632 24.5117 32.2124 25.208C32.5669 25.752 32.8901 26.312 33.1104 26.9521C33.2446 27.3521 33.0713 27.6802 32.6064 27.8799Z"
fill="#4D6BFE" fill-opacity="1.000000" fill-rule="nonzero" />
</svg>
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
#!/usr/bin/env python
# -*- coding: UTF-8 -*-
"""
@Project :MaxKB
@File :deepseek_chat_model.py
@Author :Brian Yang
@Date :5/12/24 7:44 AM
"""
from typing import List

from langchain_core.messages import BaseMessage, get_buffer_string
from langchain_openai import ChatOpenAI

from common.config.tokenizer_manage_config import TokenizerManage


class DeepSeekChatModel(ChatOpenAI):
def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
try:
return super().get_num_tokens_from_messages(messages)
except Exception as e:
tokenizer = TokenizerManage.get_tokenizer()
return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages])

def get_num_tokens(self, text: str) -> int:
try:
return super().get_num_tokens(text)
except Exception as e:
tokenizer = TokenizerManage.get_tokenizer()
return len(tokenizer.encode(text))
15 changes: 14 additions & 1 deletion ui/src/views/template/index.vue
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@
ref="createModelRef"
@submit="list_model"
@change="openCreateModel($event)"
:key="dialogState.createModelDialogKey"
></CreateModelDialog>

<SelectProviderDialog
Expand All @@ -81,7 +82,7 @@

<script lang="ts" setup>
import { ElMessage } from 'element-plus'
import { onMounted, ref, computed, watch } from 'vue'
import { onMounted, ref, computed, reactive } from 'vue'
import ModelApi from '@/api/model'
import type { Provider, Model } from '@/api/type/model'
import AppIcon from '@/components/icons/AppIcon.vue'
Expand Down Expand Up @@ -128,6 +129,7 @@ const openCreateModel = (provider?: Provider) => {
createModelRef.value?.open(provider)
} else {
selectProviderRef.value?.open()
refreshCreateModelDialogKey() // 更新key
}
}

Expand All @@ -138,6 +140,16 @@ const list_model = () => {
})
}

// 添加一个响应式的state来存储dialog的key
const dialogState = reactive({
createModelDialogKey: Date.now() // 初始值为当前的时间戳
})

// 更新dialogState.createModelDialogKey的函数
const refreshCreateModelDialogKey = () => {
dialogState.createModelDialogKey = Date.now() // 更新为新的时间戳
}

onMounted(() => {
ModelApi.getProvider(loading).then((ok) => {
active_provider.value = allObj
Expand All @@ -154,6 +166,7 @@ onMounted(() => {
width: var(--setting-left-width);
min-width: var(--setting-left-width);
}

.model-list-height {
height: calc(var(--create-dataset-height) - 70px);
}
Expand Down
Loading