Skip to content

Commit 9160410

Browse files
committed
feat: 创建模型同时配置高级参数
1 parent 3396018 commit 9160410

File tree

6 files changed

+273
-130
lines changed

6 files changed

+273
-130
lines changed

apps/setting/serializers/provider_serializers.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,10 @@
3030

3131
def get_default_model_params_setting(provider, model_type, model_name):
3232
credential = get_model_credential(provider, model_type, model_name)
33-
model_params_setting = credential.get_model_params_setting_form(model_name).to_form_list()
34-
return model_params_setting
33+
setting_form = credential.get_model_params_setting_form(model_name)
34+
if setting_form is not None:
35+
return setting_form.to_form_list()
36+
return []
3537

3638

3739
class ModelPullManage:
@@ -207,11 +209,12 @@ def insert(self, user_id, with_valid=False):
207209
model_type = self.data.get('model_type')
208210
model_name = self.data.get('model_name')
209211
permission_type = self.data.get('permission_type')
212+
model_params_form = self.data.get('model_params_form')
210213
model_credential_str = json.dumps(credential)
211214
model = Model(id=uuid.uuid1(), status=status, user_id=user_id, name=name,
212215
credential=rsa_long_encrypt(model_credential_str),
213216
provider=provider, model_type=model_type, model_name=model_name,
214-
model_params_form=get_default_model_params_setting(provider, model_type, model_name),
217+
model_params_form=model_params_form,
215218
permission_type=permission_type)
216219
model.save()
217220
if status == Status.DOWNLOAD:

apps/setting/urls.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
path('provider/model_type_list', views.Provide.ModelTypeList.as_view(), name="provider/model_type_list"),
1515
path('provider/model_list', views.Provide.ModelList.as_view(),
1616
name="provider/model_name_list"),
17+
path('provider/model_params_form', views.Provide.ModelParamsForm.as_view(),
18+
name="provider/model_params_form"),
1719
path('provider/model_form', views.Provide.ModelForm.as_view(),
1820
name="provider/model_form"),
1921
path('model', views.Model.as_view(), name='model'),

apps/setting/views/model.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from common.response import result
1717
from common.util.common import query_params_to_single_dict
1818
from setting.models_provider.constants.model_provider_constants import ModelProvideConstants
19-
from setting.serializers.provider_serializers import ProviderSerializer, ModelSerializer
19+
from setting.serializers.provider_serializers import ProviderSerializer, ModelSerializer, get_default_model_params_setting
2020
from setting.swagger_api.provide_api import ProvideApi, ModelCreateApi, ModelQueryApi, ModelEditApi
2121

2222

@@ -207,6 +207,24 @@ def get(self, request: Request):
207207
ModelProvideConstants[provider].value.get_model_list(
208208
model_type))
209209

210+
class ModelParamsForm(APIView):
211+
authentication_classes = [TokenAuth]
212+
213+
@action(methods=['GET'], detail=False)
214+
@swagger_auto_schema(operation_summary="获取模型默认参数",
215+
operation_id="获取模型创建表单",
216+
manual_parameters=ProvideApi.ModelList.get_request_params_api(),
217+
responses=result.get_api_array_response(ProvideApi.ModelList.get_response_body_api())
218+
, tags=["模型"]
219+
)
220+
@has_permissions(PermissionConstants.MODEL_READ)
221+
def get(self, request: Request):
222+
provider = request.query_params.get('provider')
223+
model_type = request.query_params.get('model_type')
224+
model_name = request.query_params.get('model_name')
225+
226+
return result.success(get_default_model_params_setting(provider, model_type, model_name))
227+
210228
class ModelForm(APIView):
211229
authentication_classes = [TokenAuth]
212230

ui/src/api/model.ts

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,15 @@ const listBaseModel: (
9898
return get(`${prefix_provider}/model_list`, { provider, model_type }, loading)
9999
}
100100

101+
const listBaseModelParamsForm: (
102+
provider: string,
103+
model_type: string,
104+
model_name: string,
105+
loading?: Ref<boolean>
106+
) => Promise<Result<Array<BaseModel>>> = (provider, model_type, model_name, loading) => {
107+
return get(`${prefix_provider}/model_params_form`, { provider, model_type, model_name}, loading)
108+
}
109+
101110
/**
102111
* 创建模型
103112
* @param request 请求对象
@@ -187,6 +196,7 @@ export default {
187196
getModelCreateForm,
188197
listModelType,
189198
listBaseModel,
199+
listBaseModelParamsForm,
190200
createModel,
191201
updateModel,
192202
deleteModel,

0 commit comments

Comments
 (0)