Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 37 additions & 25 deletions src/memos/graph_dbs/polardb.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,7 +363,7 @@ def get_memory_count(self, memory_type: str, user_name: str | None = None) -> in
WHERE ag_catalog.agtype_access_operator(properties, '"memory_type"'::agtype) = %s::agtype
"""
query += "\nAND ag_catalog.agtype_access_operator(properties, '\"user_name\"'::agtype) = %s::agtype"
params = [f'"{memory_type}"', f'"{user_name}"']
params = [self.format_param_value(memory_type), self.format_param_value(user_name)]

# Get a connection from the pool
conn = self._get_connection()
Expand All @@ -389,7 +389,7 @@ def node_not_exist(self, scope: str, user_name: str | None = None) -> int:
"""
query += "\nAND ag_catalog.agtype_access_operator(properties, '\"user_name\"'::agtype) = %s::agtype"
query += "\nLIMIT 1"
params = [f'"{scope}"', f'"{user_name}"']
params = [self.format_param_value(scope), self.format_param_value(user_name)]

# Get a connection from the pool
conn = self._get_connection()
Expand Down Expand Up @@ -427,7 +427,11 @@ def remove_oldest_memory(
ORDER BY ag_catalog.agtype_access_operator(properties, '"updated_at"'::agtype) DESC
OFFSET %s
"""
select_params = [f'"{memory_type}"', f'"{user_name}"', keep_latest]
select_params = [
self.format_param_value(memory_type),
self.format_param_value(user_name),
keep_latest,
]
conn = self._get_connection()
try:
with conn.cursor() as cursor:
Expand Down Expand Up @@ -501,19 +505,23 @@ def update_node(self, id: str, fields: dict[str, Any], user_name: str | None = N
SET properties = %s, embedding = %s
WHERE ag_catalog.agtype_access_operator(properties, '"id"'::agtype) = %s::agtype
"""
params = [json.dumps(properties), json.dumps(embedding_vector), f'"{id}"']
params = [
json.dumps(properties),
json.dumps(embedding_vector),
self.format_param_value(id),
]
else:
query = f"""
UPDATE "{self.db_name}_graph"."Memory"
SET properties = %s
WHERE ag_catalog.agtype_access_operator(properties, '"id"'::agtype) = %s::agtype
"""
params = [json.dumps(properties), f'"{id}"']
params = [json.dumps(properties), self.format_param_value(id)]

# Only add user filter when user_name is provided
if user_name is not None:
query += "\nAND ag_catalog.agtype_access_operator(properties, '\"user_name\"'::agtype) = %s::agtype"
params.append(f'"{user_name}"')
params.append(self.format_param_value(user_name))

# Get a connection from the pool
conn = self._get_connection()
Expand All @@ -538,12 +546,12 @@ def delete_node(self, id: str, user_name: str | None = None) -> None:
DELETE FROM "{self.db_name}_graph"."Memory"
WHERE ag_catalog.agtype_access_operator(properties, '"id"'::agtype) = %s::agtype
"""
params = [f'"{id}"']
params = [self.format_param_value(id)]

# Only add user filter when user_name is provided
if user_name is not None:
query += "\nAND ag_catalog.agtype_access_operator(properties, '\"user_name\"'::agtype) = %s::agtype"
params.append(f'"{user_name}"')
params.append(self.format_param_value(user_name))

# Get a connection from the pool
conn = self._get_connection()
Expand Down Expand Up @@ -831,28 +839,17 @@ def get_node(

select_fields = "id, properties, embedding" if include_embedding else "id, properties"

# Helper function to format parameter value
def format_param_value(value: str) -> str:
"""Format parameter value to handle both quoted and unquoted formats"""
# Remove outer quotes if they exist
if value.startswith('"') and value.endswith('"'):
# Already has double quotes, return as is
return value
else:
# Add double quotes
return f'"{value}"'

query = f"""
SELECT {select_fields}
FROM "{self.db_name}_graph"."Memory"
WHERE ag_catalog.agtype_access_operator(properties, '"id"'::agtype) = %s::agtype
"""
params = [format_param_value(id)]
params = [self.format_param_value(id)]

# Only add user filter when user_name is provided
if user_name is not None:
query += "\nAND ag_catalog.agtype_access_operator(properties, '\"user_name\"'::agtype) = %s::agtype"
params.append(format_param_value(user_name))
params.append(self.format_param_value(user_name))

conn = self._get_connection()
try:
Expand Down Expand Up @@ -930,7 +927,7 @@ def get_nodes(
where_conditions.append(
"ag_catalog.agtype_access_operator(properties, '\"id\"'::agtype) = %s::agtype"
)
params.append(f"{id_val}")
params.append(self.format_param_value(id_val))

where_clause = " OR ".join(where_conditions)

Expand All @@ -942,7 +939,7 @@ def get_nodes(

user_name = user_name if user_name else self.config.user_name
query += " AND ag_catalog.agtype_access_operator(properties, '\"user_name\"'::agtype) = %s::agtype"
params.append(f'"{user_name}"')
params.append(self.format_param_value(user_name))

conn = self._get_connection()
try:
Expand Down Expand Up @@ -2616,7 +2613,7 @@ def get_neighbors_by_tag(
exclude_conditions.append(
"ag_catalog.agtype_access_operator(properties, '\"id\"'::agtype) != %s::agtype"
)
params.append(f'"{exclude_id}"')
params.append(self.format_param_value(exclude_id))
where_clauses.append(f"({' AND '.join(exclude_conditions)})")

# Status filter - keep only 'activated'
Expand All @@ -2633,7 +2630,7 @@ def get_neighbors_by_tag(
where_clauses.append(
"ag_catalog.agtype_access_operator(properties, '\"user_name\"'::agtype) = %s::agtype"
)
params.append(f'"{user_name}"')
params.append(self.format_param_value(user_name))

# Testing showed no data; annotate.
where_clauses.append(
Expand Down Expand Up @@ -3022,3 +3019,18 @@ def _convert_graph_edges(self, core_node: dict) -> dict:
if tgt in id_map:
edge["target"] = id_map[tgt]
return data

def format_param_value(self, value: str | None) -> str:
"""Format parameter value to handle both quoted and unquoted formats"""
# Handle None value
if value is None:
logger.warning(f"format_param_value: value is None")
return "null"

# Remove outer quotes if they exist
if value.startswith('"') and value.endswith('"'):
# Already has double quotes, return as is
return value
else:
# Add double quotes
return f'"{value}"'
Loading