Skip to content

Commit 33da18f

Browse files
committed
feat: 支持添加图片生成模型(WIP)
1 parent daf27a7 commit 33da18f

File tree

28 files changed

+959
-16
lines changed

28 files changed

+959
-16
lines changed

apps/application/flow/step_node/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,14 +18,15 @@
1818

1919
from .document_extract_node import *
2020
from .image_understand_step_node import *
21+
from .image_generate_step_node import *
2122

2223
from .search_dataset_node import *
2324
from .start_node import *
2425

2526
node_list = [BaseStartStepNode, BaseChatNode, BaseSearchDatasetNode, BaseQuestionNode, BaseConditionNode, BaseReplyNode,
2627
BaseFunctionNodeNode, BaseFunctionLibNodeNode, BaseRerankerNode, BaseApplicationNode,
2728
BaseDocumentExtractNode,
28-
BaseImageUnderstandNode, BaseFormNode]
29+
BaseImageUnderstandNode, BaseImageGenerateNode, BaseFormNode]
2930

3031

3132
def get_node(node_type):
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
# coding=utf-8
2+
3+
from .impl import *
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
# coding=utf-8
2+
3+
from typing import Type
4+
5+
from rest_framework import serializers
6+
7+
from application.flow.i_step_node import INode, NodeResult
8+
from common.util.field_message import ErrMessage
9+
10+
11+
class ImageGenerateNodeSerializer(serializers.Serializer):
    """Parameter serializer for the image-generate workflow node.

    Validates the node's configured parameters before the node runs
    (see ``IImageGenerateNode.get_node_params_serializer_class``).
    """
    # Primary key of the text-to-image model to run.
    model_id = serializers.CharField(required=True, error_messages=ErrMessage.char("模型id"))

    # Positive prompt template (rendered with workflow variables at run time).
    prompt = serializers.CharField(required=True, error_messages=ErrMessage.char("提示词(正向)"))

    # Negative prompt; optional, defaults to an empty string.
    negative_prompt = serializers.CharField(required=False, default='', error_messages=ErrMessage.char("提示词(负向)"))
    # Number of dialogue rounds of history to carry into generation.
    dialogue_number = serializers.IntegerField(required=True, error_messages=ErrMessage.integer("多轮对话数量"))

    # Dialogue storage mode; 'WORKFLOW' switches history to workflow-level messages.
    dialogue_type = serializers.CharField(required=True, error_messages=ErrMessage.char("对话存储类型"))

    # Whether this node's answer is part of the returned result.
    is_result = serializers.BooleanField(required=False, error_messages=ErrMessage.boolean('是否返回内容'))


class IImageGenerateNode(INode):
    """Abstract image-generation step node; concrete logic lives in the impl package."""

    type = 'image-generate-node'

    def get_node_params_serializer_class(self) -> Type[serializers.Serializer]:
        """Return the serializer class used to validate this node's parameters."""
        return ImageGenerateNodeSerializer

    def _run(self):
        # Merge validated node params with flow-level params and delegate to execute().
        return self.execute(**self.node_params_serializer.data, **self.flow_params_serializer.data)

    def execute(self, model_id, prompt, negative_prompt, dialogue_number, dialogue_type, history_chat_record, chat_id,
                chat_record_id,
                **kwargs) -> NodeResult:
        """Run the image generation; implemented by subclasses, returns a NodeResult."""
        pass
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
# coding=utf-8
2+
3+
from .base_image_generate_node import BaseImageGenerateNode
Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
# coding=utf-8
2+
from functools import reduce
3+
from typing import List
4+
5+
from langchain_core.messages import BaseMessage, HumanMessage, AIMessage
6+
7+
from application.flow.i_step_node import NodeResult
8+
from application.flow.step_node.image_generate_step_node.i_image_generate_node import IImageGenerateNode
9+
from setting.models_provider.tools import get_model_instance_by_model_user_id
10+
11+
12+
class BaseImageGenerateNode(IImageGenerateNode):
    """Image-generation step node.

    Renders the prompt with workflow variables, assembles the recent dialogue
    history, calls the text-to-image model, and returns the generated images
    as markdown image links.
    """

    def save_context(self, details, workflow_manage):
        """Restore this node's answer/question context from persisted run details."""
        self.context['answer'] = details.get('answer')
        self.context['question'] = details.get('question')
        self.answer_text = details.get('answer')

    def execute(self, model_id, prompt, negative_prompt, dialogue_number, dialogue_type, history_chat_record, chat_id,
                chat_record_id,
                **kwargs) -> NodeResult:
        """Generate images for the rendered prompt.

        :param model_id: id of the text-to-image model to load for the current user
        :param prompt: positive prompt template, rendered through the workflow
        :param negative_prompt: negative prompt forwarded to the model
        :param dialogue_number: number of history rounds to include
        :param dialogue_type: history storage mode ('WORKFLOW' or node-level)
        :param history_chat_record: previous chat records of this conversation
        :return: NodeResult whose answer is the images rendered as markdown links
        """
        tti_model = get_model_instance_by_model_user_id(model_id, self.flow_params_serializer.data.get('user_id'))
        history_message = self.get_history_message(history_chat_record, dialogue_number)
        self.context['history_message'] = history_message
        question = self.generate_prompt_question(prompt)
        self.context['question'] = question
        message_list = self.generate_message_list(question, history_message)
        self.context['message_list'] = message_list
        self.context['dialogue_type'] = dialogue_type
        image_urls = tti_model.generate_image(question, negative_prompt)
        self.context['image_list'] = image_urls
        # One markdown image link per generated file, newline separated.
        answer = '\n'.join([f"![Image]({path})" for path in image_urls])
        return NodeResult({'answer': answer, 'chat_model': tti_model, 'message_list': message_list,
                           'history_message': history_message, 'question': question}, {})

    def generate_history_ai_message(self, chat_record):
        """Build the AI-side history message for one chat record.

        Uses this node's stored answer only when the record was produced by this
        node in node-level (non-'WORKFLOW') dialogue mode; otherwise falls back
        to the record's workflow-level AI message.
        """
        for val in chat_record.details.values():
            if self.node.id == val['node_id'] and 'image_list' in val:
                if val['dialogue_type'] == 'WORKFLOW':
                    return chat_record.get_ai_message()
                return AIMessage(content=val['answer'])
        return chat_record.get_ai_message()

    def get_history_message(self, history_chat_record, dialogue_number):
        """Return the last ``dialogue_number`` rounds as a flat [human, ai, ...] list."""
        start_index = len(history_chat_record) - dialogue_number
        history_message = reduce(lambda x, y: [*x, *y], [
            [self.generate_history_human_message(history_chat_record[index]),
             self.generate_history_ai_message(history_chat_record[index])]
            for index in
            range(start_index if start_index > 0 else 0, len(history_chat_record))], [])
        return history_message

    def generate_history_human_message(self, chat_record):
        """Build the human-side history message for one chat record.

        Prefers this node's stored (rendered) question; falls back to the raw
        problem text for workflow-level history or records without images.
        """
        for data in chat_record.details.values():
            if self.node.id == data['node_id'] and 'image_list' in data:
                image_list = data['image_list']
                if len(image_list) == 0 or data['dialogue_type'] == 'WORKFLOW':
                    return HumanMessage(content=chat_record.problem_text)
                return HumanMessage(content=data['question'])
        return HumanMessage(content=chat_record.problem_text)

    def generate_prompt_question(self, prompt):
        """Render the prompt template with workflow variables."""
        return self.workflow_manage.generate_prompt(prompt)

    def generate_message_list(self, question: str, history_message):
        """Append the rendered question after the history messages."""
        return [
            *history_message,
            question
        ]

    @staticmethod
    def reset_message_list(message_list: List[BaseMessage], answer_text):
        """Convert messages to role/content dicts and append the final AI answer."""
        result = [{'role': 'user' if isinstance(message, HumanMessage) else 'ai', 'content': message.content} for
                  message
                  in
                  message_list]
        result.append({'role': 'ai', 'content': answer_text})
        return result

    def get_details(self, index: int, **kwargs):
        """Return the serializable execution details persisted for this node run."""
        return {
            'name': self.node.properties.get('stepName'),
            "index": index,
            'run_time': self.context.get('run_time'),
            'history_message': [{'content': message.content, 'role': message.type} for message in
                                (self.context.get('history_message') if self.context.get(
                                    'history_message') is not None else [])],
            'question': self.context.get('question'),
            'answer': self.context.get('answer'),
            'type': self.node.type,
            'message_tokens': self.context.get('message_tokens'),
            'answer_tokens': self.context.get('answer_tokens'),
            'status': self.status,
            'err_message': self.err_message,
            'image_list': self.context.get('image_list'),
            'dialogue_type': self.context.get('dialogue_type')
        }

apps/application/flow/workflow_manage.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ def __init__(self, _id: str, _type: str, x: int, y: int, properties: dict, **kwa
5353

5454

5555
end_nodes = ['ai-chat-node', 'reply-node', 'function-node', 'function-lib-node', 'application-node',
56-
'image-understand-node']
56+
'image-understand-node', 'image-generate-node']
5757

5858

5959
class Flow:

apps/common/util/common.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,12 @@
88
"""
99
import hashlib
1010
import importlib
11+
import mimetypes
12+
import io
1113
from functools import reduce
1214
from typing import Dict, List
1315

16+
from django.core.files.uploadedfile import InMemoryUploadedFile
1417
from django.db.models import QuerySet
1518

1619
from ..exception.app_exception import AppApiException
@@ -111,3 +114,25 @@ def bulk_create_in_batches(model, data, batch_size=1000):
111114
batch = data[i:i + batch_size]
112115
model.objects.bulk_create(batch)
113116

117+
118+
def bytes_to_uploaded_file(file_bytes, file_name="file.txt"):
    """Wrap raw bytes in a Django ``InMemoryUploadedFile``.

    The content type is guessed from *file_name*; unrecognized extensions
    fall back to ``application/octet-stream``.
    """
    guessed_type, _ = mimetypes.guess_type(file_name)
    if guessed_type is None:
        # Unknown extension: treat the payload as generic binary data.
        guessed_type = "application/octet-stream"
    return InMemoryUploadedFile(
        file=io.BytesIO(file_bytes),
        field_name=None,
        name=file_name,
        content_type=guessed_type,
        size=len(file_bytes),
        charset=None,
    )

apps/setting/models_provider/base_model_provider.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,7 @@ class ModelTypeConst(Enum):
150150
STT = {'code': 'STT', 'message': '语音识别'}
151151
TTS = {'code': 'TTS', 'message': '语音合成'}
152152
IMAGE = {'code': 'IMAGE', 'message': '图片理解'}
153+
TTI = {'code': 'TTI', 'message': '图片生成'}
153154
RERANKER = {'code': 'RERANKER', 'message': '重排模型'}
154155

155156

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
# coding=utf-8
2+
from abc import abstractmethod
3+
4+
from pydantic import BaseModel
5+
6+
7+
class BaseTextToImage(BaseModel):
    """Abstract base (pydantic model) for text-to-image model wrappers."""

    @abstractmethod
    def check_auth(self):
        """Verify the configured credentials against the provider; raise on failure."""
        pass

    @abstractmethod
    def generate_image(self, prompt: str, negative_prompt: str = None):
        """Generate images for *prompt*; implementations return a list of file URLs."""
        pass
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
# coding=utf-8
2+
import base64
3+
import os
4+
from typing import Dict
5+
6+
from langchain_core.messages import HumanMessage
7+
8+
from common import forms
9+
from common.exception.app_exception import AppApiException
10+
from common.forms import BaseForm
11+
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode
12+
13+
14+
class OpenAITextToImageModelCredential(BaseForm, BaseModelCredential):
    """Credential form/validator for OpenAI-compatible text-to-image models."""

    api_base = forms.TextInputField('API 域名', required=True)
    api_key = forms.PasswordInputField('API Key', required=True)

    def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
                 raise_exception=False):
        """Validate the credential dict by instantiating the model and probing auth.

        :param model_type: model type code; must be supported by *provider*
        :param model_credential: dict that must contain 'api_base' and 'api_key'
        :param raise_exception: raise AppApiException instead of returning False
        :return: True when the credentials work
        :raises AppApiException: unsupported model type, missing field, or
            auth failure (the latter only when raise_exception is True)
        """
        model_type_list = provider.get_model_type_list()
        if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
            raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')

        for key in ['api_base', 'api_key']:
            if key not in model_credential:
                if raise_exception:
                    raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段')
                else:
                    return False
        try:
            # Build the model and let it probe the provider with the credentials.
            model = provider.get_model(model_type, model_name, model_credential)
            model.check_auth()
        except Exception as e:
            if isinstance(e, AppApiException):
                raise e
            if raise_exception:
                raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}')
            else:
                return False
        return True

    def encryption_dict(self, model: Dict[str, object]):
        """Return *model* with the api_key value encrypted/masked."""
        return {**model, 'api_key': super().encryption(model.get('api_key', ''))}

    def get_model_params_setting_form(self, model_name):
        # No extra per-model parameter form yet (feature is WIP).
        pass
Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
from typing import Dict
2+
3+
import requests
4+
from langchain_core.messages import HumanMessage
5+
from langchain_openai import ChatOpenAI
6+
from openai import OpenAI
7+
8+
from common.config.tokenizer_manage_config import TokenizerManage
9+
from common.util.common import bytes_to_uploaded_file
10+
from dataset.serializers.file_serializers import FileSerializer
11+
from setting.models_provider.base_model_provider import MaxKBBaseModel
12+
from setting.models_provider.impl.base_tti import BaseTextToImage
13+
14+
15+
def custom_get_token_ids(text: str):
    """Encode *text* with the shared tokenizer and return its token ids."""
    return TokenizerManage.get_tokenizer().encode(text)
18+
19+
20+
class OpenAITextToImage(MaxKBBaseModel, BaseTextToImage):
    """Text-to-image model backed by an OpenAI-compatible images API."""

    api_base: str
    api_key: str
    model: str
    # Extra keyword arguments forwarded verbatim to images.generate().
    params: dict

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.api_key = kwargs.get('api_key')
        self.api_base = kwargs.get('api_base')
        self.model = kwargs.get('model')
        self.params = kwargs.get('params')

    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
        """Factory used by the provider registry.

        Collects every extra kwarg except infrastructure keys
        ('model_id', 'use_local', 'streaming') into the generation params.
        """
        optional_params = {'params': {}}
        for key, value in model_kwargs.items():
            if key not in ['model_id', 'use_local', 'streaming']:
                optional_params['params'][key] = value
        return OpenAITextToImage(
            model=model_name,
            api_base=model_credential.get('api_base'),
            api_key=model_credential.get('api_key'),
            **optional_params,
        )

    def check_auth(self):
        """Probe the credentials by listing models; raises if key/base is invalid."""
        client = OpenAI(api_key=self.api_key, base_url=self.api_base)
        client.models.with_raw_response.list()

    def generate_image(self, prompt: str, negative_prompt: str = None):
        """Generate images for *prompt* and return a list of uploaded-file URLs.

        NOTE(review): *negative_prompt* is accepted for interface compatibility
        but not forwarded — the OpenAI images endpoint has no negative-prompt
        input. Confirm whether it should be merged into the prompt instead.
        """
        client = OpenAI(api_key=self.api_key, base_url=self.api_base)
        res = client.images.generate(model=self.model, prompt=prompt, **self.params)

        file_urls = []
        for content in res.data:
            # Download each generated image and re-upload it to local storage,
            # so the returned URLs outlive the provider's temporary links.
            file = bytes_to_uploaded_file(requests.get(content.url).content, 'generated_image.png')
            file_url = FileSerializer(data={'file': file, 'meta': {'debug': True}}).upload()
            file_urls.append(file_url)

        return file_urls

0 commit comments

Comments
 (0)