Skip to content

Issue with thread deletion using the adelete_thread method -- memory leak while deleting a thread #51

Closed
@shubhamgajbhiye1994

Description

@shubhamgajbhiye1994

I am attaching the code from my files below.


File 1: agent_core/workflow/remove_memory.py (removes a thread's memory asynchronously using the adelete_thread method)

import logging
from langgraph.graph.graph import CompiledGraph # For type hinting app

logger = logging.getLogger(__name__)


async def remove_thread_memory(app: "CompiledGraph", thread_id: str) -> None:
    """
    Remove the persisted memory (checkpoints) of one conversation thread.

    Awaits ``app.checkpointer.adelete_thread(thread_id)`` so the checkpointer
    drops everything it stores for that thread. Useful for explicitly clearing
    a conversation's history, e.g. on user request or for maintenance.

    This is a coroutine function: callers must ``await`` it. Calling it
    without ``await`` only creates a coroutine object and deletes nothing.

    Args:
        app: The compiled LangGraph application instance. It must have a
            checkpointer configured.
        thread_id: The unique identifier of the conversation thread whose
            memory is to be removed.

    Raises:
        AttributeError: If the app has no checkpointer (missing or ``None``),
            or the checkpointer has no ``adelete_thread`` method.
        Exception: Any exception raised by ``adelete_thread`` is logged and
            re-raised unchanged.
    """
    logger.info(f"Attempting to remove memory for thread_id: {thread_id}")

    checkpointer = getattr(app, "checkpointer", None)
    if checkpointer is None:
        logger.error(
            f"Application has no checkpointer. Cannot remove memory for thread_id: {thread_id}."
        )
        # Raising (instead of logging and returning) makes a misconfigured
        # caller visible.
        raise AttributeError("Application checkpointer is not configured.")

    # Validate the async method that is actually awaited below, so the
    # capability check and the call can never disagree.
    if not hasattr(checkpointer, "adelete_thread"):
        logger.error(
            f"Checkpointer of type {type(checkpointer).__name__} does not support 'adelete_thread'. Cannot remove memory for thread_id: {thread_id}."
        )
        raise AttributeError("Checkpointer does not support 'adelete_thread' method.")

    try:
        await checkpointer.adelete_thread(thread_id)
    except Exception as e:
        logger.error(
            f"Error removing memory for thread_id {thread_id}: {e}", exc_info=True
        )
        # Re-raise so the caller (e.g. an API endpoint) decides how to respond.
        raise
    else:
        logger.info(f"Successfully removed memory for thread_id: {thread_id}")

File 2: agent_core/workflow/workflow.py (builds the graph and attaches the LangGraph Redis checkpointer)

import logging
from langgraph.checkpoint.memory import InMemorySaver
import asyncio
from langgraph.checkpoint.redis.aio import AsyncRedisSaver
from langgraph.graph import START, StateGraph # END is used in memory_nodes
from redis.exceptions import ResponseError
from configs.config import settings

# Import all necessary nodes for the graph

from ..nodes.memory_nodes import (
maintain_messages_node,
remove_initial_input_msg_node,
add_reconstructed_msg_node,
remove_reconstructed_msg_node,
add_initial_input_msg_node,
add_sql_agent_msg_node,
)
from ..nodes.query_reconstruct import reconstruct_query_node
from langgraph.checkpoint.base import RunnableConfig
from ..nodes.sql_agent import sql_agent_node
from ..nodes.states import AgentState # The definition of the graph's state

logger = logging.getLogger(__name__)
logger.info("Defining casai-analytics LangGraph workflow...")

# Compiled graph shared by all callers; assigned by init_graph().
_graph_instance = None
# Serialises init_graph() so concurrent callers compile the graph only once.
_initialization_lock = asyncio.Lock()

# Exit stack that keeps the Redis checkpointer's async context open while the
# graph is in use. Holding it here lets close_graph() exit that context later.
_checkpointer_stack = None


async def init_graph(setup_indices: bool = False):
    """
    Compile the application's LangGraph once, backed by a Redis checkpointer.

    Calls are serialised by ``_initialization_lock``. Only the first call
    builds the graph; later calls log at debug level and return.

    ``AsyncRedisSaver.from_conn_string`` returns an async context manager. It
    is entered through an ``AsyncExitStack`` stored in ``_checkpointer_stack``
    so that ``close_graph()`` can exit it. Calling ``__aenter__()`` directly
    and discarding the manager would leave its exit/cleanup unmanaged.

    If creating the Redis checkpointer fails, an ``InMemorySaver`` is used
    instead. It keeps history in process memory only, and history is lost on
    restart.

    Args:
        setup_indices: If True, await ``checkpointer.asetup()`` after
            connecting; a ``ResponseError`` containing "Index already exists"
            is tolerated. Defaults to False, which skips setup as is done for
            deployed environments.
    """
    global _graph_instance, _checkpointer_stack
    import os
    from contextlib import AsyncExitStack

    async with _initialization_lock:
        if _graph_instance is not None:
            logger.debug("LangGraph already initialized.")
            return

        logger.info("Initializing LangGraph...")

        # Set the REDIS_URL environment variable
        os.environ["REDIS_URL"] = settings.REDIS_URL
        # NOTE(review): REDIS_URL may embed credentials; consider masking it here.
        logger.info(f"Using Redis URL: {settings.REDIS_URL}")

        stack = AsyncExitStack()
        try:
            checkpointer = await stack.enter_async_context(
                AsyncRedisSaver.from_conn_string(redis_url=settings.REDIS_URL)
            )
            logger.info("Redis checkpointer initialized successfully.")

            if setup_indices:
                try:
                    await checkpointer.asetup()
                except ResponseError as e:
                    if "Index already exists" in str(e):
                        logger.warning("Redis index already exists. Skipping creation.")
                    else:
                        logger.error("Unexpected Redis error during asetup", exc_info=True)
                        raise
        except Exception as e:
            logger.error(f"Failed to initialize Redis checkpointer: {e}", exc_info=True)
            # Exit whatever was entered before the failure so nothing leaks.
            try:
                await stack.aclose()
            except Exception:
                logger.warning(
                    "Error while closing the failed Redis checkpointer.", exc_info=True
                )
            logger.warning("Falling back to InMemorySaver...")

            # Fallback to InMemorySaver if Redis connection fails
            checkpointer = InMemorySaver()
            logger.warning(
                "InMemorySaver is being used for checkpointing. "
                "Conversation history will be lost on application restart."
            )

        builder = StateGraph(AgentState)
        builder.add_node("maintain_messages_node", maintain_messages_node)
        # builder.add_node("reconstruct_query_node", reconstruct_query_node)
        builder.add_node("sql_agent_node", sql_agent_node)
        # builder.add_node("remove_initial_input_msg_node", remove_initial_input_msg_node)
        # builder.add_node("add_reconstructed_msg_node", add_reconstructed_msg_node)
        # builder.add_node("remove_reconstructed_msg_node", remove_reconstructed_msg_node)
        # builder.add_node("add_initial_input_msg_node", add_initial_input_msg_node)
        builder.add_node("add_sql_agent_msg_node", add_sql_agent_msg_node)
        builder.add_edge(START, "maintain_messages_node")

        _graph_instance = builder.compile(checkpointer=checkpointer)
        # Keep the stack alive alongside the graph. It is empty (already
        # closed) when the InMemorySaver fallback was taken.
        _checkpointer_stack = stack
        logger.info("LangGraph initialized.")


async def close_graph() -> None:
    """
    Exit the Redis checkpointer's async context and forget the compiled graph.

    Intended to be awaited on application shutdown. Afterwards ``init_graph()``
    may be called again. If the InMemorySaver fallback was used, there is no
    Redis context left to exit.
    """
    global _graph_instance, _checkpointer_stack
    async with _initialization_lock:
        stack, _checkpointer_stack = _checkpointer_stack, None
        _graph_instance = None
        if stack is not None:
            await stack.aclose()

async def get_graph():
    """
    Return the compiled LangGraph produced by ``init_graph()``.

    Raises:
        RuntimeError: If ``init_graph()`` has not stored a graph yet.
    """
    # Only read here, so no `global` declaration is needed.
    graph = _graph_instance
    if graph is not None:
        return graph
    raise RuntimeError("LangGraph not initialized. Call init_graph() first.")

File 3: app.py (the FastAPI application that calls the graph)


import logging  # Keep for logger instance
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import Any

# Import custom logging setup and apply it
from configs.logging_config import setup_logging, TokenUsageLoggerCallback
from agent_core.workflow.workflow import init_graph, get_graph ##
from agent_core.workflow.remove_memory import remove_thread_memory
from configs.config import settings  

# Apply the project's logging configuration before the module logger is used.
setup_logging()

# Module logger, obtained only after setup_logging() has run.
logger = logging.getLogger(__name__)
logger.info(f"{settings.API_TITLE} starting...")
logger.info(
    f"Current effective log level for {__name__}: "
    f"{logging.getLevelName(logger.getEffectiveLevel())}"
)
logger.info(f"API Version: {settings.API_VERSION}")

# The FastAPI application. Its OpenAPI metadata comes from settings and the
# interactive docs are served at /docs.
api = FastAPI(
    docs_url="/docs",
    title=settings.API_TITLE,
    description=settings.API_DESCRIPTION,
    version=settings.API_VERSION,
)


class QueryRequest(BaseModel):
    """Request model for the /ask endpoint."""

    # Natural-language question; sent to the graph as a ("user", query) message.
    query: str
    # Conversation id; passed to the graph as config["configurable"]["thread_id"].
    thread_id: str


class ThreadRequest(BaseModel):
    """Request model for the /thread_memory endpoint."""

    # Conversation id whose persisted checkpoints should be deleted.
    # NOTE(review): this model is read from the body of a DELETE request; some
    # HTTP clients/proxies drop DELETE bodies — consider a path/query parameter.
    thread_id: str


@api.on_event("startup")
async def startup_event():
    """Compile the LangGraph once at startup and cache it on ``api.state``.

    Request handlers read the compiled graph from ``api.state.langgraph``.
    """
    # NOTE(review): FastAPI deprecates on_event in favour of lifespan handlers.
    await init_graph()
    compiled_graph = await get_graph()
    api.state.langgraph = compiled_graph

# Health check endpoint
@api.get("/health", summary="Health Check", tags=["General"])
async def health_check():
    """Checks if the API is running and healthy."""
    # The docstring above is published by FastAPI as the endpoint description.
    logger.info("Health check endpoint called")
    level_name = logging.getLevelName(logger.getEffectiveLevel())
    return dict(status="healthy", log_level=level_name)


@api.post("/ask/", summary="Ask the Casai Analyst", tags=["Agent"])
async def ask_endpoint(request: QueryRequest) -> Any:
    """
    Submits a query to the Casai Data Analyst agent.
    Requires a natural language query and a unique thread_id for the conversation.
    """
    thread_id = request.thread_id
    logger.info(
        f'Received query for thread_id: {thread_id}. Query (first 50 chars): "{request.query[:50]}..."'
    )

    # A new TokenUsageLoggerCallback instance is attached to every request.
    run_config = {
        "configurable": {"thread_id": thread_id},
        "callbacks": [TokenUsageLoggerCallback()],
    }
    graph_input = {"messages": [("user", request.query)]}
    graph = api.state.langgraph

    try:
        # ainvoke keeps the event loop free while the agent runs.
        result = await graph.ainvoke(graph_input, config=run_config)
    except Exception as exc:
        logger.error(
            f"Error processing query for thread_id {thread_id}: {exc}",
            exc_info=True,
        )
        raise HTTPException(
            status_code=500,
            detail=f"An error occurred while processing your query: {str(exc)}",
        )

    logger.info(f"Successfully processed query for thread_id: {thread_id}")
    return result


@api.delete("/thread_memory/", summary="Delete Thread Memory", tags=["Agent"])
async def delete_thread_memory_endpoint(request: ThreadRequest):
    """
    Deletes the memory associated with a specific conversation thread_id.
    """
    logger.info(f"Received request to delete memory for thread_id: {request.thread_id}")
    try:
        # remove_thread_memory is an async function and MUST be awaited.
        # Calling it bare only creates a coroutine object: the checkpointer is
        # never asked to delete anything, Python later warns "coroutine ... was
        # never awaited", and this endpoint would still report success while
        # the thread's checkpoints stay in storage.
        await remove_thread_memory(api.state.langgraph, request.thread_id)
        logger.info(f"Successfully deleted memory for thread_id: {request.thread_id}")
        return {
            "status": "success",
            "message": f"Memory for thread_id {request.thread_id} deleted.",
        }
    except Exception as e:
        logger.error(
            f"Error deleting memory for thread_id {request.thread_id}: {e}",
            exc_info=True,
        )
        raise HTTPException(
            status_code=500,
            detail=f"An error occurred while deleting thread memory: {str(e)}",
        )

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions