Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/memos/api/handlers/formatters_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ def post_process_pref_mem(
{
"cube_id": mem_cube_id,
"memories": pref_formatted_mem,
"total_nodes": len(pref_formatted_mem),
}
)
pref_instruction, pref_note = instruct_completion(pref_formatted_mem)
Expand Down Expand Up @@ -116,12 +117,14 @@ def post_process_textual_mem(
{
"cube_id": mem_cube_id,
"memories": fact_mem,
"total_nodes": len(fact_mem),
}
)
memories_result["tool_mem"].append(
{
"cube_id": mem_cube_id,
"memories": tool_mem,
"total_nodes": len(tool_mem),
}
)
return memories_result
75 changes: 46 additions & 29 deletions src/memos/api/handlers/memory_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,11 @@

from typing import TYPE_CHECKING, Any, Literal

from memos.api.handlers.formatters_handler import format_memory_item
from memos.api.handlers.formatters_handler import (
format_memory_item,
post_process_pref_mem,
post_process_textual_mem,
)
from memos.api.product_models import (
DeleteMemoryRequest,
DeleteMemoryResponse,
Expand Down Expand Up @@ -209,54 +213,67 @@ def handle_get_memory(memory_id: str, naive_mem_cube: NaiveMemCube) -> GetMemory
def handle_get_memories(
get_mem_req: GetMemoryRequest, naive_mem_cube: NaiveMemCube
) -> GetMemoryResponse:
# TODO: Implement get memory with filter
results: dict[str, Any] = {"text_mem": [], "pref_mem": [], "tool_mem": []}
memories = naive_mem_cube.text_mem.get_all(
user_name=get_mem_req.mem_cube_id,
user_id=get_mem_req.user_id,
page=get_mem_req.page,
page_size=get_mem_req.page_size,
)
total_nodes = memories["total_nodes"]
total_edges = memories["total_edges"]
del memories["total_nodes"]
del memories["total_edges"]
filter=get_mem_req.filter,
)["nodes"]

results = post_process_textual_mem(results, memories, get_mem_req.mem_cube_id)

if not get_mem_req.include_tool_memory:
results["tool_mem"] = []

preferences: list[TextualMemoryItem] = []
total_pref = 0

format_preferences = []
if get_mem_req.include_preference and naive_mem_cube.pref_mem is not None:
filter_params: dict[str, Any] = {}
if get_mem_req.user_id is not None:
filter_params["user_id"] = get_mem_req.user_id
if get_mem_req.mem_cube_id is not None:
filter_params["mem_cube_id"] = get_mem_req.mem_cube_id

preferences, total_pref = naive_mem_cube.pref_mem.get_memory_by_filter(
if get_mem_req.filter is not None:
# Check and remove user_id/mem_cube_id from filter if present
filter_copy = get_mem_req.filter.copy()
removed_fields = []

if "user_id" in filter_copy:
filter_copy.pop("user_id")
removed_fields.append("user_id")
if "mem_cube_id" in filter_copy:
filter_copy.pop("mem_cube_id")
removed_fields.append("mem_cube_id")

if removed_fields:
logger.warning(
f"Fields {removed_fields} found in filter will be ignored. "
f"Use request-level user_id/mem_cube_id parameters instead."
)

filter_params.update(filter_copy)

preferences, _ = naive_mem_cube.pref_mem.get_memory_by_filter(
filter_params, page=get_mem_req.page, page_size=get_mem_req.page_size
)
format_preferences = [format_memory_item(item) for item in preferences]

return GetMemoryResponse(
message="Memories retrieved successfully",
data={
"text_mem": [
{
"cube_id": get_mem_req.mem_cube_id,
"memories": memories,
"total_nodes": total_nodes,
"total_edges": total_edges,
}
],
"pref_mem": [
{
"cube_id": get_mem_req.mem_cube_id,
"memories": format_preferences,
"total_nodes": total_pref,
}
],
},
results = post_process_pref_mem(
results, format_preferences, get_mem_req.mem_cube_id, get_mem_req.include_preference
)

# Filter to only keep text_mem, pref_mem, tool_mem
filtered_results = {
"text_mem": results.get("text_mem", []),
"pref_mem": results.get("pref_mem", []),
"tool_mem": results.get("tool_mem", []),
}

return GetMemoryResponse(message="Memories retrieved successfully", data=filtered_results)


def handle_delete_memories(delete_mem_req: DeleteMemoryRequest, naive_mem_cube: NaiveMemCube):
logger.info(
Expand Down
4 changes: 3 additions & 1 deletion src/memos/api/product_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -771,7 +771,9 @@ class GetMemoryRequest(BaseRequest):

mem_cube_id: str = Field(..., description="Cube ID")
user_id: str | None = Field(None, description="User ID")
include_preference: bool = Field(True, description="Whether to handle preference memory")
include_preference: bool = Field(True, description="Whether to return preference memory")
include_tool_memory: bool = Field(False, description="Whether to return tool memory")
filter: dict[str, Any] | None = Field(None, description="Filter for the memory")
page: int | None = Field(
None,
description="Page number (starts from 1). If None, exports all data without pagination.",
Expand Down
48 changes: 33 additions & 15 deletions src/memos/mem_reader/multi_modal_struct.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import concurrent.futures
import json
import re
import traceback

from typing import Any
Expand Down Expand Up @@ -547,7 +548,11 @@ def _process_tool_trajectory_fine(
for fast_item in fast_memory_items:
# Extract memory text (string content)
mem_str = fast_item.memory or ""
if not mem_str.strip() or "tool:" not in mem_str:
if not mem_str.strip() or (
"tool:" not in mem_str
and "[tool_calls]:" not in mem_str
and not re.search(r"<tool_schema>.*?</tool_schema>", mem_str, re.DOTALL)
):
continue
try:
resp = self._get_llm_tool_trajectory_response(mem_str)
Expand All @@ -563,6 +568,8 @@ def _process_tool_trajectory_fine(
value=m.get("trajectory", ""),
info=info,
memory_type=memory_type,
correctness=m.get("correctness", ""),
experience=m.get("experience", ""),
tool_used_status=m.get("tool_used_status", []),
)
fine_memory_items.append(node)
Expand Down Expand Up @@ -606,16 +613,22 @@ def _process_multi_modal_data(
if mode == "fast":
return fast_memory_items
else:
# Part A: call llm
# Part A: call llm in parallel using thread pool
fine_memory_items = []
fine_memory_items_string_parser = self._process_string_fine(
fast_memory_items, info, custom_tags
)
fine_memory_items.extend(fine_memory_items_string_parser)

fine_memory_items_tool_trajectory_parser = self._process_tool_trajectory_fine(
fast_memory_items, info
)
with ContextThreadPoolExecutor(max_workers=2) as executor:
future_string = executor.submit(
self._process_string_fine, fast_memory_items, info, custom_tags
)
future_tool = executor.submit(
self._process_tool_trajectory_fine, fast_memory_items, info
)

# Collect results
fine_memory_items_string_parser = future_string.result()
fine_memory_items_tool_trajectory_parser = future_tool.result()

fine_memory_items.extend(fine_memory_items_string_parser)
fine_memory_items.extend(fine_memory_items_tool_trajectory_parser)

# Part B: get fine multimodal items
Expand Down Expand Up @@ -658,13 +671,18 @@ def _process_transfer_multi_modal_data(
}

fine_memory_items = []
# Part A: call llm
fine_memory_items_string_parser = self._process_string_fine([raw_node], info, custom_tags)
fine_memory_items.extend(fine_memory_items_string_parser)
# Part A: call llm in parallel using thread pool
with ContextThreadPoolExecutor(max_workers=2) as executor:
future_string = executor.submit(
self._process_string_fine, [raw_node], info, custom_tags
)
future_tool = executor.submit(self._process_tool_trajectory_fine, [raw_node], info)

fine_memory_items_tool_trajectory_parser = self._process_tool_trajectory_fine(
[raw_node], info
)
# Collect results
fine_memory_items_string_parser = future_string.result()
fine_memory_items_tool_trajectory_parser = future_tool.result()

fine_memory_items.extend(fine_memory_items_string_parser)
fine_memory_items.extend(fine_memory_items_tool_trajectory_parser)

# Part B: get fine multimodal items
Expand Down
Loading