Skip to content

Commit 4d7b280

Browse files
committed
Pebblo: Policy enforcement in Safe Retriever
1 parent d081a54 commit 4d7b280

File tree

4 files changed

+330
-16
lines changed

4 files changed

+330
-16
lines changed

libs/community/langchain_community/chains/pebblo_retrieval/base.py

Lines changed: 56 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -102,18 +102,27 @@ def _call(
102102
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
103103
question = inputs[self.input_key]
104104
auth_context = inputs.get(self.auth_context_key)
105-
semantic_context = inputs.get(self.semantic_context_key)
105+
is_privileged_user = self.pb_client.is_privileged_user(auth_context)
106+
semantic_context = self.determine_semantic_context(
107+
is_privileged_user, auth_context, inputs
108+
)
106109
_, prompt_entities = self.pb_client.check_prompt_validity(question)
107110

108111
accepts_run_manager = (
109112
"run_manager" in inspect.signature(self._get_docs).parameters
110113
)
111114
if accepts_run_manager:
112115
docs = self._get_docs(
113-
question, auth_context, semantic_context, run_manager=_run_manager
116+
question,
117+
auth_context,
118+
semantic_context,
119+
is_privileged_user,
120+
run_manager=_run_manager,
114121
)
115122
else:
116-
docs = self._get_docs(question, auth_context, semantic_context) # type: ignore[call-arg]
123+
docs = self._get_docs(
124+
question, auth_context, semantic_context, is_privileged_user
125+
) # type: ignore[call-arg]
117126
answer = self.combine_documents_chain.run(
118127
input_documents=docs, question=question, callbacks=_run_manager.get_child()
119128
)
@@ -155,7 +164,10 @@ async def _acall(
155164
_run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
156165
question = inputs[self.input_key]
157166
auth_context = inputs.get(self.auth_context_key)
158-
semantic_context = inputs.get(self.semantic_context_key)
167+
is_privileged_user = self.pb_client.is_privileged_user(auth_context)
168+
semantic_context = self.determine_semantic_context(
169+
is_privileged_user, auth_context, inputs
170+
)
159171
accepts_run_manager = (
160172
"run_manager" in inspect.signature(self._aget_docs).parameters
161173
)
@@ -164,10 +176,16 @@ async def _acall(
164176

165177
if accepts_run_manager:
166178
docs = await self._aget_docs(
167-
question, auth_context, semantic_context, run_manager=_run_manager
179+
question,
180+
auth_context,
181+
semantic_context,
182+
is_privileged_user,
183+
run_manager=_run_manager,
168184
)
169185
else:
170-
docs = await self._aget_docs(question, auth_context, semantic_context) # type: ignore[call-arg]
186+
docs = await self._aget_docs(
187+
question, auth_context, semantic_context, is_privileged_user
188+
) # type: ignore[call-arg]
171189
answer = await self.combine_documents_chain.arun(
172190
input_documents=docs, question=question, callbacks=_run_manager.get_child()
173191
)
@@ -254,6 +272,7 @@ def from_chain_type(
254272
api_key=api_key,
255273
classifier_location=classifier_location,
256274
classifier_url=classifier_url,
275+
app_name=app_name,
257276
)
258277
# send app discovery request
259278
pb_client.send_app_discover(app)
@@ -289,11 +308,13 @@ def _get_docs(
289308
question: str,
290309
auth_context: Optional[AuthContext],
291310
semantic_context: Optional[SemanticContext],
311+
is_privileged_user: bool = False,
292312
*,
293313
run_manager: CallbackManagerForChainRun,
294314
) -> List[Document]:
295315
"""Get docs."""
296-
set_enforcement_filters(self.retriever, auth_context, semantic_context)
316+
if not is_privileged_user:
317+
set_enforcement_filters(self.retriever, auth_context, semantic_context)
297318
return self.retriever.get_relevant_documents(
298319
question, callbacks=run_manager.get_child()
299320
)
@@ -303,15 +324,42 @@ async def _aget_docs(
303324
question: str,
304325
auth_context: Optional[AuthContext],
305326
semantic_context: Optional[SemanticContext],
327+
is_privileged_user: bool = False,
306328
*,
307329
run_manager: AsyncCallbackManagerForChainRun,
308330
) -> List[Document]:
309331
"""Get docs."""
310-
set_enforcement_filters(self.retriever, auth_context, semantic_context)
332+
if not is_privileged_user:
333+
set_enforcement_filters(self.retriever, auth_context, semantic_context)
311334
return await self.retriever.aget_relevant_documents(
312335
question, callbacks=run_manager.get_child()
313336
)
314337

338+
def determine_semantic_context(
339+
self,
340+
is_privileged_user: bool,
341+
auth_context: Optional[AuthContext],
342+
inputs: Dict[str, Any],
343+
) -> Optional[SemanticContext]:
344+
"""
345+
Determine semantic context based on the auth_context or inputs.
346+
347+
Args:
348+
is_privileged_user (bool): If the user is a privileged user.
349+
auth_context (Optional[AuthContext]): Authentication context.
350+
inputs (Dict[str, Any]): Input dictionary containing various parameters.
351+
352+
Returns:
353+
Optional[SemanticContext]: Resolved semantic context.
354+
"""
355+
semantic_context = None
356+
if not is_privileged_user:
357+
# Get semantic context from policy if present otherwise from inputs
358+
semantic_context = self.pb_client.get_semantic_context(
359+
auth_context
360+
) or inputs.get(self.semantic_context_key)
361+
return semantic_context
362+
315363
@staticmethod
316364
def _get_app_details( # type: ignore
317365
app_name: str, owner: str, description: str, llm: BaseLanguageModel, **kwargs

libs/community/langchain_community/chains/pebblo_retrieval/models.py

Lines changed: 54 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""Models for the PebbloRetrievalQA chain."""
22

3-
from typing import Any, List, Optional, Union
3+
from enum import Enum
4+
from typing import Any, List, Optional, Set, Union
45

56
from pydantic import BaseModel
67

@@ -10,7 +11,7 @@ class AuthContext(BaseModel):
1011

1112
name: Optional[str] = None
1213
user_id: str
13-
user_auth: List[str]
14+
user_auth: List[str] = []
1415
"""List of user authorizations, which may include their User ID and
1516
the groups they are part of"""
1617

@@ -149,3 +150,54 @@ class Qa(BaseModel):
149150
user: str
150151
user_identities: Optional[List[str]]
151152
classifier_location: str
153+
154+
155+
class PolicyType(Enum):
156+
"""Enums for policy types"""
157+
158+
IDENTITY = "identity"
159+
APPLICATION = "application"
160+
COST = "cost"
161+
162+
163+
class SemanticGuardrail(BaseModel):
164+
"""
165+
Semantic Guardrail for Entities and Topics (Restrictions).
166+
167+
Attributes:
168+
entities (Optional[Set[str]]): A set of entity restrictions.
169+
topics (Optional[Set[str]]): A set of topic restrictions.
170+
"""
171+
172+
entities: Set[str] = set()
173+
topics: Set[str] = set()
174+
175+
176+
class Policy(BaseModel):
177+
"""
178+
Policy base class.
179+
180+
Attributes:
181+
schema_version (int): The schema version of the policy.
182+
type (PolicyType): The type of policy.
183+
"""
184+
185+
schema_version: int = 1
186+
type: PolicyType
187+
188+
class Config:
189+
extra = "ignore"
190+
191+
192+
class IdentityPolicy(Policy):
193+
"""
194+
Policy for access control.
195+
196+
Attributes:
197+
privileged_identities (Set[str]): List of identities with privileged access.
198+
user_semantic_guardrail (dict[str, SemanticGuardrail]): Mapping of identities to
199+
semantic guardrail restrictions.
200+
"""
201+
202+
privileged_identities: Set[str] = set()
203+
user_semantic_guardrail: dict[str, SemanticGuardrail] = {}

0 commit comments

Comments
 (0)