|
10 | 10 | # from google.genai import AsyncClient as AsyncGeminiClient # google.genai.Client can be used with an async transport |
11 | 11 |
|
12 | 12 | from memfuse import Memory |
13 | | -from memfuse.prompts import PromptContext |
| 13 | +from memfuse.prompts import PromptContext, PromptFormatter |
14 | 14 |
|
15 | 15 | # Set up logger for this module |
16 | 16 | logger = logging.getLogger(__name__) |
@@ -133,26 +133,41 @@ def _instrument_generate_content_sync( |
133 | 133 | latest_user_query_message = gemini_query_messages[-1] |
134 | 134 |
|
135 | 135 | retrieved_memories = None |
136 | | - retrieved_chat_history = None |
| 136 | + chat_history = None |
137 | 137 |
|
138 | | - if latest_user_query_message: |
139 | | - query_response = memory.query_session(latest_user_query_message["content"]) |
140 | | - retrieved_memories = query_response["data"]["results"] if query_response else None |
141 | | - |
| 138 | + if latest_user_query_message: |
142 | 139 | # Get chat history |
143 | 140 | max_chat_history = memory.max_chat_history |
144 | | - chat_history_response = memory.list_messages(limit=max_chat_history) |
145 | | - if chat_history_response and chat_history_response.get("data", {}).get("messages"): |
146 | | - retrieved_chat_history = [ |
147 | | - {"role": msg["role"], "content": msg["content"]} |
148 | | - for msg in chat_history_response["data"]["messages"][::-1] |
149 | | - ] |
| 141 | + |
| 142 | + in_buffer_chat_history = memory.list_messages( |
| 143 | + limit=max_chat_history, |
| 144 | + buffer_only=True, |
| 145 | + ) |
| 146 | + |
| 147 | + in_buffer_messages_length = len(in_buffer_chat_history["data"]["messages"]) |
| 148 | + |
| 149 | + if in_buffer_messages_length < max_chat_history: |
| 150 | + in_db_chat_history = memory.list_messages( |
| 151 | + limit=max_chat_history - in_buffer_messages_length, |
| 152 | + buffer_only=False, |
| 153 | + ) |
| 154 | + else: |
| 155 | + in_db_chat_history = {"data": {"messages": []}}
| 156 | + |
| 157 | + chat_history = [{"role": message["role"], "content": message["content"]} for message in in_db_chat_history["data"]["messages"][::-1]] + [{"role": message["role"], "content": message["content"]} for message in in_buffer_chat_history["data"]["messages"][::-1]] |
| 158 | + |
| 159 | + # Retrieve memories |
| 160 | + query_string = PromptFormatter.messages_to_query(chat_history + gemini_query_messages) |
| 161 | + query_response = memory.query_session(query_string) |
| 162 | + retrieved_memories = query_response["data"]["results"] if query_response else None |
| 163 | + |
| 164 | + logger.info(f"Retrieved memories: {retrieved_memories}") |
150 | 165 |
|
151 | 166 | # 3. Compose the prompt context for PromptFormatter |
152 | 167 | prompt_context = PromptContext( |
153 | 168 | query_messages=gemini_query_messages, |
154 | 169 | retrieved_memories=retrieved_memories, |
155 | | - retrieved_chat_history=retrieved_chat_history, |
| 170 | + retrieved_chat_history=chat_history, |
156 | 171 | max_chat_history=memory.max_chat_history, |
157 | 172 | ) |
158 | 173 |
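The sync path now assembles the recent-history window from two sources: messages still held in the write buffer, topped up from persisted storage only if the buffer does not fill the window. A minimal sketch of that merge step, assuming the `list_messages` response shape shown in the diff (`{"data": {"messages": [...]}}`, newest-first) and using a hypothetical `merge_chat_history` helper name:

```python
# Sketch only: `merge_chat_history` is a hypothetical helper; `memory.list_messages`
# and its `buffer_only` flag are used as in the diff above.
def merge_chat_history(memory, max_chat_history):
    # Most recent turns still sitting in the write buffer.
    in_buffer = memory.list_messages(limit=max_chat_history, buffer_only=True)
    buffer_messages = in_buffer["data"]["messages"]

    # Fill any remaining slots from persisted storage.
    remaining = max_chat_history - len(buffer_messages)
    db_messages = (
        memory.list_messages(limit=remaining, buffer_only=False)["data"]["messages"]
        if remaining > 0
        else []
    )

    def to_turns(messages):
        # Responses are assumed newest-first, so reverse into chronological order.
        return [{"role": m["role"], "content": m["content"]} for m in messages[::-1]]

    # Oldest persisted turns first, then the buffered (most recent) turns.
    return to_turns(db_messages) + to_turns(buffer_messages)
```

Keeping the buffered slice last preserves chronological order while still giving unflushed messages priority over older persisted ones when the window is full.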
|
@@ -191,19 +206,33 @@ async def _instrument_generate_content_async( |
191 | 206 | retrieved_memories = None |
192 | 207 | retrieved_chat_history = None |
193 | 208 |
|
194 | | - if latest_user_query_message: |
195 | | - # Properly await async memory operations |
196 | | - query_response = await memory.query_session(latest_user_query_message["content"]) |
197 | | - retrieved_memories = query_response["data"]["results"] if query_response else None |
198 | | - |
| 209 | + if latest_user_query_message: |
199 | 210 | # Get chat history |
200 | 211 | max_chat_history = memory.max_chat_history |
201 | | - chat_history_response = await memory.list_messages(limit=max_chat_history) |
202 | | - if chat_history_response and chat_history_response.get("data", {}).get("messages"): |
203 | | - retrieved_chat_history = [ |
204 | | - {"role": msg["role"], "content": msg["content"]} |
205 | | - for msg in chat_history_response["data"]["messages"][::-1] |
206 | | - ] |
| 212 | + |
| 213 | + in_buffer_chat_history = await memory.list_messages( |
| 214 | + limit=max_chat_history, |
| 215 | + buffer_only=True, |
| 216 | + ) |
| 217 | + |
| 218 | + in_buffer_messages_length = len(in_buffer_chat_history["data"]["messages"]) |
| 219 | + |
| 220 | + if in_buffer_messages_length < max_chat_history: |
| 221 | + in_db_chat_history = await memory.list_messages( |
| 222 | + limit=max_chat_history - in_buffer_messages_length, |
| 223 | + buffer_only=False, |
| 224 | + ) |
| 225 | + else: |
| 226 | + in_db_chat_history = {"data": {"messages": []}}
| 227 | + |
| 228 | + retrieved_chat_history = [{"role": message["role"], "content": message["content"]} for message in in_db_chat_history["data"]["messages"][::-1]] + [{"role": message["role"], "content": message["content"]} for message in in_buffer_chat_history["data"]["messages"][::-1]] |
| 229 | + |
| 230 | + # Retrieve memories |
| 231 | + query_string = PromptFormatter.messages_to_query(retrieved_chat_history + gemini_query_messages) |
| 232 | + query_response = await memory.query_session(query_string) |
| 233 | + retrieved_memories = query_response["data"]["results"] if query_response else None |
| 234 | + |
| 235 | + logger.info(f"Retrieved memories: {retrieved_memories}") |
207 | 236 |
|
208 | 237 | # 3. Compose the prompt context for PromptFormatter |
209 | 238 | prompt_context = PromptContext( |
|
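Both the sync and async paths also change what gets embedded for retrieval: instead of querying with only the latest user message, the merged chat history plus the current Gemini messages are flattened into a single query string. A minimal sketch of that step, reusing the hypothetical `merge_chat_history` helper from the earlier sketch and calling `PromptFormatter.messages_to_query` and `memory.query_session` as the diff does:

```python
from memfuse.prompts import PromptFormatter

# Sketch only: `retrieve_memories` is a hypothetical wrapper; `merge_chat_history`
# is the helper sketched above, not part of memfuse itself.
def retrieve_memories(memory, gemini_query_messages):
    chat_history = merge_chat_history(memory, memory.max_chat_history)
    # Flatten history plus the current turn(s) into one retrieval query.
    query_string = PromptFormatter.messages_to_query(chat_history + gemini_query_messages)
    query_response = memory.query_session(query_string)
    return query_response["data"]["results"] if query_response else None
```

The async variant has the same shape, with `await` on both `list_messages` and `query_session`.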