Skip to content

Commit 9b55852

Browse files
authored
refac: Engine now owns the queue; Attempt to simplify engines (~15 lines less per engine) (#498)
1 parent c1fea89 commit 9b55852

File tree

7 files changed

+117
-108
lines changed

7 files changed

+117
-108
lines changed

statemachine/contrib/diagram.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ def _actions_getter(self):
7070
if isinstance(self.machine, StateMachine):
7171

7272
def getter(grouper) -> str:
73-
return self.machine._callbacks_registry.str(grouper.key)
73+
return self.machine._callbacks.str(grouper.key)
7474
else:
7575

7676
def getter(grouper) -> str:

statemachine/engines/async_.py

Lines changed: 27 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,22 @@
1-
from threading import Lock
21
from typing import TYPE_CHECKING
3-
from weakref import proxy
42

53
from ..event_data import EventData
64
from ..event_data import TriggerData
75
from ..exceptions import InvalidDefinition
86
from ..exceptions import TransitionNotAllowed
97
from ..i18n import _
10-
from ..state import State
11-
from ..transition import Transition
8+
from .base import BaseEngine
129

1310
if TYPE_CHECKING:
1411
from ..statemachine import StateMachine
12+
from ..transition import Transition
1513

1614

17-
class AsyncEngine:
15+
class AsyncEngine(BaseEngine):
1816
def __init__(self, sm: "StateMachine", rtc: bool = True):
19-
sm._engine = self
20-
self.sm = proxy(sm)
21-
self._sentinel = object()
2217
if not rtc:
2318
raise InvalidDefinition(_("Only RTC is supported on async engine"))
24-
self._processing = Lock()
19+
super().__init__(sm=sm, rtc=rtc)
2520

2621
async def activate_initial_state(self):
2722
"""
@@ -65,80 +60,71 @@ async def processing_loop(self):
6560
first_result = self._sentinel
6661
try:
6762
# Execute the triggers in the queue in FIFO order until the queue is empty
68-
while self.sm._external_queue:
69-
trigger_data = self.sm._external_queue.popleft()
63+
while self._external_queue:
64+
trigger_data = self._external_queue.popleft()
7065
try:
7166
result = await self._trigger(trigger_data)
7267
if first_result is self._sentinel:
7368
first_result = result
7469
except Exception:
7570
# Whe clear the queue as we don't have an expected behavior
7671
# and cannot keep processing
77-
self.sm._external_queue.clear()
72+
self._external_queue.clear()
7873
raise
7974
finally:
8075
self._processing.release()
8176
return first_result if first_result is not self._sentinel else None
8277

8378
async def _trigger(self, trigger_data: TriggerData):
84-
event_data = None
79+
executed = False
8580
if trigger_data.event == "__initial__":
86-
transition = Transition(State(), self.sm._get_initial_state(), event="__initial__")
87-
transition._specs.clear()
88-
event_data = EventData(trigger_data=trigger_data, transition=transition)
89-
await self._activate(event_data)
81+
transition = self._initial_transition(trigger_data)
82+
await self._activate(trigger_data, transition)
9083
return self._sentinel
9184

9285
state = self.sm.current_state
9386
for transition in state.transitions:
9487
if not transition.match(trigger_data.event):
9588
continue
9689

97-
event_data = EventData(trigger_data=trigger_data, transition=transition)
98-
args, kwargs = event_data.args, event_data.extended_kwargs
99-
await self.sm._callbacks_registry.async_call(
100-
transition.validators.key, *args, **kwargs
101-
)
102-
if not await self.sm._callbacks_registry.async_all(
103-
transition.cond.key, *args, **kwargs
104-
):
90+
executed, result = await self._activate(trigger_data, transition)
91+
if not executed:
10592
continue
106-
107-
result = await self._activate(event_data)
108-
event_data.result = result
109-
event_data.executed = True
11093
break
11194
else:
11295
if not self.sm.allow_event_without_transition:
11396
raise TransitionNotAllowed(trigger_data.event, state)
11497

115-
return event_data.result if event_data else None
98+
return result if executed else None
11699

117-
async def _activate(self, event_data: EventData):
100+
async def _activate(self, trigger_data: TriggerData, transition: "Transition"):
101+
event_data = EventData(trigger_data=trigger_data, transition=transition)
118102
args, kwargs = event_data.args, event_data.extended_kwargs
119-
transition = event_data.transition
120-
source = event_data.state
103+
104+
await self.sm._callbacks.async_call(transition.validators.key, *args, **kwargs)
105+
if not await self.sm._callbacks.async_all(transition.cond.key, *args, **kwargs):
106+
return False, None
107+
108+
source = transition.source
121109
target = transition.target
122110

123-
result = await self.sm._callbacks_registry.async_call(
124-
transition.before.key, *args, **kwargs
125-
)
111+
result = await self.sm._callbacks.async_call(transition.before.key, *args, **kwargs)
126112
if source is not None and not transition.internal:
127-
await self.sm._callbacks_registry.async_call(source.exit.key, *args, **kwargs)
113+
await self.sm._callbacks.async_call(source.exit.key, *args, **kwargs)
128114

129-
result += await self.sm._callbacks_registry.async_call(transition.on.key, *args, **kwargs)
115+
result += await self.sm._callbacks.async_call(transition.on.key, *args, **kwargs)
130116

131117
self.sm.current_state = target
132118
event_data.state = target
133119
kwargs["state"] = target
134120

135121
if not transition.internal:
136-
await self.sm._callbacks_registry.async_call(target.enter.key, *args, **kwargs)
137-
await self.sm._callbacks_registry.async_call(transition.after.key, *args, **kwargs)
122+
await self.sm._callbacks.async_call(target.enter.key, *args, **kwargs)
123+
await self.sm._callbacks.async_call(transition.after.key, *args, **kwargs)
138124

139125
if len(result) == 0:
140126
result = None
141127
elif len(result) == 1:
142128
result = result[0]
143129

144-
return result
130+
return True, result

statemachine/engines/base.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
from collections import deque
2+
from threading import Lock
3+
from typing import TYPE_CHECKING
4+
from weakref import proxy
5+
6+
from statemachine.event import BoundEvent
7+
8+
from ..event_data import TriggerData
9+
from ..state import State
10+
from ..transition import Transition
11+
12+
if TYPE_CHECKING:
13+
from ..statemachine import StateMachine
14+
15+
16+
class BaseEngine:
17+
def __init__(self, sm: "StateMachine", rtc: bool = True):
18+
self.sm: StateMachine = proxy(sm)
19+
self._external_queue: deque = deque()
20+
self._sentinel = object()
21+
self._rtc = rtc
22+
self._processing = Lock()
23+
24+
def put(self, trigger_data: TriggerData):
25+
"""Put the trigger on the queue without blocking the caller."""
26+
self._external_queue.append(trigger_data)
27+
28+
def start(self):
29+
if self.sm.current_state_value is not None:
30+
return
31+
32+
trigger_data = TriggerData(
33+
machine=self.sm,
34+
event=BoundEvent("__initial__", _sm=self.sm),
35+
)
36+
self.put(trigger_data)
37+
38+
def _initial_transition(self, trigger_data):
39+
transition = Transition(State(), self.sm._get_initial_state(), event="__initial__")
40+
transition._specs.clear()
41+
return transition

statemachine/engines/sync.py

Lines changed: 29 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,17 @@
1-
from threading import Lock
21
from typing import TYPE_CHECKING
3-
from weakref import proxy
42

53
from ..event_data import EventData
64
from ..event_data import TriggerData
75
from ..exceptions import TransitionNotAllowed
8-
from ..state import State
9-
from ..transition import Transition
6+
from .base import BaseEngine
107

118
if TYPE_CHECKING:
12-
from ..statemachine import StateMachine
9+
from ..transition import Transition
1310

1411

15-
class SyncEngine:
16-
def __init__(self, sm: "StateMachine", rtc: bool = True):
17-
sm._engine = self
18-
self.sm = proxy(sm)
19-
self._sentinel = object()
20-
self._rtc = rtc
21-
self._processing = Lock()
12+
class SyncEngine(BaseEngine):
13+
def start(self):
14+
super().start()
2215
self.activate_initial_state()
2316

2417
def activate_initial_state(self):
@@ -54,7 +47,7 @@ def processing_loop(self):
5447
"""
5548
if not self._rtc:
5649
# The machine is in "synchronous" mode
57-
trigger_data = self.sm._external_queue.popleft()
50+
trigger_data = self._external_queue.popleft()
5851
return self._trigger(trigger_data)
5952

6053
# We make sure that only the first event enters the processing critical section,
@@ -68,74 +61,72 @@ def processing_loop(self):
6861
first_result = self._sentinel
6962
try:
7063
# Execute the triggers in the queue in FIFO order until the queue is empty
71-
while self.sm._external_queue:
72-
trigger_data = self.sm._external_queue.popleft()
64+
while self._external_queue:
65+
trigger_data = self._external_queue.popleft()
7366
try:
7467
result = self._trigger(trigger_data)
7568
if first_result is self._sentinel:
7669
first_result = result
7770
except Exception:
7871
# Whe clear the queue as we don't have an expected behavior
7972
# and cannot keep processing
80-
self.sm._external_queue.clear()
73+
self._external_queue.clear()
8174
raise
8275
finally:
8376
self._processing.release()
8477
return first_result if first_result is not self._sentinel else None
8578

8679
def _trigger(self, trigger_data: TriggerData):
87-
event_data = None
80+
executed = False
8881
if trigger_data.event == "__initial__":
89-
transition = Transition(State(), self.sm._get_initial_state(), event="__initial__")
90-
transition._specs.clear()
91-
event_data = EventData(trigger_data=trigger_data, transition=transition)
92-
self._activate(event_data)
82+
transition = self._initial_transition(trigger_data)
83+
self._activate(trigger_data, transition)
9384
return self._sentinel
9485

9586
state = self.sm.current_state
9687
for transition in state.transitions:
9788
if not transition.match(trigger_data.event):
9889
continue
9990

100-
event_data = EventData(trigger_data=trigger_data, transition=transition)
101-
args, kwargs = event_data.args, event_data.extended_kwargs
102-
self.sm._callbacks_registry.call(transition.validators.key, *args, **kwargs)
103-
if not self.sm._callbacks_registry.all(transition.cond.key, *args, **kwargs):
91+
executed, result = self._activate(trigger_data, transition)
92+
if not executed:
10493
continue
10594

106-
result = self._activate(event_data)
107-
event_data.result = result
108-
event_data.executed = True
10995
break
11096
else:
11197
if not self.sm.allow_event_without_transition:
11298
raise TransitionNotAllowed(trigger_data.event, state)
11399

114-
return event_data.result if event_data else None
100+
return result if executed else None
115101

116-
def _activate(self, event_data: EventData):
102+
def _activate(self, trigger_data: TriggerData, transition: "Transition"):
103+
event_data = EventData(trigger_data=trigger_data, transition=transition)
117104
args, kwargs = event_data.args, event_data.extended_kwargs
118-
transition = event_data.transition
119-
source = event_data.state
105+
106+
self.sm._callbacks.call(transition.validators.key, *args, **kwargs)
107+
if not self.sm._callbacks.all(transition.cond.key, *args, **kwargs):
108+
return False, None
109+
110+
source = transition.source
120111
target = transition.target
121112

122-
result = self.sm._callbacks_registry.call(transition.before.key, *args, **kwargs)
113+
result = self.sm._callbacks.call(transition.before.key, *args, **kwargs)
123114
if source is not None and not transition.internal:
124-
self.sm._callbacks_registry.call(source.exit.key, *args, **kwargs)
115+
self.sm._callbacks.call(source.exit.key, *args, **kwargs)
125116

126-
result += self.sm._callbacks_registry.call(transition.on.key, *args, **kwargs)
117+
result += self.sm._callbacks.call(transition.on.key, *args, **kwargs)
127118

128119
self.sm.current_state = target
129120
event_data.state = target
130121
kwargs["state"] = target
131122

132123
if not transition.internal:
133-
self.sm._callbacks_registry.call(target.enter.key, *args, **kwargs)
134-
self.sm._callbacks_registry.call(transition.after.key, *args, **kwargs)
124+
self.sm._callbacks.call(target.enter.key, *args, **kwargs)
125+
self.sm._callbacks.call(transition.after.key, *args, **kwargs)
135126

136127
if len(result) == 0:
137128
result = None
138129
elif len(result) == 1:
139130
result = result[0]
140131

141-
return result
132+
return True, result

0 commit comments

Comments
 (0)