Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ tmp/
# evaluation data
*.csv
*.jsonl
**settings.json**
evaluation/*tmp/
evaluation/results
evaluation/.env
Expand All @@ -19,7 +20,7 @@ evaluation/scripts/personamem

# benchmarks
benchmarks/

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
Expand Down
24 changes: 8 additions & 16 deletions src/memos/api/product_models.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,14 @@
import uuid

from typing import Generic, Literal, TypeAlias, TypeVar
from typing import Generic, Literal, TypeVar

from pydantic import BaseModel, Field
from typing_extensions import TypedDict

# Import message types from core types module
from memos.types import MessageDict

T = TypeVar("T")


# ─── Message Types ──────────────────────────────────────────────────────────────

# Chat message roles
MessageRole: TypeAlias = Literal["user", "assistant", "system"]


# Message structure
class MessageDict(TypedDict):
"""Typed dictionary for chat message dictionaries."""

role: MessageRole
content: str
T = TypeVar("T")


class BaseRequest(BaseModel):
Expand Down Expand Up @@ -86,6 +74,7 @@ class ChatRequest(BaseRequest):
history: list[MessageDict] | None = Field(None, description="Chat history")
internet_search: bool = Field(True, description="Whether to use internet search")
moscube: bool = Field(False, description="Whether to use MemOSCube")
session_id: str | None = Field(None, description="Session ID for soft-filtering memories")


class ChatCompleteRequest(BaseRequest):
Expand All @@ -100,6 +89,7 @@ class ChatCompleteRequest(BaseRequest):
base_prompt: str | None = Field(None, description="Base prompt to use for chat")
top_k: int = Field(10, description="Number of results to return")
threshold: float = Field(0.5, description="Threshold for filtering references")
session_id: str | None = Field(None, description="Session ID for soft-filtering memories")


class UserCreate(BaseRequest):
Expand Down Expand Up @@ -161,6 +151,7 @@ class MemoryCreateRequest(BaseRequest):
mem_cube_id: str | None = Field(None, description="Cube ID")
source: str | None = Field(None, description="Source of the memory")
user_profile: bool = Field(False, description="User profile memory")
session_id: str | None = Field(None, description="Session id")


class SearchRequest(BaseRequest):
Expand All @@ -170,6 +161,7 @@ class SearchRequest(BaseRequest):
query: str = Field(..., description="Search query")
mem_cube_id: str | None = Field(None, description="Cube ID to search in")
top_k: int = Field(10, description="Number of results to return")
session_id: str | None = Field(None, description="Session ID for soft-filtering memories")


class SuggestionRequest(BaseRequest):
Expand Down
4 changes: 4 additions & 0 deletions src/memos/api/routers/product_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,7 @@ def create_memory(memory_req: MemoryCreateRequest):
mem_cube_id=memory_req.mem_cube_id,
source=memory_req.source,
user_profile=memory_req.user_profile,
session_id=memory_req.session_id,
)
return SimpleResponse(message="Memory created successfully")

Expand All @@ -224,6 +225,7 @@ def search_memories(search_req: SearchRequest):
user_id=search_req.user_id,
install_cube_ids=[search_req.mem_cube_id] if search_req.mem_cube_id else None,
top_k=search_req.top_k,
session_id=search_req.session_id,
)
return SearchResponse(message="Search completed successfully", data=result)

Expand Down Expand Up @@ -251,6 +253,7 @@ def generate_chat_response():
history=chat_req.history,
internet_search=chat_req.internet_search,
moscube=chat_req.moscube,
session_id=chat_req.session_id,
)

except Exception as e:
Expand Down Expand Up @@ -295,6 +298,7 @@ def chat_complete(chat_req: ChatCompleteRequest):
base_prompt=chat_req.base_prompt,
top_k=chat_req.top_k,
threshold=chat_req.threshold,
session_id=chat_req.session_id,
)

# Return the complete response
Expand Down
12 changes: 12 additions & 0 deletions src/memos/graph_dbs/nebular.py
Original file line number Diff line number Diff line change
Expand Up @@ -977,6 +977,7 @@ def search_by_embedding(
scope: str | None = None,
status: str | None = None,
threshold: float | None = None,
search_filter: dict | None = None,
**kwargs,
) -> list[dict]:
"""
Expand All @@ -989,6 +990,8 @@ def search_by_embedding(
status (str, optional): Node status filter (e.g., 'activated', 'archived').
If provided, restricts results to nodes with matching status.
threshold (float, optional): Minimum similarity score threshold (0 ~ 1).
search_filter (dict, optional): Additional metadata filters for search results.
Keys should match node properties, values are the expected values.

Returns:
list[dict]: A list of dicts with 'id' and 'score', ordered by similarity.
Expand All @@ -998,6 +1001,7 @@ def search_by_embedding(
- If scope is provided, it restricts results to nodes with matching memory_type.
- If 'status' is provided, only nodes with the matching status will be returned.
- If threshold is provided, only results with score >= threshold will be returned.
- If search_filter is provided, additional WHERE clauses will be added for metadata filtering.
- Typical use case: restrict to 'status = activated' to avoid
matching archived or merged nodes.
"""
Expand All @@ -1017,6 +1021,14 @@ def search_by_embedding(
else:
where_clauses.append(f'n.user_name = "{self.config.user_name}"')

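# Add search_filter conditions.
# NOTE(review): keys and values are interpolated into the query text as-is (no escaping),
# so search_filter must only ever contain trusted property names and values.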
if search_filter:
for key, value in search_filter.items():
if isinstance(value, str):
where_clauses.append(f'n.{key} = "{value}"')
else:
where_clauses.append(f"n.{key} = {value}")

where_clause = f"WHERE {' AND '.join(where_clauses)}" if where_clauses else ""

gql = f"""
Expand Down
19 changes: 18 additions & 1 deletion src/memos/graph_dbs/neo4j.py
Original file line number Diff line number Diff line change
Expand Up @@ -606,6 +606,7 @@ def search_by_embedding(
scope: str | None = None,
status: str | None = None,
threshold: float | None = None,
search_filter: dict | None = None,
**kwargs,
) -> list[dict]:
"""
Expand All @@ -618,6 +619,8 @@ def search_by_embedding(
status (str, optional): Node status filter (e.g., 'activated', 'archived').
If provided, restricts results to nodes with matching status.
threshold (float, optional): Minimum similarity score threshold (0 ~ 1).
search_filter (dict, optional): Additional metadata filters for search results.
Keys should match node properties, values are the expected values.

Returns:
list[dict]: A list of dicts with 'id' and 'score', ordered by similarity.
Expand All @@ -627,6 +630,7 @@ def search_by_embedding(
- If scope is provided, it restricts results to nodes with matching memory_type.
- If 'status' is provided, only nodes with the matching status will be returned.
- If threshold is provided, only results with score >= threshold will be returned.
- If search_filter is provided, additional WHERE clauses will be added for metadata filtering.
- Typical use case: restrict to 'status = activated' to avoid
matching archived or merged nodes.
"""
Expand All @@ -639,6 +643,12 @@ def search_by_embedding(
if not self.config.use_multi_db and self.config.user_name:
where_clauses.append("node.user_name = $user_name")

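# Add search_filter conditions: each value is bound as the $filter_<key> query parameter.
# NOTE(review): the key itself is interpolated into the query text as a property name.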
if search_filter:
for key, _ in search_filter.items():
param_name = f"filter_{key}"
where_clauses.append(f"node.{key} = ${param_name}")

where_clause = ""
if where_clauses:
where_clause = "WHERE " + " AND ".join(where_clauses)
Expand All @@ -650,7 +660,8 @@ def search_by_embedding(
RETURN node.id AS id, score
"""

parameters = {"embedding": vector, "k": top_k, "scope": scope}
parameters = {"embedding": vector, "k": top_k}

if scope:
parameters["scope"] = scope
if status:
Expand All @@ -661,6 +672,12 @@ def search_by_embedding(
else:
parameters["user_name"] = self.config.user_name

# Add search_filter parameters
if search_filter:
for key, value in search_filter.items():
param_name = f"filter_{key}"
parameters[param_name] = value

with self.driver.session(database=self.db_name) as session:
result = session.run(query, parameters)
records = [{"id": record["id"], "score": record["score"]} for record in result]
Expand Down
7 changes: 7 additions & 0 deletions src/memos/graph_dbs/neo4j_community.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@ def search_by_embedding(
scope: str | None = None,
status: str | None = None,
threshold: float | None = None,
search_filter: dict | None = None,
**kwargs,
) -> list[dict]:
"""
Expand All @@ -140,6 +141,7 @@ def search_by_embedding(
scope (str, optional): Memory type filter (e.g., 'WorkingMemory', 'LongTermMemory').
status (str, optional): Node status filter (e.g., 'activated', 'archived').
threshold (float, optional): Minimum similarity score threshold (0 ~ 1).
search_filter (dict, optional): Additional metadata filters to apply.

Returns:
list[dict]: A list of dicts with 'id' and 'score', ordered by similarity.
Expand All @@ -149,6 +151,7 @@ def search_by_embedding(
- If 'scope' is provided, it restricts results to nodes with matching memory_type.
- If 'status' is provided, it further filters nodes by status.
- If 'threshold' is provided, only results with score >= threshold will be returned.
- If 'search_filter' is provided, it applies additional metadata-based filtering.
- The returned IDs can be used to fetch full node data from Neo4j if needed.
"""
# Build VecDB filter
Expand All @@ -163,6 +166,10 @@ def search_by_embedding(
else:
vec_filter["user_name"] = self.config.user_name

# Add search_filter conditions
if search_filter:
vec_filter.update(search_filter)

# Perform vector search
results = self.vec_db.search(query_vector=vector, top_k=top_k, filter=vec_filter)

Expand Down
14 changes: 13 additions & 1 deletion src/memos/mem_os/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -547,6 +547,7 @@ def search(
mode: Literal["fast", "fine"] = "fast",
internet_search: bool = False,
moscube: bool = False,
session_id: str | None = None,
**kwargs,
) -> MOSSearchResult:
"""
Expand All @@ -563,6 +564,7 @@ def search(
MOSSearchResult: A dictionary containing the search results.
"""
target_user_id = user_id if user_id is not None else self.user_id

self._validate_user_exists(target_user_id)
# Get all cubes accessible by the target user
accessible_cubes = self.user_manager.get_user_cubes(target_user_id)
Expand All @@ -575,6 +577,11 @@ def search(
self._register_chat_history(target_user_id)
chat_history = self.chat_history_manager[target_user_id]

# Create search filter if session_id is provided
search_filter = None
if session_id is not None:
search_filter = {"session_id": session_id}

result: MOSSearchResult = {
"text_mem": [],
"act_mem": [],
Expand Down Expand Up @@ -602,10 +609,11 @@ def search(
manual_close_internet=not internet_search,
info={
"user_id": target_user_id,
"session_id": self.session_id,
"session_id": session_id if session_id is not None else self.session_id,
"chat_history": chat_history.chat_history,
},
moscube=moscube,
search_filter=search_filter,
)
result["text_mem"].append({"cube_id": mem_cube_id, "memories": memories})
logger.info(
Expand All @@ -624,6 +632,8 @@ def add(
doc_path: str | None = None,
mem_cube_id: str | None = None,
user_id: str | None = None,
session_id: str | None = None,
**kwargs,
) -> None:
"""
Add textual memories to a MemCube.
Expand All @@ -636,11 +646,13 @@ def add(
If None, the default MemCube for the user is used.
user_id (str, optional): The identifier of the user to add the memories to.
If None, the default user is used.
session_id (str, optional): Session identifier; assigned to self.session_id on every call,
so omitting it sets self.session_id to None.
"""
# user input messages
assert (messages is not None) or (memory_content is not None) or (doc_path is not None), (
"messages_or_doc_path or memory_content or doc_path must be provided."
)
self.session_id = session_id
target_user_id = user_id if user_id is not None else self.user_id
if mem_cube_id is None:
# Try to find a default cube for the user
Expand Down
14 changes: 12 additions & 2 deletions src/memos/mem_os/product.py
Original file line number Diff line number Diff line change
Expand Up @@ -928,6 +928,7 @@ def chat(
moscube: bool = False,
top_k: int = 10,
threshold: float = 0.5,
session_id: str | None = None,
) -> str:
"""
Chat with LLM with memory references and complete response.
Expand All @@ -942,6 +943,7 @@ def chat(
mode="fine",
internet_search=internet_search,
moscube=moscube,
session_id=session_id,
)["text_mem"]

memories_list = []
Expand Down Expand Up @@ -986,6 +988,7 @@ def chat_with_references(
top_k: int = 20,
internet_search: bool = False,
moscube: bool = False,
session_id: str | None = None,
) -> Generator[str, None, None]:
"""
Chat with LLM with memory references and streaming output.
Expand All @@ -1012,6 +1015,7 @@ def chat_with_references(
mode="fine",
internet_search=internet_search,
moscube=moscube,
session_id=session_id,
)["text_mem"]

yield f"data: {json.dumps({'type': 'status', 'data': '1'})}\n\n"
Expand Down Expand Up @@ -1300,6 +1304,7 @@ def search(
install_cube_ids: list[str] | None = None,
top_k: int = 10,
mode: Literal["fast", "fine"] = "fast",
session_id: str | None = None,
):
"""Search memories for a specific user."""

Expand All @@ -1310,7 +1315,9 @@ def search(
logger.info(
f"time search: load_user_cubes time user_id: {user_id} time is: {load_user_cubes_time_end - time_start}"
)
search_result = super().search(query, user_id, install_cube_ids, top_k, mode=mode)
search_result = super().search(
query, user_id, install_cube_ids, top_k, mode=mode, session_id=session_id
)
search_time_end = time.time()
logger.info(
f"time search: search text_mem time user_id: {user_id} time is: {search_time_end - load_user_cubes_time_end}"
Expand Down Expand Up @@ -1346,13 +1353,16 @@ def add(
mem_cube_id: str | None = None,
source: str | None = None,
user_profile: bool = False,
session_id: str | None = None,
):
"""Add memory for a specific user."""

# Load user cubes if not already loaded
self._load_user_cubes(user_id, self.default_cube_config)

result = super().add(messages, memory_content, doc_path, mem_cube_id, user_id)
result = super().add(
messages, memory_content, doc_path, mem_cube_id, user_id, session_id=session_id
)
if user_profile:
try:
user_interests = memory_content.split("'userInterests': '")[1].split("', '")[0]
Expand Down
Loading
Loading