76 changes: 74 additions & 2 deletions src/memos/llms/openai.py
@@ -1,4 +1,8 @@
import hashlib
import json

from collections.abc import Generator
from typing import ClassVar

import openai

@@ -13,11 +17,44 @@


class OpenAILLM(BaseLLM):
"""OpenAI LLM class."""
"""OpenAI LLM class with singleton pattern."""

_instances: ClassVar[dict] = {} # Class variable to store instances

def __new__(cls, config: OpenAILLMConfig) -> "OpenAILLM":
config_hash = cls._get_config_hash(config)

if config_hash not in cls._instances:
logger.info(f"Creating new OpenAI LLM instance for config hash: {config_hash}")
instance = super().__new__(cls)
cls._instances[config_hash] = instance
else:
logger.info(f"Reusing existing OpenAI LLM instance for config hash: {config_hash}")

return cls._instances[config_hash]

def __init__(self, config: OpenAILLMConfig):
# Avoid duplicate initialization
if hasattr(self, "_initialized"):
return

self.config = config
self.client = openai.Client(api_key=config.api_key, base_url=config.api_base)
self._initialized = True
logger.info("OpenAI LLM instance initialized")

@classmethod
def _get_config_hash(cls, config: OpenAILLMConfig) -> str:
"""Generate hash value of configuration"""
config_dict = config.model_dump()
config_str = json.dumps(config_dict, sort_keys=True)
return hashlib.md5(config_str.encode()).hexdigest()

@classmethod
def clear_cache(cls):
"""Clear all cached instances"""
cls._instances.clear()
logger.info("OpenAI LLM instance cache cleared")

def generate(self, messages: MessageList) -> str:
"""Generate a response from OpenAI LLM."""
@@ -71,15 +108,50 @@ def generate_stream(self, messages: MessageList, **kwargs) -> Generator[str, Non


class AzureLLM(BaseLLM):
"""Azure OpenAI LLM class."""
"""Azure OpenAI LLM class with singleton pattern."""

_instances: ClassVar[dict] = {} # Class variable to store instances

def __new__(cls, config: AzureLLMConfig):
# Generate hash value of config as cache key
config_hash = cls._get_config_hash(config)

if config_hash not in cls._instances:
logger.info(f"Creating new Azure LLM instance for config hash: {config_hash}")
instance = super().__new__(cls)
cls._instances[config_hash] = instance
else:
logger.info(f"Reusing existing Azure LLM instance for config hash: {config_hash}")

return cls._instances[config_hash]

def __init__(self, config: AzureLLMConfig):
# Avoid duplicate initialization
if hasattr(self, "_initialized"):
return

self.config = config
self.client = openai.AzureOpenAI(
azure_endpoint=config.base_url,
api_version=config.api_version,
api_key=config.api_key,
)
self._initialized = True
logger.info("Azure LLM instance initialized")

@classmethod
def _get_config_hash(cls, config: AzureLLMConfig) -> str:
"""Generate hash value of configuration"""
# Convert config to dict and sort to ensure consistency
config_dict = config.model_dump()
config_str = json.dumps(config_dict, sort_keys=True)
return hashlib.md5(config_str.encode()).hexdigest()

@classmethod
def clear_cache(cls):
"""Clear all cached instances"""
cls._instances.clear()
logger.info("Azure LLM instance cache cleared")

def generate(self, messages: MessageList) -> str:
"""Generate a response from Azure OpenAI LLM."""
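For reviewers less familiar with this idiom: the new __new__/_get_config_hash pair turns OpenAILLM(config) and AzureLLM(config) into per-config singletons. The config is serialized with json.dumps(..., sort_keys=True), hashed with MD5, and used as the key into a class-level _instances dict, so repeated registrations with identical settings reuse one client instead of constructing a new openai.Client each time; the _initialized guard keeps __init__ from redoing work when Python re-runs it on a cached instance. A minimal, self-contained sketch of the same pattern (FakeLLMConfig and CachedLLM are illustrative stand-ins, not memos classes):

```python
import hashlib
import json
from dataclasses import asdict, dataclass
from typing import ClassVar


@dataclass(frozen=True)
class FakeLLMConfig:
    # Illustrative stand-in for OpenAILLMConfig; real field names may differ.
    api_key: str
    api_base: str
    model: str


class CachedLLM:
    _instances: ClassVar[dict] = {}

    def __new__(cls, config: FakeLLMConfig) -> "CachedLLM":
        key = cls._get_config_hash(config)
        if key not in cls._instances:
            cls._instances[key] = super().__new__(cls)
        return cls._instances[key]

    def __init__(self, config: FakeLLMConfig):
        if hasattr(self, "_initialized"):
            return  # cached instance: skip rebuilding the (expensive) client
        self.config = config
        self._initialized = True

    @classmethod
    def _get_config_hash(cls, config: FakeLLMConfig) -> str:
        # sort_keys makes the JSON (and therefore the hash) deterministic
        return hashlib.md5(json.dumps(asdict(config), sort_keys=True).encode()).hexdigest()

    @classmethod
    def clear_cache(cls) -> None:
        cls._instances.clear()


cfg = FakeLLMConfig(api_key="sk-test", api_base="https://api.example.com/v1", model="gpt-4o-mini")
a = CachedLLM(cfg)
b = CachedLLM(cfg)       # same config hash -> the same object is returned
assert a is b
c = CachedLLM(FakeLLMConfig(api_key="sk-other", api_base=cfg.api_base, model=cfg.model))
assert c is not a        # any config difference -> a separate instance
CachedLLM.clear_cache()  # drops every cached instance
assert CachedLLM(cfg) is not a
```

Note that because api_key is part of the serialized config, two configs that differ only in key hash differently and therefore get separate cached clients.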
4 changes: 3 additions & 1 deletion src/memos/mem_cube/general.py
@@ -1,4 +1,5 @@
import os
import time

from typing import Literal

@@ -23,11 +24,13 @@ class GeneralMemCube(BaseMemCube):
def __init__(self, config: GeneralMemCubeConfig):
"""Initialize the MemCube with a configuration."""
self.config = config
time_start = time.time()
self._text_mem: BaseTextMemory | None = (
MemoryFactory.from_config(config.text_mem)
if config.text_mem.backend != "uninitialized"
else None
)
logger.info(f"init_text_mem in {time.time() - time_start} seconds")
self._act_mem: BaseActMemory | None = (
MemoryFactory.from_config(config.act_mem)
if config.act_mem.backend != "uninitialized"
@@ -137,7 +140,6 @@ def init_from_dir(
if default_config is not None:
config = merge_config_with_default(config, default_config)
logger.info(f"Applied default config to cube {config.cube_id}")

mem_cube = GeneralMemCube(config)
mem_cube.load(dir, memory_types)
return mem_cube
12 changes: 8 additions & 4 deletions src/memos/mem_os/core.py
@@ -483,14 +483,14 @@ def register_mem_cube(
self.mem_cubes[mem_cube_id] = mem_cube_name_or_path
logger.info(f"register new cube {mem_cube_id} for user {target_user_id}")
elif os.path.exists(mem_cube_name_or_path):
self.mem_cubes[mem_cube_id] = GeneralMemCube.init_from_dir(mem_cube_name_or_path)
mem_cube_obj = GeneralMemCube.init_from_dir(mem_cube_name_or_path)
self.mem_cubes[mem_cube_id] = mem_cube_obj
else:
logger.warning(
f"MemCube {mem_cube_name_or_path} does not exist, try to init from remote repo."
)
self.mem_cubes[mem_cube_id] = GeneralMemCube.init_from_remote_repo(
mem_cube_name_or_path
)
mem_cube_obj = GeneralMemCube.init_from_remote_repo(mem_cube_name_or_path)
self.mem_cubes[mem_cube_id] = mem_cube_obj
# Check if cube already exists in database
existing_cube = self.user_manager.get_cube(mem_cube_id)

@@ -592,9 +592,13 @@ def search(
install_cube_ids = user_cube_ids
# create exist dict in mem_cubes and avoid one search slow
tmp_mem_cubes = {}
time_start_cube_get = time.time()
for mem_cube_id in install_cube_ids:
if mem_cube_id in self.mem_cubes:
tmp_mem_cubes[mem_cube_id] = self.mem_cubes.get(mem_cube_id)
logger.info(
f"time search: transform cube time user_id: {target_user_id} time is: {time.time() - time_start_cube_get}"
)

for mem_cube_id, mem_cube in tmp_mem_cubes.items():
if (
4 changes: 4 additions & 0 deletions src/memos/mem_os/product.py
@@ -775,10 +775,14 @@ def register_mem_cube(
return

# Create MemCube from path
time_start = time.time()
if os.path.exists(mem_cube_name_or_path):
mem_cube = GeneralMemCube.init_from_dir(
mem_cube_name_or_path, memory_types, default_config
)
logger.info(
f"time register_mem_cube: init_from_dir time is: {time.time() - time_start}"
)
else:
logger.warning(
f"MemCube {mem_cube_name_or_path} does not exist, try to init from remote repo."
21 changes: 19 additions & 2 deletions src/memos/memories/textual/tree.py
Expand Up @@ -2,6 +2,7 @@
import os
import shutil
import tempfile
import time

from datetime import datetime
from pathlib import Path
@@ -32,15 +33,28 @@ class TreeTextMemory(BaseTextMemory):

def __init__(self, config: TreeTextMemoryConfig):
"""Initialize memory with the given configuration."""
time_start = time.time()
self.config: TreeTextMemoryConfig = config
self.extractor_llm: OpenAILLM | OllamaLLM | AzureLLM = LLMFactory.from_config(
config.extractor_llm
)
logger.info(f"time init: extractor_llm time is: {time.time() - time_start}")

time_start_ex = time.time()
self.dispatcher_llm: OpenAILLM | OllamaLLM | AzureLLM = LLMFactory.from_config(
config.dispatcher_llm
)
logger.info(f"time init: dispatcher_llm time is: {time.time() - time_start_ex}")

time_start_em = time.time()
self.embedder: OllamaEmbedder = EmbedderFactory.from_config(config.embedder)
logger.info(f"time init: embedder time is: {time.time() - time_start_em}")

time_start_gs = time.time()
self.graph_store: Neo4jGraphDB = GraphStoreFactory.from_config(config.graph_db)
logger.info(f"time init: graph_store time is: {time.time() - time_start_gs}")

time_start_rr = time.time()
if config.reranker is None:
default_cfg = RerankerConfigFactory.model_validate(
{
@@ -54,9 +68,10 @@ def __init__(self, config: TreeTextMemoryConfig):
self.reranker = RerankerFactory.from_config(default_cfg)
else:
self.reranker = RerankerFactory.from_config(config.reranker)

logger.info(f"time init: reranker time is: {time.time() - time_start_rr}")
self.is_reorganize = config.reorganize

time_start_mm = time.time()
self.memory_manager: MemoryManager = MemoryManager(
self.graph_store,
self.embedder,
@@ -69,7 +84,8 @@ def __init__(self, config: TreeTextMemoryConfig):
},
is_reorganize=self.is_reorganize,
)

logger.info(f"time init: memory_manager time is: {time.time() - time_start_mm}")
time_start_ir = time.time()
# Create internet retriever if configured
self.internet_retriever = None
if config.internet_retriever is not None:
@@ -81,6 +97,7 @@ def __init__(self, config: TreeTextMemoryConfig):
)
else:
logger.info("No internet retriever configured")
logger.info(f"time init: internet_retriever time is: {time.time() - time_start_ir}")

def add(self, memories: list[TextualMemoryItem | dict[str, Any]]) -> list[str]:
"""Add memories.
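The timing additions in general.py, core.py, product.py, and tree.py above all repeat the same three steps: capture time.time(), run the initialization, and log the elapsed delta under a hand-written label. If this instrumentation is meant to stay, a small context manager could remove the duplication and the risk of pairing the wrong time_start_* variable with a log line. This is a sketch only; log_time is a hypothetical helper, not part of this PR, and the logger setup is assumed:

```python
import logging
import time
from contextlib import contextmanager

logger = logging.getLogger(__name__)


@contextmanager
def log_time(label: str):
    """Log how long the wrapped block took, mirroring the PR's manual pattern."""
    start = time.time()
    try:
        yield
    finally:
        logger.info(f"time init: {label} time is: {time.time() - start}")


# Roughly equivalent to the manual instrumentation in TreeTextMemory.__init__:
with log_time("embedder"):
    time.sleep(0.01)  # stand-in for EmbedderFactory.from_config(config.embedder)
```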