Skip to content

Commit 7a17bb5

Browse files
committed
feat: 创建模型同时配置高级参数
1 parent bfc9f69 commit 7a17bb5

File tree

12 files changed

+345
-144
lines changed

12 files changed

+345
-144
lines changed

apps/setting/models_provider/impl/openai_model_provider/credential/tti.py

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,14 +10,32 @@
1010
from common.forms import BaseForm, TooltipLabel
1111
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode
1212

13+
1314
class OpenAITTIModelParams(BaseForm):
14-
size = forms.TextInputField(
15+
size = forms.SingleSelect(
1516
TooltipLabel('图片尺寸', '指定生成图片的尺寸, 如: 1024x1024'),
16-
required=True, default_value='1024x1024')
17+
required=True,
18+
default_value='1024x1024',
19+
option_list=[
20+
{'value': '1024x1024', 'label': '1024x1024'},
21+
{'value': '1024x1792', 'label': '1024x1792'},
22+
{'value': '1792x1024', 'label': '1792x1024'},
23+
],
24+
text_field='label',
25+
value_field='value'
26+
)
1727

18-
quality = forms.TextInputField(
28+
quality = forms.SingleSelect(
1929
TooltipLabel('图片质量', ''),
20-
required=True, default_value='standard')
30+
required=True,
31+
default_value='standard',
32+
option_list=[
33+
{'value': 'standard', 'label': 'standard'},
34+
{'value': 'hd', 'label': 'hd'},
35+
],
36+
text_field='label',
37+
value_field='value'
38+
)
2139

2240
n = forms.SliderField(
2341
TooltipLabel('图片数量', '指定生成图片的数量'),

apps/setting/models_provider/impl/openai_model_provider/model/tti.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ def __init__(self, **kwargs):
3232

3333
@staticmethod
3434
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
35-
optional_params = {'params': {}}
35+
optional_params = {'params': {'size': '1024x1024', 'quality': 'standard', 'n': 1}}
3636
for key, value in model_kwargs.items():
3737
if key not in ['model_id', 'use_local', 'streaming']:
3838
optional_params['params'][key] = value

apps/setting/models_provider/impl/qwen_model_provider/credential/tti.py

Lines changed: 29 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,19 +19,44 @@
1919

2020

2121
class QwenModelParams(BaseForm):
22-
size = forms.TextInputField(
22+
size = forms.SingleSelect(
2323
TooltipLabel('图片尺寸', '指定生成图片的尺寸, 如: 1024x1024'),
24-
required=True, default_value='1024x1024')
24+
required=True,
25+
default_value='1024*1024',
26+
option_list=[
27+
{'value': '1024*1024', 'label': '1024*1024'},
28+
{'value': '720*1280', 'label': '720*1280'},
29+
{'value': '768*1152', 'label': '768*1152'},
30+
{'value': '1280*720', 'label': '1280*720'},
31+
],
32+
text_field='label',
33+
value_field='value')
2534
n = forms.SliderField(
2635
TooltipLabel('图片数量', '指定生成图片的数量'),
2736
required=True, default_value=1,
2837
_min=1,
2938
_max=4,
3039
_step=1,
3140
precision=0)
32-
style = forms.TextInputField(
41+
style = forms.SingleSelect(
3342
TooltipLabel('风格', '指定生成图片的风格'),
34-
required=True, default_value='<auto>')
43+
required=True,
44+
default_value='<auto>',
45+
option_list=[
46+
{'value': '<auto>', 'label': '默认值,由模型随机输出图像风格'},
47+
{'value': '<photography>', 'label': '摄影'},
48+
{'value': '<portrait>', 'label': '人像写真'},
49+
{'value': '<3d cartoon>', 'label': '3D卡通'},
50+
{'value': '<anime>', 'label': '动画'},
51+
{'value': '<oil painting>', 'label': '油画'},
52+
{'value': '<watercolor>', 'label': '水彩'},
53+
{'value': '<sketch>', 'label': '素描'},
54+
{'value': '<chinese painting>', 'label': '中国画'},
55+
{'value': '<flat illustration>', 'label': '扁平插画'},
56+
],
57+
text_field='label',
58+
value_field='value'
59+
)
3560

3661

3762
class QwenTextToImageModelCredential(BaseForm, BaseModelCredential):

apps/setting/models_provider/impl/qwen_model_provider/model/tti.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ def __init__(self, **kwargs):
2828

2929
@staticmethod
3030
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
31-
optional_params = {'params': {}}
31+
optional_params = {'params': {'size': '1024*1024', 'style': '<auto>', 'n': 1}}
3232
for key, value in model_kwargs.items():
3333
if key not in ['model_id', 'use_local', 'streaming']:
3434
optional_params['params'][key] = value
@@ -50,6 +50,7 @@ def generate_image(self, prompt: str, negative_prompt: str = None):
5050
prompt=prompt,
5151
negative_prompt=negative_prompt,
5252
**self.params)
53+
print(rsp)
5354
file_urls = []
5455
if rsp.status_code == HTTPStatus.OK:
5556
for result in rsp.output.results:

apps/setting/models_provider/impl/zhipu_model_provider/credential/tti.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,22 @@
88

99

1010
class ZhiPuTTIModelParams(BaseForm):
11-
size = forms.TextInputField(
11+
size = forms.SingleSelect(
1212
TooltipLabel('图片尺寸',
1313
'图片尺寸,仅 cogview-3-plus 支持该参数。可选范围:[1024x1024,768x1344,864x1152,1344x768,1152x864,1440x720,720x1440],默认是1024x1024。'),
14-
required=True, default_value='1024x1024')
14+
required=True,
15+
default_value='1024x1024',
16+
option_list=[
17+
{'value': '1024x1024', 'label': '1024x1024'},
18+
{'value': '768x1344', 'label': '768x1344'},
19+
{'value': '864x1152', 'label': '864x1152'},
20+
{'value': '1344x768', 'label': '1344x768'},
21+
{'value': '1152x864', 'label': '1152x864'},
22+
{'value': '1440x720', 'label': '1440x720'},
23+
{'value': '720x1440', 'label': '720x1440'},
24+
],
25+
text_field='label',
26+
value_field='value')
1527

1628

1729
class ZhiPuTextToImageModelCredential(BaseForm, BaseModelCredential):

apps/setting/models_provider/impl/zhipu_model_provider/model/tti.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ def __init__(self, **kwargs):
3030

3131
@staticmethod
3232
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
33-
optional_params = {'params': {}}
33+
optional_params = {'params': {'size': '1024x1024'}}
3434
for key, value in model_kwargs.items():
3535
if key not in ['model_id', 'use_local', 'streaming']:
3636
optional_params['params'][key] = value
@@ -62,7 +62,7 @@ def generate_image(self, prompt: str, negative_prompt: str = None):
6262
)
6363
file_urls = []
6464
for content in response.data:
65-
url = content['url']
65+
url = content.url
6666
print(url)
6767
file_name = url.split('/')[-1]
6868
file = bytes_to_uploaded_file(requests.get(url).content, file_name)

apps/setting/serializers/provider_serializers.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,10 @@
3030

3131
def get_default_model_params_setting(provider, model_type, model_name):
3232
credential = get_model_credential(provider, model_type, model_name)
33-
model_params_setting = credential.get_model_params_setting_form(model_name).to_form_list()
34-
return model_params_setting
33+
setting_form = credential.get_model_params_setting_form(model_name)
34+
if setting_form is not None:
35+
return setting_form.to_form_list()
36+
return []
3537

3638

3739
class ModelPullManage:
@@ -178,6 +180,8 @@ class Create(serializers.Serializer):
178180

179181
model_name = serializers.CharField(required=True, error_messages=ErrMessage.char("基础模型"))
180182

183+
model_params_form = serializers.ListField(required=False, default=list, error_messages=ErrMessage.char("参数配置"))
184+
181185
credential = serializers.DictField(required=True, error_messages=ErrMessage.dict("认证信息"))
182186

183187
def is_valid(self, *, raise_exception=False):
@@ -207,11 +211,12 @@ def insert(self, user_id, with_valid=False):
207211
model_type = self.data.get('model_type')
208212
model_name = self.data.get('model_name')
209213
permission_type = self.data.get('permission_type')
214+
model_params_form = self.data.get('model_params_form')
210215
model_credential_str = json.dumps(credential)
211216
model = Model(id=uuid.uuid1(), status=status, user_id=user_id, name=name,
212217
credential=rsa_long_encrypt(model_credential_str),
213218
provider=provider, model_type=model_type, model_name=model_name,
214-
model_params_form=get_default_model_params_setting(provider, model_type, model_name),
219+
model_params_form=model_params_form,
215220
permission_type=permission_type)
216221
model.save()
217222
if status == Status.DOWNLOAD:

apps/setting/urls.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
path('provider/model_type_list', views.Provide.ModelTypeList.as_view(), name="provider/model_type_list"),
1515
path('provider/model_list', views.Provide.ModelList.as_view(),
1616
name="provider/model_name_list"),
17+
path('provider/model_params_form', views.Provide.ModelParamsForm.as_view(),
18+
name="provider/model_params_form"),
1719
path('provider/model_form', views.Provide.ModelForm.as_view(),
1820
name="provider/model_form"),
1921
path('model', views.Model.as_view(), name='model'),

apps/setting/views/model.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from common.response import result
1717
from common.util.common import query_params_to_single_dict
1818
from setting.models_provider.constants.model_provider_constants import ModelProvideConstants
19-
from setting.serializers.provider_serializers import ProviderSerializer, ModelSerializer
19+
from setting.serializers.provider_serializers import ProviderSerializer, ModelSerializer, get_default_model_params_setting
2020
from setting.swagger_api.provide_api import ProvideApi, ModelCreateApi, ModelQueryApi, ModelEditApi
2121

2222

@@ -207,6 +207,24 @@ def get(self, request: Request):
207207
ModelProvideConstants[provider].value.get_model_list(
208208
model_type))
209209

210+
class ModelParamsForm(APIView):
211+
authentication_classes = [TokenAuth]
212+
213+
@action(methods=['GET'], detail=False)
214+
@swagger_auto_schema(operation_summary="获取模型默认参数",
215+
operation_id="获取模型创建表单",
216+
manual_parameters=ProvideApi.ModelList.get_request_params_api(),
217+
responses=result.get_api_array_response(ProvideApi.ModelList.get_response_body_api())
218+
, tags=["模型"]
219+
)
220+
@has_permissions(PermissionConstants.MODEL_READ)
221+
def get(self, request: Request):
222+
provider = request.query_params.get('provider')
223+
model_type = request.query_params.get('model_type')
224+
model_name = request.query_params.get('model_name')
225+
226+
return result.success(get_default_model_params_setting(provider, model_type, model_name))
227+
210228
class ModelForm(APIView):
211229
authentication_classes = [TokenAuth]
212230

ui/src/api/model.ts

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,15 @@ const listBaseModel: (
9898
return get(`${prefix_provider}/model_list`, { provider, model_type }, loading)
9999
}
100100

101+
const listBaseModelParamsForm: (
102+
provider: string,
103+
model_type: string,
104+
model_name: string,
105+
loading?: Ref<boolean>
106+
) => Promise<Result<Array<BaseModel>>> = (provider, model_type, model_name, loading) => {
107+
return get(`${prefix_provider}/model_params_form`, { provider, model_type, model_name}, loading)
108+
}
109+
101110
/**
102111
* 创建模型
103112
* @param request 请求对象
@@ -187,6 +196,7 @@ export default {
187196
getModelCreateForm,
188197
listModelType,
189198
listBaseModel,
199+
listBaseModelParamsForm,
190200
createModel,
191201
updateModel,
192202
deleteModel,

0 commit comments

Comments
 (0)