Skip to content

Commit

Permalink
feat(agents-api): ALL TESTS PASS!! :D
Browse files Browse the repository at this point in the history
Signed-off-by: Diwank Tomer <diwank@julep.ai>
  • Loading branch information
Diwank Tomer committed Aug 17, 2024
1 parent 655d222 commit 0194c73
Show file tree
Hide file tree
Showing 10 changed files with 170 additions and 77 deletions.
13 changes: 12 additions & 1 deletion agents-api/agents_api/activities/demo.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,17 @@
from temporalio import activity

from ..env import testing


@activity.defn
async def demo_activity(a: int, b: int) -> int:
# Should throw an error if testing is not enabled
raise Exception("This should not be called in production")


async def mock_demo_activity(a: int, b: int) -> int:
return a + b


demo_activity = activity.defn(name="demo_activity")(
demo_activity if not testing else mock_demo_activity
)
12 changes: 11 additions & 1 deletion agents-api/agents_api/activities/embed_docs.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@

from ..clients import embed as embedder
from ..clients.cozo import get_cozo_client
from ..env import testing
from ..models.docs.embed_snippets import embed_snippets as embed_snippets_query
from .types import EmbedDocsPayload


@activity.defn
@beartype
async def embed_docs(payload: EmbedDocsPayload, cozo_client=None) -> None:
indices, snippets = list(zip(*enumerate(payload.content)))
Expand All @@ -30,3 +30,13 @@ async def embed_docs(payload: EmbedDocsPayload, cozo_client=None) -> None:
embeddings=embeddings,
client=cozo_client or get_cozo_client(),
)


async def mock_embed_docs(payload: EmbedDocsPayload, cozo_client=None) -> None:
# Does nothing
return None


embed_docs = activity.defn(name="embed_docs")(
embed_docs if not testing else mock_embed_docs
)
11 changes: 10 additions & 1 deletion agents-api/agents_api/activities/task_steps/evaluate_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@
StepContext,
StepOutcome,
)
from ...env import testing


@activity.defn
@beartype
async def evaluate_step(
context: StepContext[EvaluateStep],
Expand All @@ -20,3 +20,12 @@ async def evaluate_step(
output = simple_eval_dict(exprs, values=context.model_dump())

return StepOutcome(output=output)


# Note: This is here just for clarity. We could have just imported evaluate_step directly
# They do the same thing, so we dont need to mock the evaluate_step function
mock_evaluate_step = evaluate_step

evaluate_step = activity.defn(name="evaluate_step")(
evaluate_step if not testing else mock_evaluate_step
)
23 changes: 16 additions & 7 deletions agents-api/agents_api/activities/task_steps/transition_step.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,14 @@
from beartype import beartype
from temporalio import activity

from ...autogen.openapi_model import (
CreateTransitionRequest,
)
from ...common.protocol.tasks import (
StepContext,
)
from ...autogen.openapi_model import CreateTransitionRequest
from ...common.protocol.tasks import StepContext
from ...env import testing
from ...models.execution.create_execution_transition import (
create_execution_transition as create_execution_transition_query,
)


@activity.defn
@beartype
async def transition_step(
context: StepContext,
Expand All @@ -34,3 +30,16 @@ async def transition_step(
data=transition_info,
update_execution_status=True,
)


async def mock_transition_step(
context: StepContext,
transition_info: CreateTransitionRequest,
) -> None:
# Does nothing
return None


transition_step = activity.defn(name="transition_step")(
transition_step if not testing else mock_transition_step
)
20 changes: 12 additions & 8 deletions agents-api/agents_api/activities/task_steps/yield_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,12 @@

from agents_api.autogen.Executions import TransitionTarget

from ...autogen.openapi_model import (
YieldStep,
)
from ...common.protocol.tasks import (
StepContext,
StepOutcome,
)
from ...autogen.openapi_model import YieldStep
from ...common.protocol.tasks import StepContext, StepOutcome
from ...env import testing
from .utils import simple_eval_dict


@activity.defn
@beartype
async def yield_step(context: StepContext[YieldStep]) -> StepOutcome[dict[str, Any]]:
all_workflows = context.execution_input.task.workflows
Expand All @@ -36,3 +31,12 @@ async def yield_step(context: StepContext[YieldStep]) -> StepOutcome[dict[str, A
)

return StepOutcome(output=arguments, transition_to=("step", transition_target))


# Note: This is here just for clarity. We could have just imported yield_step directly
# They do the same thing, so we dont need to mock the yield_step function
mock_yield_step = yield_step

yield_step = activity.defn(name="yield_step")(
yield_step if not testing else mock_yield_step
)
10 changes: 9 additions & 1 deletion agents-api/agents_api/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,10 +85,18 @@
temporal_worker_url=temporal_worker_url,
temporal_namespace=temporal_namespace,
embedding_model_id=embedding_model_id,
testing=testing,
)

if debug:
if debug or testing:
# Print the loaded environment variables for debugging purposes.
print("Environment variables:")
pprint(environment)
print()

# Yell if testing is enabled
print("@" * 80)
print(
f"@@@ Running in {'testing' if testing else 'debug'} mode. This should not be enabled in production. @@@"
)
print("@" * 80)
10 changes: 4 additions & 6 deletions agents-api/agents_api/routers/docs/create_doc.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,6 @@ async def run_embed_docs_task(

client = client or (await temporal.get_client())

# TODO: Remove this conditional once we have a way to run workflows in
# a test environment.
if testing:
return None

embed_payload = EmbedDocsPayload(
developer_id=developer_id,
doc_id=doc_id,
Expand All @@ -49,7 +44,10 @@ async def run_embed_docs_task(
id=str(job_id),
)

background_tasks.add_task(handle.result)
# TODO: Remove this conditional once we have a way to run workflows in
# a test environment.
if not testing:
background_tasks.add_task(handle.result)

return handle

Expand Down
111 changes: 69 additions & 42 deletions agents-api/agents_api/routers/tasks/create_task_execution.py
Original file line number Diff line number Diff line change
@@ -1,37 +1,89 @@
import logging
from typing import Annotated
from uuid import uuid4
from uuid import UUID, uuid4

from fastapi import Depends, HTTPException, status
from beartype import beartype
from fastapi import BackgroundTasks, Depends, HTTPException, status
from jsonschema import validate
from jsonschema.exceptions import ValidationError
from pycozo.client import QueryException
from pydantic import UUID4
from starlette.status import HTTP_201_CREATED
from temporalio.client import WorkflowHandle

from agents_api.autogen.openapi_model import (
from ...autogen.Executions import Execution
from ...autogen.openapi_model import (
CreateExecutionRequest,
ResourceCreatedResponse,
UpdateExecutionRequest,
)
from agents_api.clients.temporal import run_task_execution_workflow
from agents_api.dependencies.developer_id import get_developer_id
from agents_api.models.execution.create_execution import (
from ...clients.temporal import run_task_execution_workflow
from ...dependencies.developer_id import get_developer_id
from ...models.execution.create_execution import (
create_execution as create_execution_query,
)
from agents_api.models.execution.create_temporal_lookup import create_temporal_lookup
from agents_api.models.execution.prepare_execution_input import prepare_execution_input
from agents_api.models.execution.update_execution import (
from ...models.execution.create_temporal_lookup import create_temporal_lookup
from ...models.execution.prepare_execution_input import prepare_execution_input
from ...models.execution.update_execution import (
update_execution as update_execution_query,
)
from agents_api.models.task.get_task import get_task as get_task_query

from ...models.task.get_task import get_task as get_task_query
from .router import router

logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)


@beartype
async def start_execution(
*,
developer_id: UUID,
task_id: UUID,
data: CreateExecutionRequest,
client=None,
) -> tuple[Execution, WorkflowHandle]:
execution_id = uuid4()

execution = create_execution_query(
developer_id=developer_id,
task_id=task_id,
execution_id=execution_id,
data=data,
client=client,
)

execution_input = prepare_execution_input(
developer_id=developer_id,
task_id=task_id,
execution_id=execution_id,
client=client,
)

try:
handle = await run_task_execution_workflow(
execution_input=execution_input,
job_id=uuid4(),
)

except Exception as e:
logger.exception(e)

update_execution_query(
developer_id=developer_id,
task_id=task_id,
execution_id=execution_id,
data=UpdateExecutionRequest(status="failed"),
client=client,
)

raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Execution creation failed",
) from e

return execution, handle


@router.post(
"/tasks/{task_id}/executions",
status_code=HTTP_201_CREATED,
Expand All @@ -41,6 +93,7 @@ async def create_task_execution(
task_id: UUID4,
data: CreateExecutionRequest,
x_developer_id: Annotated[UUID4, Depends(get_developer_id)],
background_tasks: BackgroundTasks,
) -> ResourceCreatedResponse:
try:
task = get_task_query(task_id=task_id, developer_id=x_developer_id)
Expand All @@ -60,44 +113,18 @@ async def create_task_execution(

raise

execution_id = uuid4()
execution = create_execution_query(
execution, handle = await start_execution(
developer_id=x_developer_id,
task_id=task_id,
execution_id=execution_id,
data=data,
)

execution_input = prepare_execution_input(
background_tasks.add_task(
create_temporal_lookup,
#
developer_id=x_developer_id,
task_id=task_id,
execution_id=execution_id,
)

try:
handle = await run_task_execution_workflow(
execution_input=execution_input,
job_id=uuid4(),
)
except Exception as e:
logger.exception(e)

update_execution_query(
developer_id=x_developer_id,
task_id=task_id,
execution_id=execution_id,
data=UpdateExecutionRequest(status="failed"),
)

raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Task creation failed",
)

create_temporal_lookup(
developer_id=x_developer_id,
task_id=task_id,
execution_id=execution_id,
execution_id=execution.id,
workflow_handle=handle,
)

Expand Down
10 changes: 0 additions & 10 deletions agents-api/tests/test_activities.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,26 +10,18 @@

from .fixtures import (
cozo_client,
patch_embed_acompletion,
test_developer_id,
test_doc,
)
from .utils import patch_testing_temporal

# from agents_api.activities.truncation import get_extra_entries
# from agents_api.autogen.openapi_model import Role
# from agents_api.common.protocol.entries import Entry


@test("activity: call direct embed_docs")
async def _(
cozo_client=cozo_client,
developer_id=test_developer_id,
doc=test_doc,
mocks=patch_embed_acompletion,
):
(embed, _) = mocks

title = "title"
content = ["content 1"]
include_title = True
Expand All @@ -46,8 +38,6 @@ async def _(
cozo_client,
)

embed.assert_called_once()


@test("activity: call demo workflow via temporal client")
async def _():
Expand Down
Loading

0 comments on commit 0194c73

Please sign in to comment.