Skip to content

Commit 60dea47

Browse files
committed
feat: Image understanding and image generation models support configuring model parameters
1 parent 977e68f commit 60dea47

File tree

10 files changed

+89
-24
lines changed

10 files changed

+89
-24
lines changed

apps/application/flow/step_node/image_understand_step_node/i_image_understand_node.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,9 @@ class ImageUnderstandNodeSerializer(serializers.Serializer):
2222

2323
image_list = serializers.ListField(required=False, error_messages=ErrMessage.list("图片"))
2424

25+
model_params_setting = serializers.JSONField(required=False, default=dict, error_messages=ErrMessage.json("模型参数设置"))
26+
27+
2528

2629
class IImageUnderstandNode(INode):
2730
type = 'image-understand-node'
@@ -35,6 +38,7 @@ def _run(self):
3538
return self.execute(image=res, **self.node_params_serializer.data, **self.flow_params_serializer.data)
3639

3740
def execute(self, model_id, system, prompt, dialogue_number, dialogue_type, history_chat_record, stream, chat_id,
41+
model_params_setting,
3842
chat_record_id,
3943
image,
4044
**kwargs) -> NodeResult:

apps/application/flow/step_node/image_understand_step_node/impl/base_image_understand_node.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,14 +70,15 @@ def save_context(self, details, workflow_manage):
7070
self.answer_text = details.get('answer')
7171

7272
def execute(self, model_id, system, prompt, dialogue_number, dialogue_type, history_chat_record, stream, chat_id,
73+
model_params_setting,
7374
chat_record_id,
7475
image,
7576
**kwargs) -> NodeResult:
7677
# 处理不正确的参数
7778
if image is None or not isinstance(image, list):
7879
image = []
7980

80-
image_model = get_model_instance_by_model_user_id(model_id, self.flow_params_serializer.data.get('user_id'))
81+
image_model = get_model_instance_by_model_user_id(model_id, self.flow_params_serializer.data.get('user_id'), **model_params_setting)
8182
# 执行详情中的历史消息不需要图片内容
8283
history_message = self.get_history_message_for_details(history_chat_record, dialogue_number)
8384
self.context['history_message'] = history_message

apps/common/forms/text_input_field.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
"""
99
from typing import Dict
1010

11+
from common.forms import BaseLabel
1112
from common.forms.base_field import BaseField, TriggerType
1213

1314

@@ -16,7 +17,7 @@ class TextInputField(BaseField):
1617
文本输入框
1718
"""
1819

19-
def __init__(self, label: str,
20+
def __init__(self, label: str or BaseLabel,
2021
required: bool = False,
2122
default_value=None,
2223
relation_show_field_dict: Dict = None,

apps/setting/models_provider/impl/openai_model_provider/credential/image.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,26 @@
77

88
from common import forms
99
from common.exception.app_exception import AppApiException
10-
from common.forms import BaseForm
10+
from common.forms import BaseForm, TooltipLabel
1111
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode
1212

13+
class OpenAIImageModelParams(BaseForm):
    """Parameter form returned for OpenAI image models.

    Declares two slider fields, ``temperature`` and ``max_tokens``; the
    labels and tooltips are user-facing strings and are kept verbatim.
    """

    # Sampling temperature: slider from 0.1 to 1.0 in 0.01 steps, default 0.7.
    temperature = forms.SliderField(
        TooltipLabel('温度', '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'),
        required=True,
        default_value=0.7,
        _min=0.1,
        _max=1.0,
        _step=0.01,
        precision=2,
    )

    # Maximum output tokens: integer slider from 1 to 100000, default 800.
    max_tokens = forms.SliderField(
        TooltipLabel('输出最大Tokens', '指定模型可生成的最大token个数'),
        required=True,
        default_value=800,
        _min=1,
        _max=100000,
        _step=1,
        precision=0,
    )
29+
1330

1431
class OpenAIImageModelCredential(BaseForm, BaseModelCredential):
1532
api_base = forms.TextInputField('API 域名', required=True)
@@ -45,4 +62,4 @@ def encryption_dict(self, model: Dict[str, object]):
4562
return {**model, 'api_key': super().encryption(model.get('api_key', ''))}
4663

4764
def get_model_params_setting_form(self, model_name):
48-
pass
65+
return OpenAIImageModelParams()

apps/setting/models_provider/impl/openai_model_provider/credential/tti.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,26 @@
77

88
from common import forms
99
from common.exception.app_exception import AppApiException
10-
from common.forms import BaseForm
10+
from common.forms import BaseForm, TooltipLabel
1111
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode
1212

13+
class OpenAITTIModelParams(BaseForm):
    """Parameter form returned for OpenAI text-to-image models.

    Declares ``size`` and ``quality`` as free-text inputs and ``n`` as an
    integer slider; labels and tooltips are user-facing strings kept verbatim.
    """

    # Image size as free text, default '1024x1024'.
    size = forms.TextInputField(
        TooltipLabel('图片尺寸', '指定生成图片的尺寸, 如: 1024x1024'),
        required=True,
        default_value='1024x1024',
    )

    # Image quality as free text, default 'standard' (tooltip left empty).
    quality = forms.TextInputField(
        TooltipLabel('图片质量', ''),
        required=True,
        default_value='standard',
    )

    # Number of images: integer slider from 1 to 10, default 1.
    n = forms.SliderField(
        TooltipLabel('图片数量', '指定生成图片的数量'),
        required=True,
        default_value=1,
        _min=1,
        _max=10,
        _step=1,
        precision=0,
    )
29+
1330

1431
class OpenAITextToImageModelCredential(BaseForm, BaseModelCredential):
1532
api_base = forms.TextInputField('API 域名', required=True)
@@ -44,4 +61,4 @@ def encryption_dict(self, model: Dict[str, object]):
4461
return {**model, 'api_key': super().encryption(model.get('api_key', ''))}
4562

4663
def get_model_params_setting_form(self, model_name):
47-
pass
64+
return OpenAITTIModelParams()

apps/setting/models_provider/impl/qwen_model_provider/credential/tti.py

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -19,20 +19,19 @@
1919

2020

2121
class QwenModelParams(BaseForm):
22-
temperature = forms.SliderField(TooltipLabel('温度', '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'),
23-
required=True, default_value=1.0,
24-
_min=0.1,
25-
_max=1.9,
26-
_step=0.01,
27-
precision=2)
28-
29-
max_tokens = forms.SliderField(
30-
TooltipLabel('输出最大Tokens', '指定模型可生成的最大token个数'),
31-
required=True, default_value=800,
22+
size = forms.TextInputField(
23+
TooltipLabel('图片尺寸', '指定生成图片的尺寸, 如: 1024x1024'),
24+
required=True, default_value='1024x1024')
25+
n = forms.SliderField(
26+
TooltipLabel('图片数量', '指定生成图片的数量'),
27+
required=True, default_value=1,
3228
_min=1,
33-
_max=100000,
29+
_max=4,
3430
_step=1,
3531
precision=0)
32+
style = forms.TextInputField(
33+
TooltipLabel('风格', '指定生成图片的风格'),
34+
required=True, default_value='<auto>')
3635

3736

3837
class QwenTextToImageModelCredential(BaseForm, BaseModelCredential):

apps/setting/models_provider/impl/zhipu_model_provider/credential/image.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,24 @@
77

88
from common import forms
99
from common.exception.app_exception import AppApiException
10-
from common.forms import BaseForm
10+
from common.forms import BaseForm, TooltipLabel
1111
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode
1212

13+
class ZhiPuImageModelParams(BaseForm):
    """Parameter form returned for ZhiPu image models.

    Declares two slider fields, ``temperature`` and ``max_tokens``; the
    labels and tooltips are user-facing strings and are kept verbatim.
    """

    # Sampling temperature: slider from 0.1 to 1.0 in 0.01 steps, default 0.95.
    temperature = forms.SliderField(
        TooltipLabel('温度', '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'),
        required=True,
        default_value=0.95,
        _min=0.1,
        _max=1.0,
        _step=0.01,
        precision=2,
    )

    # Maximum output tokens: integer slider from 1 to 100000, default 1024.
    max_tokens = forms.SliderField(
        TooltipLabel('输出最大Tokens', '指定模型可生成的最大token个数'),
        required=True,
        default_value=1024,
        _min=1,
        _max=100000,
        _step=1,
        precision=0,
    )
1328

1429
class ZhiPuImageModelCredential(BaseForm, BaseModelCredential):
1530
api_key = forms.PasswordInputField('API Key', required=True)
@@ -44,4 +59,4 @@ def encryption_dict(self, model: Dict[str, object]):
4459
return {**model, 'api_key': super().encryption(model.get('api_key', ''))}
4560

4661
def get_model_params_setting_form(self, model_name):
47-
pass
62+
return ZhiPuImageModelParams()

apps/setting/models_provider/impl/zhipu_model_provider/credential/tti.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,19 @@
11
# coding=utf-8
22
from typing import Dict
33

4-
from langchain_core.messages import HumanMessage
5-
64
from common import forms
75
from common.exception.app_exception import AppApiException
8-
from common.forms import BaseForm
6+
from common.forms import BaseForm, TooltipLabel
97
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode
108

119

10+
class ZhiPuTTIModelParams(BaseForm):
    """Parameter form returned for ZhiPu text-to-image models.

    Declares a single free-text ``size`` field. Its tooltip (kept verbatim)
    states that only cogview-3-plus supports the parameter and lists the
    accepted sizes.
    """

    # Image size as free text, default '1024x1024'.
    size = forms.TextInputField(
        TooltipLabel(
            '图片尺寸',
            '图片尺寸,仅 cogview-3-plus 支持该参数。可选范围:[1024x1024,768x1344,864x1152,1344x768,1152x864,1440x720,720x1440],默认是1024x1024。',
        ),
        required=True,
        default_value='1024x1024',
    )
15+
16+
1217
class ZhiPuTextToImageModelCredential(BaseForm, BaseModelCredential):
1318
api_key = forms.PasswordInputField('API Key', required=True)
1419

@@ -41,4 +46,4 @@ def encryption_dict(self, model: Dict[str, object]):
4146
return {**model, 'api_key': super().encryption(model.get('api_key', ''))}
4247

4348
def get_model_params_setting_form(self, model_name):
44-
pass
49+
return ZhiPuTTIModelParams()

apps/setting/serializers/provider_serializers.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,11 @@
2828
from setting.models_provider.base_model_provider import ValidCode, DownModelChunkStatus
2929
from setting.models_provider.constants.model_provider_constants import ModelProvideConstants
3030

31+
def get_default_model_params_setting(provider, model_type, model_name):
    """Return the default model-parameter form list for a model.

    Looks up the credential for ``provider`` / ``model_type`` /
    ``model_name`` and asks it for its parameter-setting form.

    :param provider: provider identifier passed to ``get_model_credential``
    :param model_type: model type passed to ``get_model_credential``
    :param model_name: model name passed to ``get_model_credential`` and to
        the credential's ``get_model_params_setting_form``
    :return: the form's ``to_form_list()`` result, or an empty list when the
        credential does not provide a parameter form
    """
    credential = get_model_credential(provider, model_type, model_name)
    setting_form = credential.get_model_params_setting_form(model_name)
    # A credential whose get_model_params_setting_form is a bare ``pass``
    # returns None; fall back to an empty form so model creation does not
    # fail with AttributeError.
    if setting_form is None:
        return []
    return setting_form.to_form_list()
35+
3136

3237
class ModelPullManage:
3338

@@ -206,6 +211,7 @@ def insert(self, user_id, with_valid=False):
206211
model = Model(id=uuid.uuid1(), status=status, user_id=user_id, name=name,
207212
credential=rsa_long_encrypt(model_credential_str),
208213
provider=provider, model_type=model_type, model_name=model_name,
214+
model_params_form=get_default_model_params_setting(provider, model_type, model_name),
209215
permission_type=permission_type)
210216
model.save()
211217
if status == Status.DOWNLOAD:

ui/src/views/template/component/ModelCard.vue

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@
9494
<template #dropdown>
9595
<el-dropdown-menu>
9696
<el-dropdown-item
97-
v-if="currentModel.model_type === 'TTS' || currentModel.model_type === 'LLM'"
97+
v-if="currentModel.model_type === 'TTS' || currentModel.model_type === 'LLM' || currentModel.model_type === 'IMAGE' || currentModel.model_type === 'TTI'"
9898
:disabled="!is_permisstion"
9999
icon="Setting" @click.stop="openParamSetting"
100100
>

0 commit comments

Comments
 (0)