Skip to content

Commit 92995e6

Browse files
Update Claude agent connector layering
1 parent b6249cd commit 92995e6

File tree

22 files changed

+182
-117
lines changed

22 files changed

+182
-117
lines changed

python/packages/a2a/agent_framework_a2a/_agent.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
AgentResponse,
3030
AgentResponseUpdate,
3131
AgentThread,
32-
BareAgent,
32+
BaseAgent,
3333
ChatMessage,
3434
Content,
3535
ResponseStream,
@@ -58,12 +58,12 @@ def _get_uri_data(uri: str) -> str:
5858
return match.group("base64_data")
5959

6060

61-
class A2AAgent(AgentTelemetryLayer, BareAgent):
61+
class A2AAgent(AgentTelemetryLayer, BaseAgent):
6262
"""Agent2Agent (A2A) protocol implementation.
6363
6464
Wraps an A2A Client to connect the Agent Framework with external A2A-compliant agents
6565
via HTTP/JSON-RPC. Converts framework ChatMessages to A2A Messages on send, and converts
66-
A2A responses (Messages/Tasks) back to framework types. Inherits BareAgent capabilities
66+
A2A responses (Messages/Tasks) back to framework types. Inherits BaseAgent capabilities
6767
while managing the underlying A2A protocol communication.
6868
6969
Can be initialized with a URL, AgentCard, or existing A2A Client instance.
@@ -99,7 +99,7 @@ def __init__(
9999
timeout: Request timeout configuration. Can be a float (applied to all timeout components),
100100
httpx.Timeout object (for full control), or None (uses 10.0s connect, 60.0s read,
101101
10.0s write, 5.0s pool - optimized for A2A operations).
102-
kwargs: any additional properties, passed to BareAgent.
102+
kwargs: any additional properties, passed to BaseAgent.
103103
"""
104104
super().__init__(id=id, name=name, description=description, **kwargs)
105105
self._http_client: httpx.AsyncClient | None = http_client
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
# Copyright (c) Microsoft. All rights reserved.

python/packages/claude/agent_framework_claude/_agent.py

Lines changed: 111 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,12 @@
22

33
import contextlib
44
import sys
5-
from collections.abc import AsyncIterable, Callable, MutableMapping, Sequence
5+
from collections.abc import AsyncIterable, Awaitable, Callable, MutableMapping, Sequence
66
from pathlib import Path
7-
from typing import TYPE_CHECKING, Any, ClassVar, Generic
7+
from typing import TYPE_CHECKING, Any, ClassVar, Generic, Literal, overload
88

99
from agent_framework import (
10-
AgentMiddlewareTypes,
10+
AgentMiddlewareLayer,
1111
AgentResponse,
1212
AgentResponseUpdate,
1313
AgentThread,
@@ -16,13 +16,16 @@
1616
Content,
1717
ContextProvider,
1818
FunctionTool,
19+
ResponseStream,
1920
Role,
2021
ToolProtocol,
2122
get_logger,
23+
merge_chat_options,
2224
normalize_messages,
25+
normalize_tools,
2326
)
24-
from agent_framework._types import normalize_tools
2527
from agent_framework.exceptions import ServiceException, ServiceInitializationError
28+
from agent_framework.observability import AgentTelemetryLayer
2629
from claude_agent_sdk import (
2730
ClaudeAgentOptions as SDKOptions,
2831
)
@@ -145,7 +148,7 @@ class ClaudeAgentOptions(TypedDict, total=False):
145148
)
146149

147150

148-
class ClaudeAgent(BaseAgent, Generic[TOptions]):
151+
class RawClaudeAgent(BaseAgent, Generic[TOptions]):
149152
"""Claude Agent using Claude Code CLI.
150153
151154
Wraps the Claude Agent SDK to provide agentic capabilities including
@@ -175,7 +178,7 @@ class ClaudeAgent(BaseAgent, Generic[TOptions]):
175178
.. code-block:: python
176179
177180
async with ClaudeAgent() as agent:
178-
async for update in agent.run_stream("Write a poem"):
181+
async for update in agent.run("Write a poem", stream=True):
179182
print(update.text, end="", flush=True)
180183
181184
With session management:
@@ -214,7 +217,6 @@ def __init__(
214217
name: str | None = None,
215218
description: str | None = None,
216219
context_provider: ContextProvider | None = None,
217-
middleware: Sequence[AgentMiddlewareTypes] | None = None,
218220
tools: ToolProtocol
219221
| Callable[..., Any]
220222
| MutableMapping[str, Any]
@@ -224,8 +226,9 @@ def __init__(
224226
default_options: TOptions | MutableMapping[str, Any] | None = None,
225227
env_file_path: str | None = None,
226228
env_file_encoding: str | None = None,
229+
**kwargs: Any,
227230
) -> None:
228-
"""Initialize a ClaudeAgent instance.
231+
"""Initialize a Claude agent instance.
229232
230233
Args:
231234
instructions: System prompt for the agent.
@@ -237,20 +240,20 @@ def __init__(
237240
name: Name of the agent.
238241
description: Description of the agent.
239242
context_provider: Context provider for the agent.
240-
middleware: List of middleware.
241243
tools: Tools for the agent. Can be:
242244
- Strings for built-in tools (e.g., "Read", "Write", "Bash", "Glob")
243245
- Functions or ToolProtocol instances for custom tools
244246
default_options: Default ClaudeAgentOptions including system_prompt, model, etc.
245247
env_file_path: Path to .env file.
246248
env_file_encoding: Encoding of .env file.
249+
kwargs: Additional keyword arguments passed to BaseAgent.
247250
"""
248251
super().__init__(
249252
id=id,
250253
name=name,
251254
description=description,
252255
context_provider=context_provider,
253-
middleware=middleware,
256+
**kwargs,
254257
)
255258

256259
self._client = client
@@ -295,6 +298,11 @@ def __init__(
295298
self._started = False
296299
self._current_session_id: str | None = None
297300

301+
@property
302+
def default_options(self) -> dict[str, Any]:
303+
"""Expose default options for telemetry and middleware layers."""
304+
return dict(self._default_options)
305+
298306
def _normalize_tools(
299307
self,
300308
tools: ToolProtocol
@@ -328,7 +336,7 @@ def _normalize_tools(
328336
normalized = normalize_tools(tool)
329337
self._custom_tools.extend(normalized)
330338

331-
async def __aenter__(self) -> "ClaudeAgent[TOptions]":
339+
async def __aenter__(self) -> "RawClaudeAgent[TOptions]":
332340
"""Start the agent when entering async context."""
333341
await self.start()
334342
return self
@@ -549,55 +557,80 @@ def _format_prompt(self, messages: list[ChatMessage] | None) -> str:
549557
return ""
550558
return "\n".join([msg.text or "" for msg in messages])
551559

552-
async def run(
560+
@overload
561+
def run(
553562
self,
554563
messages: str | ChatMessage | Sequence[str | ChatMessage] | None = None,
555564
*,
565+
stream: Literal[False] = ...,
556566
thread: AgentThread | None = None,
557-
options: TOptions | MutableMapping[str, Any] | None = None,
558567
**kwargs: Any,
559-
) -> AgentResponse[Any]:
560-
"""Run the agent with the given messages.
568+
) -> Awaitable[AgentResponse[Any]]: ...
561569

562-
Args:
563-
messages: The messages to process.
570+
@overload
571+
def run(
572+
self,
573+
messages: str | ChatMessage | Sequence[str | ChatMessage] | None = None,
574+
*,
575+
stream: Literal[True],
576+
thread: AgentThread | None = None,
577+
**kwargs: Any,
578+
) -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: ...
564579

565-
Keyword Args:
566-
thread: The conversation thread. If thread has service_thread_id set,
567-
the agent will resume that session.
568-
options: Runtime options (model, permission_mode can be changed per-request).
569-
kwargs: Additional keyword arguments.
580+
def run(
581+
self,
582+
messages: str | ChatMessage | Sequence[str | ChatMessage] | None = None,
583+
*,
584+
stream: bool = False,
585+
thread: AgentThread | None = None,
586+
**kwargs: Any,
587+
) -> Awaitable[AgentResponse[Any]] | ResponseStream[AgentResponseUpdate, AgentResponse[Any]]:
588+
"""Run the agent with the given messages."""
589+
options = kwargs.pop("options", None)
590+
if stream:
591+
592+
def _finalize(updates: Sequence[AgentResponseUpdate]) -> AgentResponse[Any]:
593+
response = AgentResponse.from_agent_run_response_updates(updates)
594+
session_id = _get_session_id_from_updates(updates)
595+
if session_id and thread is not None:
596+
thread.service_thread_id = session_id
597+
return response
598+
599+
return ResponseStream(
600+
self._stream_updates(messages=messages, thread=thread, options=options, **kwargs),
601+
finalizer=_finalize,
602+
)
570603

571-
Returns:
572-
AgentResponse with the agent's response.
573-
"""
604+
return self._run_impl(messages=messages, thread=thread, options=options, **kwargs)
605+
606+
async def _run_impl(
607+
self,
608+
messages: str | ChatMessage | Sequence[str | ChatMessage] | None = None,
609+
*,
610+
thread: AgentThread | None = None,
611+
options: TOptions | MutableMapping[str, Any] | None = None,
612+
**kwargs: Any,
613+
) -> AgentResponse[Any]:
614+
"""Non-streaming implementation of run."""
574615
thread = thread or self.get_new_thread()
575-
return await AgentResponse.from_agent_response_generator(
576-
self.run_stream(messages, thread=thread, options=options, **kwargs)
577-
)
616+
updates: list[AgentResponseUpdate] = []
617+
async for update in self._stream_updates(messages=messages, thread=thread, options=options, **kwargs):
618+
updates.append(update)
619+
response = AgentResponse.from_agent_run_response_updates(updates)
620+
session_id = _get_session_id_from_updates(updates)
621+
if session_id:
622+
thread.service_thread_id = session_id
623+
return response
578624

579-
async def run_stream(
625+
async def _stream_updates(
580626
self,
581627
messages: str | ChatMessage | Sequence[str | ChatMessage] | None = None,
582628
*,
583629
thread: AgentThread | None = None,
584630
options: TOptions | MutableMapping[str, Any] | None = None,
585631
**kwargs: Any,
586632
) -> AsyncIterable[AgentResponseUpdate]:
587-
"""Stream the agent's response.
588-
589-
Args:
590-
messages: The messages to process.
591-
592-
Keyword Args:
593-
thread: The conversation thread. If thread has service_thread_id set,
594-
the agent will resume that session.
595-
options: Runtime options (model, permission_mode can be changed per-request).
596-
kwargs: Additional keyword arguments.
597-
598-
Yields:
599-
AgentResponseUpdate objects containing chunks of the response.
600-
"""
633+
"""Stream the agent's response updates."""
601634
thread = thread or self.get_new_thread()
602635

603636
# Ensure we're connected to the right session
@@ -606,12 +639,18 @@ async def run_stream(
606639
if not self._client:
607640
raise ServiceException("Claude SDK client not initialized.")
608641

642+
merged_options = merge_chat_options(
643+
{"instructions": self._default_options.get("system_prompt")},
644+
merge_chat_options(self._default_options, dict(options) if options else None),
645+
)
646+
runtime_options = dict(merged_options)
647+
runtime_options.pop("system_prompt", None)
648+
runtime_options.pop("instructions", None)
649+
609650
prompt = self._format_prompt(normalize_messages(messages))
610651

611652
# Apply runtime options (model, permission_mode)
612-
await self._apply_runtime_options(dict(options) if options else None)
613-
614-
session_id: str | None = None
653+
await self._apply_runtime_options(runtime_options if runtime_options else None)
615654

616655
await self._client.query(prompt)
617656
async for message in self._client.receive_response():
@@ -638,8 +677,29 @@ async def run_stream(
638677
raw_representation=message,
639678
)
640679
elif isinstance(message, ResultMessage):
641-
session_id = message.session_id
642-
643-
# Update thread with session ID
644-
if session_id:
645-
thread.service_thread_id = session_id
680+
if message.session_id:
681+
yield AgentResponseUpdate(
682+
role=Role.ASSISTANT,
683+
contents=[Content.from_text(text="", raw_representation=message)],
684+
raw_representation=message,
685+
)
686+
687+
688+
class ClaudeAgent( # type: ignore[misc]
689+
AgentTelemetryLayer,
690+
AgentMiddlewareLayer,
691+
RawClaudeAgent[TOptions],
692+
Generic[TOptions],
693+
):
694+
"""Claude agent with middleware and telemetry layers applied."""
695+
696+
pass
697+
698+
699+
def _get_session_id_from_updates(updates: Sequence[AgentResponseUpdate]) -> str | None:
700+
"""Extract session_id from ResultMessage entries in updates."""
701+
for update in updates:
702+
raw = update.raw_representation
703+
if isinstance(raw, ResultMessage):
704+
return raw.session_id
705+
return None

python/packages/claude/tests/test_claude_agent.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -312,7 +312,7 @@ async def test_run_with_thread(self) -> None:
312312

313313

314314
class TestClaudeAgentRunStream:
315-
"""Tests for ClaudeAgent run_stream method."""
315+
"""Tests for ClaudeAgent streaming run method."""
316316

317317
@staticmethod
318318
async def _create_async_generator(items: list[Any]) -> Any:
@@ -332,7 +332,7 @@ def _create_mock_client(self, messages: list[Any]) -> MagicMock:
332332
return mock_client
333333

334334
async def test_run_stream_yields_updates(self) -> None:
335-
"""Test run_stream yields AgentResponseUpdate objects."""
335+
"""Test run(stream=True) yields AgentResponseUpdate objects."""
336336
from claude_agent_sdk import AssistantMessage, ResultMessage, TextBlock
337337
from claude_agent_sdk.types import StreamEvent
338338

@@ -371,10 +371,10 @@ async def test_run_stream_yields_updates(self) -> None:
371371
with patch("agent_framework_claude._agent.ClaudeSDKClient", return_value=mock_client):
372372
agent = ClaudeAgent()
373373
updates: list[AgentResponseUpdate] = []
374-
async for update in agent.run_stream("Hello"):
374+
async for update in agent.run("Hello", stream=True):
375375
updates.append(update)
376376
# StreamEvent yields text deltas
377-
assert len(updates) == 2
377+
assert len(updates) == 3
378378
assert updates[0].role == Role.ASSISTANT
379379
assert updates[0].text == "Streaming "
380380
assert updates[1].text == "response"

python/packages/copilotstudio/agent_framework_copilotstudio/_agent.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
AgentResponse,
99
AgentResponseUpdate,
1010
AgentThread,
11-
BareAgent,
11+
BaseAgent,
1212
ChatMessage,
1313
Content,
1414
ContextProvider,
@@ -69,7 +69,7 @@ class CopilotStudioSettings(AFBaseSettings):
6969
tenantid: str | None = None
7070

7171

72-
class CopilotStudioAgent(BareAgent):
72+
class CopilotStudioAgent(BaseAgent):
7373
"""A Copilot Studio Agent."""
7474

7575
def __init__(

0 commit comments

Comments
 (0)