veadk/runner.py (117 changes: 88 additions & 29 deletions)
@@ -14,6 +14,7 @@
 from typing import Union

 from google.adk.agents import RunConfig
+from google.adk.agents.invocation_context import LlmCallsLimitExceededError
 from google.adk.agents.run_config import StreamingMode
 from google.adk.plugins.base_plugin import BasePlugin
 from google.adk.runners import Runner as ADKRunner
@@ -49,20 +50,25 @@ class Runner:
     def __init__(
         self,
         agent: VeAgent,
-        short_term_memory: ShortTermMemory,
+        short_term_memory: ShortTermMemory | None = None,
         plugins: list[BasePlugin] | None = None,
         app_name: str = "veadk_default_app",
         user_id: str = "veadk_default_user",
     ):
         # basic settings
         self.app_name = app_name
         self.user_id = user_id

         # agent settings
         self.agent = agent

-        self.short_term_memory = short_term_memory
-        self.session_service = short_term_memory.session_service
+        if not short_term_memory:
+            logger.info(
+                "No short term memory provided, using an in-memory memory by default."
+            )
+            self.short_term_memory = ShortTermMemory()
+        else:
+            self.short_term_memory = short_term_memory
+
+        self.session_service = self.short_term_memory.session_service

         # prevent VeRemoteAgent has no long-term memory attr
         if isinstance(self.agent, Agent):
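With this hunk, `short_term_memory` becomes optional: when it is omitted, the Runner logs a notice and builds an in-memory `ShortTermMemory` itself. A minimal usage sketch of the relaxed constructor follows; the top-level `veadk` import paths, the no-argument `Agent()` construction, and passing a plain string to `run()` are assumptions about the surrounding package, not something this diff shows.

```python
import asyncio

from veadk import Agent, Runner  # assumed import path
from veadk.memory.short_term_memory import ShortTermMemory  # assumed module path


async def main():
    agent = Agent()  # hypothetical: the real constructor may require name/model/instruction

    # New behavior: short_term_memory may be omitted; the Runner falls back
    # to an in-memory ShortTermMemory and logs an info message.
    runner = Runner(agent=agent)

    # The previous, explicit form keeps working unchanged.
    runner_explicit = Runner(agent=agent, short_term_memory=ShortTermMemory())

    print(await runner.run(messages="Hello!", session_id="session_1"))


asyncio.run(main())
```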
@@ -114,35 +120,44 @@ async def _run(
         self,
         session_id: str,
         message: types.Content,
+        run_config: RunConfig | None = None,
         stream: bool = False,
     ):
         stream_mode = StreamingMode.SSE if stream else StreamingMode.NONE

-        async def event_generator():
-            async for event in self.runner.run_async(
-                user_id=self.user_id,
-                session_id=session_id,
-                new_message=message,
-                run_config=RunConfig(streaming_mode=stream_mode),
-            ):
-                if event.get_function_calls():
-                    for function_call in event.get_function_calls():
-                        logger.debug(f"Function call: {function_call}")
-                elif (
-                    event.content is not None
-                    and event.content.parts
-                    and event.content.parts[0].text is not None
-                    and len(event.content.parts[0].text.strip()) > 0
-                ):
-                    yield event.content.parts[0].text
+        if run_config is not None:
+            stream_mode = run_config.streaming_mode
+        else:
+            run_config = RunConfig(streaming_mode=stream_mode)
+        try:

-        final_output = ""
-        async for chunk in event_generator():
+            async def event_generator():
+                async for event in self.runner.run_async(
+                    user_id=self.user_id,
+                    session_id=session_id,
+                    new_message=message,
+                    run_config=run_config,
+                ):
+                    if event.get_function_calls():
+                        for function_call in event.get_function_calls():
+                            logger.debug(f"Function call: {function_call}")
+                    elif (
+                        event.content is not None
+                        and event.content.parts
+                        and event.content.parts[0].text is not None
+                        and len(event.content.parts[0].text.strip()) > 0
+                    ):
+                        yield event.content.parts[0].text
+
+            final_output = ""
+            async for chunk in event_generator():
+                if stream:
+                    print(chunk, end="", flush=True)
+                final_output += chunk
             if stream:
-                print(chunk, end="", flush=True)
-            final_output += chunk
-        if stream:
-            print()  # end with a new line
+                print()  # end with a new line
+        except LlmCallsLimitExceededError as e:
+            logger.warning(f"Max number of llm calls limit exceeded: {e}")

         return final_output
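The `run_config` branch added here makes an explicitly supplied `RunConfig` take precedence over the boolean `stream` flag; only when no config is passed is one derived from the flag. The resolution rule, pulled out as a standalone sketch for clarity (illustration only, not code from this PR):

```python
from google.adk.agents import RunConfig
from google.adk.agents.run_config import StreamingMode


def resolve_run_config(run_config: RunConfig | None, stream: bool) -> RunConfig:
    # Mirrors the branch added to _run: an explicit RunConfig wins;
    # otherwise the stream flag selects SSE or non-streaming mode.
    if run_config is not None:
        return run_config
    stream_mode = StreamingMode.SSE if stream else StreamingMode.NONE
    return RunConfig(streaming_mode=stream_mode)


# resolve_run_config(None, stream=True)         -> SSE streaming config
# resolve_run_config(RunConfig(), stream=True)  -> the caller's config, flag ignored
```

Wrapping the consumption loop in `try/except LlmCallsLimitExceededError` means a run that hits ADK's call limit now returns whatever text was collected so far instead of raising.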

@@ -151,6 +166,7 @@ async def run(
         messages: RunnerMessage,
         session_id: str,
         stream: bool = False,
+        run_config: RunConfig | None = None,
         save_tracing_data: bool = False,
     ):
         converted_messages: list = self._convert_messages(messages)
@@ -163,7 +179,9 @@

         final_output = ""
         for converted_message in converted_messages:
-            final_output = await self._run(session_id, converted_message, stream)
+            final_output = await self._run(
+                session_id, converted_message, run_config, stream
+            )

         # try to save tracing file
         if save_tracing_data:
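`run()` now simply forwards `run_config` to `_run()` for every converted message. One way a caller might use this is to cap model calls per run; the `max_llm_calls` field is an assumption about google-adk's `RunConfig`, so treat this as a sketch:

```python
from google.adk.agents import RunConfig
from google.adk.agents.run_config import StreamingMode


async def capped_run(runner, session_id: str) -> str:
    config = RunConfig(
        streaming_mode=StreamingMode.NONE,
        max_llm_calls=3,  # assumed field name; adjust to your google-adk version
    )
    # If the agent exceeds the limit, ADK raises LlmCallsLimitExceededError;
    # with this PR the Runner logs a warning and returns the partial output
    # instead of propagating the exception to the caller.
    return await runner.run(
        messages="Summarize the latest deployment logs.",
        session_id=session_id,
        run_config=config,
    )
```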
@@ -193,6 +211,47 @@ def get_trace_id(self) -> str:
             logger.warning(f"Get tracer id failed as {e}")
             return "<unknown_trace_id>"

+    async def run_with_raw_message(
+        self,
+        message: types.Content,
+        session_id: str,
+        run_config: RunConfig | None = None,
+    ):
+        run_config = RunConfig() if not run_config else run_config
+
+        await self.short_term_memory.create_session(
+            app_name=self.app_name, user_id=self.user_id, session_id=session_id
+        )
+
+        try:
+
+            async def event_generator():
+                async for event in self.runner.run_async(
+                    user_id=self.user_id,
+                    session_id=session_id,
+                    new_message=message,
+                    run_config=run_config,
+                ):
+                    if event.get_function_calls():
+                        for function_call in event.get_function_calls():
+                            logger.debug(f"Function call: {function_call}")
+                    elif (
+                        event.content is not None
+                        and event.content.parts
+                        and event.content.parts[0].text is not None
+                        and len(event.content.parts[0].text.strip()) > 0
+                    ):
+                        yield event.content.parts[0].text
+
+            final_output = ""
+
+            async for chunk in event_generator():
+                final_output += chunk
+        except LlmCallsLimitExceededError as e:
+            logger.warning(f"Max number of llm calls limit exceeded: {e}")
+
+        return final_output
+
     def _print_trace_id(self) -> None:
         if not isinstance(self.agent, Agent):
             logger.warning(
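The new `run_with_raw_message()` takes an already-built `types.Content` rather than a `RunnerMessage`, creates the session itself, and returns the concatenated text without printing or tracing. A possible call site, assuming the usual `google.genai` content/part constructors:

```python
from google.genai import types


async def ask_raw(runner, session_id: str) -> str:
    message = types.Content(
        role="user",
        parts=[types.Part(text="Which tools can you call?")],
    )
    # run_with_raw_message creates the session for the Runner's
    # app_name/user_id itself, so no prior create_session call is needed.
    return await runner.run_with_raw_message(
        message=message,
        session_id=session_id,
    )
```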