27
27
conversations_collection = db ["conversations" ]
28
28
sources_collection = db ["sources" ]
29
29
prompts_collection = db ["prompts" ]
30
- api_key_collection = db ["api_keys " ]
30
+ agents_collection = db ["agents " ]
31
31
user_logs_collection = db ["user_logs" ]
32
32
attachments_collection = db ["attachments" ]
33
33
@@ -86,19 +86,42 @@ def run_async_chain(chain, question, chat_history):
86
86
return result
87
87
88
88
89
+ def get_agent_key (agent_id , user_id ):
90
+ if not agent_id :
91
+ return None
92
+
93
+ try :
94
+ agent = agents_collection .find_one ({"_id" : ObjectId (agent_id )})
95
+ if agent is None :
96
+ raise Exception ("Agent not found" , 404 )
97
+
98
+ if agent .get ("user" ) == user_id :
99
+ agents_collection .update_one (
100
+ {"_id" : ObjectId (agent_id )},
101
+ {"$set" : {"lastUsedAt" : datetime .datetime .now (datetime .timezone .utc )}},
102
+ )
103
+ return str (agent ["key" ])
104
+
105
+ raise Exception ("Unauthorized access to the agent" , 403 )
106
+
107
+ except Exception as e :
108
+ logger .error (f"Error in get_agent_key: { str (e )} " )
109
+ raise
110
+
111
+
89
112
def get_data_from_api_key (api_key ):
90
- data = api_key_collection .find_one ({"key" : api_key })
91
- # # Raise custom exception if the API key is not found
92
- if data is None :
93
- raise Exception ("Invalid API Key, please generate new key" , 401 )
113
+ data = agents_collection .find_one ({"key" : api_key })
114
+ if not data :
115
+ raise Exception ("Invalid API Key, please generate a new key" , 401 )
94
116
95
- if "source" in data and isinstance (data ["source" ], DBRef ):
96
- source_doc = db .dereference (data ["source" ])
117
+ source = data .get ("source" )
118
+ if isinstance (source , DBRef ):
119
+ source_doc = db .dereference (source )
97
120
data ["source" ] = str (source_doc ["_id" ])
98
- if "retriever" in source_doc :
99
- data ["retriever" ] = source_doc ["retriever" ]
121
+ data ["retriever" ] = source_doc .get ("retriever" , data .get ("retriever" ))
100
122
else :
101
123
data ["source" ] = {}
124
+
102
125
return data
103
126
104
127
@@ -128,7 +151,8 @@ def save_conversation(
128
151
llm ,
129
152
decoded_token ,
130
153
index = None ,
131
- api_key = None
154
+ api_key = None ,
155
+ agent_id = None ,
132
156
):
133
157
current_time = datetime .datetime .now (datetime .timezone .utc )
134
158
if conversation_id is not None and index is not None :
@@ -202,7 +226,9 @@ def save_conversation(
202
226
],
203
227
}
204
228
if api_key :
205
- api_key_doc = api_key_collection .find_one ({"key" : api_key })
229
+ if agent_id :
230
+ conversation_data ["agent_id" ] = agent_id
231
+ api_key_doc = agents_collection .find_one ({"key" : api_key })
206
232
if api_key_doc :
207
233
conversation_data ["api_key" ] = api_key_doc ["key" ]
208
234
conversation_id = conversations_collection .insert_one (
@@ -234,14 +260,17 @@ def complete_stream(
234
260
index = None ,
235
261
should_save_conversation = True ,
236
262
attachments = None ,
263
+ agent_id = None ,
237
264
):
238
265
try :
239
266
response_full , thought , source_log_docs , tool_calls = "" , "" , [], []
240
267
attachment_ids = []
241
268
242
269
if attachments :
243
270
attachment_ids = [attachment ["id" ] for attachment in attachments ]
244
- logger .info (f"Processing request with { len (attachments )} attachments: { attachment_ids } " )
271
+ logger .info (
272
+ f"Processing request with { len (attachments )} attachments: { attachment_ids } "
273
+ )
245
274
246
275
answer = agent .gen (query = question , retriever = retriever )
247
276
@@ -294,7 +323,8 @@ def complete_stream(
294
323
llm ,
295
324
decoded_token ,
296
325
index ,
297
- api_key = user_api_key
326
+ api_key = user_api_key ,
327
+ agent_id = agent_id ,
298
328
)
299
329
else :
300
330
conversation_id = None
@@ -366,7 +396,9 @@ class Stream(Resource):
366
396
required = False , description = "Index of the query to update"
367
397
),
368
398
"save_conversation" : fields .Boolean (
369
- required = False , default = True , description = "Whether to save the conversation"
399
+ required = False ,
400
+ default = True ,
401
+ description = "Whether to save the conversation" ,
370
402
),
371
403
"attachments" : fields .List (
372
404
fields .String , required = False , description = "List of attachment IDs"
@@ -400,6 +432,14 @@ def post(self):
400
432
chunks = int (data .get ("chunks" , 2 ))
401
433
token_limit = data .get ("token_limit" , settings .DEFAULT_MAX_HISTORY )
402
434
retriever_name = data .get ("retriever" , "classic" )
435
+ agent_id = data .get ("agent_id" , None )
436
+ agent_type = settings .AGENT_NAME
437
+ agent_key = get_agent_key (agent_id , request .decoded_token .get ("sub" ))
438
+
439
+ if agent_key :
440
+ data .update ({"api_key" : agent_key })
441
+ else :
442
+ agent_id = None
403
443
404
444
if "api_key" in data :
405
445
data_key = get_data_from_api_key (data ["api_key" ])
@@ -408,6 +448,7 @@ def post(self):
408
448
source = {"active_docs" : data_key .get ("source" )}
409
449
retriever_name = data_key .get ("retriever" , retriever_name )
410
450
user_api_key = data ["api_key" ]
451
+ agent_type = data_key .get ("agent_type" , agent_type )
411
452
decoded_token = {"sub" : data_key .get ("user" )}
412
453
413
454
elif "active_docs" in data :
@@ -423,8 +464,10 @@ def post(self):
423
464
424
465
if not decoded_token :
425
466
return make_response ({"error" : "Unauthorized" }, 401 )
426
-
427
- attachments = get_attachments_content (attachment_ids , decoded_token .get ("sub" ))
467
+
468
+ attachments = get_attachments_content (
469
+ attachment_ids , decoded_token .get ("sub" )
470
+ )
428
471
429
472
logger .info (
430
473
f"/stream - request_data: { data } , source: { source } , attachments: { len (attachments )} " ,
@@ -436,7 +479,7 @@ def post(self):
436
479
chunks = 0
437
480
438
481
agent = AgentCreator .create_agent (
439
- settings . AGENT_NAME ,
482
+ agent_type ,
440
483
endpoint = "stream" ,
441
484
llm_name = settings .LLM_NAME ,
442
485
gpt_model = gpt_model ,
@@ -471,6 +514,7 @@ def post(self):
471
514
isNoneDoc = data .get ("isNoneDoc" ),
472
515
index = index ,
473
516
should_save_conversation = save_conv ,
517
+ agent_id = agent_id ,
474
518
),
475
519
mimetype = "text/event-stream" ,
476
520
)
@@ -552,6 +596,7 @@ def post(self):
552
596
chunks = int (data .get ("chunks" , 2 ))
553
597
token_limit = data .get ("token_limit" , settings .DEFAULT_MAX_HISTORY )
554
598
retriever_name = data .get ("retriever" , "classic" )
599
+ agent_type = settings .AGENT_NAME
555
600
556
601
if "api_key" in data :
557
602
data_key = get_data_from_api_key (data ["api_key" ])
@@ -560,6 +605,7 @@ def post(self):
560
605
source = {"active_docs" : data_key .get ("source" )}
561
606
retriever_name = data_key .get ("retriever" , retriever_name )
562
607
user_api_key = data ["api_key" ]
608
+ agent_type = data_key .get ("agent_type" , agent_type )
563
609
decoded_token = {"sub" : data_key .get ("user" )}
564
610
565
611
elif "active_docs" in data :
@@ -584,7 +630,7 @@ def post(self):
584
630
)
585
631
586
632
agent = AgentCreator .create_agent (
587
- settings . AGENT_NAME ,
633
+ agent_type ,
588
634
endpoint = "api/answer" ,
589
635
llm_name = settings .LLM_NAME ,
590
636
gpt_model = gpt_model ,
@@ -815,28 +861,27 @@ def post(self):
815
861
def get_attachments_content (attachment_ids , user ):
816
862
"""
817
863
Retrieve content from attachment documents based on their IDs.
818
-
864
+
819
865
Args:
820
866
attachment_ids (list): List of attachment document IDs
821
867
user (str): User identifier to verify ownership
822
-
868
+
823
869
Returns:
824
870
list: List of dictionaries containing attachment content and metadata
825
871
"""
826
872
if not attachment_ids :
827
873
return []
828
-
874
+
829
875
attachments = []
830
876
for attachment_id in attachment_ids :
831
877
try :
832
- attachment_doc = attachments_collection .find_one ({
833
- "_id" : ObjectId (attachment_id ),
834
- "user" : user
835
- })
836
-
878
+ attachment_doc = attachments_collection .find_one (
879
+ {"_id" : ObjectId (attachment_id ), "user" : user }
880
+ )
881
+
837
882
if attachment_doc :
838
883
attachments .append (attachment_doc )
839
884
except Exception as e :
840
885
logger .error (f"Error retrieving attachment { attachment_id } : { e } " )
841
-
886
+
842
887
return attachments
0 commit comments