Skip to content

Commit fd0b567

Browse files
committed
pangea-sdk: make AI Guard messages parameter stricter
1 parent e8800fc commit fd0b567

File tree

8 files changed

+43
-27
lines changed

8 files changed

+43
-27
lines changed

CHANGELOG.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
77

88
## Unreleased
99

10+
### Changed
11+
12+
- AI Guard: `messages` parameter is no longer a generic. A new `Message` model
13+
has been introduced, and `messages` is now a `Sequence[Message]`.
14+
1015
## 6.1.1 - 2025-05-12
1116

1217
### Fixed

examples/ai_guard/ai_guard_examples/guard_text.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import os
44

55
from pangea import PangeaConfig
6-
from pangea.services.ai_guard import AIGuard
6+
from pangea.services.ai_guard import AIGuard, Message
77

88
token = os.getenv("PANGEA_AI_GUARD_TOKEN", "")
99
domain = os.getenv("PANGEA_DOMAIN", "aws.us.pangea.cloud")
@@ -18,7 +18,7 @@
1818
print("Response:", text_response.result.prompt_text)
1919

2020
# Structured input.
21-
structured_input = [{"role": "user", "content": "hello world"}]
21+
structured_input = [Message(role="user", content="hello world")]
2222
print("Guarding structured input:", structured_input)
2323
structured_response = ai_guard.guard_text(messages=structured_input)
2424
assert structured_response.result

examples/asyncio/ai_guard/async_ai_guard_examples/guard_text.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
from pangea import PangeaConfig
77
from pangea.asyncio.services import AIGuardAsync
8+
from pangea.services.ai_guard import Message
89

910
token = os.getenv("PANGEA_AI_GUARD_TOKEN", "")
1011
domain = os.getenv("PANGEA_DOMAIN", "aws.us.pangea.cloud")
@@ -21,7 +22,7 @@ async def main() -> None:
2122
print("Response:", text_response.result.prompt_text)
2223

2324
# Structured input.
24-
structured_input = [{"role": "user", "content": "hello world"}]
25+
structured_input = [Message(role="user", content="hello world")]
2526
print("Guarding structured input:", structured_input)
2627
structured_response = await ai_guard.guard_text(messages=structured_input)
2728
assert structured_response.result

packages/pangea-sdk/pangea/asyncio/services/ai_guard.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
11
from __future__ import annotations
22

3+
from collections.abc import Sequence
34
from typing import overload
45

56
from typing_extensions import TypeVar
67

78
from pangea.asyncio.services.base import ServiceBaseAsync
89
from pangea.config import PangeaConfig
910
from pangea.response import PangeaResponse
10-
from pangea.services.ai_guard import LogFields, Overrides, TextGuardResult
11+
from pangea.services.ai_guard import LogFields, Message, Overrides, TextGuardResult
1112

1213
_T = TypeVar("_T")
1314

@@ -60,7 +61,7 @@ async def guard_text(
6061
debug: bool | None = None,
6162
overrides: Overrides | None = None,
6263
log_fields: LogFields | None = None,
63-
) -> PangeaResponse[TextGuardResult[None]]:
64+
) -> PangeaResponse[TextGuardResult]:
6465
"""
6566
Text Guard for scanning LLM inputs and outputs
6667
@@ -88,12 +89,12 @@ async def guard_text(
8889
async def guard_text(
8990
self,
9091
*,
91-
messages: _T,
92+
messages: Sequence[Message],
9293
recipe: str | None = None,
9394
debug: bool | None = None,
9495
overrides: Overrides | None = None,
9596
log_fields: LogFields | None = None,
96-
) -> PangeaResponse[TextGuardResult[_T]]:
97+
) -> PangeaResponse[TextGuardResult]:
9798
"""
9899
Text Guard for scanning LLM inputs and outputs
99100
@@ -115,19 +116,19 @@ async def guard_text(
115116
log_field: Additional fields to include in activity log
116117
117118
Examples:
118-
response = await ai_guard.guard_text(messages=[{"role": "user", "content": "hello world"}])
119+
response = await ai_guard.guard_text(messages=[Message(role="user", content="hello world")])
119120
"""
120121

121-
async def guard_text( # type: ignore[misc]
122+
async def guard_text(
122123
self,
123124
text: str | None = None,
124125
*,
125-
messages: _T | None = None,
126+
messages: Sequence[Message] | None = None,
126127
recipe: str | None = None,
127128
debug: bool | None = None,
128129
overrides: Overrides | None = None,
129130
log_fields: LogFields | None = None,
130-
) -> PangeaResponse[TextGuardResult[None]]:
131+
) -> PangeaResponse[TextGuardResult]:
131132
"""
132133
Text Guard for scanning LLM inputs and outputs
133134

packages/pangea-sdk/pangea/services/ai_guard.py

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
from __future__ import annotations
22

3+
from collections.abc import Sequence
34
from typing import Generic, Literal, Optional, overload
45

6+
from pydantic import BaseModel, ConfigDict
57
from typing_extensions import TypeVar
68

79
from pangea.config import PangeaConfig
@@ -19,6 +21,13 @@
1921
PiiEntityAction = Literal["disabled", "report", "block", "mask", "partial_masking", "replacement", "hash", "fpe"]
2022

2123

24+
class Message(BaseModel):
25+
model_config = ConfigDict(extra="forbid")
26+
27+
role: str
28+
content: str
29+
30+
2231
class CodeDetectionOverride(APIRequestModel):
2332
disabled: Optional[bool] = None
2433
action: Optional[Literal["report", "block"]] = None
@@ -276,14 +285,14 @@ class TextGuardDetectors(APIResponseModel):
276285
code_detection: Optional[TextGuardDetector[CodeDetectionResult]] = None
277286

278287

279-
class TextGuardResult(PangeaResponseResult, Generic[_T]):
288+
class TextGuardResult(PangeaResponseResult):
280289
detectors: TextGuardDetectors
281290
"""Result of the recipe analyzing and input prompt."""
282291

283292
prompt_text: Optional[str] = None
284293
"""Updated prompt text, if applicable."""
285294

286-
prompt_messages: Optional[_T] = None
295+
prompt_messages: Optional[object] = None
287296
"""Updated structured prompt, if applicable."""
288297

289298
blocked: bool
@@ -347,7 +356,7 @@ def guard_text(
347356
debug: bool | None = None,
348357
overrides: Overrides | None = None,
349358
log_fields: LogFields | None = None,
350-
) -> PangeaResponse[TextGuardResult[None]]:
359+
) -> PangeaResponse[TextGuardResult]:
351360
"""
352361
Text Guard for scanning LLM inputs and outputs
353362
@@ -375,12 +384,12 @@ def guard_text(
375384
def guard_text(
376385
self,
377386
*,
378-
messages: _T,
387+
messages: Sequence[Message],
379388
recipe: str | None = None,
380389
debug: bool | None = None,
381390
overrides: Overrides | None = None,
382391
log_fields: LogFields | None = None,
383-
) -> PangeaResponse[TextGuardResult[_T]]:
392+
) -> PangeaResponse[TextGuardResult]:
384393
"""
385394
Text Guard for scanning LLM inputs and outputs
386395
@@ -402,19 +411,19 @@ def guard_text(
402411
log_field: Additional fields to include in activity log
403412
404413
Examples:
405-
response = ai_guard.guard_text(messages=[{"role": "user", "content": "hello world"}])
414+
response = ai_guard.guard_text(messages=[Message(role="user", content="hello world")])
406415
"""
407416

408-
def guard_text( # type: ignore[misc]
417+
def guard_text(
409418
self,
410419
text: str | None = None,
411420
*,
412-
messages: _T | None = None,
421+
messages: Sequence[Message] | None = None,
413422
recipe: str | None = None,
414423
debug: bool | None = None,
415424
overrides: Overrides | None = None,
416425
log_fields: LogFields | None = None,
417-
) -> PangeaResponse[TextGuardResult[None]]:
426+
) -> PangeaResponse[TextGuardResult]:
418427
"""
419428
Text Guard for scanning LLM inputs and outputs
420429

packages/pangea-sdk/tests/integration/asyncio/test_ai_guard.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
from pangea import PangeaConfig
66
from pangea.asyncio.services import AIGuardAsync
7-
from pangea.services.ai_guard import LogFields
7+
from pangea.services.ai_guard import LogFields, Message
88
from pangea.tools import TestEnvironment, get_test_domain, get_test_token, logger_set_pangea_config
99
from tests.test_tools import load_test_environment
1010

@@ -42,7 +42,7 @@ async def test_text_guard(self) -> None:
4242

4343
async def test_text_guard_messages(self) -> None:
4444
response = await self.client.guard_text(
45-
messages=[{"role": "user", "content": "hello world"}], log_fields=LogFields(source="Acme Wizard")
45+
messages=[Message(role="user", content="hello world")], log_fields=LogFields(source="Acme Wizard")
4646
)
4747
assert response.status == "Success"
4848
assert response.result

packages/pangea-sdk/tests/integration/test_ai_guard.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
from pangea import PangeaConfig
66
from pangea.services import AIGuard
7-
from pangea.services.ai_guard import LogFields
7+
from pangea.services.ai_guard import LogFields, Message
88
from pangea.tools import TestEnvironment, get_test_domain, get_test_token, logger_set_pangea_config
99
from tests.test_tools import load_test_environment
1010

@@ -35,7 +35,7 @@ def test_text_guard(self) -> None:
3535

3636
def test_text_guard_messages(self) -> None:
3737
response = self.client.guard_text(
38-
messages=[{"role": "user", "content": "hello world"}], log_fields=LogFields(source="Acme Wizard")
38+
messages=[Message(role="user", content="hello world")], log_fields=LogFields(source="Acme Wizard")
3939
)
4040
assert response.status == "Success"
4141
assert response.result

packages/pangea-sdk/tests/integration2/test_ai_guard.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from pangea import PangeaConfig
99
from pangea.asyncio.services.ai_guard import AIGuardAsync
1010
from pangea.services import AIGuard
11-
from pangea.services.ai_guard import LogFields, TextGuardResult
11+
from pangea.services.ai_guard import LogFields, Message, TextGuardResult
1212

1313
from ..utils import assert_matches_type
1414

@@ -34,7 +34,7 @@ def test_text_guard(self, client: AIGuard) -> None:
3434

3535
def test_text_guard_messages(self, client: AIGuard) -> None:
3636
response = client.guard_text(
37-
messages=[{"role": "user", "content": "hello world"}],
37+
messages=[Message(role="user", content="hello world")],
3838
debug=False,
3939
log_fields=LogFields(source="Acme Wizard"),
4040
)
@@ -52,7 +52,7 @@ async def test_text_guard(self, async_client: AIGuardAsync) -> None:
5252

5353
async def test_text_guard_messages(self, async_client: AIGuardAsync) -> None:
5454
response = await async_client.guard_text(
55-
messages=[{"role": "user", "content": "hello world"}],
55+
messages=[Message(role="user", content="hello world")],
5656
debug=False,
5757
log_fields=LogFields(source="Acme Wizard"),
5858
)

0 commit comments

Comments
 (0)