Skip to content

Commit 15ab598

Browse files
committed
feat: Xinference Image Model
1 parent 767c284 commit 15ab598

File tree

5 files changed

+400
-0
lines changed

5 files changed

+400
-0
lines changed
Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
# coding=utf-8
2+
import base64
3+
import os
4+
from typing import Dict
5+
6+
from langchain_core.messages import HumanMessage
7+
8+
from common import forms
9+
from common.exception.app_exception import AppApiException
10+
from common.forms import BaseForm, TooltipLabel
11+
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode
12+
13+
class XinferenceImageModelParams(BaseForm):
    """User-adjustable generation parameters for a Xinference image (vision) model."""

    # Sampling temperature; higher values make output more random.
    temperature = forms.SliderField(
        TooltipLabel('温度', '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'),
        required=True,
        default_value=0.7,
        _min=0.1,
        _max=1.0,
        _step=0.01,
        precision=2)

    # Hard cap on the number of tokens the model may generate per reply.
    max_tokens = forms.SliderField(
        TooltipLabel('输出最大Tokens', '指定模型可生成的最大token个数'),
        required=True,
        default_value=800,
        _min=1,
        _max=100000,
        _step=1,
        precision=0)
28+
29+
30+
31+
class XinferenceImageModelCredential(BaseForm, BaseModelCredential):
    """Credential form for Xinference image (vision) models.

    Collects the API endpoint and key and validates them by issuing a
    minimal streaming chat request against the target model.
    """

    api_base = forms.TextInputField('API 域名', required=True)
    api_key = forms.PasswordInputField('API Key', required=True)

    def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
                 raise_exception=False):
        """Validate the credential against the provider.

        Returns True when the model is reachable with this credential;
        otherwise returns False, or raises AppApiException when
        raise_exception is True.
        """
        model_type_list = provider.get_model_type_list()
        # Honor raise_exception here as well, consistently with the
        # required-field checks below (the original raised unconditionally).
        if not any(mt.get('value') == model_type for mt in model_type_list):
            if raise_exception:
                raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')
            return False

        for key in ['api_base', 'api_key']:
            if key not in model_credential:
                if raise_exception:
                    raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段')
                return False
        try:
            model = provider.get_model(model_type, model_name, model_credential)
            # Fully consume a tiny streaming request so connection/auth
            # errors surface here instead of leaking to the caller later.
            for _ in model.stream([HumanMessage(content=[{"type": "text", "text": "你好"}])]):
                pass
        except AppApiException:
            raise
        except Exception as e:
            if raise_exception:
                raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}')
            return False
        return True

    def encryption_dict(self, model: Dict[str, object]):
        """Return *model* with the api_key encrypted for persistence."""
        return {**model, 'api_key': super().encryption(model.get('api_key', ''))}

    def get_model_params_setting_form(self, model_name):
        """Return the adjustable-parameter form for this model."""
        return XinferenceImageModelParams()
Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
# coding=utf-8
2+
import base64
3+
import os
4+
from typing import Dict
5+
6+
from langchain_core.messages import HumanMessage
7+
8+
from common import forms
9+
from common.exception.app_exception import AppApiException
10+
from common.forms import BaseForm, TooltipLabel
11+
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode
12+
13+
14+
class XinferenceTTIModelParams(BaseForm):
    """User-adjustable parameters for a Xinference text-to-image model."""

    # Output resolution passed straight to the images API.
    size = forms.SingleSelect(
        TooltipLabel('图片尺寸', '指定生成图片的尺寸, 如: 1024x1024'),
        required=True,
        default_value='1024x1024',
        option_list=[{'value': v, 'label': v}
                     for v in ('1024x1024', '1024x1792', '1792x1024')],
        text_field='label',
        value_field='value')

    # Rendering quality flag of the images API.
    quality = forms.SingleSelect(
        TooltipLabel('图片质量', ''),
        required=True,
        default_value='standard',
        option_list=[{'value': v, 'label': v} for v in ('standard', 'hd')],
        text_field='label',
        value_field='value')

    # Number of images to produce per request.
    n = forms.SliderField(
        TooltipLabel('图片数量', '指定生成图片的数量'),
        required=True,
        default_value=1,
        _min=1,
        _max=10,
        _step=1,
        precision=0)
47+
48+
49+
class XinferenceTextToImageModelCredential(BaseForm, BaseModelCredential):
    """Credential form for Xinference text-to-image models.

    Collects the API endpoint and key and validates them through the
    model's check_auth() probe.
    """

    api_base = forms.TextInputField('API 域名', required=True)
    api_key = forms.PasswordInputField('API Key', required=True)

    def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
                 raise_exception=False):
        """Validate the credential against the provider.

        Returns True when authentication succeeds; otherwise returns False,
        or raises AppApiException when raise_exception is True.
        """
        model_type_list = provider.get_model_type_list()
        # Honor raise_exception here as well, consistently with the
        # required-field checks below (the original raised unconditionally).
        if not any(mt.get('value') == model_type for mt in model_type_list):
            if raise_exception:
                raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')
            return False

        for key in ['api_base', 'api_key']:
            if key not in model_credential:
                if raise_exception:
                    raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段')
                return False
        try:
            model = provider.get_model(model_type, model_name, model_credential)
            # check_auth raises on bad endpoint/key; no debug printing here.
            model.check_auth()
        except AppApiException:
            raise
        except Exception as e:
            if raise_exception:
                raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}')
            return False
        return True

    def encryption_dict(self, model: Dict[str, object]):
        """Return *model* with the api_key encrypted for persistence."""
        return {**model, 'api_key': super().encryption(model.get('api_key', ''))}

    def get_model_params_setting_form(self, model_name):
        """Return the adjustable-parameter form for this model."""
        return XinferenceTTIModelParams()
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
from typing import Dict
2+
3+
from langchain_openai.chat_models import ChatOpenAI
4+
5+
from common.config.tokenizer_manage_config import TokenizerManage
6+
from setting.models_provider.base_model_provider import MaxKBBaseModel
7+
8+
9+
def custom_get_token_ids(text: str):
    """Encode *text* into token ids using the shared project tokenizer."""
    return TokenizerManage.get_tokenizer().encode(text)
12+
13+
14+
class XinferenceImage(MaxKBBaseModel, ChatOpenAI):
    """Chat-based image (vision) model served by Xinference, accessed
    through its OpenAI-compatible chat endpoint."""

    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
        """Build an XinferenceImage from stored credentials plus optional
        model parameters (temperature, max_tokens, ...)."""
        # Keep only the kwargs MaxKB recognizes as model options.
        optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs)
        return XinferenceImage(
            model_name=model_name,
            openai_api_base=model_credential.get('api_base'),
            openai_api_key=model_credential.get('api_key'),
            streaming=True,
            **optional_params,
        )
Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
import base64
2+
from typing import Dict
3+
4+
from openai import OpenAI
5+
6+
from common.config.tokenizer_manage_config import TokenizerManage
7+
from common.util.common import bytes_to_uploaded_file
8+
from dataset.serializers.file_serializers import FileSerializer
9+
from setting.models_provider.base_model_provider import MaxKBBaseModel
10+
from setting.models_provider.impl.base_tti import BaseTextToImage
11+
12+
13+
def custom_get_token_ids(text: str):
    """Encode *text* into token ids using the shared project tokenizer."""
    return TokenizerManage.get_tokenizer().encode(text)
16+
17+
18+
class XinferenceTextToImage(MaxKBBaseModel, BaseTextToImage):
    """Text-to-image model served by Xinference via the OpenAI-compatible
    images API."""

    api_base: str   # OpenAI-compatible endpoint base URL
    api_key: str    # bearer token for the endpoint
    model: str      # target model name on the Xinference server
    params: dict    # extra kwargs forwarded to images.generate (size, quality, n)

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.api_key = kwargs.get('api_key')
        self.api_base = kwargs.get('api_base')
        self.model = kwargs.get('model')
        self.params = kwargs.get('params')

    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
        """Create an instance from stored credentials.

        Unrecognized model_kwargs are treated as image-generation parameters
        and merged over the defaults below.
        """
        optional_params = {'params': {'size': '1024x1024', 'quality': 'standard', 'n': 1}}
        for key, value in model_kwargs.items():
            # model_id/use_local/streaming are MaxKB bookkeeping, not API params.
            if key not in ['model_id', 'use_local', 'streaming']:
                optional_params['params'][key] = value
        return XinferenceTextToImage(
            model=model_name,
            api_base=model_credential.get('api_base'),
            api_key=model_credential.get('api_key'),
            **optional_params,
        )

    def is_cache_model(self):
        """Instances are built per request and never cached by MaxKB."""
        return False

    def check_auth(self):
        """Probe the credential; raises if the endpoint or key is invalid.

        Listing models is a cheap call that fails fast on a bad base URL or
        API key; the response body itself is not needed.
        """
        client = OpenAI(api_key=self.api_key, base_url=self.api_base)
        client.models.with_raw_response.list()

    def generate_image(self, prompt: str, negative_prompt: str = None):
        """Generate images for *prompt* and return their uploaded file URLs.

        negative_prompt is accepted for interface compatibility but is not
        forwarded — the OpenAI images API has no such parameter.
        """
        client = OpenAI(api_key=self.api_key, base_url=self.api_base)
        res = client.images.generate(model=self.model, prompt=prompt, response_format='b64_json', **self.params)
        file_urls = []
        # Persist each returned base64 image as a temporary uploaded file.
        for img in res.data:
            file = bytes_to_uploaded_file(base64.b64decode(img.b64_json), 'file_name.jpg')
            meta = {
                'debug': True,
            }
            file_url = FileSerializer(data={'file': file, 'meta': meta}).upload()
            # TODO(review): host is hard-coded; should come from deployment config.
            file_urls.append(f'http://localhost:8080{file_url}')

        return file_urls

0 commit comments

Comments (0)