22
33import contextlib
44import sys
5- from collections .abc import AsyncIterable , Callable , MutableMapping , Sequence
5+ from collections .abc import AsyncIterable , Awaitable , Callable , MutableMapping , Sequence
66from pathlib import Path
7- from typing import TYPE_CHECKING , Any , ClassVar , Generic
7+ from typing import TYPE_CHECKING , Any , ClassVar , Generic , Literal , overload
88
99from agent_framework import (
10- AgentMiddlewareTypes ,
10+ AgentMiddlewareLayer ,
1111 AgentResponse ,
1212 AgentResponseUpdate ,
1313 AgentThread ,
1616 Content ,
1717 ContextProvider ,
1818 FunctionTool ,
19+ ResponseStream ,
1920 Role ,
2021 ToolProtocol ,
2122 get_logger ,
23+ merge_chat_options ,
2224 normalize_messages ,
25+ normalize_tools ,
2326)
24- from agent_framework ._types import normalize_tools
2527from agent_framework .exceptions import ServiceException , ServiceInitializationError
28+ from agent_framework .observability import AgentTelemetryLayer
2629from claude_agent_sdk import (
2730 ClaudeAgentOptions as SDKOptions ,
2831)
@@ -145,7 +148,7 @@ class ClaudeAgentOptions(TypedDict, total=False):
145148)
146149
147150
148- class ClaudeAgent (BaseAgent , Generic [TOptions ]):
151+ class RawClaudeAgent (BaseAgent , Generic [TOptions ]):
149152 """Claude Agent using Claude Code CLI.
150153
151154 Wraps the Claude Agent SDK to provide agentic capabilities including
@@ -175,7 +178,7 @@ class ClaudeAgent(BaseAgent, Generic[TOptions]):
175178 .. code-block:: python
176179
177180 async with ClaudeAgent() as agent:
178- async for update in agent.run_stream ("Write a poem"):
181+ async for update in agent.run ("Write a poem", stream=True ):
179182 print(update.text, end="", flush=True)
180183
181184 With session management:
@@ -214,7 +217,6 @@ def __init__(
214217 name : str | None = None ,
215218 description : str | None = None ,
216219 context_provider : ContextProvider | None = None ,
217- middleware : Sequence [AgentMiddlewareTypes ] | None = None ,
218220 tools : ToolProtocol
219221 | Callable [..., Any ]
220222 | MutableMapping [str , Any ]
@@ -224,8 +226,9 @@ def __init__(
224226 default_options : TOptions | MutableMapping [str , Any ] | None = None ,
225227 env_file_path : str | None = None ,
226228 env_file_encoding : str | None = None ,
229+ ** kwargs : Any ,
227230 ) -> None :
228- """Initialize a ClaudeAgent instance.
231+ """Initialize a Claude agent instance.
229232
230233 Args:
231234 instructions: System prompt for the agent.
@@ -237,20 +240,20 @@ def __init__(
237240 name: Name of the agent.
238241 description: Description of the agent.
239242 context_provider: Context provider for the agent.
240- middleware: List of middleware.
241243 tools: Tools for the agent. Can be:
242244 - Strings for built-in tools (e.g., "Read", "Write", "Bash", "Glob")
243245 - Functions or ToolProtocol instances for custom tools
244246 default_options: Default ClaudeAgentOptions including system_prompt, model, etc.
245247 env_file_path: Path to .env file.
246248 env_file_encoding: Encoding of .env file.
249+ kwargs: Additional keyword arguments passed to BaseAgent.
247250 """
248251 super ().__init__ (
249252 id = id ,
250253 name = name ,
251254 description = description ,
252255 context_provider = context_provider ,
253- middleware = middleware ,
256+ ** kwargs ,
254257 )
255258
256259 self ._client = client
@@ -295,6 +298,11 @@ def __init__(
295298 self ._started = False
296299 self ._current_session_id : str | None = None
297300
301+ @property
302+ def default_options (self ) -> dict [str , Any ]:
303+ """Expose default options for telemetry and middleware layers."""
304+ return dict (self ._default_options )
305+
298306 def _normalize_tools (
299307 self ,
300308 tools : ToolProtocol
@@ -328,7 +336,7 @@ def _normalize_tools(
328336 normalized = normalize_tools (tool )
329337 self ._custom_tools .extend (normalized )
330338
331- async def __aenter__ (self ) -> "ClaudeAgent [TOptions]" :
339+ async def __aenter__ (self ) -> "RawClaudeAgent [TOptions]" :
332340 """Start the agent when entering async context."""
333341 await self .start ()
334342 return self
@@ -549,55 +557,80 @@ def _format_prompt(self, messages: list[ChatMessage] | None) -> str:
549557 return ""
550558 return "\n " .join ([msg .text or "" for msg in messages ])
551559
552- async def run (
560+ @overload
561+ def run (
553562 self ,
554563 messages : str | ChatMessage | Sequence [str | ChatMessage ] | None = None ,
555564 * ,
565+ stream : Literal [False ] = ...,
556566 thread : AgentThread | None = None ,
557- options : TOptions | MutableMapping [str , Any ] | None = None ,
558567 ** kwargs : Any ,
559- ) -> AgentResponse [Any ]:
560- """Run the agent with the given messages.
568+ ) -> Awaitable [AgentResponse [Any ]]: ...
561569
562- Args:
563- messages: The messages to process.
570+ @overload
571+ def run (
572+ self ,
573+ messages : str | ChatMessage | Sequence [str | ChatMessage ] | None = None ,
574+ * ,
575+ stream : Literal [True ],
576+ thread : AgentThread | None = None ,
577+ ** kwargs : Any ,
578+ ) -> ResponseStream [AgentResponseUpdate , AgentResponse [Any ]]: ...
564579
565- Keyword Args:
566- thread: The conversation thread. If thread has service_thread_id set,
567- the agent will resume that session.
568- options: Runtime options (model, permission_mode can be changed per-request).
569- kwargs: Additional keyword arguments.
580+ def run (
581+ self ,
582+ messages : str | ChatMessage | Sequence [str | ChatMessage ] | None = None ,
583+ * ,
584+ stream : bool = False ,
585+ thread : AgentThread | None = None ,
586+ ** kwargs : Any ,
587+ ) -> Awaitable [AgentResponse [Any ]] | ResponseStream [AgentResponseUpdate , AgentResponse [Any ]]:
588+ """Run the agent with the given messages."""
589+ options = kwargs .pop ("options" , None )
590+ if stream :
591+
592+ def _finalize (updates : Sequence [AgentResponseUpdate ]) -> AgentResponse [Any ]:
593+ response = AgentResponse .from_agent_run_response_updates (updates )
594+ session_id = _get_session_id_from_updates (updates )
595+ if session_id and thread is not None :
596+ thread .service_thread_id = session_id
597+ return response
598+
599+ return ResponseStream (
600+ self ._stream_updates (messages = messages , thread = thread , options = options , ** kwargs ),
601+ finalizer = _finalize ,
602+ )
570603
571- Returns:
572- AgentResponse with the agent's response.
573- """
604+ return self ._run_impl (messages = messages , thread = thread , options = options , ** kwargs )
605+
606+ async def _run_impl (
607+ self ,
608+ messages : str | ChatMessage | Sequence [str | ChatMessage ] | None = None ,
609+ * ,
610+ thread : AgentThread | None = None ,
611+ options : TOptions | MutableMapping [str , Any ] | None = None ,
612+ ** kwargs : Any ,
613+ ) -> AgentResponse [Any ]:
614+ """Non-streaming implementation of run."""
574615 thread = thread or self .get_new_thread ()
575- return await AgentResponse .from_agent_response_generator (
576- self .run_stream (messages , thread = thread , options = options , ** kwargs )
577- )
616+ updates : list [AgentResponseUpdate ] = []
617+ async for update in self ._stream_updates (messages = messages , thread = thread , options = options , ** kwargs ):
618+ updates .append (update )
619+ response = AgentResponse .from_agent_run_response_updates (updates )
620+ session_id = _get_session_id_from_updates (updates )
621+ if session_id :
622+ thread .service_thread_id = session_id
623+ return response
578624
579- async def run_stream (
625+ async def _stream_updates (
580626 self ,
581627 messages : str | ChatMessage | Sequence [str | ChatMessage ] | None = None ,
582628 * ,
583629 thread : AgentThread | None = None ,
584630 options : TOptions | MutableMapping [str , Any ] | None = None ,
585631 ** kwargs : Any ,
586632 ) -> AsyncIterable [AgentResponseUpdate ]:
587- """Stream the agent's response.
588-
589- Args:
590- messages: The messages to process.
591-
592- Keyword Args:
593- thread: The conversation thread. If thread has service_thread_id set,
594- the agent will resume that session.
595- options: Runtime options (model, permission_mode can be changed per-request).
596- kwargs: Additional keyword arguments.
597-
598- Yields:
599- AgentResponseUpdate objects containing chunks of the response.
600- """
633+ """Stream the agent's response updates."""
601634 thread = thread or self .get_new_thread ()
602635
603636 # Ensure we're connected to the right session
@@ -606,12 +639,18 @@ async def run_stream(
606639 if not self ._client :
607640 raise ServiceException ("Claude SDK client not initialized." )
608641
642+ merged_options = merge_chat_options (
643+ {"instructions" : self ._default_options .get ("system_prompt" )},
644+ merge_chat_options (self ._default_options , dict (options ) if options else None ),
645+ )
646+ runtime_options = dict (merged_options )
647+ runtime_options .pop ("system_prompt" , None )
648+ runtime_options .pop ("instructions" , None )
649+
609650 prompt = self ._format_prompt (normalize_messages (messages ))
610651
611652 # Apply runtime options (model, permission_mode)
612- await self ._apply_runtime_options (dict (options ) if options else None )
613-
614- session_id : str | None = None
653+ await self ._apply_runtime_options (runtime_options if runtime_options else None )
615654
616655 await self ._client .query (prompt )
617656 async for message in self ._client .receive_response ():
@@ -638,8 +677,29 @@ async def run_stream(
638677 raw_representation = message ,
639678 )
640679 elif isinstance (message , ResultMessage ):
641- session_id = message .session_id
642-
643- # Update thread with session ID
644- if session_id :
645- thread .service_thread_id = session_id
680+ if message .session_id :
681+ yield AgentResponseUpdate (
682+ role = Role .ASSISTANT ,
683+ contents = [Content .from_text (text = "" , raw_representation = message )],
684+ raw_representation = message ,
685+ )
686+
687+
688+ class ClaudeAgent ( # type: ignore[misc]
689+ AgentTelemetryLayer ,
690+ AgentMiddlewareLayer ,
691+ RawClaudeAgent [TOptions ],
692+ Generic [TOptions ],
693+ ):
694+ """Claude agent with middleware and telemetry layers applied."""
695+
696+ pass
697+
698+
699+ def _get_session_id_from_updates (updates : Sequence [AgentResponseUpdate ]) -> str | None :
700+ """Extract session_id from ResultMessage entries in updates."""
701+ for update in updates :
702+ raw = update .raw_representation
703+ if isinstance (raw , ResultMessage ):
704+ return raw .session_id
705+ return None
0 commit comments