|
| 1 | +import asyncio |
| 2 | +from thread_manager import ThreadManager |
| 3 | +from assistant_manager import AssistantManager |
| 4 | + |
| 5 | + |
| 6 | +class ChatSession: |
| 7 | + def __init__(self, thread_manager: ThreadManager, assistant_manager: AssistantManager, assistant_name: str, |
| 8 | + model_name: str, assistant_id: str = None, thread_id: str = None): |
| 9 | + self.thread_manager = thread_manager |
| 10 | + self.assistant_manager = assistant_manager |
| 11 | + self.assistant_name = assistant_name |
| 12 | + self.model_name = model_name |
| 13 | + self.assistant_id = assistant_id |
| 14 | + self.thread_id = thread_id |
| 15 | + |
| 16 | + async def start_session(self): |
| 17 | + if self.thread_id is None: |
| 18 | + # Get or create a thread |
| 19 | + self.thread_id = await self.get_or_create_thread() |
| 20 | + |
| 21 | + if self.assistant_id is None: |
| 22 | + # Find or create an assistant |
| 23 | + self.assistant_id = await self.find_or_create_assistant( |
| 24 | + name=self.assistant_name, |
| 25 | + model=self.model_name |
| 26 | + ) |
| 27 | + |
| 28 | + # Display existing chat history |
| 29 | + await self.display_chat_history() |
| 30 | + |
| 31 | + prev_messages = await self.thread_manager.list_messages(self.thread_id) |
| 32 | + if prev_messages is None: |
| 33 | + print("An error occurred while retrieving messages.") |
| 34 | + return |
| 35 | + |
| 36 | + # Start the chat loop |
| 37 | + await self.chat_loop() |
| 38 | + |
| 39 | + async def chat_loop(self): |
| 40 | + try: |
| 41 | + while True: |
| 42 | + user_input = input("You: ") |
| 43 | + if user_input.lower() in ['exit', 'quit', 'bye']: |
| 44 | + break |
| 45 | + if user_input.lower() in ['/delete', '/clear']: |
| 46 | + await self.thread_manager.delete_thread(self.thread_id) |
| 47 | + self.thread_id = await self.get_or_create_thread() |
| 48 | + continue |
| 49 | + |
| 50 | + response = await self.get_latest_response(user_input) |
| 51 | + |
| 52 | + if response: |
| 53 | + print("Assistant:", response) |
| 54 | + |
| 55 | + finally: |
| 56 | + print(f"Session ended") |
| 57 | + |
| 58 | + async def get_or_create_thread(self): |
| 59 | + data = self.thread_manager.read_thread_data() |
| 60 | + thread_id = data.get('thread_id') |
| 61 | + if not thread_id: |
| 62 | + thread = await self.thread_manager.create_thread(messages=[]) |
| 63 | + thread_id = thread.id |
| 64 | + self.thread_manager.save_thread_data(thread_id) |
| 65 | + return thread_id |
| 66 | + |
| 67 | + async def find_or_create_assistant(self, name: str, model: str): |
| 68 | + """ |
| 69 | + Finds an existing assistant by name or creates a new one. |
| 70 | +
|
| 71 | + Args: |
| 72 | + name (str): The name of the assistant. |
| 73 | + model (str): The model ID for the assistant. |
| 74 | +
|
| 75 | + Returns: |
| 76 | + str: The ID of the found or created assistant. |
| 77 | + """ |
| 78 | + assistant_id = await self.assistant_manager.get_assistant_id_by_name(name) |
| 79 | + if not assistant_id: |
| 80 | + assistant = await self.assistant_manager.create_assistant(name=name, |
| 81 | + model=model, |
| 82 | + instructions="", |
| 83 | + tools=[{"type": "retrieval"}] |
| 84 | + ) |
| 85 | + assistant_id = assistant.id |
| 86 | + return assistant_id |
| 87 | + |
| 88 | + async def send_message(self, content): |
| 89 | + return await self.thread_manager.send_message(self.thread_id, content) |
| 90 | + |
| 91 | + async def display_chat_history(self): |
| 92 | + messages = await self.thread_manager.list_messages(self.thread_id) |
| 93 | + if messages is None: |
| 94 | + return |
| 95 | + print(messages) |
| 96 | + for message in reversed(messages.data): |
| 97 | + role = message.role |
| 98 | + content = message.content[0].text.value # Assuming message content is structured this way |
| 99 | + print(f"{role.title()}: {content}") |
| 100 | + |
| 101 | + async def get_latest_response(self, user_input): |
| 102 | + # Send the user message |
| 103 | + await self.send_message(user_input) |
| 104 | + |
| 105 | + # Create a new run for the assistant to respond |
| 106 | + await self.create_run() |
| 107 | + |
| 108 | + # Wait for the assistant's response |
| 109 | + await self.wait_for_assistant() |
| 110 | + |
| 111 | + # Retrieve the latest response |
| 112 | + return await self.retrieve_latest_response() |
| 113 | + |
| 114 | + async def create_run(self): |
| 115 | + return await self.thread_manager.create_run(self.thread_id, self.assistant_id) |
| 116 | + |
| 117 | + async def wait_for_assistant(self): |
| 118 | + while True: |
| 119 | + runs = await self.thread_manager.list_runs(self.thread_id) |
| 120 | + latest_run = runs.data[0] |
| 121 | + if latest_run.status in ["completed", "failed"]: |
| 122 | + break |
| 123 | + await asyncio.sleep(2) # Wait for 2 seconds before checking again |
| 124 | + |
| 125 | + async def retrieve_latest_response(self): |
| 126 | + response = await self.thread_manager.list_messages(self.thread_id) |
| 127 | + for message in response.data: |
| 128 | + if message.role == "assistant": |
| 129 | + return message.content[0].text.value |
| 130 | + return None |
0 commit comments