Skip to content

Commit

Permalink
feat+refactor(engine): Cleanup and improve webhooks (#489)
Browse files Browse the repository at this point in the history
  • Loading branch information
daryllimyt authored Nov 5, 2024
1 parent 28c7e19 commit dd443ae
Show file tree
Hide file tree
Showing 17 changed files with 138 additions and 67 deletions.
17 changes: 17 additions & 0 deletions tests/unit/test_expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -816,3 +816,20 @@ async def test_extract_expressions_errors(expr, expected, test_role, env_sandbox

for actual, ex in zip(errors, expected, strict=True):
assert_validation_result(actual, **ex)


@pytest.mark.parametrize(
"context,expr,expected",
[
({"TRIGGER": {"data": {"foo": "bar"}}}, "TRIGGER", {"data": {"foo": "bar"}}),
({"TRIGGER": "data"}, "TRIGGER", "data"),
({"TRIGGER": None}, "TRIGGER", None),
({"TRIGGER": [1, 2, 3]}, "TRIGGER", [1, 2, 3]),
],
)
def test_parse_trigger_json(context, expr, expected):
parser = ExprParser()
parse_tree = parser.parse(expr)
ev = ExprEvaluator(context=context)
actual = ev.transform(parse_tree)
assert actual == expected
2 changes: 1 addition & 1 deletion tracecat/api/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
from sqlmodel.ext.asyncio.session import AsyncSession

from tracecat import config
from tracecat.api.routers.public.webhooks import router as webhook_router
from tracecat.api.routers.users import router as users_router
from tracecat.auth.constants import AuthType
from tracecat.auth.models import UserCreate, UserRead, UserUpdate
Expand All @@ -38,6 +37,7 @@
from tracecat.secrets.router import router as secrets_router
from tracecat.types.auth import AccessLevel, Role
from tracecat.types.exceptions import TracecatException
from tracecat.webhooks.router import router as webhook_router
from tracecat.workflow.actions.router import router as workflow_actions_router
from tracecat.workflow.executions.router import router as workflow_executions_router
from tracecat.workflow.management.router import router as workflow_management_router
Expand Down
15 changes: 4 additions & 11 deletions tracecat/dsl/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,8 @@
from tracecat.contexts import RunContext
from tracecat.db.schemas import Action
from tracecat.dsl.enums import EdgeType, FailStrategy, LoopStrategy
from tracecat.dsl.models import ActionStatement, DSLConfig, Trigger
from tracecat.dsl.view import (
RFEdge,
RFGraph,
RFNode,
TriggerNode,
UDFNode,
UDFNodeData,
)
from tracecat.dsl.models import ActionStatement, DSLConfig, Trigger, TriggerInputs
from tracecat.dsl.view import RFEdge, RFGraph, RFNode, TriggerNode, UDFNode, UDFNodeData
from tracecat.expressions import patterns
from tracecat.expressions.expectations import ExpectedField
from tracecat.expressions.shared import ExprContext
Expand Down Expand Up @@ -201,7 +194,7 @@ class DSLRunArgs(BaseModel):
role: Role
dsl: DSLInput | None = None
wf_id: WorkflowID
trigger_inputs: dict[str, Any] | None = None
trigger_inputs: TriggerInputs | None = None
parent_run_context: RunContext | None = None
runtime_config: DSLConfig = Field(
default_factory=DSLConfig,
Expand All @@ -222,7 +215,7 @@ class DSLRunArgs(BaseModel):

class ExecuteChildWorkflowArgs(TypedDict):
workflow_id: WorkflowID
trigger_inputs: dict[str, Any]
trigger_inputs: TriggerInputs
environment: str | None
version: int | None
loop_strategy: LoopStrategy | None
Expand Down
7 changes: 5 additions & 2 deletions tracecat/dsl/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from dataclasses import dataclass
from typing import Annotated, Any, Generic, Literal, TypedDict, TypeVar

from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, JsonValue

from tracecat.contexts import RunContext
from tracecat.dsl.constants import DEFAULT_ACTION_TIMEOUT
Expand All @@ -17,6 +17,9 @@
SLUG_PATTERN = r"^[a-z0-9_]+$"
ACTION_TYPE_PATTERN = r"^[a-z0-9_.]+$"

TriggerInputs = JsonValue
"""Trigger inputs JSON type."""


class DSLNodeResult(TypedDict, total=False):
"""Result of executing a DSL node."""
Expand Down Expand Up @@ -167,7 +170,7 @@ class DSLContext(TypedDict, total=False):
ACTIONS: dict[str, Any]
"""DSL Actions context"""

TRIGGER: dict[str, Any]
TRIGGER: TriggerInputs
"""DSL Trigger dynamic inputs context"""

ENV: DSLEnvironment
Expand Down
30 changes: 16 additions & 14 deletions tracecat/dsl/validation.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
from __future__ import annotations

import re
from typing import Any

from pydantic import BaseModel, ConfigDict, ValidationError
from temporalio import activity

from tracecat.dsl.common import DSLInput
from tracecat.dsl.models import TriggerInputs
from tracecat.expressions.expectations import ExpectedField, create_expectation_model
from tracecat.logger import logger
from tracecat.types.validation import ValidationResult
Expand All @@ -16,7 +16,7 @@

def validate_trigger_inputs(
dsl: DSLInput,
payload: dict[str, Any] | None = None,
payload: TriggerInputs | None = None,
*,
raise_exceptions: bool = False,
model_name: str = "TriggerInputsValidator",
Expand All @@ -35,24 +35,26 @@ def validate_trigger_inputs(
field_name: ExpectedField.model_validate(field_schema)
for field_name, field_schema in dsl.entrypoint.expects.items()
}
validator = create_expectation_model(expects_schema, model_name=model_name)
try:
validator(**payload)
except ValidationError as e:
if raise_exceptions:
raise
return ValidationResult(
status="error",
msg=f"Validation error in trigger inputs ({e.title}). Please refer to the schema for more details.",
detail={"errors": e.errors()},
)
if isinstance(payload, dict):
# NOTE: We only validate dict payloads for now
validator = create_expectation_model(expects_schema, model_name=model_name)
try:
validator(**payload)
except ValidationError as e:
if raise_exceptions:
raise
return ValidationResult(
status="error",
msg=f"Validation error in trigger inputs ({e.title}). Please refer to the schema for more details.",
detail={"errors": e.errors()},
)
return ValidationResult(status="success", msg="Trigger inputs are valid.")


class ValidateTriggerInputsActivityInputs(BaseModel):
model_config: ConfigDict = ConfigDict(arbitrary_types_allowed=True)
dsl: DSLInput
trigger_inputs: dict[str, Any]
trigger_inputs: TriggerInputs


@activity.defn
Expand Down
5 changes: 3 additions & 2 deletions tracecat/dsl/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@
TracecatValidationError,
)
from tracecat.types.validation import ValidationResult
from tracecat.dsl.models import TriggerInputs
from tracecat.workflow.management.definitions import (
get_workflow_definition_activity,
)
Expand Down Expand Up @@ -336,7 +337,7 @@ async def execute_task(self, task: ActionStatement[ArgsT]) -> Any:
task_result.update(error=msg, error_typename=err_type)
raise ApplicationError(msg, non_retryable=True, type=err_type) from e
finally:
logger.warning("Setting action result", task_result=task_result)
logger.debug("Setting action result", task_result=task_result)
self.context[ExprContext.ACTIONS][task.ref] = task_result # type: ignore

ERROR_TYPE_TO_MESSAGE = {
Expand Down Expand Up @@ -511,7 +512,7 @@ async def _get_workflow_definition(
)

async def _validate_trigger_inputs(
self, trigger_inputs: dict[str, Any]
self, trigger_inputs: TriggerInputs
) -> ValidationResult:
"""Validate trigger inputs.
Expand Down
7 changes: 3 additions & 4 deletions tracecat/expressions/parser/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,11 +141,10 @@ def local_vars_assignment(self, jsonpath: str):
return jsonpath

@v_args(inline=True)
def trigger(self, jsonpath: str):
def trigger(self, jsonpath: str | None):
logger.trace("Visiting trigger:", args=jsonpath)
return functions.eval_jsonpath(
ExprContext.TRIGGER + jsonpath, self._context, strict=self._strict
)
expr = ExprContext.TRIGGER + (jsonpath or "")
return functions.eval_jsonpath(expr, self._context, strict=self._strict)

@v_args(inline=True)
def template_action_inputs(self, jsonpath: str):
Expand Down
2 changes: 1 addition & 1 deletion tracecat/expressions/parser/grammar.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
inputs: "INPUTS" jsonpath_expression
env: "ENV" jsonpath_expression
local_vars: "var" jsonpath_expression
trigger: "TRIGGER" jsonpath_expression
trigger: "TRIGGER" [jsonpath_expression]
function: "FN." FN_NAME_WITH_TRANSFORM "(" [arg_list] ")"
local_vars_assignment: "var" ATTRIBUTE_PATH
Expand Down
File renamed without changes.
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
from typing import Annotated
from typing import Annotated, cast

from fastapi import Depends, HTTPException, Request, status
import orjson
from fastapi import Depends, Header, HTTPException, Request, status
from sqlalchemy.exc import NoResultFound
from sqlmodel import select

from tracecat.contexts import ctx_role
from tracecat.db.engine import get_async_session_context_manager
from tracecat.db.schemas import Webhook, WorkflowDefinition
from tracecat.dsl.models import TriggerInputs
from tracecat.logger import logger
from tracecat.types.auth import Role

Expand Down Expand Up @@ -60,7 +62,7 @@ async def validate_incoming_webhook(
result = await session.exec(
select(WorkflowDefinition)
.where(WorkflowDefinition.workflow_id == path)
.order_by(WorkflowDefinition.version.desc())
.order_by(WorkflowDefinition.version.desc()) # type: ignore
)
try:
defn = result.first()
Expand Down Expand Up @@ -96,6 +98,63 @@ async def validate_incoming_webhook(
return validated_defn


async def parse_webhook_payload(
request: Request,
content_type: Annotated[str | None, Header(alias="content-type")] = None,
) -> TriggerInputs | None:
"""
Dependency to parse webhook payload based on Content-Type header.
Args:
request: FastAPI request object
content_type: Content-Type header value
Returns:
Parsed payload as TriggerInputs or None if no payload
"""
logger.debug("Parsing webhook payload", content_type=content_type)

body = await request.body()
if not body:
return None

match content_type:
case "application/x-ndjson" | "application/jsonlines" | "application/jsonl":
# Newline delimited json
try:
lines = body.splitlines()
return cast(TriggerInputs, [orjson.loads(line) for line in lines])
except orjson.JSONDecodeError as e:
logger.error("Failed to parse ndjson payload", error=e)
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Invalid ndjson payload",
) from e
case "application/x-www-form-urlencoded":
try:
form_data = await request.form()
return cast(TriggerInputs, dict(form_data))
except Exception as e:
logger.error("Failed to parse form data payload", error=e)
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Invalid form data payload",
) from e
case _:
# Interpret everything else as json
try:
return cast(TriggerInputs, orjson.loads(body))
except orjson.JSONDecodeError as e:
logger.error("Failed to parse json payload", error=e)
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Invalid json payload",
) from e


PayloadDep = Annotated[TriggerInputs | None, Depends(parse_webhook_payload)]


WorkflowDefinitionFromWebhook = Annotated[
WorkflowDefinition, Depends(validate_incoming_webhook)
]
14 changes: 6 additions & 8 deletions tracecat/types/api.py → tracecat/webhooks/models.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,10 @@
from __future__ import annotations

from typing import Any, Literal

from pydantic import BaseModel

from tracecat.db.schemas import Resource


class UpsertWebhookParams(BaseModel):
status: Literal["online", "offline"] | None = None
entrypoint_ref: str | None = None
method: Literal["GET", "POST"] | None = None


class WebhookResponse(Resource):
id: str
secret: str
Expand All @@ -22,3 +14,9 @@ class WebhookResponse(Resource):
method: Literal["GET", "POST"]
workflow_id: str
url: str


class UpsertWebhookParams(BaseModel):
status: Literal["online", "offline"] | None = None
entrypoint_ref: str | None = None
method: Literal["GET", "POST"] | None = None
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
from typing import Any

from fastapi import APIRouter

from tracecat.api.routers.public.dependencies import WorkflowDefinitionFromWebhook
from tracecat.contexts import ctx_role
from tracecat.dsl.common import DSLInput
from tracecat.dsl.models import DSLContext
from tracecat.logger import logger
from tracecat.webhooks.dependencies import PayloadDep, WorkflowDefinitionFromWebhook
from tracecat.workflow.executions.models import CreateWorkflowExecutionResponse
from tracecat.workflow.executions.service import WorkflowExecutionsService

Expand All @@ -14,9 +13,7 @@

@router.post("/{path}/{secret}", tags=["public"])
async def incoming_webhook(
defn: WorkflowDefinitionFromWebhook,
path: str,
payload: dict[str, Any] | None = None,
defn: WorkflowDefinitionFromWebhook, path: str, payload: PayloadDep
) -> CreateWorkflowExecutionResponse:
"""
Webhook endpoint to trigger a workflow.
Expand All @@ -37,10 +34,8 @@ async def incoming_webhook(

@router.post("/{path}/{secret}/wait", tags=["public"])
async def incoming_webhook_wait(
defn: WorkflowDefinitionFromWebhook,
path: str,
payload: dict[str, Any] | None = None,
) -> dict[str, Any]:
defn: WorkflowDefinitionFromWebhook, path: str, payload: PayloadDep
) -> DSLContext:
"""
Webhook endpoint to trigger a workflow.
Expand Down
9 changes: 7 additions & 2 deletions tracecat/workflow/executions/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,12 @@
from tracecat import identifiers
from tracecat.dsl.common import DSLRunArgs
from tracecat.dsl.enums import JoinStrategy
from tracecat.dsl.models import ActionRetryPolicy, DSLContext, RunActionInput
from tracecat.dsl.models import (
ActionRetryPolicy,
DSLContext,
RunActionInput,
TriggerInputs,
)
from tracecat.types.auth import Role
from tracecat.workflow.management.models import GetWorkflowDefinitionActivityInputs

Expand Down Expand Up @@ -255,7 +260,7 @@ class EventHistoryResponse(BaseModel, Generic[EventInput]):

class CreateWorkflowExecutionParams(BaseModel):
workflow_id: identifiers.WorkflowID
inputs: dict[str, Any] | None = None
inputs: TriggerInputs | None = None


class CreateWorkflowExecutionResponse(TypedDict):
Expand Down
Loading

0 comments on commit dd443ae

Please sign in to comment.