Skip to content
Merged
8 changes: 8 additions & 0 deletions src/memos/api/routers/server_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@
status_tracker = TaskStatusTracker(redis_client=redis_client)
embedder = components["embedder"]
graph_db = components["graph_db"]
vector_db = components["vector_db"]


# =============================================================================
Expand Down Expand Up @@ -359,6 +360,13 @@ def get_user_names_by_memory_ids(request: GetUserNamesByMemoryIdsRequest):
),
)
result = graph_db.get_user_names_by_memory_ids(memory_ids=request.memory_ids)
if vector_db:
prefs = []
for collection_name in ["explicit_preference", "implicit_preference"]:
prefs.extend(
vector_db.get_by_ids(collection_name=collection_name, ids=request.memory_ids)
)
result.update({pref.id: pref.payload.get("mem_cube_id", None) for pref in prefs})
return GetUserNamesByMemoryIdsResponse(
code=200,
message="Successfully",
Expand Down
2 changes: 1 addition & 1 deletion src/memos/memories/textual/preference.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,7 @@ def get_all(self) -> list[TextualMemoryItem]:
Returns:
list[TextualMemoryItem]: List of all memories.
"""
all_collections = self.vector_db.list_collections()
all_collections = ["explicit_preference", "implicit_preference"]
all_memories = {}
for collection_name in all_collections:
items = self.vector_db.get_all(collection_name)
Expand Down
8 changes: 4 additions & 4 deletions src/memos/memories/textual/simple_preference.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def get_with_collection_name(
return None
return TextualMemoryItem(
id=res.id,
memory=res.payload.get("dialog_str", ""),
memory=res.memory,
metadata=PreferenceTextualMemoryMetadata(**res.payload),
)
except Exception as e:
Expand All @@ -116,7 +116,7 @@ def get_by_ids_with_collection_name(
return [
TextualMemoryItem(
id=memo.id,
memory=memo.payload.get("dialog_str", ""),
memory=memo.memory,
metadata=PreferenceTextualMemoryMetadata(**memo.payload),
)
for memo in res
Expand All @@ -132,14 +132,14 @@ def get_all(self) -> list[TextualMemoryItem]:
Returns:
list[TextualMemoryItem]: List of all memories.
"""
all_collections = self.vector_db.list_collections()
all_collections = ["explicit_preference", "implicit_preference"]
all_memories = {}
for collection_name in all_collections:
items = self.vector_db.get_all(collection_name)
all_memories[collection_name] = [
TextualMemoryItem(
id=memo.id,
memory=memo.payload.get("dialog_str", ""),
memory=memo.memory,
metadata=PreferenceTextualMemoryMetadata(**memo.payload),
)
for memo in items
Expand Down
6 changes: 2 additions & 4 deletions src/memos/vec_dbs/milvus.py
Original file line number Diff line number Diff line change
Expand Up @@ -457,14 +457,13 @@ def get_by_id(self, collection_name: str, id: str) -> MilvusVecDBItem | None:
return None

entity = results[0]
payload = {k: v for k, v in entity.items() if k not in ["id", "vector", "score"]}

return MilvusVecDBItem(
id=entity["id"],
memory=entity.get("memory"),
original_text=entity.get("original_text"),
vector=entity.get("vector"),
payload=payload,
payload=entity.get("payload", {}),
)

def get_by_ids(self, collection_name: str, ids: list[str]) -> list[MilvusVecDBItem]:
Expand All @@ -479,14 +478,13 @@ def get_by_ids(self, collection_name: str, ids: list[str]) -> list[MilvusVecDBIt

items = []
for entity in results:
payload = {k: v for k, v in entity.items() if k not in ["id", "vector", "score"]}
items.append(
MilvusVecDBItem(
id=entity["id"],
memory=entity.get("memory"),
original_text=entity.get("original_text"),
vector=entity.get("vector"),
payload=payload,
payload=entity.get("payload", {}),
)
)

Expand Down