Description
I am attaching the code for my three files below.
File 1: remove_memory.py (removes a thread's memory asynchronously using the `adelete_thread` method)
import logging
from langgraph.graph.graph import CompiledGraph # For type hinting app
# Module-level logger named after this module (`name` alone is undefined).
logger = logging.getLogger(__name__)
async def remove_thread_memory(app: CompiledGraph, thread_id: str) -> None:
    """
    Remove the persisted memory (checkpoints) of one conversation thread.

    Awaits the checkpointer's ``adelete_thread`` coroutine for ``thread_id``.
    Useful for explicitly clearing a conversation's history, e.g. when a user
    requests it or for maintenance purposes.

    Args:
        app: The compiled LangGraph application instance. It must have a
            checkpointer configured.
        thread_id: The unique identifier of the conversation thread whose
            memory is to be removed.

    Raises:
        AttributeError: If ``app`` has no checkpointer, or the checkpointer
            has no ``adelete_thread`` method.
        Exception: Any exception raised by ``adelete_thread`` is logged and
            re-raised unchanged.
    """
    logger.info(f"Attempting to remove memory for thread_id: {thread_id}")

    checkpointer = getattr(app, "checkpointer", None)
    if checkpointer is None:
        logger.error(
            f"Application has no checkpointer. Cannot remove memory for thread_id: {thread_id}."
        )
        # Raise rather than silently return so that calling this on a
        # misconfigured app is visible to the caller.
        raise AttributeError("Application checkpointer is not configured.")

    # Verify the *async* method that is awaited below (not the sync
    # `delete_thread`), so the guard matches the call actually made.
    adelete_thread = getattr(checkpointer, "adelete_thread", None)
    if adelete_thread is None:
        logger.error(
            f"Checkpointer of type {type(checkpointer).__name__} does not support 'adelete_thread'. Cannot remove memory for thread_id: {thread_id}."
        )
        raise AttributeError("Checkpointer does not support 'adelete_thread' method.")

    try:
        await adelete_thread(thread_id)
    except Exception as e:
        logger.exception(f"Error removing memory for thread_id {thread_id}: {e}")
        # Re-raise so the caller (e.g. the API endpoint) can handle it.
        raise
    logger.info(f"Successfully removed memory for thread_id: {thread_id}")
File 2: LangGraph workflow module that adds the Redis checkpointer to the graph
import logging
from langgraph.checkpoint.memory import InMemorySaver
import asyncio
from langgraph.checkpoint.redis.aio import AsyncRedisSaver
from langgraph.graph import START, StateGraph # END is used in memory_nodes
from redis.exceptions import ResponseError
from configs.config import settings
# Import all necessary nodes for the graph
from ..nodes.memory_nodes import (
maintain_messages_node,
remove_initial_input_msg_node,
add_reconstructed_msg_node,
remove_reconstructed_msg_node,
add_initial_input_msg_node,
add_sql_agent_msg_node,
)
from ..nodes.query_reconstruct import reconstruct_query_node
from langgraph.checkpoint.base import RunnableConfig
from ..nodes.sql_agent import sql_agent_node
from ..nodes.states import AgentState # The definition of the graph's state
# Module-level logger named after this module (`name` alone is undefined).
logger = logging.getLogger(__name__)
logger.info("Defining casai-analytics LangGraph workflow...")

# Compiled graph; set once by init_graph() and read by get_graph().
_graph_instance = None
# Serializes concurrent init_graph() calls so the graph is built only once.
_initialization_lock = asyncio.Lock()
def _redact_url(url: str) -> str:
    """Return ``url`` with the password in its userinfo (if any) masked as ``***``.

    >>> _redact_url("redis://:secret@localhost:6379/0")
    'redis://:***@localhost:6379/0'
    >>> _redact_url("redis://localhost:6379")
    'redis://localhost:6379'
    """
    from urllib.parse import urlsplit, urlunsplit

    parts = urlsplit(url)
    userinfo, at_sign, hostinfo = parts.netloc.rpartition("@")
    if not at_sign or ":" not in userinfo:
        # No credentials, or a username without a password: nothing to hide.
        return url
    username = userinfo.split(":", 1)[0]
    return urlunsplit(parts._replace(netloc=f"{username}:***@{hostinfo}"))


async def _create_checkpointer(setup_indices: bool):
    """Create the Redis checkpointer, falling back to ``InMemorySaver`` on any error."""
    try:
        # NOTE(review): the async context manager returned by from_conn_string()
        # is entered manually and its __aexit__ is never called, so whatever it
        # would release on exit (presumably the Redis client) is never cleaned up.
        checkpointer = await AsyncRedisSaver.from_conn_string(
            redis_url=settings.REDIS_URL
        ).__aenter__()
        logger.info("Redis checkpointer initialized successfully.")

        if setup_indices:
            # Safely handle index creation: an existing index is not an error.
            try:
                await checkpointer.asetup()
            except ResponseError as e:
                if "Index already exists" in str(e):
                    logger.warning("Redis index already exists. Skipping creation.")
                else:
                    logger.error("Unexpected Redis error during asetup", exc_info=True)
                    # Propagates to the outer handler, which falls back below.
                    raise
        return checkpointer
    except Exception as e:
        logger.error(f"Failed to initialize Redis checkpointer: {e}", exc_info=True)
        logger.warning("Falling back to InMemorySaver...")
        logger.warning(
            "InMemorySaver is being used for checkpointing. "
            "Conversation history will be lost on application restart."
        )
        return InMemorySaver()


def _build_graph(checkpointer):
    """Register the workflow nodes and compile the graph with ``checkpointer``."""
    builder = StateGraph(AgentState)
    builder.add_node("maintain_messages_node", maintain_messages_node)
    # builder.add_node("reconstruct_query_node", reconstruct_query_node)
    builder.add_node("sql_agent_node", sql_agent_node)
    # builder.add_node("remove_initial_input_msg_node", remove_initial_input_msg_node)
    # builder.add_node("add_reconstructed_msg_node", add_reconstructed_msg_node)
    # builder.add_node("remove_reconstructed_msg_node", remove_reconstructed_msg_node)
    # builder.add_node("add_initial_input_msg_node", add_initial_input_msg_node)
    builder.add_node("add_sql_agent_msg_node", add_sql_agent_msg_node)
    # NOTE(review): only the START edge is declared here; routing between the
    # other nodes presumably happens inside the nodes themselves — confirm.
    builder.add_edge(START, "maintain_messages_node")
    return builder.compile(checkpointer=checkpointer)


async def init_graph(setup_indices: bool = False) -> None:
    """
    Build and compile the LangGraph workflow exactly once.

    Concurrent callers are serialized by ``_initialization_lock``; later calls
    are no-ops. A Redis checkpointer is used when it can be created, otherwise
    an ``InMemorySaver`` is used instead.

    Args:
        setup_indices: When True, await ``checkpointer.asetup()`` to create the
            Redis indices (an "Index already exists" error is only logged).
            Defaults to False, so index creation is skipped — run it with True
            for the initial setup, per the original deployment note.
    """
    global _graph_instance
    async with _initialization_lock:
        if _graph_instance is not None:
            logger.debug("LangGraph already initialized.")
            return

        logger.info("Initializing LangGraph...")
        import os

        # Set the REDIS_URL environment variable.
        os.environ["REDIS_URL"] = settings.REDIS_URL
        # The URL may embed a password, so only a redacted form is logged.
        logger.info(f"Using Redis URL: {_redact_url(settings.REDIS_URL)}")

        checkpointer = await _create_checkpointer(setup_indices)
        _graph_instance = _build_graph(checkpointer)
        logger.info("LangGraph initialized.")
async def get_graph():
    """Return the compiled graph built by ``init_graph()``.

    Raises:
        RuntimeError: If ``init_graph()`` has not completed yet.
    """
    # Read-only access to the module global, so no `global` statement is needed.
    graph = _graph_instance
    if graph is not None:
        return graph
    raise RuntimeError("LangGraph not initialized. Call init_graph() first.")
File 3: app.py (FastAPI API that calls the graph)
import logging # Keep for logger instance
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import Any
# Import custom logging setup and apply it
from configs.logging_config import setup_logging, TokenUsageLoggerCallback
from agent_core.workflow.workflow import init_graph, get_graph ##
from agent_core.workflow.remove_memory import remove_thread_memory
from configs.config import settings
# Logging must be configured before any logger below emits records.
setup_logging()

logger = logging.getLogger(__name__)
logger.info(f"{settings.API_TITLE} starting...")
_effective_level = logging.getLevelName(logger.getEffectiveLevel())
logger.info(f"Current effective log level for {__name__}: {_effective_level}")
logger.info(f"API Version: {settings.API_VERSION}")

# FastAPI application; title, description and version all come from settings.
api = FastAPI(
    docs_url="/docs",
    title=settings.API_TITLE,
    description=settings.API_DESCRIPTION,
    version=settings.API_VERSION,
)
class QueryRequest(BaseModel):
    """Request model for the /ask endpoint."""
    # Natural-language question; forwarded to the graph as a ("user", query) message.
    query: str
    # Conversation identifier; passed as configurable["thread_id"] to select the checkpointed thread.
    thread_id: str
class ThreadRequest(BaseModel):
    """Request model for the /thread_memory endpoint."""
    # Conversation whose checkpointed memory should be deleted.
    thread_id: str
@api.on_event("startup")
async def startup_event():
    """Build the LangGraph workflow and store it on ``api.state`` for the endpoints."""
    # NOTE(review): FastAPI recommends lifespan handlers over on_event; consider migrating.
    await init_graph()
    compiled_graph = await get_graph()
    api.state.langgraph = compiled_graph
# Health check endpoint
@api.get("/health", summary="Health Check", tags=["General"])
async def health_check():
    """Checks if the API is running and healthy."""
    logger.info("Health check endpoint called")
    # The active log level is reported so operators can confirm the logging setup.
    effective_level = logging.getLevelName(logger.getEffectiveLevel())
    payload = {"status": "healthy", "log_level": effective_level}
    return payload
@api.post("/ask/", summary="Ask the Casai Analyst", tags=["Agent"])
async def ask_endpoint(request: QueryRequest) -> Any:
    """
    Submits a query to the Casai Data Analyst agent.
    Requires a natural language query and a unique thread_id for the conversation.
    """
    logger.info(
        f'Received query for thread_id: {request.thread_id}. Query (first 50 chars): "{request.query[:50]}..."'
    )
    # Per-request callback (presumably logs token usage; see configs.logging_config).
    callback_handler = TokenUsageLoggerCallback()
    config = {
        # thread_id selects which checkpointed conversation the graph resumes.
        "configurable": {"thread_id": request.thread_id},
        "callbacks": [callback_handler],
    }
    messages_input = {"messages": [("user", request.query)]}
    langgraph_app = api.state.langgraph

    try:
        # ainvoke keeps the event loop free while the graph runs.
        answer = await langgraph_app.ainvoke(messages_input, config=config)
    except Exception as e:
        logger.exception(f"Error processing query for thread_id {request.thread_id}: {e}")
        # NOTE(review): str(e) exposes internal error details to API clients;
        # consider returning a generic message instead.
        raise HTTPException(
            status_code=500,
            detail=f"An error occurred while processing your query: {str(e)}",
        ) from e

    logger.info(f"Successfully processed query for thread_id: {request.thread_id}")
    # logger.debug(f"Answer for thread_id {request.thread_id}: {answer}")
    return answer
@api.delete("/thread_memory/", summary="Delete Thread Memory", tags=["Agent"])
async def delete_thread_memory_endpoint(request: ThreadRequest):
    """
    Deletes the memory associated with a specific conversation thread_id.
    """
    logger.info(f"Received request to delete memory for thread_id: {request.thread_id}")
    try:
        # remove_thread_memory is a coroutine function: it must be awaited,
        # otherwise the deletion never runs and its errors are never seen.
        await remove_thread_memory(api.state.langgraph, request.thread_id)
    except Exception as e:
        logger.exception(f"Error deleting memory for thread_id {request.thread_id}: {e}")
        raise HTTPException(
            status_code=500,
            detail=f"An error occurred while deleting thread memory: {str(e)}",
        ) from e

    logger.info(f"Successfully deleted memory for thread_id: {request.thread_id}")
    return {
        "status": "success",
        "message": f"Memory for thread_id {request.thread_id} deleted.",
    }