Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 26 additions & 6 deletions temporalio/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -2027,6 +2027,8 @@ class ActivityConfig(TypedDict, total=False):
cancellation_type: ActivityCancellationType
activity_id: Optional[str]
versioning_intent: Optional[VersioningIntent]
summary: Optional[str]
priority: temporalio.common.Priority


# Overload for async no-param activity
Expand All @@ -2043,6 +2045,7 @@ def start_activity(
cancellation_type: ActivityCancellationType = ActivityCancellationType.TRY_CANCEL,
activity_id: Optional[str] = None,
versioning_intent: Optional[VersioningIntent] = None,
summary: Optional[str] = None,
priority: temporalio.common.Priority = temporalio.common.Priority.default,
) -> ActivityHandle[ReturnType]: ...

Expand All @@ -2061,6 +2064,7 @@ def start_activity(
cancellation_type: ActivityCancellationType = ActivityCancellationType.TRY_CANCEL,
activity_id: Optional[str] = None,
versioning_intent: Optional[VersioningIntent] = None,
summary: Optional[str] = None,
priority: temporalio.common.Priority = temporalio.common.Priority.default,
) -> ActivityHandle[ReturnType]: ...

Expand All @@ -2080,6 +2084,7 @@ def start_activity(
cancellation_type: ActivityCancellationType = ActivityCancellationType.TRY_CANCEL,
activity_id: Optional[str] = None,
versioning_intent: Optional[VersioningIntent] = None,
summary: Optional[str] = None,
priority: temporalio.common.Priority = temporalio.common.Priority.default,
) -> ActivityHandle[ReturnType]: ...

Expand All @@ -2099,6 +2104,7 @@ def start_activity(
cancellation_type: ActivityCancellationType = ActivityCancellationType.TRY_CANCEL,
activity_id: Optional[str] = None,
versioning_intent: Optional[VersioningIntent] = None,
summary: Optional[str] = None,
priority: temporalio.common.Priority = temporalio.common.Priority.default,
) -> ActivityHandle[ReturnType]: ...

Expand All @@ -2118,6 +2124,7 @@ def start_activity(
cancellation_type: ActivityCancellationType = ActivityCancellationType.TRY_CANCEL,
activity_id: Optional[str] = None,
versioning_intent: Optional[VersioningIntent] = None,
summary: Optional[str] = None,
priority: temporalio.common.Priority = temporalio.common.Priority.default,
) -> ActivityHandle[ReturnType]: ...

Expand All @@ -2137,6 +2144,7 @@ def start_activity(
cancellation_type: ActivityCancellationType = ActivityCancellationType.TRY_CANCEL,
activity_id: Optional[str] = None,
versioning_intent: Optional[VersioningIntent] = None,
summary: Optional[str] = None,
priority: temporalio.common.Priority = temporalio.common.Priority.default,
) -> ActivityHandle[ReturnType]: ...

Expand All @@ -2158,6 +2166,7 @@ def start_activity(
cancellation_type: ActivityCancellationType = ActivityCancellationType.TRY_CANCEL,
activity_id: Optional[str] = None,
versioning_intent: Optional[VersioningIntent] = None,
summary: Optional[str] = None,
priority: temporalio.common.Priority = temporalio.common.Priority.default,
) -> ActivityHandle[Any]: ...

Expand Down Expand Up @@ -2234,6 +2243,7 @@ def start_activity(
activity_id=activity_id,
versioning_intent=versioning_intent,
summary=summary,
priority=priority,
)


Expand Down Expand Up @@ -4006,6 +4016,10 @@ class ChildWorkflowConfig(TypedDict, total=False):
temporalio.common.SearchAttributes, temporalio.common.TypedSearchAttributes
]
]
versioning_intent: Optional[VersioningIntent]
static_summary: Optional[str]
static_details: Optional[str]
priority: temporalio.common.Priority


# Overload for no-param workflow
Expand Down Expand Up @@ -4238,7 +4252,8 @@ async def execute_child_workflow(
]
] = None,
versioning_intent: Optional[VersioningIntent] = None,
summary: Optional[str] = None,
static_summary: Optional[str] = None,
static_details: Optional[str] = None,
priority: temporalio.common.Priority = temporalio.common.Priority.default,
) -> ReturnType: ...

Expand Down Expand Up @@ -4266,7 +4281,8 @@ async def execute_child_workflow(
]
] = None,
versioning_intent: Optional[VersioningIntent] = None,
summary: Optional[str] = None,
static_summary: Optional[str] = None,
static_details: Optional[str] = None,
priority: temporalio.common.Priority = temporalio.common.Priority.default,
) -> ReturnType: ...

Expand Down Expand Up @@ -4294,7 +4310,8 @@ async def execute_child_workflow(
]
] = None,
versioning_intent: Optional[VersioningIntent] = None,
summary: Optional[str] = None,
static_summary: Optional[str] = None,
static_details: Optional[str] = None,
priority: temporalio.common.Priority = temporalio.common.Priority.default,
) -> ReturnType: ...

Expand Down Expand Up @@ -4324,7 +4341,8 @@ async def execute_child_workflow(
]
] = None,
versioning_intent: Optional[VersioningIntent] = None,
summary: Optional[str] = None,
static_summary: Optional[str] = None,
static_details: Optional[str] = None,
priority: temporalio.common.Priority = temporalio.common.Priority.default,
) -> Any: ...

Expand Down Expand Up @@ -4352,7 +4370,8 @@ async def execute_child_workflow(
]
] = None,
versioning_intent: Optional[VersioningIntent] = None,
summary: Optional[str] = None,
static_summary: Optional[str] = None,
static_details: Optional[str] = None,
priority: temporalio.common.Priority = temporalio.common.Priority.default,
) -> Any:
"""Start a child workflow and wait for completion.
Expand All @@ -4379,7 +4398,8 @@ async def execute_child_workflow(
memo=memo,
search_attributes=search_attributes,
versioning_intent=versioning_intent,
static_summary=summary,
static_summary=static_summary,
static_details=static_details,
priority=priority,
)
return await handle
Expand Down
102 changes: 101 additions & 1 deletion tests/test_workflow.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import inspect
import itertools
from typing import Sequence
from typing import Any, Callable, Sequence, Set, Type, get_type_hints

from temporalio import workflow
from temporalio.common import RawValue, VersioningBehavior
Expand Down Expand Up @@ -469,3 +469,103 @@ def test_workflow_update_validator_not_update():
"Update validator method my_validator parameters do not match update method my_update parameters"
in str(err.value)
)


def _assert_config_function_parity(
function_obj: Callable[..., Any],
config_class: Type[Any],
excluded_params: Set[str],
) -> None:
function_name = function_obj.__name__
config_name = config_class.__name__

# Get the signature and type hints
function_sig = inspect.signature(function_obj)
config_hints = get_type_hints(config_class)

# Get parameter names from function (excluding excluded ones and applying mappings)
expected_config_params = set(
[name for name in function_sig.parameters.keys() if name not in excluded_params]
)

# Get parameter names from config
actual_config_params = set(
[name for name in config_hints.keys() if name not in excluded_params]
)

# Check for missing and extra parameters
missing_in_config = expected_config_params - actual_config_params
extra_in_config = actual_config_params - expected_config_params

# Build detailed error message if there are mismatches
if missing_in_config or extra_in_config:
error_parts = []
if missing_in_config:
error_parts.append(
f"{config_name} is missing parameters: {sorted(missing_in_config)}"
)
if extra_in_config:
error_parts.append(
f"{config_name} has extra parameters: {sorted(extra_in_config)}"
)

error_message = "; ".join(error_parts)
error_message += f"\nExpected: {sorted(expected_config_params)}\nActual: {sorted(actual_config_params)}"
assert False, error_message


async def test_activity_config_parity_with_execute_activity():
"""Test that ActivityConfig has all the same parameters as execute_activity."""
_assert_config_function_parity(
workflow.execute_activity,
workflow.ActivityConfig,
excluded_params={"activity", "arg", "args", "result_type"},
)

with pytest.raises(workflow._NotInWorkflowEventLoopError):
await workflow.execute_activity("activity", **workflow.ActivityConfig())


def test_activity_config_parity_with_start_activity():
"""Test that ActivityConfig has all the same parameters as start_activity."""
_assert_config_function_parity(
workflow.start_activity,
workflow.ActivityConfig,
excluded_params={"activity", "arg", "args", "result_type"},
)

with pytest.raises(workflow._NotInWorkflowEventLoopError):
workflow.start_activity("workflow", **workflow.ActivityConfig())


async def test_child_workflow_config_parity_with_execute_child_workflow():
"""Test that ChildWorkflowConfig has all the same parameters as execute_child_workflow."""
_assert_config_function_parity(
workflow.execute_child_workflow,
workflow.ChildWorkflowConfig,
excluded_params={"workflow", "arg", "args", "result_type"},
)

with pytest.raises(workflow._NotInWorkflowEventLoopError):
await workflow.execute_child_workflow(
"workflow", **workflow.ChildWorkflowConfig()
)


async def test_child_workflow_config_parity_with_start_child_workflow():
"""Test that ChildWorkflowConfig has all the same parameters as start_child_workflow."""
_assert_config_function_parity(
workflow.start_child_workflow,
workflow.ChildWorkflowConfig,
excluded_params={
"workflow",
"arg",
"args",
"result_type",
},
)

with pytest.raises(workflow._NotInWorkflowEventLoopError):
await workflow.start_child_workflow(
"workflow", **workflow.ChildWorkflowConfig()
)
Loading