Skip to content

Commit b9cc748

Browse files
authored
feat: add tos handler (#94)
* feat: add toshandler * feat: Normalize the code * feat: add License Header * feat: fix session_id * feat: add tos_config & modify the code * feat: modify the code * feat: modify _build_tos_object_key * feat: add test_tos.py * update code & add try-except in runner * add image url in coze * fix tos
1 parent 6dee92e commit b9cc748

File tree

6 files changed

+335
-7
lines changed

6 files changed

+335
-7
lines changed

tests/test_runner.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ def _test_convert_messages(runner):
2828
role="user",
2929
)
3030
]
31-
actual_message = runner._convert_messages(message)
31+
actual_message = runner._convert_messages(message, session_id="test_session_id")
3232
assert actual_message == expected_message
3333

3434
message = ["test message 1", "test message 2"]
@@ -42,7 +42,7 @@ def _test_convert_messages(runner):
4242
role="user",
4343
),
4444
]
45-
actual_message = runner._convert_messages(message)
45+
actual_message = runner._convert_messages(message, session_id="test_session_id")
4646
assert actual_message == expected_message
4747

4848

tests/test_tos.py

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
# Copyright (c) 2025 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import pytest
16+
from unittest import mock
17+
import veadk.integrations.ve_tos.ve_tos as tos_mod
18+
19+
# 使用 pytest-asyncio
20+
pytest_plugins = ("pytest_asyncio",)
21+
22+
23+
@pytest.fixture
24+
def mock_client(monkeypatch):
25+
fake_client = mock.Mock()
26+
27+
monkeypatch.setenv("DATABASE_TOS_REGION", "test-region")
28+
monkeypatch.setenv("VOLCENGINE_ACCESS_KEY", "test-access-key")
29+
monkeypatch.setenv("VOLCENGINE_SECRET_KEY", "test-secret-key")
30+
monkeypatch.setenv("DATABASE_TOS_BUCKET", "test-bucket")
31+
32+
monkeypatch.setattr(tos_mod.tos, "TosClientV2", lambda *a, **k: fake_client)
33+
34+
class FakeExceptions:
35+
class TosServerError(Exception):
36+
def __init__(self, msg):
37+
super().__init__(msg)
38+
self.status_code = None
39+
40+
monkeypatch.setattr(tos_mod.tos, "exceptions", FakeExceptions)
41+
monkeypatch.setattr(
42+
tos_mod.tos,
43+
"StorageClassType",
44+
type("S", (), {"Storage_Class_Standard": "STANDARD"}),
45+
)
46+
monkeypatch.setattr(
47+
tos_mod.tos, "ACLType", type("A", (), {"ACL_Private": "private"})
48+
)
49+
50+
return fake_client
51+
52+
53+
@pytest.fixture
54+
def tos_client(mock_client):
55+
return tos_mod.VeTOS()
56+
57+
58+
def test_create_bucket_exists(tos_client, mock_client):
59+
mock_client.head_bucket.return_value = None # head_bucket 正常返回表示存在
60+
result = tos_client.create_bucket()
61+
assert result is True
62+
mock_client.create_bucket.assert_not_called()
63+
64+
65+
def test_create_bucket_not_exists(tos_client, mock_client):
66+
exc = tos_mod.tos.exceptions.TosServerError("not found")
67+
exc.status_code = 404
68+
mock_client.head_bucket.side_effect = exc
69+
70+
result = tos_client.create_bucket()
71+
assert result is True
72+
mock_client.create_bucket.assert_called_once()
73+
74+
75+
@pytest.mark.asyncio
76+
async def test_upload_bytes_success(tos_client, mock_client):
77+
mock_client.head_bucket.return_value = True
78+
data = b"hello world"
79+
80+
result = await tos_client.upload("obj-key", data)
81+
assert result is True
82+
mock_client.put_object.assert_called_once()
83+
mock_client.close.assert_called_once()
84+
85+
86+
@pytest.mark.asyncio
87+
async def test_upload_file_success(tmp_path, tos_client, mock_client):
88+
mock_client.head_bucket.return_value = True
89+
file_path = tmp_path / "file.txt"
90+
file_path.write_text("hello file")
91+
92+
result = await tos_client.upload("obj-key", str(file_path))
93+
assert result is True
94+
mock_client.put_object_from_file.assert_called_once()
95+
mock_client.close.assert_called_once()
96+
97+
98+
def test_download_success(tmp_path, tos_client, mock_client):
99+
save_path = tmp_path / "out.txt"
100+
mock_client.get_object.return_value = [b"abc", b"def"]
101+
102+
result = tos_client.download("obj-key", str(save_path))
103+
assert result is True
104+
assert save_path.read_bytes() == b"abcdef"
105+
106+
107+
def test_download_fail(tos_client, mock_client):
108+
mock_client.get_object.side_effect = Exception("boom")
109+
result = tos_client.download("obj-key", "somewhere.txt")
110+
assert result is False
Lines changed: 176 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,176 @@
1+
# Copyright (c) 2025 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import os
16+
from veadk.config import getenv
17+
from veadk.utils.logger import get_logger
18+
import tos
19+
import asyncio
20+
from typing import Union
21+
from pydantic import BaseModel, Field
22+
from typing import Any
23+
from urllib.parse import urlparse
24+
from datetime import datetime
25+
26+
logger = get_logger(__name__)
27+
28+
29+
class TOSConfig(BaseModel):
30+
region: str = Field(
31+
default_factory=lambda: getenv("DATABASE_TOS_REGION"),
32+
description="TOS region",
33+
)
34+
ak: str = Field(
35+
default_factory=lambda: getenv("VOLCENGINE_ACCESS_KEY"),
36+
description="Volcengine access key",
37+
)
38+
sk: str = Field(
39+
default_factory=lambda: getenv("VOLCENGINE_SECRET_KEY"),
40+
description="Volcengine secret key",
41+
)
42+
bucket_name: str = Field(
43+
default_factory=lambda: getenv("DATABASE_TOS_BUCKET"),
44+
description="TOS bucket name",
45+
)
46+
47+
48+
class VeTOS(BaseModel):
49+
config: TOSConfig = Field(default_factory=TOSConfig)
50+
51+
def model_post_init(self, __context: Any) -> None:
52+
try:
53+
self._client = tos.TosClientV2(
54+
self.config.ak,
55+
self.config.sk,
56+
endpoint=f"tos-{self.config.region}.volces.com",
57+
region=self.config.region,
58+
)
59+
logger.info("Connected to TOS successfully.")
60+
except Exception as e:
61+
logger.error(f"Client initialization failed:{e}")
62+
return None
63+
64+
def create_bucket(self) -> bool:
65+
"""If the bucket does not exist, create it"""
66+
try:
67+
self._client.head_bucket(self.config.bucket_name)
68+
logger.info(f"Bucket {self.config.bucket_name} already exists")
69+
return True
70+
except tos.exceptions.TosServerError as e:
71+
if e.status_code == 404:
72+
self._client.create_bucket(
73+
bucket=self.config.bucket_name,
74+
storage_class=tos.StorageClassType.Storage_Class_Standard,
75+
acl=tos.ACLType.ACL_Private,
76+
)
77+
logger.info(f"Bucket {self.config.bucket_name} created successfully")
78+
return True
79+
except Exception as e:
80+
logger.error(f"Bucket creation failed: {str(e)}")
81+
return False
82+
83+
def build_tos_url(
84+
self, user_id: str, app_name: str, session_id: str, data_path: str
85+
) -> tuple[str, str]:
86+
"""generate TOS object key"""
87+
parsed_url = urlparse(data_path)
88+
89+
if parsed_url.scheme and parsed_url.scheme in ("http", "https", "ftp", "ftps"):
90+
file_name = os.path.basename(parsed_url.path)
91+
else:
92+
file_name = os.path.basename(data_path)
93+
94+
timestamp: str = datetime.now().strftime("%Y%m%d%H%M%S%f")[:-3]
95+
object_key: str = f"{app_name}-{user_id}-{session_id}/{timestamp}-{file_name}"
96+
tos_url: str = f"https://{self.config.bucket_name}.tos-{self.config.region}.volces.com/{object_key}"
97+
98+
return object_key, tos_url
99+
100+
def upload(
101+
self,
102+
object_key: str,
103+
data: Union[str, bytes],
104+
):
105+
if isinstance(data, str):
106+
data_type = "file"
107+
elif isinstance(data, bytes):
108+
data_type = "bytes"
109+
else:
110+
error_msg = f"Upload failed: data type error. Only str (file path) and bytes are supported, got {type(data)}"
111+
logger.error(error_msg)
112+
raise ValueError(error_msg)
113+
if data_type == "file":
114+
return asyncio.to_thread(self._do_upload_file, object_key, data)
115+
elif data_type == "bytes":
116+
return asyncio.to_thread(self._do_upload_bytes, object_key, data)
117+
118+
def _do_upload_bytes(self, object_key: str, bytes: bytes) -> bool:
119+
try:
120+
if not self._client:
121+
return False
122+
if not self.create_bucket():
123+
return False
124+
self._client.put_object(
125+
bucket=self.config.bucket_name, key=object_key, content=bytes
126+
)
127+
logger.debug(f"Upload success, object_key: {object_key}")
128+
self._close()
129+
return True
130+
except Exception as e:
131+
logger.error(f"Upload failed: {e}")
132+
self._close()
133+
return False
134+
135+
def _do_upload_file(self, object_key: str, file_path: str) -> bool:
136+
try:
137+
if not self._client:
138+
return False
139+
if not self.create_bucket():
140+
return False
141+
142+
self._client.put_object_from_file(
143+
bucket=self.config.bucket_name, key=object_key, file_path=file_path
144+
)
145+
self._close()
146+
logger.debug(f"Upload success, object_key: {object_key}")
147+
return True
148+
except Exception as e:
149+
logger.error(f"Upload failed: {e}")
150+
self._close()
151+
return False
152+
153+
def download(self, object_key: str, save_path: str) -> bool:
154+
"""download image from TOS"""
155+
try:
156+
object_stream = self._client.get_object(self.config.bucket_name, object_key)
157+
158+
save_dir = os.path.dirname(save_path)
159+
if save_dir and not os.path.exists(save_dir):
160+
os.makedirs(save_dir, exist_ok=True)
161+
162+
with open(save_path, "wb") as f:
163+
for chunk in object_stream:
164+
f.write(chunk)
165+
166+
logger.debug(f"Image download success, saved to: {save_path}")
167+
return True
168+
169+
except Exception as e:
170+
logger.error(f"Image download failed: {str(e)}")
171+
172+
return False
173+
174+
def _close(self):
175+
if self._client:
176+
self._client.close()

veadk/runner.py

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
import asyncio
1415
from typing import Union
1516

1617
from google.adk.agents import RunConfig
@@ -31,6 +32,7 @@
3132
from veadk.types import MediaMessage
3233
from veadk.utils.logger import get_logger
3334
from veadk.utils.misc import read_png_to_bytes
35+
from veadk.integrations.ve_tos.ve_tos import VeTOS
3436

3537
logger = get_logger(__name__)
3638

@@ -84,22 +86,34 @@ def __init__(
8486
plugins=plugins,
8587
)
8688

87-
def _convert_messages(self, messages) -> list:
89+
def _convert_messages(self, messages, session_id) -> list:
8890
if isinstance(messages, str):
8991
messages = [types.Content(role="user", parts=[types.Part(text=messages)])]
9092
elif isinstance(messages, MediaMessage):
9193
assert messages.media.endswith(".png"), (
9294
"The MediaMessage only supports PNG format file for now."
9395
)
96+
data = read_png_to_bytes(messages.media)
97+
98+
ve_tos = VeTOS()
99+
object_key, tos_url = ve_tos.build_tos_url(
100+
self.user_id, self.app_name, session_id, messages.media
101+
)
102+
try:
103+
asyncio.create_task(ve_tos.upload(object_key, data))
104+
except Exception as e:
105+
logger.error(f"Upload to TOS failed: {e}")
106+
tos_url = None
107+
94108
messages = [
95109
types.Content(
96110
role="user",
97111
parts=[
98112
types.Part(text=messages.text),
99113
types.Part(
100114
inline_data=Blob(
101-
display_name=messages.media,
102-
data=read_png_to_bytes(messages.media),
115+
display_name=tos_url,
116+
data=data,
103117
mime_type="image/png",
104118
)
105119
),
@@ -109,7 +123,7 @@ def _convert_messages(self, messages) -> list:
109123
elif isinstance(messages, list):
110124
converted_messages = []
111125
for message in messages:
112-
converted_messages.extend(self._convert_messages(message))
126+
converted_messages.extend(self._convert_messages(message, session_id))
113127
messages = converted_messages
114128
else:
115129
raise ValueError(f"Unknown message type: {type(messages)}")
@@ -169,7 +183,7 @@ async def run(
169183
run_config: RunConfig | None = None,
170184
save_tracing_data: bool = False,
171185
):
172-
converted_messages: list = self._convert_messages(messages)
186+
converted_messages: list = self._convert_messages(messages, session_id)
173187

174188
await self.short_term_memory.create_session(
175189
app_name=self.app_name, user_id=self.user_id, session_id=session_id

0 commit comments

Comments
 (0)