Skip to content

Commit db84bd7

Browse files
committed
Fix double counting of heuristic
1 parent 205b5ac commit db84bd7

File tree

4 files changed

+22
-23
lines changed

4 files changed

+22
-23
lines changed

a_star/a_star.py

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -54,26 +54,25 @@ def backtrack(self, current):
5454
yield current
5555
current = self.visited[current].came_from
5656

57-
Node = collections.namedtuple("Node", ["cost_estimate", "vertex", "came_from"])
57+
Node = collections.namedtuple("Node", ["cost_estimate", "cost_so_far", "vertex", "came_from"])
5858

5959
def a_star_search(graph, start, end, heuristic):
6060
"""
6161
Calculates the shortest path from start to end.
6262
63-
:param graph: A graph object. The graph object can be anything that implements the following methods for a vertex of any comparable and hashable type V:
63+
:param graph: A graph object. The graph object can be anything that implements the following method for a vertex of any comparable and hashable type V:
6464
65-
graph.neighbors( from_vertex:V ) : Iterable( to_vertex:V, to_vertex:V, ...)
66-
graph.cost( from_vertex:V, to_vertex:V ) : float
65+
graph.neighbors( from_vertex:V ) : Iterable( (to_vertex:V,edge_cost:float), (to_vertex:V,edge_cost:float), ...)
6766
6867
:param start: The starting vertex, as type V.
6968
:param end: The ending vertex, as type V.
70-
:param heuristic: Heuristic lower-bound cost function taking arguments ( from_vertex:V, end:V ).
69+
:param heuristic: Heuristic lower-bound cost function taking arguments ( from_vertex:V, end:V ) and returning float.
7170
:returns: A DijkstraHeap object.
7271
7372
"""
7473

7574
frontier = DijkstraHeap()
76-
frontier.insert( Node(heuristic(start, end), start, None) )
75+
frontier.insert( Node(heuristic(start, end), 0, start, None) )
7776

7877
while True:
7978

@@ -84,14 +83,11 @@ def a_star_search(graph, start, end, heuristic):
8483
if current_node.vertex == end:
8584
return frontier
8685

87-
for neighbor in graph.neighbors( current_node.vertex ):
86+
for neighbor, edge_cost in graph.neighbors( current_node.vertex ):
8887

89-
cost_so_far = current_node.cost_estimate - heuristic(current_node.vertex, end)
90-
new_cost = ( cost_so_far
91-
+ graph.cost(current_node.vertex, neighbor)
92-
+ heuristic(neighbor, end) )
88+
new_cost = current_node.cost_so_far + edge_cost
9389

94-
new_node = Node(new_cost, neighbor, current_node.vertex)
90+
new_node = Node(new_cost + heuristic(neighbor, end), new_cost, neighbor, current_node.vertex)
9591

9692
frontier.insert(new_node)
9793

a_star/tools.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,9 @@ def is_wall(self, point):
2323
"""
2424
return point not in self.walls
2525

26+
def cost(self, from_node, to_node):
27+
return 1
28+
2629
def neighbors(self, point):
2730
""" Yields the valid neighbours of a given point.
2831
@@ -34,7 +37,7 @@ def neighbors(self, point):
3437
candidates = [(x + 1, y), (x, y - 1), (x - 1, y), (x, y + 1)]
3538
candidates = filter(self.is_inside, candidates)
3639
candidates = filter(self.is_wall, candidates)
37-
yield from candidates
40+
yield from ((pt, self.cost(point, pt)) for pt in candidates)
3841

3942
def _draw_tile(self, point, style, width):
4043
""" Returns a symbol for the current point given the style dictionary and the drawing width.

examples/maze_solving_example.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -47,10 +47,10 @@ def from_id_width(point, width):
4747

4848
graph.draw(width=5, point_to = {k:v.came_from for k,v in frontier.visited.items()}, start=(1, 4), goal=(7, 8))
4949

50+
print("[costs]")
51+
52+
graph.draw(width=5, number = {k:v.cost_so_far for k,v in frontier.visited.items()}, start=(1, 4), goal=(7, 8))
53+
5054
print("[total cost estimates]")
51-
costs = {k:v.cost_estimate for k,v in frontier.visited.items()}
52-
graph.draw(width=5, number = costs, start=(1, 4), goal=(7, 8))
5355

54-
print("[costs]")
55-
costs_so_far = { k: v - manhattan_distance(k, (7, 8)) for k,v in costs.items() }
56-
graph.draw(width=5, number = costs_so_far, start=(1, 4), goal=(7, 8))
56+
graph.draw(width=5, number = {k:v.cost_estimate for k,v in frontier.visited.items()}, start=(1, 4), goal=(7, 8))

tests/test_dijkstra_heap.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ class DijkstraHeapTests(unittest.TestCase):
99

1010
def test_sorted_unique(self):
1111

12-
test_data = [a_star.Node(x,x,None) for x in (5,7,3,2,5,3,5,7,2,1,3,1)]
12+
test_data = [a_star.Node(x,x,x,None) for x in (5,7,3,2,5,3,5,7,2,1,3,1)]
1313
sorted_unique = sorted(set(test_data))
1414

1515
# Prepare and insert elements in DijkstraHeap
@@ -30,7 +30,7 @@ def test_sorted_unique(self):
3030

3131
def test_delete_repeated(self):
3232

33-
test_data = [a_star.Node(x,x,None) for x in (1,1,1,1,1,1,1,1,1)]
33+
test_data = [a_star.Node(x,x,x,None) for x in (1,1,1,1,1,1,1,1,1)]
3434

3535
# Prepare and insert elements in DijkstraHeap
3636

@@ -45,12 +45,12 @@ def test_delete_repeated(self):
4545
elem = frontier.pop()
4646
if elem:
4747
result.append(elem)
48-
self.assertTrue( result == [a_star.Node(1,1,None)] )
48+
self.assertTrue( result == [a_star.Node(1,1,1,None)] )
4949

5050

5151
def test_came_from(self):
5252

53-
test_data = [a_star.Node(x,x,x+1) for x in range(10)]
53+
test_data = [a_star.Node(x,x,x,x+1) for x in range(10)]
5454

5555
# Prepare and insert elements in DijkstraHeap
5656

@@ -70,7 +70,7 @@ def test_came_from(self):
7070

7171
def test_came_from_unique(self):
7272

73-
test_data = [a_star.Node(x,x,x+1) for x in [0,1,2,3,2,4,1]]
73+
test_data = [a_star.Node(x,x,x,x+1) for x in [0,1,2,3,2,4,1]]
7474

7575
# Prepare and insert elements in DijkstraHeap
7676

0 commit comments

Comments
 (0)