Skip to content

Commit f4c98ea

Browse files
committed
feat: Volcanic Engine Image Model
1 parent 98fae97 commit f4c98ea

File tree

7 files changed

+388
-8
lines changed

7 files changed

+388
-8
lines changed

apps/application/flow/step_node/image_understand_step_node/impl/base_image_understand_node.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from application.flow.step_node.image_understand_step_node.i_image_understand_node import IImageUnderstandNode
1313
from dataset.models import File
1414
from setting.models_provider.tools import get_model_instance_by_model_user_id
15+
from imghdr import what
1516

1617

1718
def _write_context(node_variable: Dict, workflow_variable: Dict, node: INode, workflow, answer: str):
@@ -59,8 +60,9 @@ def write_context(node_variable: Dict, workflow_variable: Dict, node: INode, wor
5960

6061
def file_id_to_base64(file_id: str):
    """Load a stored file and return its base64 content together with its image format.

    :param file_id: primary key of the ``File`` row to load.
    :return: a two-item list ``[base64_string, image_format]`` where
        ``image_format`` is the sub-type used to build a ``data:image/<format>``
        URL (for example ``png`` or ``jpeg``).
    """
    file = QuerySet(File).filter(id=file_id).first()
    file_bytes = file.get_byte()
    base64_image = base64.b64encode(file_bytes).decode("utf-8")
    # bytes() accepts both a memoryview and a plain bytes object, so this does
    # not depend on the exact buffer type returned by get_byte().
    # what() returns None when it cannot recognise the data; fall back to
    # 'jpeg' so callers never build a 'data:image/None;base64,...' URL.
    image_format = what(None, bytes(file_bytes)) or 'jpeg'
    return [base64_image, image_format]
6466

6567

6668
class BaseImageUnderstandNode(IImageUnderstandNode):
@@ -77,7 +79,7 @@ def execute(self, model_id, system, prompt, dialogue_number, dialogue_type, hist
7779
# 处理不正确的参数
7880
if image is None or not isinstance(image, list):
7981
image = []
80-
82+
print(model_params_setting)
8183
image_model = get_model_instance_by_model_user_id(model_id, self.flow_params_serializer.data.get('user_id'), **model_params_setting)
8284
# 执行详情中的历史消息不需要图片内容
8385
history_message = self.get_history_message_for_details(history_chat_record, dialogue_number)
@@ -152,7 +154,7 @@ def generate_history_human_message(self, chat_record):
152154
return HumanMessage(
153155
content=[
154156
{'type': 'text', 'text': data['question']},
155-
*[{'type': 'image_url', 'image_url': {'url': f'data:image/jpeg;base64,{base64_image}'}} for
157+
*[{'type': 'image_url', 'image_url': {'url': f'data:image/{base64_image[1]};base64,{base64_image[0]}'}} for
156158
base64_image in image_base64_list]
157159
])
158160
return HumanMessage(content=chat_record.problem_text)
@@ -167,8 +169,10 @@ def generate_message_list(self, image_model, system: str, prompt: str, history_m
167169
for img in image:
168170
file_id = img['file_id']
169171
file = QuerySet(File).filter(id=file_id).first()
170-
base64_image = base64.b64encode(file.get_byte()).decode("utf-8")
171-
images.append({'type': 'image_url', 'image_url': {'url': f'data:image/jpeg;base64,{base64_image}'}})
172+
image_bytes = file.get_byte()
173+
base64_image = base64.b64encode(image_bytes).decode("utf-8")
174+
image_format = what(None, image_bytes.tobytes())
175+
images.append({'type': 'image_url', 'image_url': {'url': f'data:image/{image_format};base64,{base64_image}'}})
172176
messages = [HumanMessage(
173177
content=[
174178
{'type': 'text', 'text': self.workflow_manage.generate_prompt(prompt)},
Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
# coding=utf-8
2+
import base64
3+
import os
4+
from typing import Dict
5+
6+
from langchain_core.messages import HumanMessage
7+
8+
from common import forms
9+
from common.exception.app_exception import AppApiException
10+
from common.forms import BaseForm, TooltipLabel
11+
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode
12+
13+
class VolcanicEngineImageModelParams(BaseForm):
    """Form describing the tunable parameters of the Volcanic Engine image model."""

    # Sampling temperature slider: 0.1 - 1.0 in steps of 0.01, default 0.95.
    temperature = forms.SliderField(
        TooltipLabel('温度', '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'),
        required=True,
        default_value=0.95,
        _min=0.1,
        _max=1.0,
        _step=0.01,
        precision=2,
    )

    # Maximum output tokens slider: 1 - 100000 in steps of 1, default 1024.
    max_tokens = forms.SliderField(
        TooltipLabel('输出最大Tokens', '指定模型可生成的最大token个数'),
        required=True,
        default_value=1024,
        _min=1,
        _max=100000,
        _step=1,
        precision=0,
    )
28+
29+
class VolcanicEngineImageModelCredential(BaseForm, BaseModelCredential):
    """Credential form (API key + API base) for the Volcanic Engine image model."""

    api_key = forms.PasswordInputField('API Key', required=True)
    api_base = forms.TextInputField('API 域名', required=True)

    def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
                 raise_exception=False):
        """Check that the credential is complete and that the model answers a test prompt.

        :param model_type: model type value; must be one of ``provider.get_model_type_list()``.
        :param model_name: name of the model to instantiate through ``provider``.
        :param model_credential: mapping expected to contain ``api_key`` and ``api_base``.
        :param provider: object exposing ``get_model_type_list()`` and ``get_model()``.
        :param raise_exception: when true, failures raise ``AppApiException``
            instead of returning ``False``.
        :return: ``True`` when validation succeeds, ``False`` when it fails and
            ``raise_exception`` is false.
        :raises AppApiException: for an unsupported ``model_type`` (always), or for
            a missing field / failed test call when ``raise_exception`` is true.
        """
        model_type_list = provider.get_model_type_list()
        if not any(mt.get('value') == model_type for mt in model_type_list):
            raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')

        for key in ['api_key', 'api_base']:
            if key not in model_credential:
                if raise_exception:
                    raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段')
                return False
        try:
            model = provider.get_model(model_type, model_name, model_credential)
            res = model.stream([HumanMessage(content=[{"type": "text", "text": "你好"}])])
            # Drain the stream so that any error raised while streaming surfaces
            # here; the chunks themselves are not needed (and are no longer
            # dumped to stdout).
            for _ in res:
                pass
        except AppApiException:
            raise
        except Exception as e:
            if raise_exception:
                raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}')
            return False
        return True

    def encryption_dict(self, model: Dict[str, object]):
        """Return a copy of ``model`` whose ``api_key`` has been passed through ``encryption()``."""
        return {**model, 'api_key': super().encryption(model.get('api_key', ''))}

    def get_model_params_setting_form(self, model_name):
        """Return the parameter form used for every model name of this credential."""
        return VolcanicEngineImageModelParams()
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
# coding=utf-8
2+
3+
from typing import Dict
4+
5+
from common import forms
6+
from common.exception.app_exception import AppApiException
7+
from common.forms import BaseForm, TooltipLabel
8+
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode
9+
10+
11+
class VolcanicEngineTTIModelGeneralParams(BaseForm):
    """Form describing the tunable parameters of the text-to-image models."""

    # Image size selector; each option uses the same 'width*height' string as
    # both its value and its label.
    size = forms.SingleSelect(
        TooltipLabel(
            '图片尺寸',
            '宽、高与512差距过大,则出图效果不佳、延迟过长概率显著增加。超分前建议比例及对应宽高:width*height',
        ),
        required=True,
        default_value='512*512',
        option_list=[
            {'value': dimensions, 'label': dimensions}
            for dimensions in (
                '512*512',
                '512*384',
                '384*512',
                '512*341',
                '341*512',
                '512*288',
                '288*512',
            )
        ],
        text_field='label',
        value_field='value',
    )
28+
29+
30+
class VolcanicEngineTTIModelCredential(BaseForm, BaseModelCredential):
    """Credential form (access key + secret key) for the text-to-image models."""

    access_key = forms.PasswordInputField('Access Key', required=True)
    secret_key = forms.PasswordInputField('Secret Key', required=True)

    def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
                 raise_exception=False):
        """Check that the credential is complete and that ``check_auth()`` succeeds.

        Returns ``True`` on success. On failure returns ``False``, or raises
        ``AppApiException`` when ``raise_exception`` is true. An unsupported
        ``model_type`` always raises ``AppApiException``.
        """
        supported = [mt for mt in provider.get_model_type_list() if mt.get('value') == model_type]
        if not any(supported):
            raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')

        # Report the first required field that is absent from the credential.
        for key in ('access_key', 'secret_key'):
            if key in model_credential:
                continue
            if not raise_exception:
                return False
            raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段')

        try:
            provider.get_model(model_type, model_name, model_credential).check_auth()
        except AppApiException:
            # Application-level errors are passed through unchanged.
            raise
        except Exception as e:
            if not raise_exception:
                return False
            raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}')
        return True

    def encryption_dict(self, model: Dict[str, object]):
        """Return a copy of ``model`` whose ``secret_key`` has been passed through ``encryption()``."""
        masked = dict(model)
        masked['secret_key'] = super().encryption(model.get('secret_key', ''))
        return masked

    def get_model_params_setting_form(self, model_name):
        """Return the parameter form used for every model name of this credential."""
        return VolcanicEngineTTIModelGeneralParams()
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
from typing import Dict
2+
3+
from langchain_openai.chat_models import ChatOpenAI
4+
5+
from common.config.tokenizer_manage_config import TokenizerManage
6+
from setting.models_provider.base_model_provider import MaxKBBaseModel
7+
8+
9+
def custom_get_token_ids(text: str):
    """Encode ``text`` with the tokenizer supplied by ``TokenizerManage`` and return the result."""
    return TokenizerManage.get_tokenizer().encode(text)
12+
13+
14+
class VolcanicEngineImage(MaxKBBaseModel, ChatOpenAI):
    """Volcanic Engine image model built on ``ChatOpenAI``."""

    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
        """Build a streaming ``VolcanicEngineImage`` from the stored credential.

        ``model_kwargs`` is reduced by ``MaxKBBaseModel.filter_optional_params``
        and the remaining entries are forwarded to the constructor.
        """
        extra_kwargs = MaxKBBaseModel.filter_optional_params(model_kwargs)
        api_key = model_credential.get('api_key')
        api_base = model_credential.get('api_base')
        # NOTE(review): stream_options={"include_usage": True} was left
        # commented out in the original — confirm whether it is wanted.
        return VolcanicEngineImage(
            model_name=model_name,
            openai_api_key=api_key,
            openai_api_base=api_base,
            streaming=True,
            **extra_kwargs,
        )
Lines changed: 171 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,171 @@
1+
# coding=utf-8
2+
3+
'''
4+
requires Python 3.6 or later
5+
6+
pip install asyncio
7+
pip install websockets
8+
9+
'''
10+
11+
import datetime
12+
import hashlib
13+
import hmac
14+
import json
15+
import sys
16+
from typing import Dict
17+
18+
import requests
19+
from langchain_openai import ChatOpenAI
20+
21+
from setting.models_provider.base_model_provider import MaxKBBaseModel
22+
from setting.models_provider.impl.base_tti import BaseTextToImage
23+
24+
# Fixed request settings used when signing and sending requests.
method = 'POST'
host = 'visual.volcengineapi.com'
region = 'cn-north-1'
endpoint = 'https://' + host
service = 'cv'

# Maps a model name to the value sent as "req_key" in the request body.
req_key_dict = dict([
    ('general_v1.4', 'high_aes_general_v14'),
    ('general_v2.0', 'high_aes_general_v20'),
    ('general_v2.0_L', 'high_aes_general_v20_L'),
])
35+
36+
37+
def sign(key, msg):
    """Return the raw HMAC-SHA256 digest of ``msg`` (encoded as UTF-8) keyed with the bytes ``key``."""
    mac = hmac.new(key, digestmod=hashlib.sha256)
    mac.update(msg.encode('utf-8'))
    return mac.digest()
39+
40+
41+
def getSignatureKey(key, dateStamp, regionName, serviceName):
    """Derive the signing key by chaining ``sign`` over date, region, service and 'request'.

    ``key`` is a string and is UTF-8 encoded to seed the chain; each step uses
    the previous digest as the key for the next one. Returns the final digest.
    """
    derived = key.encode('utf-8')
    for scope_part in (dateStamp, regionName, serviceName, 'request'):
        derived = sign(derived, scope_part)
    return derived
47+
48+
49+
def formatQuery(parameters):
    """Join ``parameters`` into a ``key=value&key=value`` string with keys in sorted order.

    Values are concatenated as-is (no URL encoding); an empty mapping yields ''.
    """
    return '&'.join(name + '=' + parameters[name] for name in sorted(parameters))
55+
56+
57+
def signV4Request(access_key, secret_key, service, req_query, req_body):
    """Sign a request with HMAC-SHA256 and POST it to ``endpoint``.

    :param access_key: access key placed in the ``Credential=`` part of the
        Authorization header.
    :param secret_key: secret key used to derive the signing key.
    :param service: service name used in the credential scope.
    :param req_query: already formatted query string (``key=value&...``).
    :param req_body: request body as a JSON string.
    :return: the value found at ``['data']['image_urls']`` in the decoded response.
    :raises ValueError: if ``access_key`` or ``secret_key`` is missing.
    :raises Exception: if the response status code is not 200, or if sending
        the request fails.
    """
    import logging

    if not access_key or not secret_key:
        # Raise instead of calling sys.exit(): SystemExit is not caught by
        # ``except Exception`` handlers and would terminate the caller.
        raise ValueError('No access key is available.')

    logger = logging.getLogger(__name__)

    # Timezone-aware replacement for the deprecated datetime.utcnow();
    # the formatted strings are identical.
    t = datetime.datetime.now(datetime.timezone.utc)
    current_date = t.strftime('%Y%m%dT%H%M%SZ')
    datestamp = t.strftime('%Y%m%d')  # Date w/o time, used in credential scope
    canonical_uri = '/'
    canonical_querystring = req_query
    signed_headers = 'content-type;host;x-content-sha256;x-date'
    payload_hash = hashlib.sha256(req_body.encode('utf-8')).hexdigest()
    content_type = 'application/json'
    canonical_headers = 'content-type:' + content_type + '\n' + 'host:' + host + \
                        '\n' + 'x-content-sha256:' + payload_hash + \
                        '\n' + 'x-date:' + current_date + '\n'
    canonical_request = method + '\n' + canonical_uri + '\n' + canonical_querystring + \
                        '\n' + canonical_headers + '\n' + signed_headers + '\n' + payload_hash
    algorithm = 'HMAC-SHA256'
    credential_scope = datestamp + '/' + region + '/' + service + '/' + 'request'
    string_to_sign = algorithm + '\n' + current_date + '\n' + credential_scope + '\n' + hashlib.sha256(
        canonical_request.encode('utf-8')).hexdigest()
    signing_key = getSignatureKey(secret_key, datestamp, region, service)
    signature = hmac.new(signing_key, string_to_sign.encode('utf-8'), hashlib.sha256).hexdigest()

    authorization_header = algorithm + ' ' + 'Credential=' + access_key + '/' + \
                           credential_scope + ', ' + 'SignedHeaders=' + \
                           signed_headers + ', ' + 'Signature=' + signature
    headers = {'X-Date': current_date,
               'Authorization': authorization_header,
               'X-Content-Sha256': payload_hash,
               'Content-Type': content_type
               }

    # ************* SEND THE REQUEST *************
    request_url = endpoint + '?' + canonical_querystring
    logger.debug('Request URL = %s', request_url)
    # NOTE(review): no timeout is passed to requests.post — consider adding one.
    try:
        r = requests.post(request_url, headers=headers, data=req_body)
    except Exception as err:
        logger.error('error occurred: %s', err)
        raise

    logger.debug('Response code: %s', r.status_code)
    # Replace the escaped \u0026 sequences in the raw response text with a literal &.
    resp_str = r.text.replace("\\u0026", "&")
    if r.status_code != 200:
        raise Exception(f'Error: {resp_str}')
    logger.debug('Response body: %s', resp_str)
    return json.loads(resp_str)['data']['image_urls']
118+
119+
120+
class VolcanicEngineTextToImage(MaxKBBaseModel, BaseTextToImage):
    """Text-to-image model that sends signed requests through ``signV4Request``."""

    access_key: str
    secret_key: str
    model_version: str
    params: dict

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.access_key = kwargs.get('access_key')
        self.secret_key = kwargs.get('secret_key')
        self.model_version = kwargs.get('model_version')
        self.params = kwargs.get('params')

    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
        """Build an instance from the stored credential.

        Every entry of ``model_kwargs`` except ``model_id``, ``use_local`` and
        ``streaming`` is collected into the instance's ``params`` dict.
        """
        optional_params = {'params': {}}
        for key, value in model_kwargs.items():
            if key not in ['model_id', 'use_local', 'streaming']:
                optional_params['params'][key] = value
        return VolcanicEngineTextToImage(
            model_version=model_name,
            access_key=model_credential.get('access_key'),
            secret_key=model_credential.get('secret_key'),
            **optional_params
        )

    def check_auth(self):
        """Validate the credential by generating a test image; errors propagate to the caller."""
        self.generate_image('生成一张小猫图片')

    def generate_image(self, prompt: str, negative_prompt: str = None):
        """Request an image for ``prompt`` and return the result of ``signV4Request``.

        :param prompt: text prompt sent as ``prompt`` in the request body.
        :param negative_prompt: accepted but not sent.
            NOTE(review): confirm whether the API should receive it.
        :raises KeyError: if ``model_version`` has no entry in ``req_key_dict``.
        """
        # Request query; fill in as specified by the API documentation.
        query_params = {
            'Action': 'CVProcess',
            'Version': '2022-08-31',
        }
        formatted_query = formatQuery(query_params)
        # Work on a copy: popping 'size' from self.params itself would make
        # every later call on this instance fall back to the default size.
        params = dict(self.params or {})
        size = params.pop('size', '512*512').split('*')
        body_params = {
            "req_key": req_key_dict[self.model_version],
            "prompt": prompt,
            "model_version": self.model_version,
            "return_url": True,
            "width": int(size[0]),
            "height": int(size[1]),
            **params
        }
        formatted_body = json.dumps(body_params)
        return signV4Request(self.access_key, self.secret_key, service, formatted_query, formatted_body)

    def is_cache_model(self):
        """Return ``False``."""
        return False

0 commit comments

Comments
 (0)