diff --git a/tests/unit/test_expressions.py b/tests/unit/test_expressions.py index 527d4fd11..e44ca867b 100644 --- a/tests/unit/test_expressions.py +++ b/tests/unit/test_expressions.py @@ -816,3 +816,20 @@ async def test_extract_expressions_errors(expr, expected, test_role, env_sandbox for actual, ex in zip(errors, expected, strict=True): assert_validation_result(actual, **ex) + + +@pytest.mark.parametrize( + "context,expr,expected", + [ + ({"TRIGGER": {"data": {"foo": "bar"}}}, "TRIGGER", {"data": {"foo": "bar"}}), + ({"TRIGGER": "data"}, "TRIGGER", "data"), + ({"TRIGGER": None}, "TRIGGER", None), + ({"TRIGGER": [1, 2, 3]}, "TRIGGER", [1, 2, 3]), + ], +) +def test_parse_trigger_json(context, expr, expected): + parser = ExprParser() + parse_tree = parser.parse(expr) + ev = ExprEvaluator(context=context) + actual = ev.transform(parse_tree) + assert actual == expected diff --git a/tracecat/api/app.py b/tracecat/api/app.py index 23b43169b..8fcd891b1 100644 --- a/tracecat/api/app.py +++ b/tracecat/api/app.py @@ -11,7 +11,6 @@ from sqlmodel.ext.asyncio.session import AsyncSession from tracecat import config -from tracecat.api.routers.public.webhooks import router as webhook_router from tracecat.api.routers.users import router as users_router from tracecat.auth.constants import AuthType from tracecat.auth.models import UserCreate, UserRead, UserUpdate @@ -38,6 +37,7 @@ from tracecat.secrets.router import router as secrets_router from tracecat.types.auth import AccessLevel, Role from tracecat.types.exceptions import TracecatException +from tracecat.webhooks.router import router as webhook_router from tracecat.workflow.actions.router import router as workflow_actions_router from tracecat.workflow.executions.router import router as workflow_executions_router from tracecat.workflow.management.router import router as workflow_management_router diff --git a/tracecat/dsl/common.py b/tracecat/dsl/common.py index 4672af912..8c9887d34 100644 --- a/tracecat/dsl/common.py +++ b/tracecat/dsl/common.py @@ -14,15 +14,8 @@ from tracecat.contexts import RunContext from tracecat.db.schemas import Action from tracecat.dsl.enums import EdgeType, FailStrategy, LoopStrategy -from tracecat.dsl.models import ActionStatement, DSLConfig, Trigger -from tracecat.dsl.view import ( - RFEdge, - RFGraph, - RFNode, - TriggerNode, - UDFNode, - UDFNodeData, -) +from tracecat.dsl.models import ActionStatement, DSLConfig, Trigger, TriggerInputs +from tracecat.dsl.view import RFEdge, RFGraph, RFNode, TriggerNode, UDFNode, UDFNodeData from tracecat.expressions import patterns from tracecat.expressions.expectations import ExpectedField from tracecat.expressions.shared import ExprContext @@ -201,7 +194,7 @@ class DSLRunArgs(BaseModel): role: Role dsl: DSLInput | None = None wf_id: WorkflowID - trigger_inputs: dict[str, Any] | None = None + trigger_inputs: TriggerInputs | None = None parent_run_context: RunContext | None = None runtime_config: DSLConfig = Field( default_factory=DSLConfig, @@ -222,7 +215,7 @@ class DSLRunArgs(BaseModel): class ExecuteChildWorkflowArgs(TypedDict): workflow_id: WorkflowID - trigger_inputs: dict[str, Any] + trigger_inputs: TriggerInputs environment: str | None version: int | None loop_strategy: LoopStrategy | None diff --git a/tracecat/dsl/models.py b/tracecat/dsl/models.py index 317496866..e08a5bfd4 100644 --- a/tracecat/dsl/models.py +++ b/tracecat/dsl/models.py @@ -4,7 +4,7 @@ from dataclasses import dataclass from typing import Annotated, Any, Generic, Literal, TypedDict, TypeVar -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, JsonValue from tracecat.contexts import RunContext from tracecat.dsl.constants import DEFAULT_ACTION_TIMEOUT @@ -17,6 +17,9 @@ SLUG_PATTERN = r"^[a-z0-9_]+$" ACTION_TYPE_PATTERN = r"^[a-z0-9_.]+$" +TriggerInputs = JsonValue +"""Trigger inputs JSON type.""" + class DSLNodeResult(TypedDict, total=False): """Result of executing a DSL node.""" @@ -167,7 +170,7 @@ class DSLContext(TypedDict, total=False): ACTIONS: dict[str, Any] """DSL Actions context""" - TRIGGER: dict[str, Any] + TRIGGER: TriggerInputs """DSL Trigger dynamic inputs context""" ENV: DSLEnvironment diff --git a/tracecat/dsl/validation.py b/tracecat/dsl/validation.py index aab95ee90..eb9d37a56 100644 --- a/tracecat/dsl/validation.py +++ b/tracecat/dsl/validation.py @@ -1,12 +1,12 @@ from __future__ import annotations import re -from typing import Any from pydantic import BaseModel, ConfigDict, ValidationError from temporalio import activity from tracecat.dsl.common import DSLInput +from tracecat.dsl.models import TriggerInputs from tracecat.expressions.expectations import ExpectedField, create_expectation_model from tracecat.logger import logger from tracecat.types.validation import ValidationResult @@ -16,7 +16,7 @@ def validate_trigger_inputs( dsl: DSLInput, - payload: dict[str, Any] | None = None, + payload: TriggerInputs | None = None, *, raise_exceptions: bool = False, model_name: str = "TriggerInputsValidator", @@ -35,24 +35,26 @@ def validate_trigger_inputs( field_name: ExpectedField.model_validate(field_schema) for field_name, field_schema in dsl.entrypoint.expects.items() } - validator = create_expectation_model(expects_schema, model_name=model_name) - try: - validator(**payload) - except ValidationError as e: - if raise_exceptions: - raise - return ValidationResult( - status="error", - msg=f"Validation error in trigger inputs ({e.title}). Please refer to the schema for more details.", - detail={"errors": e.errors()}, - ) + if isinstance(payload, dict): + # NOTE: We only validate dict payloads for now + validator = create_expectation_model(expects_schema, model_name=model_name) + try: + validator(**payload) + except ValidationError as e: + if raise_exceptions: + raise + return ValidationResult( + status="error", + msg=f"Validation error in trigger inputs ({e.title}). Please refer to the schema for more details.", + detail={"errors": e.errors()}, + ) return ValidationResult(status="success", msg="Trigger inputs are valid.") class ValidateTriggerInputsActivityInputs(BaseModel): model_config: ConfigDict = ConfigDict(arbitrary_types_allowed=True) dsl: DSLInput - trigger_inputs: dict[str, Any] + trigger_inputs: TriggerInputs @activity.defn diff --git a/tracecat/dsl/workflow.py b/tracecat/dsl/workflow.py index b0cad444e..febe518ab 100644 --- a/tracecat/dsl/workflow.py +++ b/tracecat/dsl/workflow.py @@ -60,6 +60,7 @@ TracecatValidationError, ) from tracecat.types.validation import ValidationResult + from tracecat.dsl.models import TriggerInputs from tracecat.workflow.management.definitions import ( get_workflow_definition_activity, ) @@ -336,7 +337,7 @@ async def execute_task(self, task: ActionStatement[ArgsT]) -> Any: task_result.update(error=msg, error_typename=err_type) raise ApplicationError(msg, non_retryable=True, type=err_type) from e finally: - logger.warning("Setting action result", task_result=task_result) + logger.debug("Setting action result", task_result=task_result) self.context[ExprContext.ACTIONS][task.ref] = task_result # type: ignore ERROR_TYPE_TO_MESSAGE = { @@ -511,7 +512,7 @@ async def _get_workflow_definition( ) async def _validate_trigger_inputs( - self, trigger_inputs: dict[str, Any] + self, trigger_inputs: TriggerInputs ) -> ValidationResult: """Validate trigger inputs. diff --git a/tracecat/expressions/parser/evaluator.py b/tracecat/expressions/parser/evaluator.py index 913c1bae5..5e290da42 100644 --- a/tracecat/expressions/parser/evaluator.py +++ b/tracecat/expressions/parser/evaluator.py @@ -141,11 +141,10 @@ def local_vars_assignment(self, jsonpath: str): return jsonpath @v_args(inline=True) - def trigger(self, jsonpath: str): + def trigger(self, jsonpath: str | None): logger.trace("Visiting trigger:", args=jsonpath) - return functions.eval_jsonpath( - ExprContext.TRIGGER + jsonpath, self._context, strict=self._strict - ) + expr = ExprContext.TRIGGER + (jsonpath or "") + return functions.eval_jsonpath(expr, self._context, strict=self._strict) @v_args(inline=True) def template_action_inputs(self, jsonpath: str): diff --git a/tracecat/expressions/parser/grammar.py b/tracecat/expressions/parser/grammar.py index 2a1645acd..854f56a44 100644 --- a/tracecat/expressions/parser/grammar.py +++ b/tracecat/expressions/parser/grammar.py @@ -36,7 +36,7 @@ inputs: "INPUTS" jsonpath_expression env: "ENV" jsonpath_expression local_vars: "var" jsonpath_expression -trigger: "TRIGGER" jsonpath_expression +trigger: "TRIGGER" [jsonpath_expression] function: "FN." FN_NAME_WITH_TRANSFORM "(" [arg_list] ")" local_vars_assignment: "var" ATTRIBUTE_PATH diff --git a/tracecat/api/routers/public/__init__.py b/tracecat/webhooks/__init__.py similarity index 100% rename from tracecat/api/routers/public/__init__.py rename to tracecat/webhooks/__init__.py diff --git a/tracecat/api/routers/public/dependencies.py b/tracecat/webhooks/dependencies.py similarity index 59% rename from tracecat/api/routers/public/dependencies.py rename to tracecat/webhooks/dependencies.py index 4378a2cd4..e7193e952 100644 --- a/tracecat/api/routers/public/dependencies.py +++ b/tracecat/webhooks/dependencies.py @@ -1,12 +1,14 @@ -from typing import Annotated +from typing import Annotated, cast -from fastapi import Depends, HTTPException, Request, status +import orjson +from fastapi import Depends, Header, HTTPException, Request, status from sqlalchemy.exc import NoResultFound from sqlmodel import select from tracecat.contexts import ctx_role from tracecat.db.engine import get_async_session_context_manager from tracecat.db.schemas import Webhook, WorkflowDefinition +from tracecat.dsl.models import TriggerInputs from tracecat.logger import logger from tracecat.types.auth import Role @@ -60,7 +62,7 @@ async def validate_incoming_webhook( result = await session.exec( select(WorkflowDefinition) .where(WorkflowDefinition.workflow_id == path) - .order_by(WorkflowDefinition.version.desc()) + .order_by(WorkflowDefinition.version.desc()) # type: ignore ) try: defn = result.first() @@ -96,6 +98,63 @@ async def validate_incoming_webhook( return validated_defn +async def parse_webhook_payload( + request: Request, + content_type: Annotated[str | None, Header(alias="content-type")] = None, +) -> TriggerInputs | None: + """ + Dependency to parse webhook payload based on Content-Type header. + + Args: + request: FastAPI request object + content_type: Content-Type header value + + Returns: + Parsed payload as TriggerInputs or None if no payload + """ + logger.debug("Parsing webhook payload", content_type=content_type) + + body = await request.body() + if not body: + return None + + match content_type: + case "application/x-ndjson" | "application/jsonlines" | "application/jsonl": + # Newline delimited json + try: + lines = body.splitlines() + return cast(TriggerInputs, [orjson.loads(line) for line in lines]) + except orjson.JSONDecodeError as e: + logger.error("Failed to parse ndjson payload", error=e) + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Invalid ndjson payload", + ) from e + case "application/x-www-form-urlencoded": + try: + form_data = await request.form() + return cast(TriggerInputs, dict(form_data)) + except Exception as e: + logger.error("Failed to parse form data payload", error=e) + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Invalid form data payload", + ) from e + case _: + # Interpret everything else as json + try: + return cast(TriggerInputs, orjson.loads(body)) + except orjson.JSONDecodeError as e: + logger.error("Failed to parse json payload", error=e) + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Invalid json payload", + ) from e + + +PayloadDep = Annotated[TriggerInputs | None, Depends(parse_webhook_payload)] + + WorkflowDefinitionFromWebhook = Annotated[ WorkflowDefinition, Depends(validate_incoming_webhook) ] diff --git a/tracecat/types/api.py b/tracecat/webhooks/models.py similarity index 93% rename from tracecat/types/api.py rename to tracecat/webhooks/models.py index 593814b83..ee943099c 100644 --- a/tracecat/types/api.py +++ b/tracecat/webhooks/models.py @@ -1,5 +1,3 @@ -from __future__ import annotations - from typing import Any, Literal from pydantic import BaseModel @@ -7,12 +5,6 @@ from tracecat.db.schemas import Resource -class UpsertWebhookParams(BaseModel): - status: Literal["online", "offline"] | None = None - entrypoint_ref: str | None = None - method: Literal["GET", "POST"] | None = None - - class WebhookResponse(Resource): id: str secret: str @@ -22,3 +14,9 @@ class WebhookResponse(Resource): method: Literal["GET", "POST"] workflow_id: str url: str + + +class UpsertWebhookParams(BaseModel): + status: Literal["online", "offline"] | None = None + entrypoint_ref: str | None = None + method: Literal["GET", "POST"] | None = None diff --git a/tracecat/api/routers/public/webhooks.py b/tracecat/webhooks/router.py similarity index 83% rename from tracecat/api/routers/public/webhooks.py rename to tracecat/webhooks/router.py index 8de9eb8ff..e340c932a 100644 --- a/tracecat/api/routers/public/webhooks.py +++ b/tracecat/webhooks/router.py @@ -1,11 +1,10 @@ -from typing import Any - from fastapi import APIRouter -from tracecat.api.routers.public.dependencies import WorkflowDefinitionFromWebhook from tracecat.contexts import ctx_role from tracecat.dsl.common import DSLInput +from tracecat.dsl.models import DSLContext from tracecat.logger import logger +from tracecat.webhooks.dependencies import PayloadDep, WorkflowDefinitionFromWebhook from tracecat.workflow.executions.models import CreateWorkflowExecutionResponse from tracecat.workflow.executions.service import WorkflowExecutionsService @@ -14,9 +13,7 @@ @router.post("/{path}/{secret}", tags=["public"]) async def incoming_webhook( - defn: WorkflowDefinitionFromWebhook, - path: str, - payload: dict[str, Any] | None = None, + defn: WorkflowDefinitionFromWebhook, path: str, payload: PayloadDep ) -> CreateWorkflowExecutionResponse: """ Webhook endpoint to trigger a workflow. @@ -37,10 +34,8 @@ async def incoming_webhook( @router.post("/{path}/{secret}/wait", tags=["public"]) async def incoming_webhook_wait( - defn: WorkflowDefinitionFromWebhook, - path: str, - payload: dict[str, Any] | None = None, -) -> dict[str, Any]: + defn: WorkflowDefinitionFromWebhook, path: str, payload: PayloadDep +) -> DSLContext: """ Webhook endpoint to trigger a workflow. diff --git a/tracecat/workflow/executions/models.py b/tracecat/workflow/executions/models.py index c7bf01e77..93f8fcfe1 100644 --- a/tracecat/workflow/executions/models.py +++ b/tracecat/workflow/executions/models.py @@ -15,7 +15,12 @@ from tracecat import identifiers from tracecat.dsl.common import DSLRunArgs from tracecat.dsl.enums import JoinStrategy -from tracecat.dsl.models import ActionRetryPolicy, DSLContext, RunActionInput +from tracecat.dsl.models import ( + ActionRetryPolicy, + DSLContext, + RunActionInput, + TriggerInputs, +) from tracecat.types.auth import Role from tracecat.workflow.management.models import GetWorkflowDefinitionActivityInputs @@ -255,7 +260,7 @@ class EventHistoryResponse(BaseModel, Generic[EventInput]): class CreateWorkflowExecutionParams(BaseModel): workflow_id: identifiers.WorkflowID - inputs: dict[str, Any] | None = None + inputs: TriggerInputs | None = None class CreateWorkflowExecutionResponse(TypedDict): diff --git a/tracecat/workflow/executions/service.py b/tracecat/workflow/executions/service.py index 27fdbe199..3ba98195a 100644 --- a/tracecat/workflow/executions/service.py +++ b/tracecat/workflow/executions/service.py @@ -25,6 +25,7 @@ from tracecat.contexts import ctx_role from tracecat.dsl.client import get_temporal_client from tracecat.dsl.common import DSLInput, DSLRunArgs +from tracecat.dsl.models import TriggerInputs from tracecat.dsl.validation import validate_trigger_inputs from tracecat.dsl.workflow import DSLWorkflow, retry_policies from tracecat.logger import logger @@ -319,7 +320,7 @@ def create_workflow_execution_nowait( dsl: DSLInput, *, wf_id: identifiers.WorkflowID, - payload: dict[str, Any] | None = None, + payload: TriggerInputs | None = None, ) -> CreateWorkflowExecutionResponse: """Create a new workflow execution. @@ -338,7 +339,7 @@ def create_workflow_execution( dsl: DSLInput, *, wf_id: identifiers.WorkflowID, - payload: dict[str, Any] | None = None, + payload: TriggerInputs | None = None, ) -> Awaitable[DispatchWorkflowResult]: """Create a new workflow execution. @@ -364,7 +365,7 @@ async def _dispatch_workflow( dsl: DSLInput, wf_id: identifiers.WorkflowID, wf_exec_id: identifiers.WorkflowExecutionID, - trigger_inputs: dict[str, Any] | None = None, + trigger_inputs: TriggerInputs | None = None, **kwargs: Any, ) -> DispatchWorkflowResult: logger.info( diff --git a/tracecat/workflow/management/management.py b/tracecat/workflow/management/management.py index 63ffeec3b..a446733ef 100644 --- a/tracecat/workflow/management/management.py +++ b/tracecat/workflow/management/management.py @@ -182,7 +182,6 @@ async def build_dsl_from_workflow(self, workflow: Workflow) -> DSLInput: raise TracecatValidationError( "Workflow has no actions. Please add an action to the workflow before committing." ) - logger.warning("Building graph from workflow", workflow=workflow) graph = RFGraph.from_workflow(workflow) if not graph.logical_entrypoint: raise TracecatValidationError( diff --git a/tracecat/workflow/management/models.py b/tracecat/workflow/management/models.py index f9c35c8c9..0af4a1593 100644 --- a/tracecat/workflow/management/models.py +++ b/tracecat/workflow/management/models.py @@ -12,8 +12,8 @@ from tracecat.expressions.expectations import ExpectedField from tracecat.identifiers import OwnerID, WorkflowID, WorkspaceID from tracecat.registry.actions.models import RegistryActionValidateResponse -from tracecat.types.api import WebhookResponse from tracecat.types.auth import Role +from tracecat.webhooks.models import WebhookResponse from tracecat.workflow.actions.models import ActionRead @@ -116,7 +116,7 @@ def from_database(defn: WorkflowDefinition) -> ExternalWorkflowDefinition: created_at=defn.created_at, updated_at=defn.updated_at, version=defn.version, - definition=defn.content, + definition=DSLInput(**defn.content), ) diff --git a/tracecat/workflow/management/router.py b/tracecat/workflow/management/router.py index 63c7f6990..8cfb13b30 100644 --- a/tracecat/workflow/management/router.py +++ b/tracecat/workflow/management/router.py @@ -27,9 +27,9 @@ from tracecat.identifiers import WorkflowID from tracecat.logger import logger from tracecat.registry.actions.models import RegistryActionValidateResponse -from tracecat.types.api import UpsertWebhookParams, WebhookResponse from tracecat.types.auth import Role from tracecat.types.exceptions import TracecatValidationError +from tracecat.webhooks.models import UpsertWebhookParams, WebhookResponse from tracecat.workflow.actions.models import ActionRead from tracecat.workflow.management.definitions import WorkflowDefinitionsService from tracecat.workflow.management.management import WorkflowsManagementService @@ -445,10 +445,9 @@ async def create_webhook( webhook = Webhook( owner_id=role.workspace_id, - entrypoint_ref=params.entrypoint_ref, method=params.method or "POST", workflow_id=workflow_id, - ) + ) # type: ignore session.add(webhook) await session.commit() await session.refresh(webhook)