Skip to content

Commit 9e5906c

Browse files
feat: add pydantic validation dunder to BaseInterceptor
1 parent fc52183 commit 9e5906c

File tree

8 files changed

+1258
-1723
lines changed

8 files changed

+1258
-1723
lines changed

lib/crewai/src/crewai/llms/hooks/base.py

Lines changed: 52 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,14 @@
77
from __future__ import annotations
88

99
from abc import ABC, abstractmethod
10-
from typing import Generic, TypeVar
10+
from typing import TYPE_CHECKING, Any, Generic, TypeVar
11+
12+
from pydantic_core import core_schema
13+
14+
15+
if TYPE_CHECKING:
16+
from pydantic import GetCoreSchemaHandler
17+
from pydantic_core import CoreSchema
1118

1219

1320
T = TypeVar("T")
@@ -25,6 +32,7 @@ class BaseInterceptor(ABC, Generic[T, U]):
2532
U: Inbound message type (e.g., httpx.Response)
2633
2734
Example:
35+
>>> import httpx
2836
>>> class CustomInterceptor(BaseInterceptor[httpx.Request, httpx.Response]):
2937
... def on_outbound(self, message: httpx.Request) -> httpx.Request:
3038
... message.headers["X-Custom-Header"] = "value"
@@ -80,3 +88,46 @@ async def aon_inbound(self, message: U) -> U:
8088
Modified message object.
8189
"""
8290
raise NotImplementedError
91+
92+
@classmethod
93+
def __get_pydantic_core_schema__(
94+
cls, _source_type: Any, _handler: GetCoreSchemaHandler
95+
) -> CoreSchema:
96+
"""Generate Pydantic core schema for BaseInterceptor.
97+
98+
This allows the generic BaseInterceptor to be used in Pydantic models
99+
without requiring arbitrary_types_allowed=True. The schema validates
100+
that the value is an instance of BaseInterceptor.
101+
102+
Args:
103+
_source_type: The source type being validated (unused).
104+
_handler: Handler for generating schemas (unused).
105+
106+
Returns:
107+
A Pydantic core schema that validates BaseInterceptor instances.
108+
"""
109+
return core_schema.no_info_plain_validator_function(
110+
_validate_interceptor,
111+
serialization=core_schema.plain_serializer_function_ser_schema(
112+
lambda x: x, return_schema=core_schema.any_schema()
113+
),
114+
)
115+
116+
117+
def _validate_interceptor(value: Any) -> BaseInterceptor[T, U]:
118+
"""Validate that the value is a BaseInterceptor instance.
119+
120+
Args:
121+
value: The value to validate.
122+
123+
Returns:
124+
The validated BaseInterceptor instance.
125+
126+
Raises:
127+
ValueError: If the value is not a BaseInterceptor instance.
128+
"""
129+
if not isinstance(value, BaseInterceptor):
130+
raise ValueError(
131+
f"Expected BaseInterceptor instance, got {type(value).__name__}"
132+
)
133+
return value

lib/crewai/src/crewai/llms/hooks/transport.py

Lines changed: 49 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -6,16 +6,53 @@
66

77
from __future__ import annotations
88

9-
from typing import TYPE_CHECKING, Any
9+
from collections.abc import Iterable
10+
from typing import TYPE_CHECKING, TypedDict
1011

11-
import httpx
12+
from httpx import (
13+
AsyncHTTPTransport as _AsyncHTTPTransport,
14+
HTTPTransport as _HTTPTransport,
15+
)
16+
from typing_extensions import NotRequired, Unpack
1217

1318

1419
if TYPE_CHECKING:
20+
from ssl import SSLContext
21+
22+
from httpx import Limits, Request, Response
23+
from httpx._types import CertTypes, ProxyTypes
24+
1525
from crewai.llms.hooks.base import BaseInterceptor
1626

1727

18-
class HTTPTransport(httpx.HTTPTransport):
28+
class HTTPTransportKwargs(TypedDict):
29+
"""Typed dictionary for httpx.HTTPTransport initialization parameters.
30+
31+
These parameters configure the underlying HTTP transport behavior including
32+
SSL verification, proxies, connection limits, and low-level socket options.
33+
"""
34+
35+
verify: bool | str | SSLContext
36+
cert: NotRequired[CertTypes | None]
37+
trust_env: bool
38+
http1: bool
39+
http2: bool
40+
limits: Limits
41+
proxy: NotRequired[ProxyTypes | None]
42+
uds: NotRequired[str | None]
43+
local_address: NotRequired[str | None]
44+
retries: int
45+
socket_options: NotRequired[
46+
Iterable[
47+
tuple[int, int, int]
48+
| tuple[int, int, bytes | bytearray]
49+
| tuple[int, int, None, int]
50+
]
51+
| None
52+
]
53+
54+
55+
class HTTPTransport(_HTTPTransport):
1956
"""HTTP transport that uses an interceptor for request/response modification.
2057
2158
This transport is used internally when a user provides a BaseInterceptor.
@@ -25,19 +62,19 @@ class HTTPTransport(httpx.HTTPTransport):
2562

2663
def __init__(
2764
self,
28-
interceptor: BaseInterceptor[httpx.Request, httpx.Response],
29-
**kwargs: Any,
65+
interceptor: BaseInterceptor[Request, Response],
66+
**kwargs: Unpack[HTTPTransportKwargs],
3067
) -> None:
3168
"""Initialize transport with interceptor.
3269
3370
Args:
3471
interceptor: HTTP interceptor for modifying raw request/response objects.
35-
**kwargs: Additional arguments passed to httpx.HTTPTransport.
72+
**kwargs: HTTPTransport configuration parameters (verify, cert, proxy, etc.).
3673
"""
3774
super().__init__(**kwargs)
3875
self.interceptor = interceptor
3976

40-
def handle_request(self, request: httpx.Request) -> httpx.Response:
77+
def handle_request(self, request: Request) -> Response:
4178
"""Handle request with interception.
4279
4380
Args:
@@ -51,7 +88,7 @@ def handle_request(self, request: httpx.Request) -> httpx.Response:
5188
return self.interceptor.on_inbound(response)
5289

5390

54-
class AsyncHTTPransport(httpx.AsyncHTTPTransport):
91+
class AsyncHTTPTransport(_AsyncHTTPTransport):
5592
"""Async HTTP transport that uses an interceptor for request/response modification.
5693
5794
This transport is used internally when a user provides a BaseInterceptor.
@@ -61,19 +98,19 @@ class AsyncHTTPransport(httpx.AsyncHTTPTransport):
6198

6299
def __init__(
63100
self,
64-
interceptor: BaseInterceptor[httpx.Request, httpx.Response],
65-
**kwargs: Any,
101+
interceptor: BaseInterceptor[Request, Response],
102+
**kwargs: Unpack[HTTPTransportKwargs],
66103
) -> None:
67104
"""Initialize async transport with interceptor.
68105
69106
Args:
70107
interceptor: HTTP interceptor for modifying raw request/response objects.
71-
**kwargs: Additional arguments passed to httpx.AsyncHTTPTransport.
108+
**kwargs: HTTPTransport configuration parameters (verify, cert, proxy, etc.).
72109
"""
73110
super().__init__(**kwargs)
74111
self.interceptor = interceptor
75112

76-
async def handle_async_request(self, request: httpx.Request) -> httpx.Response:
113+
async def handle_async_request(self, request: Request) -> Response:
77114
"""Handle async request with interception.
78115
79116
Args:

lib/crewai/tests/agents/test_agent.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2117,15 +2117,14 @@ def test_agent_with_only_crewai_knowledge():
21172117
goal="Provide information based on knowledge sources",
21182118
backstory="You have access to specific knowledge sources.",
21192119
llm=LLM(
2120-
model="openrouter/openai/gpt-4o-mini",
2121-
api_key=os.getenv("OPENROUTER_API_KEY"),
2120+
model="gpt-4o-mini",
21222121
),
21232122
)
21242123

21252124
# Create a task that requires the agent to use the knowledge
21262125
task = Task(
21272126
description="What is Vidit's favorite color?",
2128-
expected_output="Vidit's favorclearite color.",
2127+
expected_output="Vidit's favorite color.",
21292128
agent=agent,
21302129
)
21312130

0 commit comments

Comments
 (0)