Skip to content

Commit e9a6044

Browse files
Merge branch 'main' into main
2 parents c50ff6f + 474d700 commit e9a6044

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

48 files changed

+3147
-921
lines changed

application/agents/base.py

+27-3
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from application.llm.llm_creator import LLMCreator
1111
from application.logging import build_stack_data, log_activity, LogContext
1212
from application.retriever.base import BaseRetriever
13+
from bson.objectid import ObjectId
1314

1415

1516
class BaseAgent(ABC):
@@ -23,7 +24,7 @@ def __init__(
2324
prompt: str = "",
2425
chat_history: Optional[List[Dict]] = None,
2526
decoded_token: Optional[Dict] = None,
26-
attachments: Optional[List[Dict]]=None,
27+
attachments: Optional[List[Dict]] = None,
2728
):
2829
self.endpoint = endpoint
2930
self.llm_name = llm_name
@@ -58,6 +59,27 @@ def _gen_inner(
5859
) -> Generator[Dict, None, None]:
5960
pass
6061

62+
def _get_tools(self, api_key: str = None) -> Dict[str, Dict]:
63+
mongo = MongoDB.get_client()
64+
db = mongo["docsgpt"]
65+
agents_collection = db["agents"]
66+
tools_collection = db["user_tools"]
67+
68+
agent_data = agents_collection.find_one({"key": api_key or self.user_api_key})
69+
tool_ids = agent_data.get("tools", []) if agent_data else []
70+
71+
tools = (
72+
tools_collection.find(
73+
{"_id": {"$in": [ObjectId(tool_id) for tool_id in tool_ids]}}
74+
)
75+
if tool_ids
76+
else []
77+
)
78+
tools = list(tools)
79+
tools_by_id = {str(tool["_id"]): tool for tool in tools} if tools else {}
80+
81+
return tools_by_id
82+
6183
def _get_user_tools(self, user="local"):
6284
mongo = MongoDB.get_client()
6385
db = mongo["docsgpt"]
@@ -243,9 +265,11 @@ def _llm_handler(
243265
tools_dict: Dict,
244266
messages: List[Dict],
245267
log_context: Optional[LogContext] = None,
246-
attachments: Optional[List[Dict]] = None
268+
attachments: Optional[List[Dict]] = None,
247269
):
248-
resp = self.llm_handler.handle_response(self, resp, tools_dict, messages, attachments)
270+
resp = self.llm_handler.handle_response(
271+
self, resp, tools_dict, messages, attachments
272+
)
249273
if log_context:
250274
data = build_stack_data(self.llm_handler)
251275
log_context.stacks.append({"component": "llm_handler", "data": data})

application/agents/classic_agent.py

+8-4
Original file line numberDiff line numberDiff line change
@@ -5,21 +5,25 @@
55

66
from application.retriever.base import BaseRetriever
77
import logging
8+
89
logger = logging.getLogger(__name__)
910

11+
1012
class ClassicAgent(BaseAgent):
1113
def _gen_inner(
1214
self, query: str, retriever: BaseRetriever, log_context: LogContext
1315
) -> Generator[Dict, None, None]:
1416
retrieved_data = self._retriever_search(retriever, query, log_context)
15-
16-
tools_dict = self._get_user_tools(self.user)
17+
if self.user_api_key:
18+
tools_dict = self._get_tools(self.user_api_key)
19+
else:
20+
tools_dict = self._get_user_tools(self.user)
1721
self._prepare_tools(tools_dict)
1822

1923
messages = self._build_messages(self.prompt, query, retrieved_data)
2024

2125
resp = self._llm_gen(messages, log_context)
22-
26+
2327
attachments = self.attachments
2428

2529
if isinstance(resp, str):
@@ -33,7 +37,7 @@ def _gen_inner(
3337
yield {"answer": resp.message.content}
3438
return
3539

36-
resp = self._llm_handler(resp, tools_dict, messages, log_context,attachments)
40+
resp = self._llm_handler(resp, tools_dict, messages, log_context, attachments)
3741

3842
if isinstance(resp, str):
3943
yield {"answer": resp}

application/agents/react_agent.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,10 @@ def _gen_inner(
3030
) -> Generator[Dict, None, None]:
3131
retrieved_data = self._retriever_search(retriever, query, log_context)
3232

33-
tools_dict = self._get_user_tools(self.user)
33+
if self.user_api_key:
34+
tools_dict = self._get_tools(self.user_api_key)
35+
else:
36+
tools_dict = self._get_user_tools(self.user)
3437
self._prepare_tools(tools_dict)
3538

3639
docs_together = "\n".join([doc["text"] for doc in retrieved_data])

application/api/answer/routes.py

+72-27
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
conversations_collection = db["conversations"]
2828
sources_collection = db["sources"]
2929
prompts_collection = db["prompts"]
30-
api_key_collection = db["api_keys"]
30+
agents_collection = db["agents"]
3131
user_logs_collection = db["user_logs"]
3232
attachments_collection = db["attachments"]
3333

@@ -86,19 +86,42 @@ def run_async_chain(chain, question, chat_history):
8686
return result
8787

8888

89+
def get_agent_key(agent_id, user_id):
90+
if not agent_id:
91+
return None
92+
93+
try:
94+
agent = agents_collection.find_one({"_id": ObjectId(agent_id)})
95+
if agent is None:
96+
raise Exception("Agent not found", 404)
97+
98+
if agent.get("user") == user_id:
99+
agents_collection.update_one(
100+
{"_id": ObjectId(agent_id)},
101+
{"$set": {"lastUsedAt": datetime.datetime.now(datetime.timezone.utc)}},
102+
)
103+
return str(agent["key"])
104+
105+
raise Exception("Unauthorized access to the agent", 403)
106+
107+
except Exception as e:
108+
logger.error(f"Error in get_agent_key: {str(e)}")
109+
raise
110+
111+
89112
def get_data_from_api_key(api_key):
90-
data = api_key_collection.find_one({"key": api_key})
91-
# # Raise custom exception if the API key is not found
92-
if data is None:
93-
raise Exception("Invalid API Key, please generate new key", 401)
113+
data = agents_collection.find_one({"key": api_key})
114+
if not data:
115+
raise Exception("Invalid API Key, please generate a new key", 401)
94116

95-
if "source" in data and isinstance(data["source"], DBRef):
96-
source_doc = db.dereference(data["source"])
117+
source = data.get("source")
118+
if isinstance(source, DBRef):
119+
source_doc = db.dereference(source)
97120
data["source"] = str(source_doc["_id"])
98-
if "retriever" in source_doc:
99-
data["retriever"] = source_doc["retriever"]
121+
data["retriever"] = source_doc.get("retriever", data.get("retriever"))
100122
else:
101123
data["source"] = {}
124+
102125
return data
103126

104127

@@ -128,7 +151,8 @@ def save_conversation(
128151
llm,
129152
decoded_token,
130153
index=None,
131-
api_key=None
154+
api_key=None,
155+
agent_id=None,
132156
):
133157
current_time = datetime.datetime.now(datetime.timezone.utc)
134158
if conversation_id is not None and index is not None:
@@ -202,7 +226,9 @@ def save_conversation(
202226
],
203227
}
204228
if api_key:
205-
api_key_doc = api_key_collection.find_one({"key": api_key})
229+
if agent_id:
230+
conversation_data["agent_id"] = agent_id
231+
api_key_doc = agents_collection.find_one({"key": api_key})
206232
if api_key_doc:
207233
conversation_data["api_key"] = api_key_doc["key"]
208234
conversation_id = conversations_collection.insert_one(
@@ -234,14 +260,17 @@ def complete_stream(
234260
index=None,
235261
should_save_conversation=True,
236262
attachments=None,
263+
agent_id=None,
237264
):
238265
try:
239266
response_full, thought, source_log_docs, tool_calls = "", "", [], []
240267
attachment_ids = []
241268

242269
if attachments:
243270
attachment_ids = [attachment["id"] for attachment in attachments]
244-
logger.info(f"Processing request with {len(attachments)} attachments: {attachment_ids}")
271+
logger.info(
272+
f"Processing request with {len(attachments)} attachments: {attachment_ids}"
273+
)
245274

246275
answer = agent.gen(query=question, retriever=retriever)
247276

@@ -294,7 +323,8 @@ def complete_stream(
294323
llm,
295324
decoded_token,
296325
index,
297-
api_key=user_api_key
326+
api_key=user_api_key,
327+
agent_id=agent_id,
298328
)
299329
else:
300330
conversation_id = None
@@ -366,7 +396,9 @@ class Stream(Resource):
366396
required=False, description="Index of the query to update"
367397
),
368398
"save_conversation": fields.Boolean(
369-
required=False, default=True, description="Whether to save the conversation"
399+
required=False,
400+
default=True,
401+
description="Whether to save the conversation",
370402
),
371403
"attachments": fields.List(
372404
fields.String, required=False, description="List of attachment IDs"
@@ -400,6 +432,14 @@ def post(self):
400432
chunks = int(data.get("chunks", 2))
401433
token_limit = data.get("token_limit", settings.DEFAULT_MAX_HISTORY)
402434
retriever_name = data.get("retriever", "classic")
435+
agent_id = data.get("agent_id", None)
436+
agent_type = settings.AGENT_NAME
437+
agent_key = get_agent_key(agent_id, request.decoded_token.get("sub"))
438+
439+
if agent_key:
440+
data.update({"api_key": agent_key})
441+
else:
442+
agent_id = None
403443

404444
if "api_key" in data:
405445
data_key = get_data_from_api_key(data["api_key"])
@@ -408,6 +448,7 @@ def post(self):
408448
source = {"active_docs": data_key.get("source")}
409449
retriever_name = data_key.get("retriever", retriever_name)
410450
user_api_key = data["api_key"]
451+
agent_type = data_key.get("agent_type", agent_type)
411452
decoded_token = {"sub": data_key.get("user")}
412453

413454
elif "active_docs" in data:
@@ -423,8 +464,10 @@ def post(self):
423464

424465
if not decoded_token:
425466
return make_response({"error": "Unauthorized"}, 401)
426-
427-
attachments = get_attachments_content(attachment_ids, decoded_token.get("sub"))
467+
468+
attachments = get_attachments_content(
469+
attachment_ids, decoded_token.get("sub")
470+
)
428471

429472
logger.info(
430473
f"/stream - request_data: {data}, source: {source}, attachments: {len(attachments)}",
@@ -436,7 +479,7 @@ def post(self):
436479
chunks = 0
437480

438481
agent = AgentCreator.create_agent(
439-
settings.AGENT_NAME,
482+
agent_type,
440483
endpoint="stream",
441484
llm_name=settings.LLM_NAME,
442485
gpt_model=gpt_model,
@@ -471,6 +514,7 @@ def post(self):
471514
isNoneDoc=data.get("isNoneDoc"),
472515
index=index,
473516
should_save_conversation=save_conv,
517+
agent_id=agent_id,
474518
),
475519
mimetype="text/event-stream",
476520
)
@@ -552,6 +596,7 @@ def post(self):
552596
chunks = int(data.get("chunks", 2))
553597
token_limit = data.get("token_limit", settings.DEFAULT_MAX_HISTORY)
554598
retriever_name = data.get("retriever", "classic")
599+
agent_type = settings.AGENT_NAME
555600

556601
if "api_key" in data:
557602
data_key = get_data_from_api_key(data["api_key"])
@@ -560,6 +605,7 @@ def post(self):
560605
source = {"active_docs": data_key.get("source")}
561606
retriever_name = data_key.get("retriever", retriever_name)
562607
user_api_key = data["api_key"]
608+
agent_type = data_key.get("agent_type", agent_type)
563609
decoded_token = {"sub": data_key.get("user")}
564610

565611
elif "active_docs" in data:
@@ -584,7 +630,7 @@ def post(self):
584630
)
585631

586632
agent = AgentCreator.create_agent(
587-
settings.AGENT_NAME,
633+
agent_type,
588634
endpoint="api/answer",
589635
llm_name=settings.LLM_NAME,
590636
gpt_model=gpt_model,
@@ -815,28 +861,27 @@ def post(self):
815861
def get_attachments_content(attachment_ids, user):
816862
"""
817863
Retrieve content from attachment documents based on their IDs.
818-
864+
819865
Args:
820866
attachment_ids (list): List of attachment document IDs
821867
user (str): User identifier to verify ownership
822-
868+
823869
Returns:
824870
list: List of dictionaries containing attachment content and metadata
825871
"""
826872
if not attachment_ids:
827873
return []
828-
874+
829875
attachments = []
830876
for attachment_id in attachment_ids:
831877
try:
832-
attachment_doc = attachments_collection.find_one({
833-
"_id": ObjectId(attachment_id),
834-
"user": user
835-
})
836-
878+
attachment_doc = attachments_collection.find_one(
879+
{"_id": ObjectId(attachment_id), "user": user}
880+
)
881+
837882
if attachment_doc:
838883
attachments.append(attachment_doc)
839884
except Exception as e:
840885
logger.error(f"Error retrieving attachment {attachment_id}: {e}")
841-
886+
842887
return attachments

0 commit comments

Comments
 (0)