Skip to content

dev->main #1284

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Mar 25, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 7 additions & 6 deletions agents-api/agents_api/clients/pg.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,13 @@


async def _init_conn(conn):
await conn.set_type_codec(
"jsonb",
encoder=json.dumps,
decoder=json.loads,
schema="pg_catalog",
)
for datatype in ["json", "jsonb"]:
await conn.set_type_codec(
datatype,
encoder=json.dumps,
decoder=json.loads,
schema="pg_catalog",
)


async def create_db_pool(dsn: str | None = None, **kwargs):
Expand Down
38 changes: 28 additions & 10 deletions agents-api/agents_api/common/utils/template.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from typing import TypeVar

from beartype import beartype
from fastapi import HTTPException, status
from jinja2 import TemplateSyntaxError, UndefinedError
from jinja2.sandbox import ImmutableSandboxedEnvironment
from jinja2schema import infer, to_json_schema
from jsonschema import validate
Expand Down Expand Up @@ -34,16 +36,32 @@ async def render_template_string(
variables: dict,
check: bool = False,
) -> str:
# Parse template
template = jinja_env.from_string(template_string)

# If check is required, get required vars from template and validate variables
if check:
schema = to_json_schema(infer(template_string))
validate(instance=variables, schema=schema)

# Render
return await template.render_async(**variables)
try:
# Parse template
template = jinja_env.from_string(template_string)

# If check is required, get required vars from template and validate variables
if check:
schema = to_json_schema(infer(template_string))
validate(instance=variables, schema=schema)

# Render
return await template.render_async(**variables)
except TemplateSyntaxError as e:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Template syntax error: {e!s}",
)
except UndefinedError as e:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Template undefined variable: {e!s}",
)
except Exception as e:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Template rendering error: {e!s}",
)


# A render function that can render arbitrarily nested lists of dicts
Expand Down
35 changes: 20 additions & 15 deletions agents-api/agents_api/queries/entries/get_history.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import json
from uuid import UUID

from beartype import beartype
Expand All @@ -10,7 +9,7 @@

# Define the raw SQL query for getting history with a developer check and relations
history_query = """
WITH entries AS (
WITH collected_entries AS (
SELECT
e.entry_id AS id,
e.session_id,
Expand All @@ -23,13 +22,16 @@
e.timestamp,
e.tool_calls,
e.tool_call_id,
e.tokenizer
e.tokenizer,
s.created_at AS session_created_at
FROM entries e
JOIN developers d ON d.developer_id = $3
WHERE e.session_id = $1
AND e.source = ANY($2)
JOIN sessions s ON s.session_id = e.session_id
AND s.session_id = $1
AND s.developer_id = $3
WHERE e.source = ANY($2)
AND e.created_at >= s.created_at
),
relations AS (
collected_relations AS (
SELECT
er.head,
er.relation,
Expand All @@ -38,30 +40,31 @@
WHERE er.session_id = $1
)
SELECT
(SELECT json_agg(e) FROM entries e) AS entries,
(SELECT json_agg(r) FROM relations r) AS relations,
(SELECT json_agg(e) FROM collected_entries e) AS entries,
(SELECT json_agg(r) FROM collected_relations r) AS relations,
(SELECT session_created_at FROM collected_entries) AS created_at,
$1::uuid AS session_id
"""


def _transform(d):
def _transform(row):
return {
"entries": [
{
**entry,
}
for entry in json.loads(d.get("entries") or "[]")
for entry in (row["entries"] or [])
],
"relations": [
{
"head": r["head"],
"relation": r["relation"],
"tail": r["tail"],
}
for r in (d.get("relations") or [])
for r in (row["relations"] or [])
],
"session_id": d.get("session_id"),
"created_at": utcnow(),
"session_id": row["session_id"],
"created_at": row["created_at"] or utcnow(),
}


Expand All @@ -79,7 +82,8 @@ async def get_history(
session_id: UUID,
allowed_sources: list[str] = ["api_request", "api_response"],
) -> tuple[str, list] | tuple[str, list, str]:
"""Get the history of a session.
"""
Get session history.

Parameters:
developer_id (UUID): The ID of the developer.
Expand All @@ -89,6 +93,7 @@ async def get_history(
Returns:
tuple[str, list] | tuple[str, list, str]: SQL query and parameters for getting the history.
"""

return (
history_query,
[session_id, allowed_sources, developer_id],
Expand Down
15 changes: 11 additions & 4 deletions agents-api/agents_api/queries/entries/list_entries.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from datetime import timedelta
from typing import Literal
from uuid import UUID

from beartype import beartype
from fastapi import HTTPException

from ...autogen.openapi_model import Entry
from ...common.utils.datetime import utcnow
from ...common.utils.db_exceptions import common_db_exceptions
from ...metrics.counters import query_metrics
from ..utils import pg_query, rewrap_exceptions, wrap_in_class
Expand Down Expand Up @@ -34,11 +36,13 @@
e.model,
e.tokenizer
FROM entries e
JOIN developers d ON d.developer_id = $5
LEFT JOIN entry_relations er ON er.head = e.entry_id AND er.session_id = e.session_id
JOIN developers d ON d.developer_id = $5
LEFT JOIN entry_relations er ON er.head = e.entry_id AND er.session_id = e.session_id
WHERE e.session_id = $1
AND e.source = ANY($2)
AND (er.relation IS NULL OR er.relation != ALL($6))
AND e.source = ANY($2)
AND (er.relation IS NULL OR er.relation != ALL($6))
AND e.created_at >= $7
AND e.created_at >= (select created_at from sessions where session_id = $1)
ORDER BY e.{sort_by} {direction} -- safe to interpolate
LIMIT $3
OFFSET $4;
Expand All @@ -60,6 +64,7 @@ async def list_entries(
sort_by: Literal["created_at", "timestamp"] = "timestamp",
direction: Literal["asc", "desc"] = "asc",
exclude_relations: list[str] = [],
search_window: timedelta = timedelta(weeks=4),
) -> list[tuple[str, list] | tuple[str, list, str]]:
"""List entries in a session.

Expand Down Expand Up @@ -94,7 +99,9 @@ async def list_entries(
offset, # $4
developer_id, # $5
exclude_relations, # $6
utcnow() - search_window, # 7
]

return [
(
session_exists_query,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
SELECT COUNT(*) FROM latest_executions
WHERE
developer_id = $1
AND created_at >= (select created_at from developers where developer_id = $1)
AND task_id = $2;
"""

Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from datetime import timedelta
from typing import Literal
from uuid import UUID

from beartype import beartype

from ...common.utils.datetime import utcnow
from ...common.utils.db_exceptions import common_db_exceptions
from ..utils import pg_query, rewrap_exceptions, wrap_in_class

Expand All @@ -22,6 +24,8 @@
FROM latest_transitions
WHERE
execution_id = $1
AND created_at >= $2
AND created_at >= (select created_at from executions where execution_id = $1)
AND type = 'wait';
"""

Expand All @@ -33,6 +37,7 @@
async def get_paused_execution_token(
*,
execution_id: UUID,
search_window: timedelta = timedelta(weeks=4),
) -> tuple[str, list, Literal["fetch", "fetchmany", "fetchrow"]]:
"""
Get a paused execution token for a given execution.
Expand All @@ -47,6 +52,6 @@ async def get_paused_execution_token(

return (
get_paused_execution_token_query,
[execution_id],
[execution_id, utcnow() - search_window],
"fetchrow",
)
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from datetime import timedelta
from typing import Literal
from uuid import UUID

Expand All @@ -15,7 +16,9 @@
SELECT * FROM transitions
WHERE
execution_id = $1
AND (current_step).scope_id = $6
AND (current_step).scope_id = $7
AND created_at >= $6
AND created_at >= (select created_at from executions where execution_id = $1)
ORDER BY
CASE WHEN $4 = 'created_at' AND $5 = 'asc' THEN created_at END ASC NULLS LAST,
CASE WHEN $4 = 'created_at' AND $5 = 'desc' THEN created_at END DESC NULLS LAST
Expand Down Expand Up @@ -82,6 +85,7 @@ async def list_execution_transitions(
sort_by: Literal["created_at"] = "created_at",
direction: Literal["asc", "desc"] = "desc",
scope_id: UUID | None = None,
search_window: timedelta = timedelta(weeks=2),
) -> tuple[str, list]:
"""
List execution transitions for a given execution.
Expand Down Expand Up @@ -112,11 +116,12 @@ async def list_execution_transitions(
offset,
sort_by,
direction,
utcnow() - search_window,
]

query = list_execution_transitions_query
if scope_id is None:
query = query.replace("AND (current_step).scope_id = $6", "")
query = query.replace("AND (current_step).scope_id = $7", "")
else:
params.append(str(scope_id))

Expand Down
17 changes: 8 additions & 9 deletions agents-api/agents_api/queries/executions/list_executions.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,15 +33,14 @@
FROM
latest_executions
WHERE
developer_id = $1 AND
task_id = $2
developer_id = $1
AND task_id = $2
ORDER BY
CASE WHEN $3 = 'asc' THEN created_at END ASC NULLS LAST,
CASE WHEN $3 = 'desc' THEN created_at END DESC NULLS LAST
-- Add this back once we update the view to support sorting by updated_at
-- CASE WHEN $3 = 'updated_at' AND $4 = 'asc' THEN updated_at END ASC NULLS LAST,
-- CASE WHEN $3 = 'updated_at' AND $4 = 'desc' THEN updated_at END DESC NULLS LAST
LIMIT $4 OFFSET $5;
CASE WHEN $3 = 'created_at' AND $4 = 'asc' THEN created_at END ASC NULLS LAST,
CASE WHEN $3 = 'created_at' AND $4 = 'desc' THEN created_at END DESC NULLS LAST,
CASE WHEN $3 = 'updated_at' AND $4 = 'asc' THEN updated_at END ASC NULLS LAST,
CASE WHEN $3 = 'updated_at' AND $4 = 'desc' THEN updated_at END DESC NULLS LAST
LIMIT $5 OFFSET $6;
"""


Expand Down Expand Up @@ -108,7 +107,7 @@ async def list_executions(
[
developer_id,
task_id,
# sort_by,
sort_by,
direction,
limit,
offset,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@
SELECT to_jsonb(a) AS agent FROM (
SELECT * FROM agents
WHERE
developer_id = $1 AND
agent_id = (
developer_id = $1
AND agent_id = (
SELECT agent_id FROM tasks
WHERE developer_id = $1 AND task_id = $2
LIMIT 1
Expand All @@ -26,17 +26,17 @@
SELECT COALESCE(jsonb_agg(r), '[]'::jsonb) AS tools FROM (
SELECT * FROM tools
WHERE
developer_id = $1 AND
task_id = $2
developer_id = $1
AND task_id = $2
) r
) AS tools,
(
SELECT to_jsonb(e) AS execution FROM (
SELECT * FROM latest_executions
WHERE
developer_id = $1 AND
task_id = $2 AND
execution_id = $3
developer_id = $1
AND task_id = $2
AND execution_id = $3
LIMIT 1
) e
) AS execution;
Expand Down Expand Up @@ -87,11 +87,8 @@ async def prepare_execution_input(
Returns:
tuple[str, list]: SQL query and parameters for preparing the execution input.
"""

return (
prepare_execution_input_query,
[
str(developer_id),
str(task_id),
str(execution_id),
],
[str(developer_id), str(task_id), str(execution_id)],
)
2 changes: 2 additions & 0 deletions agents-api/tests/test_execution_queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,8 @@ async def _(
result = await list_executions(
developer_id=developer_id,
task_id=task.id,
sort_by="updated_at",
direction="asc",
connection_pool=pool,
)

Expand Down
Loading