Skip to content

feat: Support Azure embedding model #1868

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Dec 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -11,18 +11,39 @@
from common.util.file_util import get_file_content
from setting.models_provider.base_model_provider import IModelProvider, ModelProvideInfo, ModelInfo, \
ModelTypeConst, ModelInfoManage
from setting.models_provider.impl.azure_model_provider.credential.embedding import AzureOpenAIEmbeddingCredential
from setting.models_provider.impl.azure_model_provider.credential.llm import AzureLLMModelCredential
from setting.models_provider.impl.azure_model_provider.model.azure_chat_model import AzureChatModel
from setting.models_provider.impl.azure_model_provider.model.embedding import AzureOpenAIEmbeddingModel
from smartdoc.conf import PROJECT_DIR

# Credential form instances shared by every ModelInfo entry declared below:
# one for chat (LLM) models and one for embedding models.
base_azure_llm_model_credential = AzureLLMModelCredential()
base_azure_embedding_model_credential = AzureOpenAIEmbeddingCredential()

# Chat-model entry registered both as a default and as a regular model.
# The description string reads "the concrete base model is determined by the
# deployment name".
default_model_info = ModelInfo('Azure OpenAI', '具体的基础模型由部署名决定', ModelTypeConst.LLM,
base_azure_llm_model_credential, AzureChatModel, api_version='2024-02-15-preview'
)

# NOTE(review): this binding is replaced by the second
# `model_info_manage = (...)` assignment further down, so only the later
# value (which also registers the embedding models) is visible to code that
# runs after this module-level section. Presumably a leftover of the
# pre-embedding registration - confirm it can be removed.
model_info_manage = ModelInfoManage.builder().append_default_model_info(default_model_info).append_model_info(
default_model_info).build()
# Embedding models offered by this provider. All entries share the embedding
# credential form, the AzureOpenAIEmbeddingModel class and the same
# api_version value; only the first entry carries a description.
embedding_model_info = [
ModelInfo('text-embedding-3-large', '具体的基础模型由部署名决定', ModelTypeConst.EMBEDDING,
base_azure_embedding_model_credential, AzureOpenAIEmbeddingModel, api_version='2023-05-15'
),
ModelInfo('text-embedding-3-small', '', ModelTypeConst.EMBEDDING,
base_azure_embedding_model_credential, AzureOpenAIEmbeddingModel, api_version='2023-05-15'
),
ModelInfo('text-embedding-ada-002', '', ModelTypeConst.EMBEDDING,
base_azure_embedding_model_credential, AzureOpenAIEmbeddingModel, api_version='2023-05-15'
),
]

# Final registry: the chat model (as default and as a listed model), every
# embedding model, and the first embedding entry ('text-embedding-3-large')
# passed to append_default_model_info as well.
model_info_manage = (
ModelInfoManage.builder()
.append_default_model_info(default_model_info)
.append_model_info(default_model_info)
.append_model_info_list(embedding_model_info)
.append_default_model_info(embedding_model_info[0])
.build()
)


class AzureModelProvider(IModelProvider):
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
# coding=utf-8
"""
@project: MaxKB
@Author:虎
@file: embedding.py
@date:2024/7/11 17:08
@desc:
"""
from typing import Dict

from langchain_core.messages import HumanMessage

from common import forms
from common.exception.app_exception import AppApiException
from common.forms import BaseForm, TooltipLabel
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode



class AzureOpenAIEmbeddingCredential(BaseForm, BaseModelCredential):
    """Credential form and validator for Azure OpenAI embedding models.

    Declares the three required form fields (``api_version``, ``api_base``,
    ``api_key``) and checks a submitted credential by building the model
    through the provider and running one ``embed_query`` call on it.
    """

    def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
                 raise_exception=False):
        """Check that ``model_credential`` can be used for ``model_name``.

        :param model_type: model type value; must match the ``value`` of one
            entry returned by ``provider.get_model_type_list()``.
        :param model_name: model name forwarded to ``provider.get_model``.
        :param model_credential: mapping expected to contain ``api_base``,
            ``api_key`` and ``api_version``.
        :param provider: object exposing ``get_model_type_list()`` and
            ``get_model(model_type, model_name, model_credential)``.
        :param raise_exception: when true, a missing key or a failed probe
            raises ``AppApiException`` instead of returning ``False``.
        :return: ``True`` when the probe call succeeds, ``False`` otherwise
            (only when ``raise_exception`` is false).
        :raises AppApiException: always for an unsupported ``model_type``;
            for the other failures only when ``raise_exception`` is true.
            An ``AppApiException`` raised while building or probing the
            model is always propagated unchanged.
        """
        model_type_list = provider.get_model_type_list()
        # Test the predicate itself rather than the truthiness of the
        # filtered entries.
        if not any(mt.get('value') == model_type for mt in model_type_list):
            raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')

        for key in ['api_base', 'api_key', 'api_version']:
            if key not in model_credential:
                if raise_exception:
                    raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段')
                return False

        try:
            model = provider.get_model(model_type, model_name, model_credential)
            # Probe call: any exception here is treated as an invalid credential.
            model.embed_query('你好')
        except AppApiException:
            # Application-level errors already carry a user-facing message.
            raise
        except Exception as e:
            if raise_exception:
                # Chain the original error so the root cause stays in the traceback.
                raise AppApiException(ValidCode.valid_error.value, '校验失败,请检查参数是否正确') from e
            return False

        return True

    def encryption_dict(self, model: Dict[str, object]):
        """Return a copy of ``model`` whose ``api_key`` went through ``encryption``.

        A missing ``api_key`` is treated as an empty string before being
        passed to the inherited ``encryption`` helper.
        """
        return {**model, 'api_key': super().encryption(model.get('api_key', ''))}

    # Form fields; all three are required.
    api_version = forms.TextInputField("API 版本 (api_version)", required=True)

    api_base = forms.TextInputField('API 域名 (azure_endpoint)', required=True)

    api_key = forms.PasswordInputField("API Key (api_key)", required=True)

Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
# coding=utf-8
"""
@project: MaxKB
@Author:虎
@file: embedding.py
@date:2024/7/12 17:44
@desc:
"""
from typing import Dict

from langchain_openai import AzureOpenAIEmbeddings

from setting.models_provider.base_model_provider import MaxKBBaseModel


class AzureOpenAIEmbeddingModel(MaxKBBaseModel, AzureOpenAIEmbeddings):
    """Azure OpenAI embedding model entry for the Azure model provider."""

    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
        """Build an ``AzureOpenAIEmbeddings`` client from a stored credential.

        ``model_name`` is passed as ``model``; ``api_key``, ``api_base`` and
        ``api_version`` are read from ``model_credential`` (``None`` when a
        key is absent). ``model_type`` and ``model_kwargs`` are accepted but
        not used. The returned object is a plain ``AzureOpenAIEmbeddings``
        instance, not an instance of this class.
        """
        client_options = {
            'model': model_name,
            'openai_api_key': model_credential.get('api_key'),
            'azure_endpoint': model_credential.get('api_base'),
            'openai_api_version': model_credential.get('api_version'),
            'openai_api_type': "azure",
        }
        return AzureOpenAIEmbeddings(**client_options)
Loading