diff --git a/solve/solver.py b/solve/solver.py index 3afb04c..70e845c 100644 --- a/solve/solver.py +++ b/solve/solver.py @@ -1,5 +1,3 @@ -from itertools import chain - from .states import * @@ -12,13 +10,34 @@ def solve_state[S, M]( """ Makes moves stating from the given initial state until one of the states is completed. """ + if initial_state.is_done: return initial_state + + # region First cycle + if is_doing_triumph and last_position_touched: + states = [ + next_state + for next_state in initial_state.next_states(is_doing_triumph) + if last_position_touched != next_state.first_position + ] + else: + states = list(initial_state.next_states(is_doing_triumph)) + + for state in states: + if state.is_done: + return state + + # endregion + max_cycles = initial_state.max_cycles - states = [initial_state] - for _ in range(max_cycles): - states = list(chain.from_iterable(s.next_states(is_doing_triumph) for s in states)) - for s in states: - if s.is_done and (not is_doing_triumph or last_position_touched != s.first_position): - return s + for _ in range(max_cycles - 1): + states = [ + next_state + for state in states + for next_state in state.next_states(is_doing_triumph) + ] + for state in states: + if state.is_done: + return state else: raise ValueError( f'cannot solve encounter with initial {initial_state} '