Skip to content

Commit 4f4c3c0

Browse files
authored
Add builtin tools: image generate, image edit, video generate (#98)
1 parent 0a4b0c9 commit 4f4c3c0

File tree

3 files changed

+330
-0
lines changed

3 files changed

+330
-0
lines changed
Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
# Copyright (c) 2025 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
from typing import Dict
15+
from google.adk.tools import ToolContext
16+
from google.genai import types
17+
from volcenginesdkarkruntime import Ark
18+
from veadk.config import getenv
19+
import base64
20+
21+
from veadk.utils.logger import get_logger
22+
23+
# Module-level logger named after this module.
logger = get_logger(__name__)
24+
25+
# Shared Ark client used by the image tool below. It is created once at import
# time from the MODEL_IMAGE_API_KEY / MODEL_IMAGE_API_BASE values returned by
# getenv.
client = Ark(
26+
api_key=getenv("MODEL_IMAGE_API_KEY"),
27+
base_url=getenv("MODEL_IMAGE_API_BASE"),
28+
)
29+
30+
31+
async def image_edit(
    origin_image: str,
    image_name: str,
    image_prompt: str,
    response_format: str,
    guidance_scale: float,
    watermark: bool,
    seed: int,
    tool_context: ToolContext,
) -> Dict:
    """Edit an image according to the prompt.

    The request is sent through the module-level Ark ``client`` to the model
    named by the ``MODEL_EDIT_NAME`` setting. The first image of the response
    is published in ``tool_context.state["generated_image_url"]``; for
    ``b64_json`` responses the decoded bytes are additionally saved as an
    artifact named ``image_name``.

    Args:
        origin_image: The url or the base64 string of the image to edit.
        image_name: The name of the generated image.
        image_prompt: The prompt that describes the image.
        response_format: "b64_json" or "url"; use "url" by default.
        guidance_scale: Guidance scale forwarded to the model; use 2.5 by
            default.
        watermark: Whether to watermark the result; use True by default.
        seed: Random seed forwarded to the model; use -1 by default.
        tool_context: Tool context whose state receives the image location
            and which stores the artifact.

    Returns:
        ``{"status": "success", "image_name": ..., "image": ...}`` where
        ``image`` is the url or the base64 string, or
        ``{"status": "error", "message": ...}`` on any failure. Exceptions
        are converted into the error form instead of being raised.
    """
    # Reject unknown formats before spending an API call; otherwise the
    # result could not be extracted from the response afterwards.
    if response_format not in ("url", "b64_json"):
        error_details = (
            f"Unsupported response_format {response_format!r}; "
            "expected 'url' or 'b64_json'."
        )
        logger.error(error_details)
        return {"status": "error", "message": error_details}

    try:
        response = client.images.generate(
            model=getenv("MODEL_EDIT_NAME"),
            image=origin_image,
            prompt=image_prompt,
            response_format=response_format,
            guidance_scale=guidance_scale,
            watermark=watermark,
            seed=seed,
        )

        if not response.data:
            error_details = f"No images returned by Doubao model: {response}"
            logger.error(error_details)
            return {"status": "error", "message": error_details}

        # Only the first returned image is used.
        item = response.data[0]
        if response_format == "url":
            image = item.url
            tool_context.state["generated_image_url"] = image
        else:
            image = item.b64_json
            image_bytes = base64.b64decode(image)

            # Label the data URL and the artifact with the same MIME type,
            # derived from the PNG signature; anything else is treated as
            # JPEG.
            if image_bytes.startswith(b"\x89PNG\r\n\x1a\n"):
                mime_type = "image/png"
            else:
                mime_type = "image/jpeg"

            tool_context.state["generated_image_url"] = (
                f"data:{mime_type};base64,{image}"
            )

            report_artifact = types.Part.from_bytes(
                data=image_bytes, mime_type=mime_type
            )
            await tool_context.save_artifact(image_name, report_artifact)
            logger.debug(f"Image saved as ADK artifact: {image_name}")

        return {"status": "success", "image_name": image_name, "image": image}

    except Exception as e:
        # Tool boundary: report the failure to the caller instead of raising,
        # but keep a trace of it in the log.
        error_details = f"Doubao image generation failed: {str(e)}"
        logger.error(error_details)
        return {"status": "error", "message": error_details}
Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
# Copyright (c) 2025 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
from typing import Dict
15+
16+
from google.genai import types
17+
from google.adk.tools import ToolContext
18+
from veadk.config import getenv
19+
import base64
20+
from volcenginesdkarkruntime import Ark
21+
22+
from veadk.utils.logger import get_logger
23+
24+
# Module-level logger named after this module.
logger = get_logger(__name__)
25+
26+
# Shared Ark client used by the image tool below. It is created once at import
# time from the MODEL_IMAGE_API_KEY / MODEL_IMAGE_API_BASE values returned by
# getenv.
client = Ark(
27+
api_key=getenv("MODEL_IMAGE_API_KEY"),
28+
base_url=getenv("MODEL_IMAGE_API_BASE"),
29+
)
30+
31+
32+
async def image_generate(
    image_name: str,
    image_prompt: str,
    response_format: str,
    size: str,
    guidance_scale: float,
    watermark: bool,
    seed: int,
    tool_context: ToolContext,
) -> Dict:
    """Generate an image according to the prompt.

    The request is sent through the module-level Ark ``client`` to the model
    named by the ``MODEL_IMAGE_NAME`` setting. The first image of the response
    is published in ``tool_context.state["generated_image_url"]``; for
    ``b64_json`` responses the decoded bytes are additionally saved as an
    artifact named ``image_name``.

    Args:
        image_name: The name of the generated image.
        image_prompt: The prompt that describes the image.
        response_format: "b64_json" or "url"; use "url" by default.
        size: Image size forwarded to the model; use "1024x1024" by default.
        guidance_scale: Guidance scale forwarded to the model; use 2.5 by
            default.
        watermark: Whether to watermark the result; use True by default.
        seed: Random seed forwarded to the model; use -1 by default.
        tool_context: Tool context whose state receives the image location
            and which stores the artifact.

    Returns:
        ``{"status": "success", "image_name": ..., "image": ...}`` where
        ``image`` is the url or the base64 string, or
        ``{"status": "error", "message": ...}`` on any failure. Exceptions
        are converted into the error form instead of being raised.
    """
    # Reject unknown formats before spending an API call; otherwise the
    # result could not be extracted from the response afterwards.
    if response_format not in ("url", "b64_json"):
        error_details = (
            f"Unsupported response_format {response_format!r}; "
            "expected 'url' or 'b64_json'."
        )
        logger.error(error_details)
        return {"status": "error", "message": error_details}

    try:
        response = client.images.generate(
            model=getenv("MODEL_IMAGE_NAME"),
            prompt=image_prompt,
            response_format=response_format,
            size=size,
            guidance_scale=guidance_scale,
            watermark=watermark,
            seed=seed,
        )

        if not response.data:
            error_details = f"No images returned by Doubao model: {response}"
            logger.error(error_details)
            return {"status": "error", "message": error_details}

        # Only the first returned image is used.
        item = response.data[0]
        if response_format == "url":
            image = item.url
            tool_context.state["generated_image_url"] = image
        else:
            image = item.b64_json
            image_bytes = base64.b64decode(image)

            # Label the data URL and the artifact with the same MIME type,
            # derived from the PNG signature; anything else is treated as
            # JPEG.
            if image_bytes.startswith(b"\x89PNG\r\n\x1a\n"):
                mime_type = "image/png"
            else:
                mime_type = "image/jpeg"

            tool_context.state["generated_image_url"] = (
                f"data:{mime_type};base64,{image}"
            )

            report_artifact = types.Part.from_bytes(
                data=image_bytes, mime_type=mime_type
            )
            await tool_context.save_artifact(image_name, report_artifact)
            logger.debug(f"Image saved as ADK artifact: {image_name}")

        return {"status": "success", "image_name": image_name, "image": image}

    except Exception as e:
        # Tool boundary: report the failure to the caller instead of raising,
        # but keep a trace of it in the log.
        error_details = f"Doubao image generation failed: {str(e)}"
        logger.error(error_details)
        return {"status": "error", "message": error_details}
Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,141 @@
1+
# Copyright (c) 2025 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
from typing import Dict
15+
from google.adk.tools import ToolContext
16+
from volcenginesdkarkruntime import Ark
17+
from veadk.config import getenv
18+
import time
19+
import traceback
20+
21+
from veadk.utils.logger import get_logger
22+
23+
# Module-level logger named after this module.
logger = get_logger(__name__)
24+
25+
# Shared Ark client used by the video functions below. It is created once at
# import time from the MODEL_VIDEO_API_KEY / MODEL_VIDEO_API_BASE values
# returned by getenv.
client = Ark(
26+
api_key=getenv("MODEL_VIDEO_API_KEY"),
27+
base_url=getenv("MODEL_VIDEO_API_BASE"),
28+
)
29+
30+
31+
async def generate(tool_context, prompt, first_frame_image=None, last_frame_image=None):
    """Submit one video generation task and return the creation response.

    The request content always carries the text prompt. A first frame, or a
    first and last frame pair, is attached when provided; a last frame
    without a first frame is ignored.

    Args:
        tool_context: Not used by this function; accepted so call sites can
            pass their context through.
        prompt: The text prompt of the video.
        first_frame_image: Optional first frame, forwarded as the image url.
            ``None`` or an empty string means "not provided".
        last_frame_image: Optional last frame, forwarded as the image url.
            Only used together with ``first_frame_image``.

    Returns:
        The object returned by ``client.content_generation.tasks.create``.

    Raises:
        Exception: Any error raised while creating the task is logged and
            re-raised unchanged.
    """
    content = [{"type": "text", "text": prompt}]
    if not first_frame_image:
        logger.debug("text generation")
    elif not last_frame_image:
        logger.debug("first frame generation")
        content.append(
            {
                "type": "image_url",
                "image_url": {"url": first_frame_image},
            }
        )
    else:
        logger.debug("last frame generation")
        content.append(
            {
                "type": "image_url",
                "image_url": {"url": first_frame_image},
                "role": "first_frame",
            }
        )
        content.append(
            {
                "type": "image_url",
                "image_url": {"url": last_frame_image},
                "role": "last_frame",
            }
        )

    try:
        return client.content_generation.tasks.create(
            model=getenv("MODEL_VIDEO_NAME"),
            content=content,
        )
    except Exception:
        # Record the failure through the module logger, then let the caller
        # decide how to handle it.
        logger.error(
            f"Failed to create video generation task: {traceback.format_exc()}"
        )
        raise
75+
76+
77+
async def video_generate(params: list, tool_context: ToolContext) -> Dict:
    """Generate video in batch according to the prompt.

    Tasks are submitted in batches of 10. Each batch is polled every 10
    seconds until all of its tasks have reached a final status. The url of
    every finished video is stored in
    ``tool_context.state["<video_name>_video_url"]``.

    Args:
        params: A list of dicts, one per video, each with the keys:
            video_name: The name of the generated video.
            first_frame: The first frame of the video, url or base64 string,
                or None.
            last_frame: The last frame of the video, url or base64 string,
                or None.
            prompt: The prompt of the video.
        tool_context: Tool context whose state receives the video urls.

    Returns:
        ``{"status": "success", "message": ...}`` when at least one video was
        generated (the message lists generated and failed videos), otherwise
        ``{"status": "error", "message": ...}`` listing the failed videos.
    """
    # Local import: the polling loop must yield to the event loop while
    # waiting instead of blocking it with time.sleep.
    import asyncio

    batch_size = 10
    poll_interval_seconds = 10
    # Final statuses that will never turn into "succeeded"; polling such a
    # task again would loop forever.
    terminal_failures = ("failed", "cancelled", "expired")
    success_list = []
    error_list = []

    for start_idx in range(0, len(params), batch_size):
        batch = params[start_idx : start_idx + batch_size]
        task_dict = {}

        # Submit every task of the batch; a bad entry must not abort the rest.
        for index, item in enumerate(batch, start=start_idx):
            if (
                not isinstance(item, dict)
                or "video_name" not in item
                or "prompt" not in item
            ):
                logger.error(f"Invalid video params at index {index}: {item!r}")
                if isinstance(item, dict) and "video_name" in item:
                    error_list.append(item["video_name"])
                else:
                    error_list.append(f"params[{index}]")
                continue

            video_name = item["video_name"]
            prompt = item["prompt"]
            # Frames are optional: a missing key means the same as None.
            first_frame = item.get("first_frame") or None
            last_frame = item.get("last_frame") or None
            if first_frame is None and last_frame is not None:
                logger.warning(
                    f"last_frame of {video_name!r} is ignored because "
                    "first_frame is missing"
                )

            try:
                response = await generate(
                    tool_context, prompt, first_frame, last_frame
                )
                task_dict[response.id] = video_name
            except Exception:
                logger.error(
                    f"Failed to submit video task {video_name!r}: "
                    f"{traceback.format_exc()}"
                )
                # Report the failed submission to the caller as well.
                error_list.append(video_name)

        # Poll until every task of this batch has reached a final status.
        while task_dict:
            for task_id in list(task_dict):
                result = client.content_generation.tasks.get(task_id=task_id)
                status = result.status
                video_name = task_dict[task_id]
                if status == "succeeded":
                    logger.debug("----- task succeeded -----")
                    video_url = result.content.video_url
                    tool_context.state[f"{video_name}_video_url"] = video_url
                    success_list.append({video_name: video_url})
                    del task_dict[task_id]
                elif status in terminal_failures:
                    logger.debug(f"----- task {status} -----")
                    logger.debug(f"Error: {getattr(result, 'error', None)}")
                    error_list.append(video_name)
                    del task_dict[task_id]
                else:
                    logger.debug(
                        f"Current status: {status}, Retrying after 10 seconds..."
                    )
            # Wait only while something is still pending.
            if task_dict:
                await asyncio.sleep(poll_interval_seconds)

    if not success_list:
        return {"status": "error", "message": f"Following videos failed: {error_list}"}
    return {
        "status": "success",
        "message": f"Following videos generated: {success_list}\nFollowing videos failed: {error_list}",
    }

0 commit comments

Comments
 (0)