Skip to content

feat: Add finalize action after every transition (#386) #529

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions docs/actions.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ when something changes, and are not bound to a specific state or event:

- `after_transition()`

- `finalize()`

The following example offers an overview of the "generic" callbacks available:

```py
Expand Down Expand Up @@ -54,6 +56,10 @@ The following example offers an overview of the "generic" callbacks available:
...
... def after_transition(self, event, state):
... print(f"After '{event}', on the '{state.id}' state.")
...
... def finalize(self, event, source, target, state):
... print(f"Finalizing transition {event} from {source.id} to {target.id}")
... print(f"Current state: {state.id}")


>>> sm = ExampleStateMachine() # On initialization, the machine run a special event `__initial__`
Expand All @@ -65,6 +71,8 @@ Exiting 'initial' state from 'loop' event.
On 'loop', on the 'initial' state.
Entering 'initial' state from 'loop' event.
After 'loop', on the 'initial' state.
Finalizing transition loop from initial to initial
Current state: initial
['before_transition_return', 'on_transition_return']

>>> sm.go()
Expand All @@ -73,6 +81,8 @@ Exiting 'initial' state from 'go' event.
On 'go', on the 'initial' state.
Entering 'final' state from 'go' event.
After 'go', on the 'final' state.
Finalizing transition go from initial to final
Current state: final
['before_transition_return', 'on_transition_return']

```
Expand Down Expand Up @@ -346,6 +356,10 @@ Actions registered on the same group don't have order guaranties and are execute
- `after_<event>()`, `after_transition()`
- `destination`
- Callbacks declared in the transition or event.
* - Finalize
- `finalize()`
- `destination`
- Guaranteed to run after every transition attempt, whether successful or failed.

```

Expand Down Expand Up @@ -381,6 +395,9 @@ defined explicitly. The following provides an example:
... def on_loop(self):
... return "On loop"
...
... def finalize(self):
... # Finalize return values are not included in results
... return "Finalize"

>>> sm = ExampleStateMachine()

Expand Down
1 change: 1 addition & 0 deletions statemachine/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ class CallbackGroup(IntEnum):
ON = auto()
AFTER = auto()
COND = auto()
FINALIZE = auto()

def build_key(self, specs: "CallbackSpecList") -> str:
return f"{self.name}@{id(specs)}"
Expand Down
60 changes: 40 additions & 20 deletions statemachine/engines/async_.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import TYPE_CHECKING

from ..callbacks import CallbackGroup
from ..event_data import EventData
from ..event_data import TriggerData
from ..exceptions import InvalidDefinition
Expand Down Expand Up @@ -101,30 +102,49 @@ async def _activate(self, trigger_data: TriggerData, transition: "Transition"):
event_data = EventData(trigger_data=trigger_data, transition=transition)
args, kwargs = event_data.args, event_data.extended_kwargs

await self.sm._callbacks.async_call(transition.validators.key, *args, **kwargs)
if not await self.sm._callbacks.async_all(transition.cond.key, *args, **kwargs):
return False, None
try:
await self.sm._callbacks.async_call(transition.validators.key, *args, **kwargs)
if not await self.sm._callbacks.async_all(transition.cond.key, *args, **kwargs):
return False, None

source = transition.source
target = transition.target

source = transition.source
target = transition.target
result = await self.sm._callbacks.async_call(transition.before.key, *args, **kwargs)
if source is not None and not transition.internal:
await self.sm._callbacks.async_call(source.exit.key, *args, **kwargs)

result = await self.sm._callbacks.async_call(transition.before.key, *args, **kwargs)
if source is not None and not transition.internal:
await self.sm._callbacks.async_call(source.exit.key, *args, **kwargs)
result += await self.sm._callbacks.async_call(transition.on.key, *args, **kwargs)

result += await self.sm._callbacks.async_call(transition.on.key, *args, **kwargs)
self.sm.current_state = target
event_data.state = target
kwargs["state"] = target

self.sm.current_state = target
event_data.state = target
kwargs["state"] = target
if not transition.internal:
await self.sm._callbacks.async_call(target.enter.key, *args, **kwargs)
await self.sm._callbacks.async_call(transition.after.key, *args, **kwargs)

if not transition.internal:
await self.sm._callbacks.async_call(target.enter.key, *args, **kwargs)
await self.sm._callbacks.async_call(transition.after.key, *args, **kwargs)
if len(result) == 0:
result = None
elif len(result) == 1:
result = result[0]

if len(result) == 0:
result = None
elif len(result) == 1:
result = result[0]
return True, result
finally:
# Run finalize actions regardless of success/failure
await self._run_finalize_actions(event_data)

return True, result
async def _run_finalize_actions(self, event_data: EventData):
"""Run finalize actions after a transition attempt."""
try:
args, kwargs = event_data.args, event_data.extended_kwargs
await self.sm._callbacks.async_call(
CallbackGroup.FINALIZE.build_key(event_data.transition._specs),
*args,
**kwargs,
)
except Exception as e:
# Log but don't re-raise finalize errors
import logging

logging.error(f"Error in finalize action: {e}")
60 changes: 40 additions & 20 deletions statemachine/engines/sync.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import TYPE_CHECKING

from ..callbacks import CallbackGroup
from ..event_data import EventData
from ..event_data import TriggerData
from ..exceptions import TransitionNotAllowed
Expand Down Expand Up @@ -103,30 +104,49 @@ def _activate(self, trigger_data: TriggerData, transition: "Transition"):
event_data = EventData(trigger_data=trigger_data, transition=transition)
args, kwargs = event_data.args, event_data.extended_kwargs

self.sm._callbacks.call(transition.validators.key, *args, **kwargs)
if not self.sm._callbacks.all(transition.cond.key, *args, **kwargs):
return False, None
try:
self.sm._callbacks.call(transition.validators.key, *args, **kwargs)
if not self.sm._callbacks.all(transition.cond.key, *args, **kwargs):
return False, None

source = transition.source
target = transition.target

source = transition.source
target = transition.target
result = self.sm._callbacks.call(transition.before.key, *args, **kwargs)
if source is not None and not transition.internal:
self.sm._callbacks.call(source.exit.key, *args, **kwargs)

result = self.sm._callbacks.call(transition.before.key, *args, **kwargs)
if source is not None and not transition.internal:
self.sm._callbacks.call(source.exit.key, *args, **kwargs)
result += self.sm._callbacks.call(transition.on.key, *args, **kwargs)

result += self.sm._callbacks.call(transition.on.key, *args, **kwargs)
self.sm.current_state = target
event_data.state = target
kwargs["state"] = target

self.sm.current_state = target
event_data.state = target
kwargs["state"] = target
if not transition.internal:
self.sm._callbacks.call(target.enter.key, *args, **kwargs)
self.sm._callbacks.call(transition.after.key, *args, **kwargs)

if not transition.internal:
self.sm._callbacks.call(target.enter.key, *args, **kwargs)
self.sm._callbacks.call(transition.after.key, *args, **kwargs)
if len(result) == 0:
result = None
elif len(result) == 1:
result = result[0]

if len(result) == 0:
result = None
elif len(result) == 1:
result = result[0]
return True, result
finally:
# Run finalize actions regardless of success/failure
self._run_finalize_actions(event_data)

return True, result
def _run_finalize_actions(self, event_data: EventData):
"""Run finalize actions after a transition attempt."""
try:
args, kwargs = event_data.args, event_data.extended_kwargs
self.sm._callbacks.call(
CallbackGroup.FINALIZE.build_key(event_data.transition._specs),
*args,
**kwargs,
)
except Exception as e:
# Log but don't re-raise finalize errors
import logging

logging.error(f"Error in finalize action: {e}")
13 changes: 13 additions & 0 deletions statemachine/transition.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ class Transition:
before the transition is executed.
after (Optional[Union[str, Callable, List[Callable]]]): The callbacks to be invoked
after the transition is executed.
finalize (Optional[Union[str, Callable, List[Callable]]]): The callbacks to be invoked
after the transition is executed.
"""

def __init__(
Expand All @@ -48,6 +50,7 @@ def __init__(
on=None,
before=None,
after=None,
finalize=None,
):
self.source = source
self.target = target
Expand All @@ -68,6 +71,9 @@ def __init__(
self.after = self._specs.grouper(CallbackGroup.AFTER).add(
after, priority=CallbackPriority.INLINE
)
self.finalize = self._specs.grouper(CallbackGroup.FINALIZE).add(
finalize, priority=CallbackPriority.INLINE
)
self.cond = (
self._specs.grouper(CallbackGroup.COND)
.add(cond, priority=CallbackPriority.INLINE, expected_value=True)
Expand All @@ -87,6 +93,7 @@ def _setup(self):
before = self.before.add
on = self.on.add
after = self.after.add
finalize = self.finalize.add

before("before_transition", priority=CallbackPriority.GENERIC, is_convention=True)
on("on_transition", priority=CallbackPriority.GENERIC, is_convention=True)
Expand Down Expand Up @@ -118,6 +125,12 @@ def _setup(self):
is_convention=True,
)

finalize(
"finalize",
priority=CallbackPriority.AFTER,
is_convention=True,
)

def match(self, event: str):
return self._events.match(event)

Expand Down
Loading