Skip to content

Commit

Permalink
fix: Apply various small fixes to task execution logic
Browse files Browse the repository at this point in the history
  • Loading branch information
whiterabbit1983 committed Jul 23, 2024
1 parent ab0ec70 commit 277c168
Show file tree
Hide file tree
Showing 11 changed files with 160 additions and 156 deletions.
41 changes: 26 additions & 15 deletions agents-api/agents_api/activities/task_steps/__init__.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,18 @@
import asyncio

# import celpy
from simpleeval import simple_eval
from openai.types.chat.chat_completion import ChatCompletion
from temporalio import activity
from uuid import uuid4

from ...autogen.openapi_model import (
PromptWorkflowStep,
# EvaluateWorkflowStep,
# YieldWorkflowStep,
YieldWorkflowStep,
# ToolCallWorkflowStep,
# ErrorWorkflowStep,
# IfElseWorkflowStep,
IfElseWorkflowStep,
InputChatMLMessage,
)

Expand Down Expand Up @@ -79,20 +80,24 @@ async def prompt_step(context: StepContext) -> dict:
# return {"result": result}


# @activity.defn
# async def yield_step(context: StepContext) -> dict:
# if not isinstance(context.definition, YieldWorkflowStep):
# return {}
@activity.defn
async def yield_step(context: StepContext) -> dict:
if not isinstance(context.definition, YieldWorkflowStep):
return {}

# # TODO: implement
# TODO: implement

# return {"test": "result"}
return {"test": "result"}


# @activity.defn
# async def tool_call_step(context: StepContext) -> dict:
# if not isinstance(context.definition, ToolCallWorkflowStep):
# return {}
# assert isinstance(context.definition, ToolCallWorkflowStep)

# context.definition.tool_id
# context.definition.arguments
# # get tool by id
# # call tool

# # TODO: implement

Expand All @@ -107,12 +112,18 @@ async def prompt_step(context: StepContext) -> dict:
# return {"error": context.definition.error}


# @activity.defn
# async def if_else_step(context: StepContext) -> dict:
# if not isinstance(context.definition, IfElseWorkflowStep):
# return {}
@activity.defn
async def if_else_step(context: StepContext) -> dict:
assert isinstance(context.definition, IfElseWorkflowStep)

# return {"test": "result"}
context_data: dict = context.model_dump()
next_workflow = (
context.definition.then
if simple_eval(context.definition.if_, names=context_data)
else context.definition.else_
)

return {"workflow": next_workflow}


@activity.defn
Expand Down
112 changes: 57 additions & 55 deletions agents-api/agents_api/autogen/openapi_model.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# generated by datamodel-codegen:
# filename: openapi.yaml
# timestamp: 2024-07-10T09:10:51+00:00
# timestamp: 2024-07-17T06:45:55+00:00

from __future__ import annotations

Expand Down Expand Up @@ -858,13 +858,25 @@ class ChatMLImageContentPart(BaseModel):
"""


class ToolResponse(BaseModel):
id: UUID
"""
Optional Tool ID
"""
output: Dict[str, Any]


class CELObject(BaseModel):
model_config = ConfigDict(
extra="allow",
)
workflow: str
arguments: CELObject
arguments: Dict[str, Any]


class YieldWorkflowStep(CELObject):
pass
class YieldWorkflowStep(BaseModel):
workflow: str
arguments: Dict[str, Any]


class ToolCallWorkflowStep(BaseModel):
Expand All @@ -885,61 +897,59 @@ class IfElseWorkflowStep(BaseModel):
else_: Annotated[YieldWorkflowStep, Field(alias="else")]


class CreateExecution(BaseModel):
task_id: UUID
arguments: Dict[str, Any]
"""
JSON Schema of parameters
class TransitionType(str, Enum):
"""


class ToolResponse(BaseModel):
id: UUID
"""
Optional Tool ID
"""
output: Dict[str, Any]


class Type3(str, Enum):
"""
Transition type
Execution Status
"""

finish = "finish"
wait = "wait"
error = "error"
step = "step"
cancelled = "cancelled"


class UpdateExecutionTransitionRequest(BaseModel):
class ExecutionStatus(str, Enum):
"""
Update execution transition request schema
Execution Status
"""

type: Type3
"""
Transition type
"""
from_: Annotated[List[str | int], Field(alias="from", max_length=2, min_length=2)]
queued = "queued"
starting = "starting"
running = "running"
awaiting_input = "awaiting_input"
succeeded = "succeeded"
failed = "failed"
cancelled = "cancelled"


class CreateExecution(BaseModel):
task_id: UUID
arguments: Dict[str, Any]
"""
From state
JSON Schema of parameters
"""
to: Annotated[List[str | int] | None, Field(None, max_length=2, min_length=2)]


class StopExecution(BaseModel):
status: Literal["cancelled"] = "cancelled"
"""
To state
Stop Execution Status
"""
output: Dict[str, Any]


class ResumeExecutionTransitionRequest(BaseModel):
"""
Execution output
Update execution transition request schema
"""
task_token: str | None = None

task_token: str
"""
Task token
"""
metadata: Dict[str, Any] | None = None
output: Dict[str, Any]
"""
Custom metadata
Output of the execution
"""


Expand Down Expand Up @@ -1175,34 +1185,26 @@ class PatchToolRequest(BaseModel):
class Execution(BaseModel):
id: UUID
task_id: UUID
created_at: UUID
created_at: AwareDatetime
arguments: Dict[str, Any]
"""
JSON Schema of parameters
"""
status: Annotated[
str,
Field(pattern="^(queued|starting|running|awaiting_input|succeeded|failed)$"),
]
"""
Execution Status
"""
status: ExecutionStatus


class ExecutionTransition(BaseModel):
id: UUID
execution_id: UUID
created_at: AwareDatetime
updated_at: AwareDatetime
outputs: Dict[str, Any]
"""
Outputs from an Execution Transition
"""
from_: Annotated[List[str | int], Field(alias="from")]
to: List[str | int]
type: Annotated[str, Field(pattern="^(finish|wait|error|step)$")]
"""
Execution Status
"""
current: List[str | int]
next: List[str | int]
type: TransitionType


class PromptWorkflowStep(BaseModel):
Expand Down Expand Up @@ -1259,6 +1261,9 @@ class Task(BaseModel):
Describes a Task
"""

model_config = ConfigDict(
extra="allow",
)
name: str
"""
Name of the Task
Expand Down Expand Up @@ -1291,8 +1296,5 @@ class Task(BaseModel):
ID of the Task
"""
created_at: AwareDatetime
updated_at: AwareDatetime
agent_id: UUID


CELObject.model_rebuild()
YieldWorkflowStep.model_rebuild()
2 changes: 2 additions & 0 deletions agents-api/agents_api/clients/temporal.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
temporal_private_key,
)
from ..common.protocol.tasks import ExecutionInput
from ..worker.codec import pydantic_data_converter


async def get_client():
Expand All @@ -22,6 +23,7 @@ async def get_client():
temporal_worker_url,
namespace=temporal_namespace,
tls=tls_config,
data_converter=pydantic_data_converter,
)


Expand Down
4 changes: 2 additions & 2 deletions agents-api/agents_api/models/execution/create_execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,14 @@ def create_execution_query(
execution_id: UUID,
session_id: UUID | None = None,
status: Literal[
"pending",
"queued",
"starting",
"running",
"awaiting_input",
"succeeded",
"failed",
] = "pending",
"cancelled",
] = "queued",
arguments: Dict[str, Any] = {},
) -> tuple[str, dict]:
# TODO: Check for agent in developer ID; Assert whether dev can access agent and by relation the task
Expand Down
17 changes: 12 additions & 5 deletions agents-api/agents_api/routers/agents/create_agent.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import Annotated
from uuid import uuid4

from fastapi import Depends
from pydantic import UUID4
Expand All @@ -17,13 +18,19 @@ async def create_agent(
request: CreateAgentRequest,
x_developer_id: Annotated[UUID4, Depends(get_developer_id)],
) -> ResourceCreatedResponse:
agent_id = create_agent_query(
new_agent_id = uuid4()

resp = create_agent_query(
developer_id=x_developer_id,
agent_id=new_agent_id,
name=request.name,
about=request.about,
instructions=request.instructions,
instructions=request.instructions or [],
model=request.model,
default_settings=request.default_settings,
metadata=request.metadata,
default_settings=request.default_settings or {},
metadata=request.metadata or {},
)
return ResourceCreatedResponse(id=agent_id, created_at=utcnow())

resp.iterrows()

return ResourceCreatedResponse(id=new_agent_id, created_at=resp["created_at"])
12 changes: 9 additions & 3 deletions agents-api/agents_api/routers/tasks/routers.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import logging
from typing import Annotated
from uuid import uuid4
from jsonschema import validate
Expand Down Expand Up @@ -34,7 +35,6 @@
ExecutionTransition,
ResourceCreatedResponse,
ResourceUpdatedResponse,
UpdateExecutionTransitionRequest,
CreateExecution,
)
from agents_api.dependencies.developer_id import get_developer_id
Expand All @@ -43,6 +43,10 @@
from agents_api.clients.cozo import client as cozo_client


logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)


class TaskList(BaseModel):
items: list[Task]

Expand Down Expand Up @@ -206,7 +210,9 @@ async def create_task_execution(
execution_input=execution_input,
job_id=uuid4(),
)
except Exception:
except Exception as e:
logger.exception(e)

update_execution_status_query(
task_id=task_id,
execution_id=execution_id,
Expand Down Expand Up @@ -276,7 +282,7 @@ async def get_execution_transition(
async def update_execution_transition(
execution_id: UUID4,
transition_id: UUID4,
request: UpdateExecutionTransitionRequest,
request: ExecutionTransition,
) -> ResourceUpdatedResponse:
try:
resp = update_execution_transition_query(
Expand Down
4 changes: 2 additions & 2 deletions agents-api/agents_api/worker/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from ..activities.task_steps import (
prompt_step,
# evaluate_step,
# yield_step,
yield_step,
# tool_call_step,
# error_step,
# if_else_step,
Expand Down Expand Up @@ -77,7 +77,7 @@ async def main():
task_activities = [
prompt_step,
# evaluate_step,
# yield_step,
yield_step,
# tool_call_step,
# error_step,
# if_else_step,
Expand Down
Loading

0 comments on commit 277c168

Please sign in to comment.