# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+
from typing import Dict
from google.adk.tools import ToolContext
from google.genai import types
from volcenginesdkarkruntime import Ark
from veadk.config import getenv
import base64
-
+from opentelemetry import trace
+import traceback
+import json
+from veadk.version import VERSION
+from opentelemetry.trace import Span
from veadk.utils.logger import get_logger

logger = get_logger(__name__)

client = Ark(
-    api_key=getenv("MODEL_IMAGE_API_KEY"),
-    base_url=getenv("MODEL_IMAGE_API_BASE"),
+    api_key=getenv("MODEL_EDIT_API_KEY"),
+    base_url=getenv("MODEL_EDIT_API_BASE"),
)


async def image_edit(
-    origin_image: str,
-    image_name: str,
-    image_prompt: str,
-    response_format: str,
-    guidance_scale: float,
-    watermark: bool,
-    seed: int,
+    params: list,
    tool_context: ToolContext,
) -> Dict:
-    """Edit an image accoding to the prompt.
+    """
+    Edit images in batch according to prompts and optional settings.
+
+    Each item in `params` describes a single image-edit request.

    Args:
-        origin_image: The url or the base64 string of the edited image.
-        image_name: The name of the generated image.
-        image_prompt: The prompt that describes the image.
-        response_format: str, b64_json or url, default url.
-        guidance_scale: default 2.5.
-        watermark: default True.
-        seed: default -1.
+        params (list[dict]):
+            A list of image-edit requests. Each item supports:
+
+            Required:
+            - origin_image (str):
+                The URL or Base64 string of the original image to edit.
+                Example:
+                * URL: "https://example.com/image.png"
+                * Base64: "data:image/png;base64,<BASE64>"
+
+            - prompt (str):
+                The textual instruction describing how to edit the image.
+                Supports English and Chinese.
+
+            Optional:
+            - image_name (str):
+                Name/identifier for the generated image.

+            - response_format (str):
+                Format of the returned image.
+                * "url": JPEG link (default)
+                * "b64_json": Base64 string in JSON
+
+            - guidance_scale (float):
+                How strongly the prompt affects the result.
+                Range: [1.0, 10.0], default 2.5.
+
+            - watermark (bool):
+                Whether to add a watermark.
+                Default: True.
+
+            - seed (int):
+                Random seed for reproducibility.
+                Range: [-1, 2^31-1], default -1 (random).
+
+    Returns:
+        Dict: Batch result containing per-image successes and failures.
+        Example:
+            {
+                "status": "success",
+                "success_list": [{"<image_name>": "<url or b64 data URI>"}],
+                "error_list": ["<failed_image_name>"]
+            }
+
+    Notes:
+        - Uses the SeedEdit 3.0 model.
+        - Provide the same `seed` for consistent outputs across runs.
+        - A higher `guidance_scale` enforces stricter adherence to the text prompt.
    """
-    try:
-        response = client.images.generate(
-            model=getenv("MODEL_EDIT_NAME"),
-            image=origin_image,
-            prompt=image_prompt,
-            response_format=response_format,
-            guidance_scale=guidance_scale,
-            watermark=watermark,
-            seed=seed,
-        )
-
-        if response.data and len(response.data) > 0:
-            for item in response.data:
-                if response_format == "url":
-                    image = item.url
-                    tool_context.state["generated_image_url"] = image
-
-                elif response_format == "b64_json":
-                    image = item.b64_json
-                    image_bytes = base64.b64decode(image)
-
-                    tool_context.state["generated_image_url"] = (
-                        f"data:image/jpeg;base64,{image}"
-                    )
-
-                    report_artifact = types.Part.from_bytes(
-                        data=image_bytes, mime_type="image/png"
-                    )
-                    await tool_context.save_artifact(image_name, report_artifact)
-                    logger.debug(f"Image saved as ADK artifact: {image_name}")
-
-            return {"status": "success", "image_name": image_name, "image": image}
-        else:
-            error_details = f"No images returned by Doubao model: {response}"
+    success_list = []
+    error_list = []
+    for idx, item in enumerate(params):
+        image_name = item.get("image_name", f"generated_image_{idx}")
+        prompt = item.get("prompt")
+        origin_image = item.get("origin_image")
+        response_format = item.get("response_format", "url")
+        guidance_scale = item.get("guidance_scale", 2.5)
+        watermark = item.get("watermark", True)
+        seed = item.get("seed", -1)
+
+        try:
+            tracer = trace.get_tracer("gcp.vertex.agent")
+            with tracer.start_as_current_span("call_llm") as span:
+                inputs = {
+                    "prompt": prompt,
+                    "image": origin_image,
+                    "response_format": response_format,
+                    "guidance_scale": guidance_scale,
+                    "watermark": watermark,
+                    "seed": seed,
+                }
+                input_part = {
+                    "role": "user",
+                    "content": json.dumps(inputs, ensure_ascii=False),
+                }
+                response = client.images.generate(
+                    model=getenv("MODEL_EDIT_NAME"), **inputs
+                )
+                output_part = None
+                if response.data and len(response.data) > 0:
+                    for data_item in response.data:
+                        if response_format == "url":
+                            image = data_item.url
+                            tool_context.state[f"{image_name}_url"] = image
+                            output_part = {
+                                "message.role": "model",
+                                "message.content": image,
+                            }
+                        elif response_format == "b64_json":
+                            image = data_item.b64_json
+                            image_bytes = base64.b64decode(image)
+
+                            tool_context.state[f"{image_name}_url"] = (
+                                f"data:image/jpeg;base64,{image}"
+                            )
+
+                            report_artifact = types.Part.from_bytes(
+                                data=image_bytes, mime_type="image/png"
+                            )
+                            await tool_context.save_artifact(
+                                image_name, report_artifact
+                            )
+                            logger.debug(f"Image saved as ADK artifact: {image_name}")
+
+                        success_list.append({image_name: image})
+                else:
+                    error_details = f"No images returned by Doubao model: {response}"
+                    logger.error(error_details)
+                    error_list.append(image_name)
+
+                add_span_attributes(
+                    span,
+                    tool_context,
+                    input_part=input_part,
+                    output_part=output_part,
+                    output_tokens=response.usage.output_tokens,
+                    total_tokens=response.usage.total_tokens,
+                    request_model=getenv("MODEL_EDIT_NAME"),
+                    response_model=getenv("MODEL_EDIT_NAME"),
+                )
+
+        except Exception as e:
+            error_details = f"Doubao image edit failed for {image_name}: {e}"
            logger.error(error_details)
-            return {"status": "error", "message": error_details}
+            traceback.print_exc()
+            error_list.append(image_name)

-    except Exception as e:
+    if len(success_list) == 0:
        return {
            "status": "error",
-            "message": f"Doubao image generation failed: {str(e)}",
+            "success_list": success_list,
+            "error_list": error_list,
+        }
+    else:
+        return {
+            "status": "success",
+            "success_list": success_list,
+            "error_list": error_list,
        }
+
+
+def add_span_attributes(
+    span: Span,
+    tool_context: ToolContext,
+    input_part: dict = None,
+    output_part: dict = None,
+    input_tokens: int = None,
+    output_tokens: int = None,
+    total_tokens: int = None,
+    request_model: str = None,
+    response_model: str = None,
+):
+    try:
+        # common attributes
+        app_name = tool_context._invocation_context.app_name
+        user_id = tool_context._invocation_context.user_id
+        agent_name = tool_context.agent_name
+        session_id = tool_context._invocation_context.session.id
+        span.set_attribute("gen_ai.agent.name", agent_name)
+        span.set_attribute("openinference.instrumentation.veadk", VERSION)
+        span.set_attribute("gen_ai.app.name", app_name)
+        span.set_attribute("gen_ai.user.id", user_id)
+        span.set_attribute("gen_ai.session.id", session_id)
+        span.set_attribute("agent_name", agent_name)
+        span.set_attribute("agent.name", agent_name)
+        span.set_attribute("app_name", app_name)
+        span.set_attribute("app.name", app_name)
+        span.set_attribute("user.id", user_id)
+        span.set_attribute("session.id", session_id)
+        span.set_attribute("cozeloop.report.source", "veadk")
+
+        # llm attributes
+        span.set_attribute("gen_ai.system", "openai")
+        span.set_attribute("gen_ai.operation.name", "chat")
+        if request_model:
+            span.set_attribute("gen_ai.request.model", request_model)
+        if response_model:
+            span.set_attribute("gen_ai.response.model", response_model)
+        if total_tokens:
+            span.set_attribute("gen_ai.usage.total_tokens", total_tokens)
+        if output_tokens:
+            span.set_attribute("gen_ai.usage.output_tokens", output_tokens)
+        if input_tokens:
+            span.set_attribute("gen_ai.usage.input_tokens", input_tokens)
+        if input_part:
+            span.add_event("gen_ai.user.message", input_part)
+        if output_part:
+            span.add_event("gen_ai.choice", output_part)
+
+    except Exception:
+        traceback.print_exc()
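
For reference, a minimal usage sketch of the batched tool introduced by this change. It is illustrative only and not part of the commit: it assumes the module's imports are available, that MODEL_EDIT_API_KEY, MODEL_EDIT_API_BASE and MODEL_EDIT_NAME are configured, and that a ToolContext is injected by the ADK runtime when the tool is invoked; the URL, prompt, and names below are placeholders.

# Hypothetical driver; `tool_context` is normally supplied by the ADK runtime.
async def demo(tool_context: ToolContext) -> None:
    params = [
        {
            "origin_image": "https://example.com/cat.png",  # placeholder URL
            "prompt": "Replace the background with a sunset beach",
            "image_name": "cat_sunset",
            "response_format": "b64_json",
            "guidance_scale": 2.5,
            "watermark": True,
            "seed": 42,
        },
    ]
    result = await image_edit(params, tool_context)
    if result["status"] == "success":
        # Each success entry maps the image name to the returned URL or Base64 data;
        # the edited image is also stored under `<image_name>_url` in the tool state.
        for entry in result["success_list"]:
            for name in entry:
                print(name, "->", tool_context.state[f"{name}_url"])
    else:
        print("All edits failed:", result["error_list"])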
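
Because the tool now emits a "call_llm" span via trace.get_tracer("gcp.vertex.agent"), the attributes written by add_span_attributes can be inspected with the standard OpenTelemetry SDK. A sketch under stated assumptions: the in-memory exporter wiring below is test-style setup of my own, not part of veadk or this commit, and it must run before the tool creates its tracer.

from opentelemetry import trace
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import SimpleSpanProcessor
from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter

# Install an in-memory exporter so finished spans can be read back in-process.
exporter = InMemorySpanExporter()
provider = TracerProvider()
provider.add_span_processor(SimpleSpanProcessor(exporter))
trace.set_tracer_provider(provider)

# ... run the agent / await image_edit(...) here ...

for span in exporter.get_finished_spans():
    if span.name == "call_llm":
        print(span.attributes.get("gen_ai.request.model"))
        print(span.attributes.get("gen_ai.usage.total_tokens"))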