Skip to content

Update Client: Adds Workflow-Reset API #687

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 85 additions & 0 deletions temporalio/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -1632,6 +1632,53 @@ async def query(
)
)

async def reset(
self,
*,
workflow_task_finish_event_id: int,
reason: Optional[str] = None,
reset_reapply_exclude_types: Optional[Iterable[temporalio.api.enums.v1.ResetReapplyExcludeType]] = None,
rpc_metadata: Optional[Mapping[str, str]] = None,
rpc_timeout: Optional[timedelta] = None,
) -> str:
"""Reset the workflow.

This will issue a workflow reset for :py:attr:`run_id` if present.

Args:
workflow_task_finish_event_id: The event-id of the completed
workflow task to reset to.
reason: Reason for the reset.
reset_reapply_exclude_types: When the workflow is reset, these
will not be reapplied. Eg. if we don't want Signals to be
reapplied in the new run generated due to the reset, then
we could pass in
`[ResetReapplyExcludeType.RESET_REAPPLY_EXCLUDE_TYPE_SIGNAL]`.
rpc_metadata: Headers used on the RPC call. Keys here override
client-level RPC metadata keys.
rpc_timeout: Optional RPC deadline to set for the RPC call.

Returns:
The run-id associated with the new run.

Raises:
RPCError: Workflow could not be reset.
"""
if rpc_metadata is None:
rpc_metadata = {}
resp = await self._client._impl.reset_workflow(
ResetWorkflowInput(
id=self._id,
run_id=self._run_id,
reason=reason,
workflow_task_finish_event_id=workflow_task_finish_event_id,
reset_reapply_exclude_types=reset_reapply_exclude_types,
rpc_metadata=rpc_metadata,
rpc_timeout=rpc_timeout,
)
)
return resp.run_id

# Overload for no-param signal
@overload
async def signal(
Expand Down Expand Up @@ -4599,6 +4646,18 @@ class QueryWorkflowInput:
rpc_metadata: Mapping[str, str]
rpc_timeout: Optional[timedelta]

@dataclass
class ResetWorkflowInput:
"""Input for :py:meth:`OutboundInterceptor.reset_workflow`."""

id: str
run_id: Optional[str]
reason: Optional[str]
workflow_task_finish_event_id: int
reset_reapply_exclude_types: Optional[Iterable[temporalio.api.enums.v1.ResetReapplyExcludeType]]
rpc_metadata: Mapping[str, str]
rpc_timeout: Optional[timedelta]


@dataclass
class SignalWorkflowInput:
Expand Down Expand Up @@ -4889,6 +4948,13 @@ async def query_workflow(self, input: QueryWorkflowInput) -> Any:
"""Called for every :py:meth:`WorkflowHandle.query` call."""
return await self.next.query_workflow(input)

async def reset_workflow(
self,
input: ResetWorkflowInput,
) -> temporalio.api.workflowservice.v1.ResetWorkflowExecutionResponse:
"""Called for every :py:meth:`WorkflowHandle.reset` call."""
return await self.next.reset_workflow(input)

async def signal_workflow(self, input: SignalWorkflowInput) -> None:
"""Called for every :py:meth:`WorkflowHandle.signal` call."""
await self.next.signal_workflow(input)
Expand Down Expand Up @@ -5209,6 +5275,25 @@ async def query_workflow(self, input: QueryWorkflowInput) -> Any:
warnings.warn(f"Expected single query result, got {len(results)}")
return results[0]

async def reset_workflow(
self,
input: ResetWorkflowInput,
) -> temporalio.api.workflowservice.v1.ResetWorkflowExecutionResponse:
req = temporalio.api.workflowservice.v1.ResetWorkflowExecutionRequest(
namespace=self._client.namespace,
workflow_execution=temporalio.api.common.v1.WorkflowExecution(
workflow_id=input.id,
run_id=input.run_id or "",
),
reason=input.reason or "",
workflow_task_finish_event_id=input.workflow_task_finish_event_id,
request_id=str(uuid.uuid4()),
reset_reapply_exclude_types=input.reset_reapply_exclude_types,
)
return await self._client.workflow_service.reset_workflow_execution(
req, retry=True, metadata=input.rpc_metadata, timeout=input.rpc_timeout
)

async def signal_workflow(self, input: SignalWorkflowInput) -> None:
req = temporalio.api.workflowservice.v1.SignalWorkflowExecutionRequest(
namespace=self._client.namespace,
Expand Down