Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(agents-api): Fix map reduce step and activity #466

Merged
merged 1 commit into from
Aug 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions agents-api/agents_api/activities/task_steps/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# ruff: noqa: F401, F403, F405

from .base_evaluate import base_evaluate
from .evaluate_step import evaluate_step
from .for_each_step import for_each_step
from .if_else_step import if_else_step
Expand Down
44 changes: 44 additions & 0 deletions agents-api/agents_api/activities/task_steps/base_evaluate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
from typing import Any

from beartype import beartype
from temporalio import activity

from ...env import testing
from ..utils import get_evaluator


@beartype
async def base_evaluate(
exprs: str | list[str] | dict[str, str],
values: dict[str, Any] = {},
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Using a mutable default argument (like a dictionary) is not recommended as it can lead to unexpected behavior. Consider using None as the default value and initializing the dictionary inside the function.

Suggested change
values: dict[str, Any] = {},
values: dict[str, Any] = None,

) -> Any | list[Any] | dict[str, Any]:
input_len = 1 if isinstance(exprs, str) else len(exprs)
assert input_len > 0, "exprs must be a non-empty string, list or dict"

evaluator = get_evaluator(names=values)

try:
match exprs:
case str():
return evaluator.eval(exprs)

case list():
return [evaluator.eval(expr) for expr in exprs]

case dict():
return {k: evaluator.eval(v) for k, v in exprs.items()}

except BaseException as e:
if activity.in_activity():
activity.logger.error(f"Error in base_evaluate: {e}")

raise


# Note: This is here just for clarity. We could have just imported base_evaluate directly
# They do the same thing, so we dont need to mock the base_evaluate function
mock_base_evaluate = base_evaluate

base_evaluate = activity.defn(name="base_evaluate")(
base_evaluate if not testing else mock_base_evaluate
)
29 changes: 18 additions & 11 deletions agents-api/agents_api/activities/task_steps/evaluate_step.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,35 @@
from typing import Any

from beartype import beartype
from temporalio import activity

from ...activities.task_steps.utils import simple_eval_dict
from ...autogen.openapi_model import EvaluateStep
from ...activities.utils import simple_eval_dict
from ...common.protocol.tasks import StepContext, StepOutcome
from ...env import testing


@beartype
async def evaluate_step(context: StepContext) -> StepOutcome:
# NOTE: This activity is only for returning immediately, so we just evaluate the expression
# Hence, it's a local activity and SHOULD NOT fail
async def evaluate_step(
context: StepContext,
additional_values: dict[str, Any] = {},
override_expr: dict[str, str] | None = None,
) -> StepOutcome:
try:
assert isinstance(context.current_step, EvaluateStep)

exprs = context.current_step.evaluate
output = simple_eval_dict(exprs, values=context.model_dump())

expr = (
override_expr
if override_expr is not None
else context.current_step.evaluate
)

values = context.model_dump() | additional_values
output = simple_eval_dict(expr, values)
result = StepOutcome(output=output)

return result

except BaseException as e:
activity.logger.error(f"Error in evaluate_step: {e}")
return StepOutcome(error=str(e))
return StepOutcome(error=str(e) or repr(e))


# Note: This is here just for clarity. We could have just imported evaluate_step directly
Expand Down
10 changes: 5 additions & 5 deletions agents-api/agents_api/activities/task_steps/for_each_step.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import logging

from beartype import beartype
from simpleeval import simple_eval
from temporalio import activity

from ...autogen.openapi_model import ForeachStep
Expand All @@ -10,18 +9,19 @@
StepOutcome,
)
from ...env import testing
from .base_evaluate import base_evaluate


@beartype
async def for_each_step(context: StepContext) -> StepOutcome:
try:
assert isinstance(context.current_step, ForeachStep)

return StepOutcome(
output=simple_eval(
context.current_step.foreach.in_, names=context.model_dump()
)
output = await base_evaluate(
context.current_step.foreach.in_, context.model_dump()
)
return StepOutcome(output=output)

except BaseException as e:
logging.error(f"Error in for_each_step: {e}")
return StepOutcome(error=str(e))
Expand Down
4 changes: 2 additions & 2 deletions agents-api/agents_api/activities/task_steps/if_else_step.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from beartype import beartype
from simpleeval import simple_eval
from temporalio import activity

from ...autogen.openapi_model import IfElseWorkflowStep
Expand All @@ -8,6 +7,7 @@
StepOutcome,
)
from ...env import testing
from .base_evaluate import base_evaluate


@beartype
Expand All @@ -18,7 +18,7 @@ async def if_else_step(context: StepContext) -> StepOutcome:
assert isinstance(context.current_step, IfElseWorkflowStep)

expr: str = context.current_step.if_
output = simple_eval(expr, names=context.model_dump())
output = await base_evaluate(expr, context.model_dump())
output: bool = bool(output)

result = StepOutcome(output=output)
Expand Down
4 changes: 2 additions & 2 deletions agents-api/agents_api/activities/task_steps/log_step.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from beartype import beartype
from simpleeval import simple_eval
from temporalio import activity

from ...autogen.openapi_model import LogStep
Expand All @@ -8,6 +7,7 @@
StepOutcome,
)
from ...env import testing
from .base_evaluate import base_evaluate


@beartype
Expand All @@ -18,7 +18,7 @@ async def log_step(context: StepContext) -> StepOutcome:
assert isinstance(context.current_step, LogStep)

expr: str = context.current_step.log
output = simple_eval(expr, names=context.model_dump())
output = await base_evaluate(expr, context.model_dump())

result = StepOutcome(output=output)
return result
Expand Down
11 changes: 5 additions & 6 deletions agents-api/agents_api/activities/task_steps/map_reduce_step.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import logging

from beartype import beartype
from simpleeval import simple_eval
from temporalio import activity

from ...autogen.openapi_model import MapReduceStep
Expand All @@ -10,18 +9,18 @@
StepOutcome,
)
from ...env import testing
from .base_evaluate import base_evaluate


@beartype
async def map_reduce_step(context: StepContext) -> StepOutcome:
try:
assert isinstance(context.current_step, MapReduceStep)

return StepOutcome(
output=simple_eval(
context.current_step.map.over, names=context.model_dump()
)
)
output = await base_evaluate(context.current_step.over, context.model_dump())

return StepOutcome(output=output)

except BaseException as e:
logging.error(f"Error in map_reduce_step: {e}")
return StepOutcome(error=str(e))
Expand Down
4 changes: 2 additions & 2 deletions agents-api/agents_api/activities/task_steps/return_step.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
from temporalio import activity

from ...activities.task_steps.utils import simple_eval_dict
from ...autogen.openapi_model import ReturnStep
from ...common.protocol.tasks import (
StepContext,
StepOutcome,
)
from ...env import testing
from .base_evaluate import base_evaluate


async def return_step(context: StepContext) -> StepOutcome:
Expand All @@ -16,7 +16,7 @@ async def return_step(context: StepContext) -> StepOutcome:
assert isinstance(context.current_step, ReturnStep)

exprs: dict[str, str] = context.current_step.return_
output = simple_eval_dict(exprs, values=context.model_dump())
output = await base_evaluate(exprs, context.model_dump())

result = StepOutcome(output=output)
return result
Expand Down
7 changes: 4 additions & 3 deletions agents-api/agents_api/activities/task_steps/switch_step.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from beartype import beartype
from simpleeval import simple_eval
from temporalio import activity

from ...autogen.openapi_model import SwitchStep
Expand All @@ -8,6 +7,7 @@
StepOutcome,
)
from ...env import testing
from ..utils import get_evaluator


@beartype
Expand All @@ -17,11 +17,12 @@ async def switch_step(context: StepContext) -> StepOutcome:

# Assume that none of the cases evaluate to truthy
output: int = -1

cases: list[str] = [c.case for c in context.current_step.switch]

evaluator = get_evaluator(names=context.model_dump())

for i, case in enumerate(cases):
result = simple_eval(case, names=context.model_dump())
result = evaluator.eval(case)

if result:
output = i
Expand Down
11 changes: 0 additions & 11 deletions agents-api/agents_api/activities/task_steps/utils.py

This file was deleted.

17 changes: 11 additions & 6 deletions agents-api/agents_api/activities/task_steps/wait_for_input_step.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,24 @@
from temporalio import activity

from ...activities.task_steps.utils import simple_eval_dict
from ...autogen.openapi_model import WaitForInputStep
from ...common.protocol.tasks import StepContext, StepOutcome
from ...env import testing
from .base_evaluate import base_evaluate


async def wait_for_input_step(context: StepContext) -> StepOutcome:
assert isinstance(context.current_step, WaitForInputStep)
try:
assert isinstance(context.current_step, WaitForInputStep)

exprs = context.current_step.wait_for_input
output = simple_eval_dict(exprs, values=context.model_dump())
exprs = context.current_step.wait_for_input
output = await base_evaluate(exprs, context.model_dump())

result = StepOutcome(output=output)
return result
result = StepOutcome(output=output)
return result

except BaseException as e:
activity.logger.error(f"Error in wait_for_input_step: {e}")
return StepOutcome(error=str(e))


# Note: This is here just for clarity. We could have just imported wait_for_input_step directly
Expand Down
6 changes: 3 additions & 3 deletions agents-api/agents_api/activities/task_steps/yield_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from ...common.protocol.tasks import StepContext, StepOutcome
from ...env import testing
from .utils import simple_eval_dict
from .base_evaluate import base_evaluate


@beartype
Expand All @@ -19,14 +19,14 @@ async def yield_step(context: StepContext) -> StepOutcome:

all_workflows = context.execution_input.task.workflows
workflow = context.current_step.workflow
exprs = context.current_step.arguments

assert workflow in [
wf.name for wf in all_workflows
], f"Workflow {workflow} not found in task"

# Evaluate the expressions in the arguments
exprs = context.current_step.arguments
arguments = simple_eval_dict(exprs, values=context.model_dump())
arguments = await base_evaluate(exprs, context.model_dump())

# Transition to the first step of that workflow
transition_target = TransitionTarget(
Expand Down
17 changes: 17 additions & 0 deletions agents-api/agents_api/activities/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
from typing import Any

from beartype import beartype
from simpleeval import EvalWithCompoundTypes, SimpleEval


@beartype
def get_evaluator(names: dict[str, Any]) -> SimpleEval:
evaluator = EvalWithCompoundTypes(names=names)
return evaluator


@beartype
def simple_eval_dict(exprs: dict[str, str], values: dict[str, Any]) -> dict[str, Any]:
evaluator = get_evaluator(names=values)

return {k: evaluator.eval(v) for k, v in exprs.items()}
10 changes: 0 additions & 10 deletions agents-api/agents_api/autogen/Common.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,6 @@
from pydantic import AwareDatetime, BaseModel, ConfigDict, Field, RootModel


class JinjaTemplate(RootModel[str]):
model_config = ConfigDict(
populate_by_name=True,
)
root: str
"""
A valid jinja template.
"""


class Limit(RootModel[int]):
model_config = ConfigDict(
populate_by_name=True,
Expand Down
2 changes: 1 addition & 1 deletion agents-api/agents_api/autogen/Executions.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ class Transition(BaseModel):
Field(json_schema_extra={"readOnly": True}),
]
execution_id: Annotated[UUID, Field(json_schema_extra={"readOnly": True})]
output: Annotated[dict[str, Any], Field(json_schema_extra={"readOnly": True})]
output: Annotated[Any, Field(json_schema_extra={"readOnly": True})]
current: Annotated[TransitionTarget, Field(json_schema_extra={"readOnly": True})]
next: Annotated[
TransitionTarget | None, Field(json_schema_extra={"readOnly": True})
Expand Down
Loading
Loading