Skip to content

Commit 98fae97

Browse files
committed
feat: Create model and configure advanced parameters
1 parent 60dea47 commit 98fae97

File tree

15 files changed

+382
-191
lines changed

15 files changed

+382
-191
lines changed

apps/application/flow/step_node/image_generate_step_node/i_image_generate_node.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@ class ImageGenerateNodeSerializer(serializers.Serializer):
2121

2222
is_result = serializers.BooleanField(required=False, error_messages=ErrMessage.boolean('是否返回内容'))
2323

24+
model_params_setting = serializers.JSONField(required=False, default=dict, error_messages=ErrMessage.json("模型参数设置"))
25+
2426

2527
class IImageGenerateNode(INode):
2628
type = 'image-generate-node'
@@ -32,6 +34,7 @@ def _run(self):
3234
return self.execute(**self.node_params_serializer.data, **self.flow_params_serializer.data)
3335

3436
def execute(self, model_id, prompt, negative_prompt, dialogue_number, dialogue_type, history_chat_record, chat_id,
37+
model_params_setting,
3538
chat_record_id,
3639
**kwargs) -> NodeResult:
3740
pass

apps/application/flow/step_node/image_generate_step_node/impl/base_image_generate_node.py

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,13 @@
22
from functools import reduce
33
from typing import List
44

5+
import requests
56
from langchain_core.messages import BaseMessage, HumanMessage, AIMessage
67

78
from application.flow.i_step_node import NodeResult
89
from application.flow.step_node.image_generate_step_node.i_image_generate_node import IImageGenerateNode
10+
from common.util.common import bytes_to_uploaded_file
11+
from dataset.serializers.file_serializers import FileSerializer
912
from setting.models_provider.tools import get_model_instance_by_model_user_id
1013

1114

@@ -16,10 +19,12 @@ def save_context(self, details, workflow_manage):
1619
self.answer_text = details.get('answer')
1720

1821
def execute(self, model_id, prompt, negative_prompt, dialogue_number, dialogue_type, history_chat_record, chat_id,
22+
model_params_setting,
1923
chat_record_id,
2024
**kwargs) -> NodeResult:
21-
22-
tti_model = get_model_instance_by_model_user_id(model_id, self.flow_params_serializer.data.get('user_id'))
25+
print(model_params_setting)
26+
application = self.workflow_manage.work_flow_post_handler.chat_info.application
27+
tti_model = get_model_instance_by_model_user_id(model_id, self.flow_params_serializer.data.get('user_id'), **model_params_setting)
2328
history_message = self.get_history_message(history_chat_record, dialogue_number)
2429
self.context['history_message'] = history_message
2530
question = self.generate_prompt_question(prompt)
@@ -28,10 +33,21 @@ def execute(self, model_id, prompt, negative_prompt, dialogue_number, dialogue_t
2833
self.context['message_list'] = message_list
2934
self.context['dialogue_type'] = dialogue_type
3035
print(message_list)
31-
print(negative_prompt)
3236
image_urls = tti_model.generate_image(question, negative_prompt)
33-
self.context['image_list'] = image_urls
34-
answer = '\n'.join([f"![Image]({path})" for path in image_urls])
37+
# 保存图片
38+
file_urls = []
39+
for image_url in image_urls:
40+
file_name = 'generated_image.png'
41+
file = bytes_to_uploaded_file(requests.get(image_url).content, file_name)
42+
meta = {
43+
'debug': False if application.id else True,
44+
'chat_id': chat_id,
45+
'application_id': str(application.id) if application.id else None,
46+
}
47+
file_url = FileSerializer(data={'file': file, 'meta': meta}).upload()
48+
file_urls.append(file_url)
49+
self.context['image_list'] = file_urls
50+
answer = '\n'.join([f"![Image]({path})" for path in file_urls])
3551
return NodeResult({'answer': answer, 'chat_model': tti_model, 'message_list': message_list,
3652
'image': [{'file_id': path.split('/')[-1], 'file_url': path} for path in file_urls],
3753
'history_message': history_message, 'question': question}, {})

apps/setting/models_provider/impl/openai_model_provider/credential/tti.py

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,14 +10,32 @@
1010
from common.forms import BaseForm, TooltipLabel
1111
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode
1212

13+
1314
class OpenAITTIModelParams(BaseForm):
14-
size = forms.TextInputField(
15+
size = forms.SingleSelect(
1516
TooltipLabel('图片尺寸', '指定生成图片的尺寸, 如: 1024x1024'),
16-
required=True, default_value='1024x1024')
17+
required=True,
18+
default_value='1024x1024',
19+
option_list=[
20+
{'value': '1024x1024', 'label': '1024x1024'},
21+
{'value': '1024x1792', 'label': '1024x1792'},
22+
{'value': '1792x1024', 'label': '1792x1024'},
23+
],
24+
text_field='label',
25+
value_field='value'
26+
)
1727

18-
quality = forms.TextInputField(
28+
quality = forms.SingleSelect(
1929
TooltipLabel('图片质量', ''),
20-
required=True, default_value='standard')
30+
required=True,
31+
default_value='standard',
32+
option_list=[
33+
{'value': 'standard', 'label': 'standard'},
34+
{'value': 'hd', 'label': 'hd'},
35+
],
36+
text_field='label',
37+
value_field='value'
38+
)
2139

2240
n = forms.SliderField(
2341
TooltipLabel('图片数量', '指定生成图片的数量'),

apps/setting/models_provider/impl/openai_model_provider/model/tti.py

Lines changed: 5 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,8 @@
11
from typing import Dict
22

3-
import requests
4-
from langchain_core.messages import HumanMessage
5-
from langchain_openai import ChatOpenAI
63
from openai import OpenAI
74

85
from common.config.tokenizer_manage_config import TokenizerManage
9-
from common.util.common import bytes_to_uploaded_file
10-
from dataset.serializers.file_serializers import FileSerializer
116
from setting.models_provider.base_model_provider import MaxKBBaseModel
127
from setting.models_provider.impl.base_tti import BaseTextToImage
138

@@ -32,7 +27,7 @@ def __init__(self, **kwargs):
3227

3328
@staticmethod
3429
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
35-
optional_params = {'params': {}}
30+
optional_params = {'params': {'size': '1024x1024', 'quality': 'standard', 'n': 1}}
3631
for key, value in model_kwargs.items():
3732
if key not in ['model_id', 'use_local', 'streaming']:
3833
optional_params['params'][key] = value
@@ -43,25 +38,21 @@ def new_instance(model_type, model_name, model_credential: Dict[str, object], **
4338
**optional_params,
4439
)
4540

41+
def is_cache_model(self):
42+
return False
43+
4644
def check_auth(self):
4745
chat = OpenAI(api_key=self.api_key, base_url=self.api_base)
4846
response_list = chat.models.with_raw_response.list()
4947

5048
# self.generate_image('生成一个小猫图片')
5149

5250
def generate_image(self, prompt: str, negative_prompt: str = None):
53-
5451
chat = OpenAI(api_key=self.api_key, base_url=self.api_base)
5552
res = chat.images.generate(model=self.model, prompt=prompt, **self.params)
56-
5753
file_urls = []
5854
for content in res.data:
5955
url = content.url
60-
print(url)
61-
file_name = 'generated_image.png'
62-
file = bytes_to_uploaded_file(requests.get(url).content, file_name)
63-
meta = {'debug': True}
64-
file_url = FileSerializer(data={'file': file, 'meta': meta}).upload()
65-
file_urls.append(file_url)
56+
file_urls.append(url)
6657

6758
return file_urls

apps/setting/models_provider/impl/qwen_model_provider/credential/tti.py

Lines changed: 29 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,19 +19,44 @@
1919

2020

2121
class QwenModelParams(BaseForm):
22-
size = forms.TextInputField(
22+
size = forms.SingleSelect(
2323
TooltipLabel('图片尺寸', '指定生成图片的尺寸, 如: 1024x1024'),
24-
required=True, default_value='1024x1024')
24+
required=True,
25+
default_value='1024*1024',
26+
option_list=[
27+
{'value': '1024*1024', 'label': '1024*1024'},
28+
{'value': '720*1280', 'label': '720*1280'},
29+
{'value': '768*1152', 'label': '768*1152'},
30+
{'value': '1280*720', 'label': '1280*720'},
31+
],
32+
text_field='label',
33+
value_field='value')
2534
n = forms.SliderField(
2635
TooltipLabel('图片数量', '指定生成图片的数量'),
2736
required=True, default_value=1,
2837
_min=1,
2938
_max=4,
3039
_step=1,
3140
precision=0)
32-
style = forms.TextInputField(
41+
style = forms.SingleSelect(
3342
TooltipLabel('风格', '指定生成图片的风格'),
34-
required=True, default_value='<auto>')
43+
required=True,
44+
default_value='<auto>',
45+
option_list=[
46+
{'value': '<auto>', 'label': '默认值,由模型随机输出图像风格'},
47+
{'value': '<photography>', 'label': '摄影'},
48+
{'value': '<portrait>', 'label': '人像写真'},
49+
{'value': '<3d cartoon>', 'label': '3D卡通'},
50+
{'value': '<anime>', 'label': '动画'},
51+
{'value': '<oil painting>', 'label': '油画'},
52+
{'value': '<watercolor>', 'label': '水彩'},
53+
{'value': '<sketch>', 'label': '素描'},
54+
{'value': '<chinese painting>', 'label': '中国画'},
55+
{'value': '<flat illustration>', 'label': '扁平插画'},
56+
],
57+
text_field='label',
58+
value_field='value'
59+
)
3560

3661

3762
class QwenTextToImageModelCredential(BaseForm, BaseModelCredential):

apps/setting/models_provider/impl/qwen_model_provider/model/tti.py

Lines changed: 5 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,11 @@
11
# coding=utf-8
22
from http import HTTPStatus
3-
from pathlib import PurePosixPath
43
from typing import Dict
5-
from urllib.parse import unquote, urlparse
64

7-
import requests
85
from dashscope import ImageSynthesis
96
from langchain_community.chat_models import ChatTongyi
107
from langchain_core.messages import HumanMessage
118

12-
from common.util.common import bytes_to_uploaded_file
13-
from dataset.serializers.file_serializers import FileSerializer
149
from setting.models_provider.base_model_provider import MaxKBBaseModel
1510
from setting.models_provider.impl.base_tti import BaseTextToImage
1611

@@ -28,7 +23,7 @@ def __init__(self, **kwargs):
2823

2924
@staticmethod
3025
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
31-
optional_params = {'params': {}}
26+
optional_params = {'params': {'size': '1024*1024', 'style': '<auto>', 'n': 1}}
3227
for key, value in model_kwargs.items():
3328
if key not in ['model_id', 'use_local', 'streaming']:
3429
optional_params['params'][key] = value
@@ -39,6 +34,9 @@ def new_instance(model_type, model_name, model_credential: Dict[str, object], **
3934
)
4035
return chat_tong_yi
4136

37+
def is_cache_model(self):
38+
return False
39+
4240
def check_auth(self):
4341
chat = ChatTongyi(api_key=self.api_key, model_name='qwen-max')
4442
chat.invoke([HumanMessage([{"type": "text", "text": "你好"}])])
@@ -53,11 +51,7 @@ def generate_image(self, prompt: str, negative_prompt: str = None):
5351
file_urls = []
5452
if rsp.status_code == HTTPStatus.OK:
5553
for result in rsp.output.results:
56-
file_name = PurePosixPath(unquote(urlparse(result.url).path)).parts[-1]
57-
file = bytes_to_uploaded_file(requests.get(result.url).content, file_name)
58-
meta = {'debug': True}
59-
file_url = FileSerializer(data={'file': file, 'meta': meta}).upload()
60-
file_urls.append(file_url)
54+
file_urls.append(result.url)
6155
else:
6256
print('sync_call Failed, status_code: %s, code: %s, message: %s' %
6357
(rsp.status_code, rsp.code, rsp.message))

apps/setting/models_provider/impl/tencent_model_provider/model/tti.py

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,12 @@
33
import json
44
from typing import Dict
55

6-
import requests
76
from tencentcloud.common import credential
87
from tencentcloud.common.exception.tencent_cloud_sdk_exception import TencentCloudSDKException
98
from tencentcloud.common.profile.client_profile import ClientProfile
109
from tencentcloud.common.profile.http_profile import HttpProfile
1110
from tencentcloud.hunyuan.v20230901 import hunyuan_client, models
1211

13-
from common.util.common import bytes_to_uploaded_file
14-
from dataset.serializers.file_serializers import FileSerializer
1512
from setting.models_provider.base_model_provider import MaxKBBaseModel
1613
from setting.models_provider.impl.base_tti import BaseTextToImage
1714
from setting.models_provider.impl.tencent_model_provider.model.hunyuan import ChatHunyuan
@@ -87,12 +84,8 @@ def generate_image(self, prompt: str, negative_prompt: str = None):
8784
# 输出json格式的字符串回包
8885
print(resp.to_json_string())
8986
file_urls = []
90-
file_name = 'generated_image.png'
91-
file = bytes_to_uploaded_file(requests.get(resp.ResultImage).content, file_name)
92-
meta = {'debug': True}
93-
file_url = FileSerializer(data={'file': file, 'meta': meta}).upload()
94-
file_urls.append(file_url)
87+
88+
file_urls.append(resp.ResultImage)
9589
return file_urls
9690
except TencentCloudSDKException as err:
9791
print(err)
98-

apps/setting/models_provider/impl/zhipu_model_provider/credential/tti.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,22 @@
88

99

1010
class ZhiPuTTIModelParams(BaseForm):
11-
size = forms.TextInputField(
11+
size = forms.SingleSelect(
1212
TooltipLabel('图片尺寸',
1313
'图片尺寸,仅 cogview-3-plus 支持该参数。可选范围:[1024x1024,768x1344,864x1152,1344x768,1152x864,1440x720,720x1440],默认是1024x1024。'),
14-
required=True, default_value='1024x1024')
14+
required=True,
15+
default_value='1024x1024',
16+
option_list=[
17+
{'value': '1024x1024', 'label': '1024x1024'},
18+
{'value': '768x1344', 'label': '768x1344'},
19+
{'value': '864x1152', 'label': '864x1152'},
20+
{'value': '1344x768', 'label': '1344x768'},
21+
{'value': '1152x864', 'label': '1152x864'},
22+
{'value': '1440x720', 'label': '1440x720'},
23+
{'value': '720x1440', 'label': '720x1440'},
24+
],
25+
text_field='label',
26+
value_field='value')
1527

1628

1729
class ZhiPuTextToImageModelCredential(BaseForm, BaseModelCredential):

apps/setting/models_provider/impl/zhipu_model_provider/model/tti.py

Lines changed: 7 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,10 @@
11
from typing import Dict
22

3-
import requests
43
from langchain_community.chat_models import ChatZhipuAI
54
from langchain_core.messages import HumanMessage
65
from zhipuai import ZhipuAI
76

87
from common.config.tokenizer_manage_config import TokenizerManage
9-
from common.util.common import bytes_to_uploaded_file
10-
from dataset.serializers.file_serializers import FileSerializer
118
from setting.models_provider.base_model_provider import MaxKBBaseModel
129
from setting.models_provider.impl.base_tti import BaseTextToImage
1310

@@ -30,7 +27,7 @@ def __init__(self, **kwargs):
3027

3128
@staticmethod
3229
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
33-
optional_params = {'params': {}}
30+
optional_params = {'params': {'size': '1024x1024'}}
3431
for key, value in model_kwargs.items():
3532
if key not in ['model_id', 'use_local', 'streaming']:
3633
optional_params['params'][key] = value
@@ -40,6 +37,9 @@ def new_instance(model_type, model_name, model_credential: Dict[str, object], **
4037
**optional_params,
4138
)
4239

40+
def is_cache_model(self):
41+
return False
42+
4343
def check_auth(self):
4444
chat = ChatZhipuAI(
4545
zhipuai_api_key=self.api_key,
@@ -58,16 +58,11 @@ def generate_image(self, prompt: str, negative_prompt: str = None):
5858
response = chat.images.generations(
5959
model=self.model, # 填写需要调用的模型编码
6060
prompt=prompt, # 填写需要生成图片的文本
61-
**self.params # 填写额外参数
61+
**self.params # 填写额外参数
6262
)
6363
file_urls = []
6464
for content in response.data:
65-
url = content['url']
66-
print(url)
67-
file_name = url.split('/')[-1]
68-
file = bytes_to_uploaded_file(requests.get(url).content, file_name)
69-
meta = {'debug': True}
70-
file_url = FileSerializer(data={'file': file, 'meta': meta}).upload()
71-
file_urls.append(file_url)
65+
url = content.url
66+
file_urls.append(url)
7267

7368
return file_urls

0 commit comments

Comments
 (0)