feat: Support image generate model #1812
Changes from all commits: 977e68f, 60dea47, 98fae97, 767c284, 15ab598
@@ -0,0 +1,3 @@
# coding=utf-8

from .impl import *
@@ -0,0 +1,40 @@
# coding=utf-8

from typing import Type

from rest_framework import serializers

from application.flow.i_step_node import INode, NodeResult
from common.util.field_message import ErrMessage


class ImageGenerateNodeSerializer(serializers.Serializer):
    model_id = serializers.CharField(required=True, error_messages=ErrMessage.char("模型id"))

    prompt = serializers.CharField(required=True, error_messages=ErrMessage.char("提示词(正向)"))

    negative_prompt = serializers.CharField(required=False, default='', error_messages=ErrMessage.char("提示词(负向)"))
    # Number of history dialogue rounds to carry into generation
    dialogue_number = serializers.IntegerField(required=True, error_messages=ErrMessage.integer("多轮对话数量"))

    dialogue_type = serializers.CharField(required=True, error_messages=ErrMessage.char("对话存储类型"))

    is_result = serializers.BooleanField(required=False, error_messages=ErrMessage.boolean('是否返回内容'))

    model_params_setting = serializers.JSONField(required=False, default=dict, error_messages=ErrMessage.json("模型参数设置"))


class IImageGenerateNode(INode):
    type = 'image-generate-node'

    def get_node_params_serializer_class(self) -> Type[serializers.Serializer]:
        return ImageGenerateNodeSerializer

    def _run(self):
        return self.execute(**self.node_params_serializer.data, **self.flow_params_serializer.data)

    def execute(self, model_id, prompt, negative_prompt, dialogue_number, dialogue_type, history_chat_record, chat_id,
                model_params_setting,
                chat_record_id,
                **kwargs) -> NodeResult:
        pass
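For context on how this interface is exercised: `_run` feeds the node's stored configuration through `ImageGenerateNodeSerializer` before unpacking it into `execute`. A minimal sketch of that validation step, assuming standard Django REST Framework behavior; the parameter values below are invented for illustration and are not taken from the PR:

# Hypothetical node parameters; field names match ImageGenerateNodeSerializer above.
params = {
    'model_id': '42',
    'prompt': 'A watercolor lighthouse at dusk',
    'dialogue_number': 3,
    'dialogue_type': 'WORKFLOW',
}
serializer = ImageGenerateNodeSerializer(data=params)
serializer.is_valid(raise_exception=True)  # raises ValidationError carrying the ErrMessage texts on failure
print(serializer.data['negative_prompt'])  # '' -- optional fields fall back to their declared defaults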
@@ -0,0 +1,3 @@
# coding=utf-8

from .base_image_generate_node import BaseImageGenerateNode
@@ -0,0 +1,117 @@
# coding=utf-8
from functools import reduce
from typing import List

import requests
from langchain_core.messages import BaseMessage, HumanMessage, AIMessage

from application.flow.i_step_node import NodeResult
from application.flow.step_node.image_generate_step_node.i_image_generate_node import IImageGenerateNode
from common.util.common import bytes_to_uploaded_file
from dataset.serializers.file_serializers import FileSerializer
from setting.models_provider.tools import get_model_instance_by_model_user_id


class BaseImageGenerateNode(IImageGenerateNode):
    def save_context(self, details, workflow_manage):
        self.context['answer'] = details.get('answer')
        self.context['question'] = details.get('question')
        self.answer_text = details.get('answer')

    def execute(self, model_id, prompt, negative_prompt, dialogue_number, dialogue_type, history_chat_record, chat_id,
                model_params_setting,
                chat_record_id,
                **kwargs) -> NodeResult:
        print(model_params_setting)
        application = self.workflow_manage.work_flow_post_handler.chat_info.application
        tti_model = get_model_instance_by_model_user_id(model_id, self.flow_params_serializer.data.get('user_id'),
                                                        **model_params_setting)
        history_message = self.get_history_message(history_chat_record, dialogue_number)
        self.context['history_message'] = history_message
        question = self.generate_prompt_question(prompt)
        self.context['question'] = question
        message_list = self.generate_message_list(question, history_message)
        self.context['message_list'] = message_list
        self.context['dialogue_type'] = dialogue_type
        print(message_list)
        image_urls = tti_model.generate_image(question, negative_prompt)
        # Save the generated images
        file_urls = []
        for image_url in image_urls:
            file_name = 'generated_image.png'
            file = bytes_to_uploaded_file(requests.get(image_url).content, file_name)
            meta = {
                'debug': False if application.id else True,
                'chat_id': chat_id,
                'application_id': str(application.id) if application.id else None,
            }
            file_url = FileSerializer(data={'file': file, 'meta': meta}).upload()
            file_urls.append(file_url)
        self.context['image_list'] = file_urls
        # Render each uploaded image as a markdown image link in the answer
        answer = '\n'.join([f'![]({path})' for path in file_urls])
        return NodeResult({'answer': answer, 'chat_model': tti_model, 'message_list': message_list,
                           'image': [{'file_id': path.split('/')[-1], 'file_url': path} for path in file_urls],
                           'history_message': history_message, 'question': question}, {})

    def generate_history_ai_message(self, chat_record):
        for val in chat_record.details.values():
            if self.node.id == val['node_id'] and 'image_list' in val:
                if val['dialogue_type'] == 'WORKFLOW':
                    return chat_record.get_ai_message()
                return AIMessage(content=val['answer'])
        return chat_record.get_ai_message()

    def get_history_message(self, history_chat_record, dialogue_number):
        start_index = len(history_chat_record) - dialogue_number
        history_message = reduce(lambda x, y: [*x, *y], [
            [self.generate_history_human_message(history_chat_record[index]),
             self.generate_history_ai_message(history_chat_record[index])]
            for index in
            range(start_index if start_index > 0 else 0, len(history_chat_record))], [])
        return history_message

    def generate_history_human_message(self, chat_record):
        for data in chat_record.details.values():
            if self.node.id == data['node_id'] and 'image_list' in data:
                image_list = data['image_list']
                if len(image_list) == 0 or data['dialogue_type'] == 'WORKFLOW':
                    return HumanMessage(content=chat_record.problem_text)
                return HumanMessage(content=data['question'])
        return HumanMessage(content=chat_record.problem_text)

    def generate_prompt_question(self, prompt):
        return self.workflow_manage.generate_prompt(prompt)

    def generate_message_list(self, question: str, history_message):
        return [
            *history_message,
            question
        ]

    @staticmethod
    def reset_message_list(message_list: List[BaseMessage], answer_text):
        result = [{'role': 'user' if isinstance(message, HumanMessage) else 'ai', 'content': message.content}
                  for message in message_list]
        result.append({'role': 'ai', 'content': answer_text})
        return result

    def get_details(self, index: int, **kwargs):
        return {
            'name': self.node.properties.get('stepName'),
            "index": index,
            'run_time': self.context.get('run_time'),
            'history_message': [{'content': message.content, 'role': message.type} for message in
                                (self.context.get('history_message') if self.context.get(
                                    'history_message') is not None else [])],
            'question': self.context.get('question'),
            'answer': self.context.get('answer'),
            'type': self.node.type,
            'message_tokens': self.context.get('message_tokens'),
            'answer_tokens': self.context.get('answer_tokens'),
            'status': self.status,
            'err_message': self.err_message,
            'image_list': self.context.get('image_list'),
            'dialogue_type': self.context.get('dialogue_type')
        }
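As a quick illustration of the helper above, `reset_message_list` flattens LangChain message objects into plain role/content dicts and appends the final answer. A sketch with invented message contents:

from langchain_core.messages import HumanMessage, AIMessage

messages = [HumanMessage(content='Draw a red fox'), AIMessage(content='![](/file/abc)')]
result = BaseImageGenerateNode.reset_message_list(messages, '![](/file/def)')
# [{'role': 'user', 'content': 'Draw a red fox'},
#  {'role': 'ai', 'content': '![](/file/abc)'},
#  {'role': 'ai', 'content': '![](/file/def)'}]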
This code has the following shortcomings and areas for improvement.

Below are some examples after the changes.

Modified version:

def generate_prompt(self, query):
    """Generate the prompt question from the query"""
    return self.workflow_manage.generate_prompt(query)

@staticmethod
def process_generated_images(image_urls, app_id=None):
    """Process the generated images and upload them to the server"""
    result = []
    for url in image_urls:
        response = requests.get(url)
        image_data = response.content
        file_name = f"generated_image_{len(result)}.png"
        uploaded_file = bytes_to_uploaded_file(image_data, file_name)
        metadata = {
            'debug': not bool(app_id),  # debug mode only when no application id is given
            'app_id': str(app_id) if app_id else None,
        }
        upload_response = FileSerializer(data={'file': uploaded_file, 'meta': metadata}).upload()
        result.append(upload_response.url)
    return result

def save_result(self, task_detail, run_time):
    """Save the task result"""
    return NodeResult({
        'answer': task_detail.answer,
        'chat_model': task_detail.chat_model,
        'message_list': task_detail.message_list,
        'image': [{'file_id': path.rsplit('/', 1)[-1], 'file_url': path} for path in task_detail.image_urls],
        'history_message': task_detail.history_message,
        'question': task_detail.question
    }, '')

def handle_workflow_messages(self, task_detail):
    """Handle workflow messages"""
    for val in task_detail.details.values():
        if val['node_id'] == self.node.id and 'image_list' in val:
            if val['dialogue_type'] == 'WORKFLOW':
                return task_detail.ai_message
            elif not (val['dialogue_type'] == 'IMAGE_GENERATION'):
                return AIMessage(content=val['answer'].strip())

def build_history_messages(self, chat_records, count):
    """Build the history message list"""
    history_messages = []
    for record in chat_records[-count:]:
        history_messages.extend([
            HumanMessage(content=record.problem_text),
            self.handle_workflow_messages(record)
        ])
    return history_messages

class ImageGenerationWorkflow(BaseStepNode, IImageGenerateNode):
    """Image generation workflow node"""

    def execute(
            self,
            model_id,
            prompt,
            negative_prompt='',
            dialogue_number=1,
            dialect='DEFAULT',
            history=None,
            dialog_type='CHAT',
            flow_params_serializer=None,
            **kwargs
    ) -> TaskDetailResponseModel_v3_0:
        try:
            task_details = {}
            task_results = {}
            # Get the model instance and build the context.
            ...  # (original implementation unchanged)
            generated_images = task_details.get('generated_images')
            if generated_images:
                image_upload_urls = ImageGenerationWorkflow.process_generated_images(generated_images, app_id)
                node_result = self.save_result(
                    {'id': task_details.get('task_id', ''), 'question': task_details.get('prompt', ''),
                     'generated_answers': task_details.get('response', ''), 'image_list': image_upload_urls},
                    datetime.datetime.now()
                )
                return node_result
        except Exception as e:
            self.log_error(f"Exception during execution: {e}")
            raise

With the improvement suggestions above, the code becomes more robust and easier to maintain and extend.

This code has some issues and areas that could be optimized: 1.
@@ -0,0 +1,14 @@
# coding=utf-8
from abc import abstractmethod

from pydantic import BaseModel


class BaseTextToImage(BaseModel):
    @abstractmethod
    def check_auth(self):
        pass

    @abstractmethod
    def generate_image(self, prompt: str, negative_prompt: str = None):
        pass
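A concrete provider would subclass this interface and implement both abstract methods. A minimal sketch of what such an implementation could look like; the endpoint URL, payload shape, and response format are hypothetical placeholders, not part of this PR:

import requests

class ExampleTextToImage(BaseTextToImage):
    # pydantic fields; api_base is a made-up endpoint for illustration
    api_key: str
    api_base: str = 'https://example.invalid/v1/images'

    def check_auth(self):
        # Probe the endpoint and raise if the credentials are rejected.
        response = requests.get(self.api_base, headers={'Authorization': f'Bearer {self.api_key}'})
        response.raise_for_status()

    def generate_image(self, prompt: str, negative_prompt: str = None):
        payload = {'prompt': prompt}
        if negative_prompt:
            payload['negative_prompt'] = negative_prompt
        response = requests.post(self.api_base, json=payload,
                                 headers={'Authorization': f'Bearer {self.api_key}'})
        response.raise_for_status()
        # Return a list of image URLs, the shape BaseImageGenerateNode.execute consumes.
        return [item['url'] for item in response.json().get('data', [])]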