Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 45 additions & 13 deletions agentkit/apps/agent_server_app/agent_server_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import logging
from contextlib import asynccontextmanager
from typing import override

import uvicorn
import json
from fastapi import Request
from fastapi import HTTPException
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import StreamingResponse
from opentelemetry import trace
from google.adk.a2a.utils.agent_to_a2a import to_a2a
from google.adk.agents.base_agent import BaseAgent
from google.adk.agents.run_config import RunConfig, StreamingMode
from google.adk.artifacts.in_memory_artifact_service import (
InMemoryArtifactService,
)
Expand All @@ -35,15 +37,20 @@
from google.adk.evaluation.local_eval_sets_manager import LocalEvalSetsManager
from google.adk.memory.in_memory_memory_service import InMemoryMemoryService
from google.adk.sessions.base_session_service import BaseSessionService
from google.adk.agents.run_config import RunConfig, StreamingMode
from google.adk.utils.context_utils import Aclosing
from google.genai import types
from opentelemetry import trace
from veadk import Agent
from veadk.memory.short_term_memory import ShortTermMemory
from veadk.runner import Runner

from agentkit.apps.base_app import BaseAgentkitApp
from agentkit.apps.agent_server_app.middleware import (
AgentkitTelemetryHTTPMiddleware,
)
from agentkit.apps.agent_server_app.telemetry import telemetry
from agentkit.apps.agent_server_app.middleware import AgentkitTelemetryHTTPMiddleware
from agentkit.apps.base_app import BaseAgentkitApp

logger = logging.getLogger(__name__)


class AgentKitAgentLoader(BaseAgentLoader):
Expand All @@ -63,7 +70,9 @@ def list_agents(self) -> list[str]:

class AgentkitAgentServerApp(BaseAgentkitApp):
def __init__(
self, agent: BaseAgent, short_term_memory: BaseSessionService | ShortTermMemory
self,
agent: BaseAgent,
short_term_memory: BaseSessionService | ShortTermMemory,
) -> None:
super().__init__()

Expand All @@ -88,7 +97,22 @@ def __init__(
agents_dir=".",
)

self.app = self.server.get_fast_api_app()
runner = Runner(agent=agent)
_a2a_server_app = to_a2a(agent=agent, runner=runner)

@asynccontextmanager
async def lifespan(app: FastAPI):
# trigger A2A server app startup
logger.info(
"Triggering A2A server app startup within API server..."
)
for handler in _a2a_server_app.router.on_startup:
await handler()
yield

self.app = self.server.get_fast_api_app(lifespan=lifespan)

self.app.mount("/", _a2a_server_app)

# Attach ASGI middleware for unified telemetry across all routes
self.app.add_middleware(AgentkitTelemetryHTTPMiddleware)
Expand All @@ -100,14 +124,18 @@ async def _invoke_compat(request: Request):
# Extract headers (fallback keys supported)
headers = request.headers
user_id = (
headers.get("user_id") or headers.get("x-user-id") or "agentkit_user"
headers.get("user_id")
or headers.get("x-user-id")
or "agentkit_user"
)
session_id = headers.get("session_id") or ""

# Determine app_name from loader
app_names = self.server.agent_loader.list_agents()
if not app_names:
raise HTTPException(status_code=404, detail="No agents configured")
raise HTTPException(
status_code=404, detail="No agents configured"
)
app_name = app_names[0]

# Parse payload and convert to ADK Content
Expand Down Expand Up @@ -156,7 +184,9 @@ async def event_generator():
user_id=user_id,
session_id=session_id,
new_message=content,
run_config=RunConfig(streaming_mode=StreamingMode.SSE),
run_config=RunConfig(
streaming_mode=StreamingMode.SSE
),
)
) as agen:
async for event in agen:
Expand All @@ -171,7 +201,9 @@ async def event_generator():
pass
except Exception as e:
yield f'data: {{"error": "{str(e)}"}}\n\n'
telemetry.trace_agent_server_finish(func_result="", exception=e)
telemetry.trace_agent_server_finish(
func_result="", exception=e
)

return StreamingResponse(
event_generator(),
Expand Down