15
15
from ms_agent .tools import ToolManager
16
16
from ms_agent .utils import async_retry
17
17
from ms_agent .utils .logger import logger
18
- from omegaconf import DictConfig
18
+ from omegaconf import DictConfig , OmegaConf
19
19
20
20
from ..utils .utils import read_history , save_history
21
21
from .base import Agent
24
24
from .plan .utils import planer_mapping
25
25
from .runtime import Runtime
26
26
27
- DEFAULT_YAML = os .path .join (
28
- os .path .dirname (os .path .abspath (__file__ )), 'agent.yaml' )
29
-
30
27
31
28
class LLMAgent (Agent ):
32
29
"""
@@ -51,7 +48,7 @@ class LLMAgent(Agent):
51
48
DEFAULT_SYSTEM = 'You are a helpful assistant.'
52
49
53
50
def __init__ (self ,
54
- config_dir_or_id : Optional [str ] = DEFAULT_YAML ,
51
+ config_dir_or_id : Optional [str ] = None ,
55
52
config : Optional [DictConfig ] = None ,
56
53
env : Optional [Dict [str , str ]] = None ,
57
54
** kwargs ):
@@ -311,8 +308,9 @@ def _log_output(content: str, tag: str):
311
308
for _line in line .split ('\\ n' ):
312
309
logger .info (f'[{ tag } ] { _line } ' )
313
310
314
- @async_retry (max_attempts = 2 )
315
- async def _step (self , messages : List [Message ], tag : str ) -> List [Message ]:
311
+ @async_retry (max_attempts = 2 , delay = 1.0 )
312
+ async def _step (self , messages : List [Message ],
313
+ tag : str ) -> List [Message ]: # type: ignore
316
314
"""
317
315
Execute a single step in the agent's interaction loop.
318
316
@@ -348,12 +346,18 @@ async def _step(self, messages: List[Message], tag: str) -> List[Message]:
348
346
self .config .generation_config , 'stream' , False ):
349
347
self ._log_output ('[assistant]:' , tag = tag )
350
348
_content = ''
349
+ is_first = True
351
350
for _response_message in self ._handle_stream_message (
352
351
messages , tools = tools ):
352
+ if is_first :
353
+ messages .append (_response_message )
354
+ is_first = False
353
355
new_content = _response_message .content [len (_content ):]
354
356
sys .stdout .write (new_content )
355
357
sys .stdout .flush ()
356
358
_content = _response_message .content
359
+ messages [- 1 ] = _response_message
360
+ yield messages
357
361
sys .stdout .write ('\n ' )
358
362
else :
359
363
_response_message = self .llm .generate (messages , tools = tools )
@@ -384,7 +388,7 @@ async def _step(self, messages: List[Message], tag: str) -> List[Message]:
384
388
f'[usage] prompt_tokens: { _response_message .prompt_tokens } , '
385
389
f'completion_tokens: { _response_message .completion_tokens } ' ,
386
390
tag = tag )
387
- return messages
391
+ yield messages
388
392
389
393
def _prepare_llm (self ):
390
394
"""Initialize the LLM model from the configuration."""
@@ -443,13 +447,8 @@ def _save_history(self, messages: List[Message], **kwargs):
443
447
config = config ,
444
448
messages = messages )
445
449
446
- async def run (self , messages : Union [List [Message ], str ],
447
- ** kwargs ) -> List [Message ]:
448
- """
449
- Main method to execute the agent.
450
-
451
- Runs the agent loop, which includes generating responses,
452
- calling tools, and managing memory and planning.
450
+ async def _run (self , messages : Union [List [Message ], str ], ** kwargs ):
451
+ """Run the agent, mainly contains a llm calling and tool calling loop.
453
452
454
453
Args:
455
454
messages (Union[List[Message], str]): Input data for the agent. Can be a raw string prompt,
@@ -486,7 +485,9 @@ async def run(self, messages: Union[List[Message], str],
486
485
self ._log_output ('[' + message .role + ']:' , tag = self .tag )
487
486
self ._log_output (message .content , tag = self .tag )
488
487
while not self .runtime .should_stop :
489
- messages = await self ._step (messages , self .tag )
488
+ yield_step = self ._step (messages , self .tag )
489
+ async for messages in yield_step :
490
+ yield messages
490
491
self .runtime .round += 1
491
492
# +1 means the next round the assistant may give a conclusion
492
493
if self .runtime .round >= self .max_chat_round + 1 :
@@ -498,15 +499,35 @@ async def run(self, messages: Union[List[Message], str],
498
499
f'Task { messages [1 ].content } failed, max round({ self .max_chat_round } ) exceeded.'
499
500
))
500
501
self .runtime .should_stop = True
502
+ yield messages
501
503
# save history
502
504
self ._save_history (messages , ** kwargs )
503
505
504
506
await self ._loop_callback ('on_task_end' , messages )
505
507
await self ._cleanup_tools ()
506
- return messages
507
508
except Exception as e :
508
509
if hasattr (self .config , 'help' ):
509
510
logger .error (
510
511
f'[{ self .tag } ] Runtime error, please follow the instructions:\n \n { self .config .help } '
511
512
)
512
513
raise e
514
+
515
async def run(self, messages: Union[List[Message], str],
              **kwargs) -> List[Message]:
    """Execute the agent, optionally streaming intermediate states.

    Args:
        messages (Union[List[Message], str]): Raw prompt string or an
            existing message history to run the agent loop on.
        **kwargs: Extra options forwarded to ``_run``. ``stream=True``
            switches the LLM generation config into streaming mode and
            changes the return contract (see below).

    Returns:
        List[Message]: The final message list when ``stream`` is falsy.
        When ``stream=True``, an async generator yielding successive
        message-list snapshots is returned instead.  # NOTE(review): the
        annotation only covers the non-streaming case — confirm callers.
    """
    if kwargs.get('stream', False):
        # Force streaming in the generation config so _step emits
        # incremental chunks.
        OmegaConf.update(
            self.config, 'generation_config.stream', True, merge=True)
        # _run is already an async generator; return it directly
        # instead of wrapping it in a pass-through generator.
        return self._run(messages=messages, **kwargs)

    # Non-streaming: drain the generator and keep the last snapshot,
    # which holds the complete conversation.
    res = None
    async for chunk in self._run(messages=messages, **kwargs):
        res = chunk
    return res
0 commit comments