Skip to content

Commit af50ef0

Browse files
author
久氢
committed
feat(memory_collection): enhance MySQL support
Change-Id: I247f03ec92a23a90c208e6c33f35e82d881e8042 Co-developed-by: Cursor <noreply@cursor.com> Signed-off-by: 久氢 <mapenghui.mph@alibaba-inc.com>
1 parent 5110e13 commit af50ef0

File tree

4 files changed

+613
-27
lines changed

4 files changed

+613
-27
lines changed

agentrun/memory_collection/__memory_collection_async_template.py

Lines changed: 105 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -128,21 +128,27 @@ async def list_all_async(
128128
cls,
129129
*,
130130
memory_collection_name: Optional[str] = None,
131+
status: Optional[str] = None,
132+
type: Optional[str] = None,
131133
config: Optional[Config] = None,
132134
) -> List[MemoryCollectionListOutput]:
133135
"""列出所有记忆集合(异步)
134136
135137
Args:
136-
memory_collection_name: 记忆集合名称(可选)
137-
config: 配置
138+
memory_collection_name: 记忆集合名称(可选) / Memory Collection name (optional)
139+
status: 状态过滤(可选) / Status filter (optional)
140+
type: 类型过滤(可选) / Type filter (optional)
141+
config: 配置 / Configuration
138142
139143
Returns:
140-
List[MemoryCollectionListOutput]: 记忆集合列表
144+
List[MemoryCollectionListOutput]: 记忆集合列表 / Memory collection list
141145
"""
142146
return await cls._list_all_async(
143147
lambda mc: mc.memory_collection_id or "",
144148
config=config,
145149
memory_collection_name=memory_collection_name,
150+
status=status,
151+
type=type,
146152
)
147153

148154
async def update_async(
@@ -291,6 +297,32 @@ def _convert_vpc_endpoint_to_public(endpoint: str) -> str:
291297
)
292298
return endpoint
293299

300+
@staticmethod
301+
def _get_mysql_public_host(internal_host: str) -> str:
302+
"""获取 MySQL 公网地址
303+
304+
优先从环境变量 AGENTRUN_MYSQL_PUBLIC_HOST 读取公网地址,
305+
如果未设置则使用内网地址。
306+
307+
Args:
308+
internal_host: 内网地址
309+
310+
Returns:
311+
str: 公网地址或内网地址
312+
313+
Example:
314+
>>> import os
315+
>>> os.environ["AGENTRUN_MYSQL_PUBLIC_HOST"] = "public.mysql.com"
316+
>>> _get_mysql_public_host("internal.mysql.com")
317+
"public.mysql.com"
318+
"""
319+
import os
320+
321+
public_host = os.environ.get("AGENTRUN_MYSQL_PUBLIC_HOST")
322+
if public_host:
323+
return public_host
324+
return internal_host
325+
294326
@classmethod
295327
async def _build_mem0_config_async(
296328
cls,
@@ -315,6 +347,7 @@ async def _build_mem0_config_async(
315347
vector_store_config = memory_collection.vector_store_config
316348
provider = vector_store_config.provider or ""
317349

350+
# 处理 aliyun_tablestore provider
318351
if vector_store_config.config:
319352
vs_config = vector_store_config.config
320353
vector_store: Dict[str, Any] = {
@@ -357,6 +390,47 @@ async def _build_mem0_config_async(
357390

358391
mem0_config["vector_store"] = vector_store
359392

393+
# 处理 alibabacloud_mysql provider
394+
elif vector_store_config.mysql_config:
395+
mysql_config = vector_store_config.mysql_config
396+
vector_store: Dict[str, Any] = {
397+
"provider": provider,
398+
"config": {},
399+
}
400+
401+
# 获取 MySQL 密码
402+
password = ""
403+
if mysql_config.credential_name:
404+
try:
405+
password = await cls._get_credential_secret_async(
406+
mysql_config.credential_name, config
407+
)
408+
except Exception as e:
409+
raise ValueError(
410+
"Failed to get MySQL password from credential "
411+
f"'{mysql_config.credential_name}': {e}"
412+
) from e
413+
414+
# 获取公网地址(优先从环境变量读取)
415+
host = cls._get_mysql_public_host(mysql_config.host or "")
416+
417+
# 构建 MySQL 配置
418+
vector_store["config"] = {
419+
"dbname": mysql_config.db_name,
420+
"collection_name": mysql_config.collection_name,
421+
"user": mysql_config.user,
422+
"password": password,
423+
"host": host,
424+
"port": mysql_config.port or 3306,
425+
"embedding_model_dims": (
426+
mysql_config.vector_dimension or 1536
427+
),
428+
"distance_function": "cosine",
429+
"m_value": 16,
430+
}
431+
432+
mem0_config["vector_store"] = vector_store
433+
360434
# 构建 llm 配置
361435
if memory_collection.llm_config:
362436
llm_config = memory_collection.llm_config
@@ -407,6 +481,34 @@ async def _build_mem0_config_async(
407481

408482
return mem0_config
409483

484+
@staticmethod
485+
async def _get_credential_secret_async(
486+
credential_name: str, config: Optional[Config]
487+
) -> str:
488+
"""从 Credential 获取密钥(异步)
489+
490+
Args:
491+
credential_name: Credential 名称
492+
config: AgentRun 配置
493+
494+
Returns:
495+
str: 密钥
496+
497+
Raises:
498+
ValueError: 如果 Credential 不存在或密钥为空
499+
"""
500+
from agentrun.credential import Credential
501+
502+
credential = await Credential.get_by_name_async(
503+
credential_name, config=config
504+
)
505+
if not credential.credential_secret:
506+
raise ValueError(
507+
f"Credential {credential_name} secret is empty. "
508+
"Please ensure the credential is properly configured."
509+
)
510+
return credential.credential_secret
511+
410512
@staticmethod
411513
async def _resolve_model_service_config_async(
412514
model_service_name: str, config: Optional[Config]

0 commit comments

Comments
 (0)