Skip to content

Commit

Permalink
Get rid of mypy errors.
Browse files Browse the repository at this point in the history
We need to ignore numpy and to heavily use assertions/exceptions to down-scope
the types we actually use without violating the Liskov principle.
This wouldn't be necessary if we had proper metaclasses :/ maybe we should look
into ABC.

The List[B] -> Sequence[B] changes are because Sequences, unlike Lists are
covariant and allow specializing the type of their elements.
This is WAI python/mypy#2984
  • Loading branch information
Dietr1ch committed Sep 7, 2021
1 parent 93bc065 commit afd6905
Show file tree
Hide file tree
Showing 6 changed files with 69 additions and 34 deletions.
2 changes: 1 addition & 1 deletion main.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import random
from typing import List

import numpy as np
import numpy as np # type: ignore
from termcolor import colored

from search.algorithms.bfs import BFS
Expand Down
15 changes: 9 additions & 6 deletions search/algorithms/dijkstra.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@


class Dijkstra(SearchAlgorithm):
"""Breadth-first Search.
"""Best-first Search.
Implements Open with a List and a set.
It uses the base Node class as we don't need to extend it.
Implements Open with an intrusive Heap.
Extends the base search Nodes to store internals.
"""

class DijkstraNode(Node, IntrusiveHeap.Node):
Expand Down Expand Up @@ -63,11 +63,11 @@ def __str__(self) -> str:
# IntrusiveHeap.Node
# ------------------
def __lt__(self, other) -> bool:
"""Returns < of "(f, h)" to perform informed/optimistic tie-breaking."""
"""Compares the cost of reaching the nodes."""
return self.g < other.g

class Open(SearchAlgorithm.Open):
"""An Open set implementation using a Queue."""
"""An Open set implementation using an intrusive Heap."""

def __init__(self):
self.heap: IntrusiveHeap = IntrusiveHeap()
Expand Down Expand Up @@ -120,8 +120,11 @@ def create_starting_node(self, state: Space.State) -> Node:
self.nodes_created += 1
return Dijkstra.DijkstraNode(state, action=None, parent=None, g=0)

def reach(self, state: Space.State, action: Space.Action, parent: DijkstraNode):
def reach(self, state: Space.State, action: Space.Action, parent: Node):
"""Reaches a state and updates Open."""
if not isinstance(parent, Dijkstra.DijkstraNode):
raise TypeError("Only DijkstraNode is supported")

# pylint: disable=invalid-name
g = parent.g + action.cost(parent.state)

Expand Down
35 changes: 26 additions & 9 deletions search/problems/grid/board2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@
import copy
import random
from enum import Enum
from typing import Iterable, List, Set, Tuple
from typing import Iterable, List, Sequence, Set, Tuple

import numpy as np
import numpy as np # type: ignore
from search.space import Heuristic, Problem, RandomAccessSpace, Space
from termcolor import colored

Expand Down Expand Up @@ -51,8 +51,11 @@ def __str__(self) -> str:
"""The string representation of this state."""
return "Grid2d.State[{}]".format(self.agent_position)

def __eq__(self, other: Grid2D.State) -> bool:
def __eq__(self, other: object) -> bool:
"""Compares 2 states."""
if not isinstance(other, Grid2D.State):
return NotImplemented

return self.agent_position == other.agent_position

class Action(Space.Action, Enum):
Expand Down Expand Up @@ -107,9 +110,12 @@ def adjacent_coordinates(
# Space
# -----
def neighbors(
self, state: Grid2D.State
self, state: Space.State
) -> Iterable[Tuple[Grid2D.Action, Grid2D.State]]:
"""Generates the Actions and neighbor States."""
if not isinstance(state, Grid2D.State):
raise TypeError("Only Grid2D.State is supported")

# pylint: disable=invalid-name
for a, cell in self.adjacent_coordinates(cell=state.agent_position):
if not self.is_wall(cell):
Expand Down Expand Up @@ -168,14 +174,16 @@ class Grid2DProblem(Problem):
def __init__(
self,
space: Grid2D,
starting_states: Set[Grid2D.State],
starting_states: Sequence[Grid2D.State],
goals: Set[Tuple[int, int]],
):
super().__init__(space, starting_states)
self.goals = goals

def is_goal(self, state: Grid2D.State) -> bool:
def is_goal(self, state: Space.State) -> bool:
"""Checks if a state is a goal for this Problem."""
if not isinstance(state, Grid2D.State):
raise TypeError("Only Grid2D.State is supported")
return state.agent_position in self.goals

def all_heuristics(self) -> List[Heuristic]:
Expand All @@ -197,8 +205,11 @@ class Grid2DDiscreteMetric(Heuristic):
def __init__(self, problem):
super().__init__(problem)

def __call__(self, state: Grid2D.State):
def __call__(self, state: Space.State):
"""The estimated cost of reaching the goal."""
if not isinstance(state, Grid2D.State):
raise TypeError("Only Grid2D.State is supported")

if state.agent_position in self.problem.goals:
return 0
return 1
Expand All @@ -210,8 +221,11 @@ class Grid2DSingleDimensionDistance(Heuristic):
def __init__(self, problem):
super().__init__(problem)

def __call__(self, state: Grid2D.State):
def __call__(self, state: Space.State):
"""The estimated cost of reaching the goal."""
if not isinstance(state, Grid2D.State):
raise TypeError("Only Grid2D.State is supported")

if self.problem.goals:
pos = state.agent_position
return max(
Expand All @@ -227,8 +241,11 @@ class Grid2DManhattanDistance(Heuristic):
def __init__(self, problem):
super().__init__(problem)

def __call__(self, state: Grid2D.State):
def __call__(self, state: Space.State):
"""The estimated cost of reaching the goal."""
if not isinstance(state, Grid2D.State):
raise TypeError("Only Grid2D.State is supported")

if self.problem.goals:
pos = state.agent_position
return min([manhattan_distance_2d(pos, g) for g in self.problem.goals])
Expand Down
36 changes: 26 additions & 10 deletions search/problems/grid/bomb.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@
import copy
import random
from enum import Enum
from typing import Iterable, List, Set, Tuple
from typing import Iterable, List, Sequence, Set, Tuple

import numpy as np
import numpy as np # type: ignore
from search.space import Heuristic, Problem, RandomAccessSpace, Space
from termcolor import colored

Expand Down Expand Up @@ -253,12 +253,16 @@ def to_str(self, problem: Problem, state: Space.State) -> str:

space = problem.space
grid_str = ""
grid_str += "Bombs: " + colored(state.bombs, "red", attrs=["bold"]) + "\n"
grid_str += "Bombs: " + colored(str(state.bombs), "red", attrs=["bold"]) + "\n"
grid_str += colored(
(" █" + ("█" * (space.W)) + "█\n"), "green", attrs=["bold"]
)

starting_positions = [s.agent_position for s in problem.starting_states]
starting_positions = []
for starting_state in problem.starting_states:
assert isinstance(starting_state, Bombs2D.State)
starting_positions.append(starting_state.agent_position)

for y in range(space.H):
grid_str += colored("%3d " % y, "white")
grid_str += colored("█", "green", attrs=["bold"])
Expand Down Expand Up @@ -297,14 +301,17 @@ class Bombs2DProblem(Problem):
def __init__(
self,
space: Bombs2D,
starting_states: Set[Bombs2D.State],
starting_states: Sequence[Bombs2D.State],
goals: Set[Tuple[int, int]],
):
super().__init__(space, starting_states)
self.goals = goals

def is_goal(self, state: Bombs2D.State) -> bool:
def is_goal(self, state: Space.State) -> bool:
"""Checks if a state is a goal for this Problem."""
if not isinstance(state, Bombs2D.State):
raise TypeError("Only Bombs2D.State is supported")

return state.agent_position in self.goals

def all_heuristics(self) -> List[Heuristic]:
Expand All @@ -326,8 +333,11 @@ class Bombs2DDiscreteMetric(Heuristic):
def __init__(self, problem):
super().__init__(problem)

def __call__(self, state: Bombs2D.State):
def __call__(self, state: Space.State):
"""The estimated cost of reaching the goal."""
if not isinstance(state, Bombs2D.State):
raise TypeError("Only Bombs2D.State is supported")

if state.agent_position in self.problem.goals:
return 0
return 1
Expand All @@ -339,8 +349,11 @@ class Bombs2DSingleDimensionDistance(Heuristic):
def __init__(self, problem):
super().__init__(problem)

def __call__(self, state: Bombs2D.State):
def __call__(self, state: Space.State):
"""The estimated cost of reaching the goal."""
if not isinstance(state, Bombs2D.State):
raise TypeError("Only Bombs2D.State is supported")

if self.problem.goals:
pos = state.agent_position
return max(
Expand All @@ -356,8 +369,11 @@ class Bombs2DManhattanDistance(Heuristic):
def __init__(self, problem):
super().__init__(problem)

def __call__(self, state: Bombs2D.State):
def __call__(self, state: Space.State):
"""The estimated cost of reaching the goal."""
if not isinstance(state, Bombs2D.State):
raise TypeError("Only Bombs2D.State is supported")

if self.problem.goals:
pos = state.agent_position
return min([manhattan_distance_2d(pos, g) for g in self.problem.goals])
Expand Down Expand Up @@ -397,7 +413,7 @@ def __init__(self, grid_lines: List[str], starting_bombs: int):
self.goal_positions.append(position)

self.space = Bombs2D(grid)
self.starting_states: List[Bombs2D] = []
self.starting_states: List[Bombs2D.State] = []
for start_position in start_positions:
self.starting_states.append(
Bombs2D.State(
Expand Down
11 changes: 5 additions & 6 deletions search/problems/nm_puzzle.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@
import copy
import random
from enum import Enum
from typing import Iterable, List, Set, Tuple
from typing import Iterable, List, Sequence, Set, Tuple

import numpy as np
import numpy as np # type: ignore
from search.space import Heuristic, Problem, RandomAccessSpace, Space
from termcolor import colored

Expand Down Expand Up @@ -240,12 +240,11 @@ class NMPuzzleProblem(Problem):
def __init__(
self,
space: NMPuzzle,
starting_states: Set[NMPuzzle.State],
goal_states: Set[NMPuzzle.State],
starting_states: Sequence[NMPuzzle.State],
goal_states: Sequence[NMPuzzle.State],
):
# NOTE: mypy refuses to say that [B] is a [A] when B:A :/
super().__init__(space, starting_states)
self.goal_states = goal_states
self.goal_states = set(goal_states)

def is_goal(self, state: Space.State) -> bool:
"""Checks if a state is a goal for this Problem."""
Expand Down
4 changes: 2 additions & 2 deletions search/space.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import copy
from enum import Enum
from typing import Iterable, List, Set, Tuple
from typing import Iterable, List, Sequence, Tuple


class Space:
Expand Down Expand Up @@ -90,7 +90,7 @@ def random_state(self) -> Space.State:
class Problem:
"""A generic problem definition that uses a goal function."""

def __init__(self, space: Space, starting_states: Set[Space.State]):
def __init__(self, space: Space, starting_states: Sequence[Space.State]):
self.space = space
self.starting_states = starting_states

Expand Down

0 comments on commit afd6905

Please sign in to comment.