Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/memos/api/product_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ class ChatRequest(BaseRequest):
mem_cube_id: str | None = Field(None, description="Cube ID to use for chat")
history: list[MessageDict] | None = Field(None, description="Chat history")
internet_search: bool = Field(True, description="Whether to use internet search")
moscube: bool = Field(False, description="Whether to use MemOSCube")


class UserCreate(BaseRequest):
Expand Down
1 change: 1 addition & 0 deletions src/memos/api/routers/product_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,7 @@ def generate_chat_response():
cube_id=chat_req.mem_cube_id,
history=chat_req.history,
internet_search=chat_req.internet_search,
moscube=chat_req.moscube,
)

except Exception as e:
Expand Down
3 changes: 3 additions & 0 deletions src/memos/mem_os/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -546,6 +546,8 @@ def search(
top_k: int | None = None,
mode: Literal["fast", "fine"] = "fast",
internet_search: bool = False,
moscube: bool = False,
**kwargs,
) -> MOSSearchResult:
"""
Search for textual memories across all registered MemCubes.
Expand Down Expand Up @@ -603,6 +605,7 @@ def search(
"session_id": self.session_id,
"chat_history": chat_history.chat_history,
},
moscube=moscube,
)
result["text_mem"].append({"cube_id": mem_cube_id, "memories": memories})
logger.info(
Expand Down
4 changes: 3 additions & 1 deletion src/memos/mem_os/product.py
Original file line number Diff line number Diff line change
Expand Up @@ -523,7 +523,7 @@ def _send_message_to_scheduler(
self.mem_scheduler.submit_messages(messages=[message_item])

def _filter_memories_by_threshold(
self, memories: list[TextualMemoryItem], threshold: float = 0.20, min_num: int = 3
self, memories: list[TextualMemoryItem], threshold: float = 0.50, min_num: int = 3
) -> list[TextualMemoryItem]:
"""
Filter memories by threshold.
Expand Down Expand Up @@ -717,6 +717,7 @@ def chat_with_references(
history: MessageList | None = None,
top_k: int = 10,
internet_search: bool = False,
moscube: bool = False,
) -> Generator[str, None, None]:
"""
Chat with LLM with memory references and streaming output.
Expand All @@ -742,6 +743,7 @@ def chat_with_references(
top_k=top_k,
mode="fine",
internet_search=internet_search,
moscube=moscube,
)["text_mem"]

yield f"data: {json.dumps({'type': 'status', 'data': '1'})}\n\n"
Expand Down
3 changes: 3 additions & 0 deletions src/memos/memories/textual/tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ def search(
mode: str = "fast",
memory_type: str = "All",
manual_close_internet: bool = False,
moscube: bool = False,
) -> list[TextualMemoryItem]:
"""Search for memories based on a query.
User query -> TaskGoalParser -> MemoryPathResolver ->
Expand All @@ -122,13 +123,15 @@ def search(
self.graph_store,
self.embedder,
internet_retriever=None,
moscube=moscube,
)
else:
searcher = Searcher(
self.dispatcher_llm,
self.graph_store,
self.embedder,
internet_retriever=self.internet_retriever,
moscube=moscube,
)
return searcher.search(query, top_k, info, mode, memory_type)

Expand Down
21 changes: 12 additions & 9 deletions src/memos/memories/textual/tree_text_memory/retrieve/searcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ def __init__(
graph_store: Neo4jGraphDB,
embedder: OllamaEmbedder,
internet_retriever: InternetRetrieverFactory | None = None,
moscube: bool = False,
):
self.graph_store = graph_store
self.embedder = embedder
Expand All @@ -38,6 +39,7 @@ def __init__(

# Create internet retriever from config if provided
self.internet_retriever = internet_retriever
self.moscube = moscube

@timed
def search(
Expand Down Expand Up @@ -157,16 +159,17 @@ def _retrieve_paths(self, query, parsed_goal, query_embedding, info, top_k, mode
memory_type,
)
)
tasks.append(
executor.submit(
self._retrieve_from_memcubes,
query,
parsed_goal,
query_embedding,
top_k,
"memos_cube01",
if self.moscube:
tasks.append(
executor.submit(
self._retrieve_from_memcubes,
query,
parsed_goal,
query_embedding,
top_k,
"memos_cube01",
)
)
)

results = []
for t in tasks:
Expand Down
Loading