Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 67 additions & 21 deletions src/memos/graph_dbs/nebular.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import json
import traceback

from contextlib import suppress
Expand Down Expand Up @@ -35,7 +36,28 @@ def _compose_node(item: dict[str, Any]) -> tuple[str, str, dict[str, Any]]:

@timed
def _escape_str(value: str) -> str:
return value.replace('"', '\\"')
out = []
for ch in value:
code = ord(ch)
if ch == "\\":
out.append("\\\\")
elif ch == '"':
out.append('\\"')
elif ch == "\n":
out.append("\\n")
elif ch == "\r":
out.append("\\r")
elif ch == "\t":
out.append("\\t")
elif ch == "\b":
out.append("\\b")
elif ch == "\f":
out.append("\\f")
elif code < 0x20 or code in (0x2028, 0x2029):
out.append(f"\\u{code:04x}")
else:
out.append(ch)
return "".join(out)


@timed
Expand Down Expand Up @@ -1153,28 +1175,36 @@ def import_graph(self, data: dict[str, Any]) -> None:
data: A dictionary containing all nodes and edges to be loaded.
"""
for node in data.get("nodes", []):
    # Per-record isolation: a malformed node is logged and skipped so the
    # rest of the import still proceeds instead of aborting mid-load.
    try:
        id, memory, metadata = _compose_node(node)

        # Single-DB multi-tenant mode: tag every node with the owning user.
        if not self.config.use_multi_db and self.config.user_name:
            metadata["user_name"] = self.config.user_name

        metadata = self._prepare_node_metadata(metadata)
        metadata.update({"id": id, "memory": memory})
        properties = ", ".join(
            f"{k}: {self._format_value(v, k)}" for k, v in metadata.items()
        )
        node_gql = f"INSERT OR IGNORE (n@Memory {{{properties}}})"
        self.execute_query(node_gql)
    except Exception as e:
        logger.error(f"Fail to load node: {node}, error: {e}")

for edge in data.get("edges", []):
    # Same best-effort policy for edges: log failures, keep importing.
    try:
        source_id, target_id = edge["source"], edge["target"]
        edge_type = edge["type"]
        props = ""
        if not self.config.use_multi_db and self.config.user_name:
            props = f'{{user_name: "{self.config.user_name}"}}'
        edge_gql = f'''
        MATCH (a@Memory {{id: "{source_id}"}}), (b@Memory {{id: "{target_id}"}})
        INSERT OR IGNORE (a) -[e@{edge_type} {props}]-> (b)
        '''
        self.execute_query(edge_gql)
    except Exception as e:
        logger.error(f"Fail to load edge: {edge}, error: {e}")

@timed
def get_all_memory_items(self, scope: str, include_embedding: bool = False) -> (list)[dict]:
Expand Down Expand Up @@ -1555,6 +1585,7 @@ def _prepare_node_metadata(self, metadata: dict[str, Any]) -> dict[str, Any]:
# Normalize embedding type
embedding = metadata.get("embedding")
if embedding and isinstance(embedding, list):
metadata.pop("embedding")
metadata[self.dim_field] = _normalize([float(x) for x in embedding])

return metadata
Expand All @@ -1563,26 +1594,41 @@ def _prepare_node_metadata(self, metadata: dict[str, Any]) -> dict[str, Any]:
def _format_value(self, val: Any, key: str = "") -> str:
    """Render a Python value as a NebulaGraph GQL literal.

    Args:
        val: Value to serialize. Supported: None, bool, str, int/float,
            datetime, list, NVector, dict; anything else falls back to
            ``str(val)`` wrapped in quotes.
        key: Property name; when it equals ``self.dim_field`` a list or
            NVector is emitted as a typed ``VECTOR<dim, FLOAT>`` literal.

    Returns:
        A GQL literal string safe to splice into a query.
    """
    from nebulagraph_python.py_data_types import NVector

    # None
    if val is None:
        return "NULL"
    # bool must be checked before numbers: bool is a subclass of int.
    if isinstance(val, bool):
        return "true" if val else "false"
    # str
    if isinstance(val, str):
        return f'"{_escape_str(val)}"'
    # num
    if isinstance(val, (int, float)):
        return str(val)
    # time
    if isinstance(val, datetime):
        return f'datetime("{val.isoformat()}")'
    # list
    if isinstance(val, list):
        if key == self.dim_field:
            # Embedding column: fixed-dimension typed vector literal.
            dim = len(val)
            joined = ",".join(str(float(x)) for x in val)
            return f"VECTOR<{dim}, FLOAT>([{joined}])"
        return f"[{', '.join(self._format_value(v) for v in val)}]"
    # NVector
    if isinstance(val, NVector):
        if key == self.dim_field:
            dim = len(val)
            joined = ",".join(str(float(x)) for x in val)
            return f"VECTOR<{dim}, FLOAT>([{joined}])"
        # NVector outside the vector field: warn, then fall through to
        # the generic string fallback below.
        logger.warning("Invalid NVector")
    # dict: stored as compact JSON inside a string property.
    if isinstance(val, dict):
        j = json.dumps(val, ensure_ascii=False, separators=(",", ":"))
        return f'"{_escape_str(j)}"'
    return f'"{_escape_str(str(val))}"'

Expand Down
15 changes: 5 additions & 10 deletions src/memos/graph_dbs/neo4j.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,12 +323,11 @@ def edge_exists(
return result.single() is not None

# Graph Query & Reasoning
def get_node(self, id: str, include_embedding: bool = True) -> dict[str, Any] | None:
def get_node(self, id: str, **kwargs) -> dict[str, Any] | None:
"""
Retrieve the metadata and memory of a node.
Args:
id: Node identifier.
include_embedding (bool): Whether to include the large embedding field.
Returns:
Dictionary of node fields, or None if not found.
"""
Expand All @@ -345,12 +344,11 @@ def get_node(self, id: str, include_embedding: bool = True) -> dict[str, Any] |
record = session.run(query, params).single()
return self._parse_node(dict(record["n"])) if record else None

def get_nodes(self, ids: list[str], include_embedding: bool = True) -> list[dict[str, Any]]:
def get_nodes(self, ids: list[str], **kwargs) -> list[dict[str, Any]]:
"""
Retrieve the metadata and memory of a list of nodes.
Args:
ids: List of Node identifier.
include_embedding (bool): Whether to include the large embedding field.
Returns:
list[dict]: Parsed node records containing 'id', 'memory', and 'metadata'.

Expand Down Expand Up @@ -833,7 +831,7 @@ def clear(self) -> None:
logger.error(f"[ERROR] Failed to clear database '{self.db_name}': {e}")
raise

def export_graph(self, include_embedding: bool = True) -> dict[str, Any]:
def export_graph(self, **kwargs) -> dict[str, Any]:
"""
Export all graph nodes and edges in a structured form.

Expand Down Expand Up @@ -914,13 +912,12 @@ def import_graph(self, data: dict[str, Any]) -> None:
target_id=edge["target"],
)

def get_all_memory_items(self, scope: str, include_embedding: bool = True) -> list[dict]:
def get_all_memory_items(self, scope: str, **kwargs) -> list[dict]:
"""
Retrieve all memory items of a specific memory_type.

Args:
scope (str): Must be one of 'WorkingMemory', 'LongTermMemory', or 'UserMemory'.
include_embedding (bool): Whether to include the large embedding field.
Returns:
Expand All @@ -946,9 +943,7 @@ def get_all_memory_items(self, scope: str, include_embedding: bool = True) -> li
results = session.run(query, params)
return [self._parse_node(dict(record["n"])) for record in results]

def get_structure_optimization_candidates(
self, scope: str, include_embedding: bool = True
) -> list[dict]:
def get_structure_optimization_candidates(self, scope: str, **kwargs) -> list[dict]:
"""
Find nodes that are likely candidates for structure optimization:
- Isolated nodes, nodes with empty background, or nodes with exactly one child.
Expand Down
4 changes: 1 addition & 3 deletions src/memos/graph_dbs/neo4j_community.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,14 +169,12 @@ def search_by_embedding(
# Return consistent format
return [{"id": r.id, "score": r.score} for r in results]

def get_all_memory_items(self, scope: str, include_embedding: bool = True) -> list[dict]:
def get_all_memory_items(self, scope: str, **kwargs) -> list[dict]:
"""
Retrieve all memory items of a specific memory_type.

Args:
scope (str): Must be one of 'WorkingMemory', 'LongTermMemory', or 'UserMemory'.
include_embedding (bool): Whether to include the large embedding field.

Returns:
list[dict]: Full list of memory items under this scope.
"""
Expand Down
Loading