Skip to content

Commit bac34d8

Browse files
committed
fix: chat session object
1 parent 2d94cd4 commit bac34d8

File tree

1 file changed

+130
-0
lines changed

1 file changed

+130
-0
lines changed

chat_session.py

Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,130 @@
1+
import asyncio
2+
from thread_manager import ThreadManager
3+
from assistant_manager import AssistantManager
4+
5+
6+
class ChatSession:
7+
def __init__(self, thread_manager: ThreadManager, assistant_manager: AssistantManager, assistant_name: str,
8+
model_name: str, assistant_id: str = None, thread_id: str = None):
9+
self.thread_manager = thread_manager
10+
self.assistant_manager = assistant_manager
11+
self.assistant_name = assistant_name
12+
self.model_name = model_name
13+
self.assistant_id = assistant_id
14+
self.thread_id = thread_id
15+
16+
async def start_session(self):
17+
if self.thread_id is None:
18+
# Get or create a thread
19+
self.thread_id = await self.get_or_create_thread()
20+
21+
if self.assistant_id is None:
22+
# Find or create an assistant
23+
self.assistant_id = await self.find_or_create_assistant(
24+
name=self.assistant_name,
25+
model=self.model_name
26+
)
27+
28+
# Display existing chat history
29+
await self.display_chat_history()
30+
31+
prev_messages = await self.thread_manager.list_messages(self.thread_id)
32+
if prev_messages is None:
33+
print("An error occurred while retrieving messages.")
34+
return
35+
36+
# Start the chat loop
37+
await self.chat_loop()
38+
39+
async def chat_loop(self):
40+
try:
41+
while True:
42+
user_input = input("You: ")
43+
if user_input.lower() in ['exit', 'quit', 'bye']:
44+
break
45+
if user_input.lower() in ['/delete', '/clear']:
46+
await self.thread_manager.delete_thread(self.thread_id)
47+
self.thread_id = await self.get_or_create_thread()
48+
continue
49+
50+
response = await self.get_latest_response(user_input)
51+
52+
if response:
53+
print("Assistant:", response)
54+
55+
finally:
56+
print(f"Session ended")
57+
58+
async def get_or_create_thread(self):
59+
data = self.thread_manager.read_thread_data()
60+
thread_id = data.get('thread_id')
61+
if not thread_id:
62+
thread = await self.thread_manager.create_thread(messages=[])
63+
thread_id = thread.id
64+
self.thread_manager.save_thread_data(thread_id)
65+
return thread_id
66+
67+
async def find_or_create_assistant(self, name: str, model: str):
68+
"""
69+
Finds an existing assistant by name or creates a new one.
70+
71+
Args:
72+
name (str): The name of the assistant.
73+
model (str): The model ID for the assistant.
74+
75+
Returns:
76+
str: The ID of the found or created assistant.
77+
"""
78+
assistant_id = await self.assistant_manager.get_assistant_id_by_name(name)
79+
if not assistant_id:
80+
assistant = await self.assistant_manager.create_assistant(name=name,
81+
model=model,
82+
instructions="",
83+
tools=[{"type": "retrieval"}]
84+
)
85+
assistant_id = assistant.id
86+
return assistant_id
87+
88+
async def send_message(self, content):
89+
return await self.thread_manager.send_message(self.thread_id, content)
90+
91+
async def display_chat_history(self):
92+
messages = await self.thread_manager.list_messages(self.thread_id)
93+
if messages is None:
94+
return
95+
print(messages)
96+
for message in reversed(messages.data):
97+
role = message.role
98+
content = message.content[0].text.value # Assuming message content is structured this way
99+
print(f"{role.title()}: {content}")
100+
101+
async def get_latest_response(self, user_input):
102+
# Send the user message
103+
await self.send_message(user_input)
104+
105+
# Create a new run for the assistant to respond
106+
await self.create_run()
107+
108+
# Wait for the assistant's response
109+
await self.wait_for_assistant()
110+
111+
# Retrieve the latest response
112+
return await self.retrieve_latest_response()
113+
114+
async def create_run(self):
115+
return await self.thread_manager.create_run(self.thread_id, self.assistant_id)
116+
117+
async def wait_for_assistant(self):
118+
while True:
119+
runs = await self.thread_manager.list_runs(self.thread_id)
120+
latest_run = runs.data[0]
121+
if latest_run.status in ["completed", "failed"]:
122+
break
123+
await asyncio.sleep(2) # Wait for 2 seconds before checking again
124+
125+
async def retrieve_latest_response(self):
126+
response = await self.thread_manager.list_messages(self.thread_id)
127+
for message in response.data:
128+
if message.role == "assistant":
129+
return message.content[0].text.value
130+
return None

0 commit comments

Comments
 (0)