9
9
2. Method-style for direct tool access: `agent.tool.tool_name(param1="value")`
10
10
"""
11
11
12
- import asyncio
13
12
import json
14
13
import logging
15
14
import os
16
15
import random
17
16
from concurrent .futures import ThreadPoolExecutor
18
- from threading import Thread
19
- from typing import Any , AsyncIterator , Callable , Dict , List , Mapping , Optional , Type , TypeVar , Union
20
- from uuid import uuid4
17
+ from typing import Any , AsyncIterator , Callable , Generator , Mapping , Optional , Type , TypeVar , Union , cast
21
18
22
19
from opentelemetry import trace
23
20
from pydantic import BaseModel
24
21
25
22
from ..event_loop .event_loop import event_loop_cycle
26
- from ..handlers .callback_handler import CompositeCallbackHandler , PrintingCallbackHandler , null_callback_handler
23
+ from ..handlers .callback_handler import PrintingCallbackHandler , null_callback_handler
27
24
from ..handlers .tool_handler import AgentToolHandler
28
25
from ..models .bedrock import BedrockModel
29
26
from ..telemetry .metrics import EventLoopMetrics
@@ -183,7 +180,7 @@ def __init__(
183
180
self ,
184
181
model : Union [Model , str , None ] = None ,
185
182
messages : Optional [Messages ] = None ,
186
- tools : Optional [List [Union [str , Dict [str , str ], Any ]]] = None ,
183
+ tools : Optional [list [Union [str , dict [str , str ], Any ]]] = None ,
187
184
system_prompt : Optional [str ] = None ,
188
185
callback_handler : Optional [
189
186
Union [Callable [..., Any ], _DefaultCallbackHandlerSentinel ]
@@ -255,7 +252,7 @@ def __init__(
255
252
self .conversation_manager = conversation_manager if conversation_manager else SlidingWindowConversationManager ()
256
253
257
254
# Process trace attributes to ensure they're of compatible types
258
- self .trace_attributes : Dict [str , AttributeValue ] = {}
255
+ self .trace_attributes : dict [str , AttributeValue ] = {}
259
256
if trace_attributes :
260
257
for k , v in trace_attributes .items ():
261
258
if isinstance (v , (str , int , float , bool )) or (
@@ -312,7 +309,7 @@ def tool(self) -> ToolCaller:
312
309
return self .tool_caller
313
310
314
311
@property
315
- def tool_names (self ) -> List [str ]:
312
+ def tool_names (self ) -> list [str ]:
316
313
"""Get a list of all registered tool names.
317
314
318
315
Returns:
@@ -357,19 +354,25 @@ def __call__(self, prompt: str, **kwargs: Any) -> AgentResult:
357
354
- metrics: Performance metrics from the event loop
358
355
- state: The final state of the event loop
359
356
"""
357
+ callback_handler = kwargs .get ("callback_handler" , self .callback_handler )
358
+
360
359
self ._start_agent_trace_span (prompt )
361
360
362
361
try :
363
- # Run the event loop and get the result
364
- result = self ._run_loop (prompt , kwargs )
362
+ events = self ._run_loop (callback_handler , prompt , kwargs )
363
+ for event in events :
364
+ if "callback" in event :
365
+ callback_handler (** event ["callback" ])
366
+
367
+ stop_reason , message , metrics , state = event ["stop" ]
368
+ result = AgentResult (stop_reason , message , metrics , state )
365
369
366
370
self ._end_agent_trace_span (response = result )
367
371
368
372
return result
373
+
369
374
except Exception as e :
370
375
self ._end_agent_trace_span (error = e )
371
-
372
- # Re-raise the exception to preserve original behavior
373
376
raise
374
377
375
378
def structured_output (self , output_model : Type [T ], prompt : Optional [str ] = None ) -> T :
@@ -383,9 +386,9 @@ def structured_output(self, output_model: Type[T], prompt: Optional[str] = None)
383
386
instruct the model to output the structured data.
384
387
385
388
Args:
386
- output_model(Type[BaseModel]) : The output model (a JSON schema written as a Pydantic BaseModel)
389
+ output_model: The output model (a JSON schema written as a Pydantic BaseModel)
387
390
that the agent will use when responding.
388
- prompt(Optional[str]) : The prompt to use for the agent.
391
+ prompt: The prompt to use for the agent.
389
392
"""
390
393
messages = self .messages
391
394
if not messages and not prompt :
@@ -396,7 +399,12 @@ def structured_output(self, output_model: Type[T], prompt: Optional[str] = None)
396
399
messages .append ({"role" : "user" , "content" : [{"text" : prompt }]})
397
400
398
401
# get the structured output from the model
399
- return self .model .structured_output (output_model , messages , self .callback_handler )
402
+ events = self .model .structured_output (output_model , messages )
403
+ for event in events :
404
+ if "callback" in event :
405
+ self .callback_handler (** cast (dict , event ["callback" ]))
406
+
407
+ return event ["output" ]
400
408
401
409
async def stream_async (self , prompt : str , ** kwargs : Any ) -> AsyncIterator [Any ]:
402
410
"""Process a natural language prompt and yield events as an async iterator.
@@ -428,94 +436,63 @@ async def stream_async(self, prompt: str, **kwargs: Any) -> AsyncIterator[Any]:
428
436
yield event["data"]
429
437
```
430
438
"""
431
- self . _start_agent_trace_span ( prompt )
439
+ callback_handler = kwargs . get ( "callback_handler" , self . callback_handler )
432
440
433
- _stop_event = uuid4 ()
434
-
435
- queue = asyncio .Queue [Any ]()
436
- loop = asyncio .get_event_loop ()
437
-
438
- def enqueue (an_item : Any ) -> None :
439
- nonlocal queue
440
- nonlocal loop
441
- loop .call_soon_threadsafe (queue .put_nowait , an_item )
442
-
443
- def queuing_callback_handler (** handler_kwargs : Any ) -> None :
444
- enqueue (handler_kwargs .copy ())
441
+ self ._start_agent_trace_span (prompt )
445
442
446
- def target_callback () -> None :
447
- nonlocal kwargs
443
+ try :
444
+ events = self ._run_loop (callback_handler , prompt , kwargs )
445
+ for event in events :
446
+ if "callback" in event :
447
+ callback_handler (** event ["callback" ])
448
+ yield event ["callback" ]
448
449
449
- try :
450
- result = self ._run_loop (prompt , kwargs , supplementary_callback_handler = queuing_callback_handler )
451
- self ._end_agent_trace_span (response = result )
452
- except Exception as e :
453
- self ._end_agent_trace_span (error = e )
454
- enqueue (e )
455
- finally :
456
- enqueue (_stop_event )
450
+ stop_reason , message , metrics , state = event ["stop" ]
451
+ result = AgentResult (stop_reason , message , metrics , state )
457
452
458
- thread = Thread (target = target_callback , daemon = True )
459
- thread .start ()
453
+ self ._end_agent_trace_span (response = result )
460
454
461
- try :
462
- while True :
463
- item = await queue .get ()
464
- if item == _stop_event :
465
- break
466
- if isinstance (item , Exception ):
467
- raise item
468
- yield item
469
- finally :
470
- thread .join ()
455
+ except Exception as e :
456
+ self ._end_agent_trace_span (error = e )
457
+ raise
471
458
472
459
def _run_loop (
473
- self , prompt : str , kwargs : Dict [ str , Any ], supplementary_callback_handler : Optional [ Callable [... , Any ]] = None
474
- ) -> AgentResult :
460
+ self , callback_handler : Callable [... , Any ], prompt : str , kwargs : dict [ str , Any ]
461
+ ) -> Generator [ dict [ str , Any ], None , None ] :
475
462
"""Execute the agent's event loop with the given prompt and parameters."""
476
463
try :
477
- # If the call had a callback_handler passed in, then for this event_loop
478
- # cycle we call both handlers as the callback_handler
479
- invocation_callback_handler = (
480
- CompositeCallbackHandler (self .callback_handler , supplementary_callback_handler )
481
- if supplementary_callback_handler is not None
482
- else self .callback_handler
483
- )
484
-
485
464
# Extract key parameters
486
- invocation_callback_handler ( init_event_loop = True , ** kwargs )
465
+ yield { "callback" : { " init_event_loop" : True , ** kwargs }}
487
466
488
467
# Set up the user message with optional knowledge base retrieval
489
- message_content : List [ContentBlock ] = [{"text" : prompt }]
468
+ message_content : list [ContentBlock ] = [{"text" : prompt }]
490
469
new_message : Message = {"role" : "user" , "content" : message_content }
491
470
self .messages .append (new_message )
492
471
493
472
# Execute the event loop cycle with retry logic for context limits
494
- return self ._execute_event_loop_cycle (invocation_callback_handler , kwargs )
473
+ yield from self ._execute_event_loop_cycle (callback_handler , kwargs )
495
474
496
475
finally :
497
476
self .conversation_manager .apply_management (self )
498
477
499
- def _execute_event_loop_cycle (self , callback_handler : Callable [..., Any ], kwargs : dict [str , Any ]) -> AgentResult :
478
+ def _execute_event_loop_cycle (
479
+ self , callback_handler : Callable [..., Any ], kwargs : dict [str , Any ]
480
+ ) -> Generator [dict [str , Any ], None , None ]:
500
481
"""Execute the event loop cycle with retry logic for context window limits.
501
482
502
483
This internal method handles the execution of the event loop cycle and implements
503
484
retry logic for handling context window overflow exceptions by reducing the
504
485
conversation context and retrying.
505
486
506
- Args:
507
- callback_handler: The callback handler to use for events.
508
- kwargs: Additional parameters to pass through event loop.
509
-
510
- Returns:
511
- The result of the event loop cycle.
487
+ Yields:
488
+ Events of the loop cycle.
512
489
"""
513
490
# Add `Agent` to kwargs to keep backwards-compatibility
514
491
kwargs ["agent" ] = self
515
492
516
493
try :
517
494
# Execute the main event loop cycle
518
- events = event_loop_cycle (
495
+ yield from event_loop_cycle (
519
496
model = self .model ,
520
497
system_prompt = self .system_prompt ,
521
498
messages = self .messages , # will be modified by event_loop_cycle
@@ -527,19 +504,11 @@ def _execute_event_loop_cycle(self, callback_handler: Callable[..., Any], kwargs
527
504
event_loop_parent_span = self .trace_span ,
528
505
kwargs = kwargs ,
529
506
)
530
- for event in events :
531
- if "callback" in event :
532
- callback_handler (** event ["callback" ])
533
-
534
- stop_reason , message , metrics , state = event ["stop" ]
535
-
536
- return AgentResult (stop_reason , message , metrics , state )
537
507
538
508
except ContextWindowOverflowException as e :
539
509
# Try reducing the context size and retrying
540
-
541
510
self .conversation_manager .reduce_context (self , e = e )
542
- return self ._execute_event_loop_cycle (callback_handler , kwargs )
511
+ yield from self ._execute_event_loop_cycle (callback_handler_override , kwargs )
543
512
544
513
def _record_tool_execution (
545
514
self ,
@@ -625,7 +594,7 @@ def _end_agent_trace_span(
625
594
error: Error to record as a trace attribute.
626
595
"""
627
596
if self .trace_span :
628
- trace_attributes : Dict [str , Any ] = {
597
+ trace_attributes : dict [str , Any ] = {
629
598
"span" : self .trace_span ,
630
599
}
631
600
0 commit comments