@@ -102,18 +102,27 @@ def _call(
102
102
_run_manager = run_manager or CallbackManagerForChainRun .get_noop_manager ()
103
103
question = inputs [self .input_key ]
104
104
auth_context = inputs .get (self .auth_context_key )
105
- semantic_context = inputs .get (self .semantic_context_key )
105
+ is_privileged_user = self .pb_client .is_privileged_user (auth_context )
106
+ semantic_context = self .determine_semantic_context (
107
+ is_privileged_user , auth_context , inputs
108
+ )
106
109
_ , prompt_entities = self .pb_client .check_prompt_validity (question )
107
110
108
111
accepts_run_manager = (
109
112
"run_manager" in inspect .signature (self ._get_docs ).parameters
110
113
)
111
114
if accepts_run_manager :
112
115
docs = self ._get_docs (
113
- question , auth_context , semantic_context , run_manager = _run_manager
116
+ question ,
117
+ auth_context ,
118
+ semantic_context ,
119
+ is_privileged_user ,
120
+ run_manager = _run_manager ,
114
121
)
115
122
else :
116
- docs = self ._get_docs (question , auth_context , semantic_context ) # type: ignore[call-arg]
123
+ docs = self ._get_docs (
124
+ question , auth_context , semantic_context , is_privileged_user
125
+ ) # type: ignore[call-arg]
117
126
answer = self .combine_documents_chain .run (
118
127
input_documents = docs , question = question , callbacks = _run_manager .get_child ()
119
128
)
@@ -155,7 +164,10 @@ async def _acall(
155
164
_run_manager = run_manager or AsyncCallbackManagerForChainRun .get_noop_manager ()
156
165
question = inputs [self .input_key ]
157
166
auth_context = inputs .get (self .auth_context_key )
158
- semantic_context = inputs .get (self .semantic_context_key )
167
+ is_privileged_user = self .pb_client .is_privileged_user (auth_context )
168
+ semantic_context = self .determine_semantic_context (
169
+ is_privileged_user , auth_context , inputs
170
+ )
159
171
accepts_run_manager = (
160
172
"run_manager" in inspect .signature (self ._aget_docs ).parameters
161
173
)
@@ -164,10 +176,16 @@ async def _acall(
164
176
165
177
if accepts_run_manager :
166
178
docs = await self ._aget_docs (
167
- question , auth_context , semantic_context , run_manager = _run_manager
179
+ question ,
180
+ auth_context ,
181
+ semantic_context ,
182
+ is_privileged_user ,
183
+ run_manager = _run_manager ,
168
184
)
169
185
else :
170
- docs = await self ._aget_docs (question , auth_context , semantic_context ) # type: ignore[call-arg]
186
+ docs = await self ._aget_docs (
187
+ question , auth_context , semantic_context , is_privileged_user
188
+ ) # type: ignore[call-arg]
171
189
answer = await self .combine_documents_chain .arun (
172
190
input_documents = docs , question = question , callbacks = _run_manager .get_child ()
173
191
)
@@ -254,6 +272,7 @@ def from_chain_type(
254
272
api_key = api_key ,
255
273
classifier_location = classifier_location ,
256
274
classifier_url = classifier_url ,
275
+ app_name = app_name ,
257
276
)
258
277
# send app discovery request
259
278
pb_client .send_app_discover (app )
@@ -289,11 +308,13 @@ def _get_docs(
289
308
question : str ,
290
309
auth_context : Optional [AuthContext ],
291
310
semantic_context : Optional [SemanticContext ],
311
+ is_privileged_user : bool = False ,
292
312
* ,
293
313
run_manager : CallbackManagerForChainRun ,
294
314
) -> List [Document ]:
295
315
"""Get docs."""
296
- set_enforcement_filters (self .retriever , auth_context , semantic_context )
316
+ if not is_privileged_user :
317
+ set_enforcement_filters (self .retriever , auth_context , semantic_context )
297
318
return self .retriever .get_relevant_documents (
298
319
question , callbacks = run_manager .get_child ()
299
320
)
@@ -303,15 +324,42 @@ async def _aget_docs(
303
324
question : str ,
304
325
auth_context : Optional [AuthContext ],
305
326
semantic_context : Optional [SemanticContext ],
327
+ is_privileged_user : bool = False ,
306
328
* ,
307
329
run_manager : AsyncCallbackManagerForChainRun ,
308
330
) -> List [Document ]:
309
331
"""Get docs."""
310
- set_enforcement_filters (self .retriever , auth_context , semantic_context )
332
+ if not is_privileged_user :
333
+ set_enforcement_filters (self .retriever , auth_context , semantic_context )
311
334
return await self .retriever .aget_relevant_documents (
312
335
question , callbacks = run_manager .get_child ()
313
336
)
314
337
338
+ def determine_semantic_context (
339
+ self ,
340
+ is_privileged_user : bool ,
341
+ auth_context : Optional [AuthContext ],
342
+ inputs : Dict [str , Any ],
343
+ ) -> Optional [SemanticContext ]:
344
+ """
345
+ Determine semantic context based on the auth_context or inputs.
346
+
347
+ Args:
348
+ is_privileged_user (bool): If the user is a privileged user.
349
+ auth_context (Optional[AuthContext]): Authentication context.
350
+ inputs (Dict[str, Any]): Input dictionary containing various parameters.
351
+
352
+ Returns:
353
+ Optional[SemanticContext]: Resolved semantic context.
354
+ """
355
+ semantic_context = None
356
+ if not is_privileged_user :
357
+ # Get semantic context from policy if present otherwise from inputs
358
+ semantic_context = self .pb_client .get_semantic_context (
359
+ auth_context
360
+ ) or inputs .get (self .semantic_context_key )
361
+ return semantic_context
362
+
315
363
@staticmethod
316
364
def _get_app_details ( # type: ignore
317
365
app_name : str , owner : str , description : str , llm : BaseLanguageModel , ** kwargs
0 commit comments