Skip to content

Commit 427c18f

Browse files
committed
feat: add minimum spanning tree tests
1 parent f3a60a9 commit 427c18f

File tree

2 files changed

+45
-0
lines changed

2 files changed

+45
-0
lines changed

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -347,6 +347,7 @@ algorithm Prim(G, root):
347347
**Examples**
348348

349349
- Maze problem: [java](java-algorithm/src/main/java/com/example/algorithm/graph) | A maze problem is that find a path from the start to the goal. The maze is represented by a graph. The start and the goal are represented by vertices. The path is represented by a sequence of vertices.
350+
- Minimum spanning tree (Kruskal, Prim, Boruvka), CCSP#4.4.2: [python](python-algorithm/algorithm/graph/test)(test) | Find the minimum spanning tree of a graph.
350351

351352
[:arrow_up_small: back to toc](#table-of-contents)
352353

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
import matplotlib.pyplot as plt
2+
import networkx
3+
import pytest
4+
5+
from algorithm.graph.test.graph_data_utils import create_weighted_city_graph
6+
7+
city_graph = create_weighted_city_graph()
8+
9+
10+
@pytest.mark.skip(reason='This test is for visualization only')
11+
def test_graph_mst_temp():
12+
mst = networkx.minimum_spanning_tree(city_graph, algorithm='kruskal')
13+
pos = networkx.spring_layout(mst)
14+
networkx.draw_networkx_nodes(mst, pos)
15+
networkx.draw_networkx_edges(mst, pos, width=1)
16+
networkx.draw_networkx_labels(mst, pos, font_size=10)
17+
plt.show()
18+
19+
20+
city_edges = {('Seattle', 'San Francisco'), ('San Francisco', 'Los Angeles'), ('Los Angeles', 'Riverside'),
21+
('Riverside', 'Phoenix'), ('Phoenix', 'Dallas'), ('Dallas', 'Houston'), ('Houston', 'Atlanta'),
22+
('Atlanta', 'Miami'), ('Atlanta', 'Washington'), ('Washington', 'Philadelphia'),
23+
('Philadelphia', 'New York'), ('New York', 'Boston'), ('Washington', 'Detroit'), ('Detroit', 'Chicago')}
24+
25+
26+
@pytest.mark.benchmark(group='graph_minimum_spanning_tree')
27+
@pytest.mark.parametrize(
28+
argnames='graph, algorithm, expected_total_weight, expected_edges',
29+
argvalues=[
30+
(city_graph, 'kruskal', 5372, city_edges),
31+
(city_graph, 'prim', 5372, city_edges),
32+
(city_graph, 'boruvka', 5372, city_edges),
33+
],
34+
ids=['kruskal', 'prim', 'boruvka'])
35+
def test_graph_mst(benchmark, graph, algorithm, expected_total_weight, expected_edges):
36+
mst = benchmark(networkx.minimum_spanning_tree, graph, algorithm=algorithm)
37+
38+
mst_total_weight = sum(d['weight'] for u, v, d in mst.edges(data=True))
39+
assert expected_total_weight == mst_total_weight
40+
41+
mst_edges = set((u, v) for u, v, d in mst.edges(data=True))
42+
normalized_set1 = {tuple(sorted(edge)) for edge in expected_edges}
43+
normalized_set2 = {tuple(sorted(edge)) for edge in mst_edges}
44+
assert normalized_set1 == normalized_set2

0 commit comments

Comments
 (0)