Skip to content

Commit 606beaa

Browse files
committed
add state machine
state_machine test state_machine_update
1 parent 6477929 commit 606beaa

File tree

3 files changed

+434
-0
lines changed

3 files changed

+434
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
from state_machine import StateMachine
2+
3+
4+
class Robot:
5+
def __init__(self):
6+
self.battery = 100
7+
self.task_progress = 0
8+
9+
# Initialize state machine
10+
self.machine = StateMachine(self, "robot_sm")
11+
12+
# Add state transition rules
13+
self.machine.add_transition(
14+
src_state="patrolling",
15+
event="detect_task",
16+
dst_state="executing_task",
17+
guard=None,
18+
action=None,
19+
)
20+
21+
self.machine.add_transition(
22+
src_state="executing_task",
23+
event="task_complete",
24+
dst_state="patrolling",
25+
guard=None,
26+
action="reset_task",
27+
)
28+
29+
self.machine.add_transition(
30+
src_state="executing_task",
31+
event="low_battery",
32+
dst_state="returning_to_base",
33+
guard="is_battery_low",
34+
)
35+
36+
self.machine.add_transition(
37+
src_state="returning_to_base",
38+
event="reach_base",
39+
dst_state="charging",
40+
guard=None,
41+
action=None,
42+
)
43+
44+
self.machine.add_transition(
45+
src_state="charging",
46+
event="charge_complete",
47+
dst_state="patrolling",
48+
guard=None,
49+
action="battery_full",
50+
)
51+
52+
# Set initial state
53+
self.machine.set_current_state("patrolling")
54+
55+
def is_battery_low(self):
56+
"""Battery level check condition"""
57+
return self.battery < 30
58+
59+
def reset_task(self):
60+
"""Reset task progress"""
61+
self.task_progress = 0
62+
print("[Action] Task progress has been reset")
63+
64+
# Modify state entry callback naming convention (add state_ prefix)
65+
def on_enter_executing_task(self):
66+
print("\n------ Start Executing Task ------")
67+
print(f"Current battery: {self.battery}%")
68+
while self.machine.get_current_state().name == "executing_task":
69+
self.task_progress += 10
70+
self.battery -= 25
71+
print(
72+
f"Task progress: {self.task_progress}%, Remaining battery: {self.battery}%"
73+
)
74+
75+
if self.task_progress >= 100:
76+
self.machine.process("task_complete")
77+
break
78+
elif self.is_battery_low():
79+
self.machine.process("low_battery")
80+
break
81+
82+
def on_enter_returning_to_base(self):
83+
print("\nLow battery, returning to charging station...")
84+
self.machine.process("reach_base")
85+
86+
def on_enter_charging(self):
87+
print("\n------ Charging ------")
88+
self.battery = 100
89+
print("Charging complete!")
90+
self.machine.process("charge_complete")
91+
92+
93+
# Keep the test section structure the same, only modify the trigger method
94+
if __name__ == "__main__":
95+
robot = Robot()
96+
97+
print(f"Initial state: {robot.machine.get_current_state().name}")
98+
print("------------")
99+
100+
# Trigger task detection event
101+
robot.machine.process("detect_task")
102+
103+
print("\n------------")
104+
print(f"Final state: {robot.machine.get_current_state().name}")
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,166 @@
1+
from collections.abc import Callable
2+
3+
4+
class State:
5+
def __init__(self, name, on_enter=None, on_exit=None):
6+
self.name = name
7+
self.on_enter = on_enter
8+
self.on_exit = on_exit
9+
10+
def enter(self):
11+
print(f"entering <{self.name}>")
12+
if self.on_enter:
13+
self.on_enter()
14+
15+
def exit(self):
16+
print(f"exiting <{self.name}>")
17+
if self.on_exit:
18+
self.on_exit()
19+
20+
21+
class StateMachine(State):
22+
def __init__(self, model: object, name: str, on_enter=None, on_exit=None):
23+
State.__init__(self, name, on_enter, on_exit)
24+
self.states = {}
25+
self.events = {}
26+
self.transition_table = {}
27+
self._model = model
28+
self._state: StateMachine = None
29+
30+
def add_transition(
31+
self,
32+
src_state: str | State,
33+
event: str,
34+
dst_state: str | State,
35+
guard: str | Callable = None,
36+
action: str | Callable = None,
37+
) -> None:
38+
"""Add a transition to the state machine.
39+
40+
Args:
41+
src_state: Source state name or State object
42+
event: Event name or Event object
43+
dst_state: Destination state name or State object
44+
guard: Guard function name or callable
45+
action: Action function name or callable
46+
"""
47+
# Convert string parameters to objects if necessary
48+
self.register_state(src_state)
49+
self.register_event(event)
50+
self.register_state(dst_state)
51+
52+
def get_state_obj(state):
53+
return state if isinstance(state, State) else self.get_state(state)
54+
55+
def get_callable(func):
56+
return func if callable(func) else getattr(self._model, func, None)
57+
58+
src_state_obj = get_state_obj(src_state)
59+
dst_state_obj = get_state_obj(dst_state)
60+
61+
guard_func = get_callable(guard) if guard else None
62+
action_func = get_callable(action) if action else None
63+
self.transition_table[(src_state_obj.name, event)] = (
64+
dst_state_obj,
65+
guard_func,
66+
action_func,
67+
)
68+
69+
def state_transition(self, src_state: State, event: str):
70+
if (src_state.name, event) not in self.transition_table:
71+
raise ValueError(
72+
f"|{self.name}| invalid transition: <{src_state.name}> : [{event}]"
73+
)
74+
75+
dst_state, guard, action = self.transition_table[(src_state.name, event)]
76+
77+
def call_guard(guard):
78+
if callable(guard):
79+
return guard()
80+
else:
81+
return True
82+
83+
def call_action(action):
84+
if callable(action):
85+
action()
86+
87+
if call_guard(guard):
88+
call_action(action)
89+
if src_state.name != dst_state.name:
90+
print(
91+
f"|{self.name}| transitioning from <{src_state.name}> to <{dst_state.name}>"
92+
)
93+
src_state.exit()
94+
self._state = dst_state
95+
dst_state.enter()
96+
else:
97+
print(
98+
f"|{self.name}| skipping transition from <{src_state.name}> to <{dst_state.name}> because guard failed"
99+
)
100+
101+
def register_state(self, state: str | State, on_enter=None, on_exit=None):
102+
"""Register a state in the state machine.
103+
104+
Args:
105+
state (str | State): The state to register. Can be either a string (state name)
106+
or a State object.
107+
on_enter (Callable, optional): Callback function to be executed when entering the state.
108+
If state is a string and on_enter is None, it will look for
109+
a method named 'on_enter_<state>' in the model.
110+
on_exit (Callable, optional): Callback function to be executed when exiting the state.
111+
If state is a string and on_exit is None, it will look for
112+
a method named 'on_exit_<state>' in the model.
113+
114+
Raises:
115+
ValueError: If a state with the same name is already registered with a different type.
116+
"""
117+
if isinstance(state, str):
118+
if on_enter is None:
119+
on_enter = getattr(self._model, "on_enter_" + state, None)
120+
if on_exit is None:
121+
on_exit = getattr(self._model, "on_exit_" + state, None)
122+
self.states[state] = State(state, on_enter, on_exit)
123+
return
124+
125+
name = state.name
126+
if name in self.states and type(self.states[name]) is not type(state):
127+
raise ValueError(
128+
f'State "{name}" {type(state).__name__} already registered as {type(self.states[name]).__name__}'
129+
)
130+
131+
self.states[name] = state
132+
133+
def register_event(self, event: str):
134+
self.events[event] = event
135+
136+
def get_state(self, name):
137+
return self.states[name]
138+
139+
def get_event(self, name):
140+
return self.events[name]
141+
142+
def has_event(self, event: str):
143+
return event in self.events
144+
145+
def set_current_state(self, state: State | str):
146+
if isinstance(state, str):
147+
self._state = self.get_state(state)
148+
else:
149+
self._state = state
150+
151+
def get_current_state(self):
152+
return self._state
153+
154+
def process(self, event: str) -> None:
155+
"""Process an event in the state machine.
156+
157+
Args:
158+
event: Event name or Event object
159+
"""
160+
if self._state is None:
161+
raise ValueError("State machine is not initialized")
162+
163+
if self.has_event(event):
164+
self.state_transition(self._state, event)
165+
else:
166+
raise ValueError(f"Invalid event: {event}")

0 commit comments

Comments
 (0)