Skip to content

Commit 205b5ac

Browse files
committed
Generalize to non-int-tuple vertices
1 parent 65e14f7 commit 205b5ac

File tree

6 files changed

+66
-71
lines changed

6 files changed

+66
-71
lines changed

a_star/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
__author__ = 'Pablo Galindo Salgado'
22

3-
from .a_star import a_star_search,heuristic,Node,DijkstraHeap
3+
from .a_star import a_star_search, DijkstraHeap, Node

a_star/a_star.py

Lines changed: 37 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -2,105 +2,96 @@
22
import heapq
33

44

5-
def heuristic(a, b):
6-
"""
7-
The heuristic function of the A* algorithm. In this case the Manhattan distance.
8-
9-
:param a: Tuple of two ints ( Point A)
10-
:param b: Tuple of two ints ( Point B)
11-
:returns: integer ( Distance between A nd B )
12-
"""
13-
(x1, y1) = a
14-
(x2, y2) = b
15-
return abs(x1 - x2) + abs(y1 - y2)
16-
17-
185
class DijkstraHeap(list):
196
"""
207
An augmented heap for the A* algorithm. This class encapsulated the residual logic of
21-
the A* algorithm like for example how to manage elements already visited that remain
22-
in the heap, elements already visited that are not in the heap and from where we came to
23-
a visited element.
8+
the A* algorithm like for example how to manage nodes already visited that remain
9+
in the heap, nodes already visited that are not in the heap and from where we came to
10+
a visited node.
2411
2512
This class will have three main elements:
2613
2714
- A heap that will act as a cost queue (self).
28-
- A visited dict that will act as a visited set and as a mapping of the form point:came_from
29-
- A costs dict that will act as a mapping of the form point:cost_so_far
15+
- A visited dict that will act as a visited set and as a mapping of the form vertex:node
3016
"""
31-
def __init__(self, first_node = None):
17+
def __init__(self):
3218
self.visited = dict()
33-
self.costs = dict()
3419

35-
if first_node is not None:
36-
self.insert(first_node)
37-
38-
def insert(self, element):
20+
def insert(self, node):
3921
"""
40-
Insert an element into the Dijkstra Heap.
22+
Insert a node into the Dijkstra Heap.
4123
42-
:param element: A Node object.
24+
:param node: A Node object.
4325
:return: None
4426
"""
4527

46-
if element.point not in self.visited:
47-
heapq.heappush(self,element)
28+
if node.vertex not in self.visited:
29+
heapq.heappush(self, node)
4830

4931
def pop(self):
5032
"""
51-
Pop an element from the Dijkstra Heap, adding it to the visited and cost dicts.
33+
Pop a node from the Dijkstra Heap, adding it to the visited dict.
5234
5335
:return: A Node object
5436
"""
5537

56-
while self and self[0].point in self.visited:
38+
while self and self[0].vertex in self.visited:
5739
heapq.heappop(self)
5840

5941
if self:
6042
next_elem = heapq.heappop(self)
61-
self.visited[next_elem.point] = next_elem.came_from
62-
self.costs[next_elem.point] = next_elem.cost_estimate
43+
self.visited[next_elem.vertex] = next_elem
6344
return next_elem
6445

46+
def backtrack(self, current):
47+
"""
48+
Retrieve the backward path, starting from current.
6549
50+
:param current: The starting vertex of the backward path
51+
:return: A generator of vertices; listify and reverse to get forward path
52+
"""
53+
while current is not None:
54+
yield current
55+
current = self.visited[current].came_from
6656

57+
Node = collections.namedtuple("Node", ["cost_estimate", "vertex", "came_from"])
6758

68-
Node = collections.namedtuple("Node","cost_estimate point came_from")
69-
70-
def a_star_search(graph, start, end):
59+
def a_star_search(graph, start, end, heuristic):
7160
"""
7261
Calculates the shortest path from start to end.
7362
74-
:param graph: A graph object. The graph object can be anything that implements the following methods:
63+
:param graph: A graph object. The graph object can be anything that implements the following methods for a vertex of any comparable and hashable type V:
7564
76-
graph.neighbors( (x:int, y:int) ) : Iterable( (x:int,y:int), (x:int,y:int), ...)
77-
graph.cost( (x:int,y:int) ) : int
65+
graph.neighbors( from_vertex:V ) : Iterable( to_vertex:V, to_vertex:V, ...)
66+
graph.cost( from_vertex:V, to_vertex:V ) : float
7867
79-
:param start: Tuple of two ints representing the starting point.
80-
:param end: Tuple of two ints representing the ending point.
68+
:param start: The starting vertex, as type V.
69+
:param end: The ending vertex, as type V.
70+
:param heuristic: Heuristic lower-bound cost function taking arguments ( from_vertex:V, end:V ).
8171
:returns: A DijkstraHeap object.
8272
8373
"""
8474

85-
frontier = DijkstraHeap( Node(heuristic(start, end), start, None) )
75+
frontier = DijkstraHeap()
76+
frontier.insert( Node(heuristic(start, end), start, None) )
8677

8778
while True:
8879

8980
current_node = frontier.pop()
9081

9182
if not current_node:
9283
raise ValueError("No path from start to end")
93-
if current_node.point == end:
84+
if current_node.vertex == end:
9485
return frontier
9586

96-
for neighbor in graph.neighbors( current_node.point ):
87+
for neighbor in graph.neighbors( current_node.vertex ):
9788

98-
cost_so_far = current_node.cost_estimate - heuristic(current_node.point, end)
89+
cost_so_far = current_node.cost_estimate - heuristic(current_node.vertex, end)
9990
new_cost = ( cost_so_far
100-
+ graph.cost(current_node.point, neighbor)
91+
+ graph.cost(current_node.vertex, neighbor)
10192
+ heuristic(neighbor, end) )
10293

103-
new_node = Node(new_cost, neighbor, current_node.point)
94+
new_node = Node(new_cost, neighbor, current_node.vertex)
10495

10596
frontier.insert(new_node)
10697

a_star/tools.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,3 +100,15 @@ def cost(self, from_node, to_node):
100100
return self.weights.get(to_node, 1)
101101

102102

103+
def manhattan_distance(a, b):
104+
"""
105+
The heuristic function of the A* algorithm. In this case the Manhattan distance,
106+
to be used in the case where vertices are 2D coordinate tuples (a grid).
107+
108+
:param a: Tuple of two ints ( Point A)
109+
:param b: Tuple of two ints ( Point B)
110+
:returns: integer ( Distance between A nd B )
111+
"""
112+
(x1, y1) = a
113+
(x2, y2) = b
114+
return abs(x1 - x2) + abs(y1 - y2)

examples/maze_solving_example.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ def from_id_width(point, width):
1313
if __name__ == '__main__':
1414

1515
import a_star
16-
from a_star.tools import GridWithWeights
16+
from a_star.tools import GridWithWeights, manhattan_distance
1717
import random
1818

1919
# Construct a cool wall collection from this aparently arbitraty points.
@@ -41,17 +41,16 @@ def from_id_width(point, width):
4141
(7, 3), (7, 4), (7, 5)]}
4242

4343
# Call the A* algorithm and get the frontier
44-
frontier = a_star.a_star_search(graph = graph, start=(1, 4), end=(7, 8))
44+
frontier = a_star.a_star_search(graph = graph, start=(1, 4), end=(7, 8), heuristic=manhattan_distance)
4545

4646
# Print the results
4747

48-
graph.draw(width=5, point_to = frontier.visited, start=(1, 4), goal=(7, 8))
49-
50-
print("[costs]")
51-
52-
costs_so_far = { k: v - a_star.heuristic(k, (7, 8)) for k,v in frontier.costs.items() }
53-
graph.draw(width=5, number = costs_so_far, start=(1, 4), goal=(7, 8))
48+
graph.draw(width=5, point_to = {k:v.came_from for k,v in frontier.visited.items()}, start=(1, 4), goal=(7, 8))
5449

5550
print("[total cost estimates]")
51+
costs = {k:v.cost_estimate for k,v in frontier.visited.items()}
52+
graph.draw(width=5, number = costs, start=(1, 4), goal=(7, 8))
5653

57-
graph.draw(width=5, number = frontier.costs, start=(1, 4), goal=(7, 8)) # cost estimates
54+
print("[costs]")
55+
costs_so_far = { k: v - manhattan_distance(k, (7, 8)) for k,v in costs.items() }
56+
graph.draw(width=5, number = costs_so_far, start=(1, 4), goal=(7, 8))

tests/test_dijkstra_heap.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ def test_came_from(self):
6666
came_from_dic = { x:x+1 for x in range(10) }
6767

6868

69-
self.assertTrue( came_from_dic == frontier.visited )
69+
self.assertTrue( came_from_dic == {k:v.came_from for k,v in frontier.visited.items()} )
7070

7171
def test_came_from_unique(self):
7272

@@ -86,7 +86,7 @@ def test_came_from_unique(self):
8686
came_from_dic = { x:x+1 for x in range(5) }
8787

8888

89-
self.assertTrue( came_from_dic == frontier.visited )
89+
self.assertTrue( came_from_dic == {k:v.came_from for k,v in frontier.visited.items()} )
9090

9191
if __name__ == "__main__":
9292
unittest.main()

tests/test_maze.py

Lines changed: 6 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2,17 +2,9 @@
22

33
import unittest
44
import a_star
5-
from a_star.tools import GridWithWeights
5+
from a_star.tools import GridWithWeights, manhattan_distance
66

77

8-
def backtrack( came_from_dict , start, end):
9-
10-
current = end
11-
yield current
12-
while current != start:
13-
current = came_from_dict[current]
14-
yield current
15-
168
class MazeTests(unittest.TestCase):
179

1810
def test_maze_1(self):
@@ -22,12 +14,13 @@ def test_maze_1(self):
2214
weights = {(1,0):20,(3,0) : 2}
2315
maze.weights = weights
2416
my_solution = [(3,0),(3,1),(3,2),(3,3),(2,3),(1,3),(0,3),(0,2)]
17+
2518
end = (3,0)
2619
start = (0,2)
2720

2821
# Call the A* algorithm and get the frontier
29-
frontier = a_star.a_star_search(graph = maze, start=start, end=end)
30-
solution = list(backtrack(frontier.visited,start,end))
22+
frontier = a_star.a_star_search(graph = maze, start=start, end=end, heuristic=manhattan_distance)
23+
solution = list(frontier.backtrack(end))
3124
self.assertTrue( solution == my_solution )
3225

3326

@@ -45,9 +38,9 @@ def test_weights_instead_of_walls(self):
4538
(2, 0), (1, 0), (0, 0), (0, 1),
4639
(0, 2), (0, 3)]
4740
# Call the A* algorithm and get the frontier
48-
frontier = a_star.a_star_search(graph = maze, start=start, end=end)
41+
frontier = a_star.a_star_search(graph = maze, start=start, end=end, heuristic=manhattan_distance)
4942

50-
solution = list(backtrack(frontier.visited,start,end))
43+
solution = list(frontier.backtrack(end))
5144

5245
self.assertTrue( solution == my_solution )
5346

0 commit comments

Comments
 (0)