|
| 1 | +import matplotlib.pyplot as plt |
| 2 | +import networkx |
| 3 | +import pytest |
| 4 | + |
| 5 | +from algorithm.graph.test.graph_data_utils import create_weighted_city_graph |
| 6 | + |
| 7 | +city_graph = create_weighted_city_graph() |
| 8 | + |
| 9 | + |
| 10 | +@pytest.mark.skip(reason='This test is for visualization only') |
| 11 | +def test_graph_mst_temp(): |
| 12 | + mst = networkx.minimum_spanning_tree(city_graph, algorithm='kruskal') |
| 13 | + pos = networkx.spring_layout(mst) |
| 14 | + networkx.draw_networkx_nodes(mst, pos) |
| 15 | + networkx.draw_networkx_edges(mst, pos, width=1) |
| 16 | + networkx.draw_networkx_labels(mst, pos, font_size=10) |
| 17 | + plt.show() |
| 18 | + |
| 19 | + |
| 20 | +city_edges = {('Seattle', 'San Francisco'), ('San Francisco', 'Los Angeles'), ('Los Angeles', 'Riverside'), |
| 21 | + ('Riverside', 'Phoenix'), ('Phoenix', 'Dallas'), ('Dallas', 'Houston'), ('Houston', 'Atlanta'), |
| 22 | + ('Atlanta', 'Miami'), ('Atlanta', 'Washington'), ('Washington', 'Philadelphia'), |
| 23 | + ('Philadelphia', 'New York'), ('New York', 'Boston'), ('Washington', 'Detroit'), ('Detroit', 'Chicago')} |
| 24 | + |
| 25 | + |
| 26 | +@pytest.mark.benchmark(group='graph_minimum_spanning_tree') |
| 27 | +@pytest.mark.parametrize( |
| 28 | + argnames='graph, algorithm, expected_total_weight, expected_edges', |
| 29 | + argvalues=[ |
| 30 | + (city_graph, 'kruskal', 5372, city_edges), |
| 31 | + (city_graph, 'prim', 5372, city_edges), |
| 32 | + (city_graph, 'boruvka', 5372, city_edges), |
| 33 | + ], |
| 34 | + ids=['kruskal', 'prim', 'boruvka']) |
| 35 | +def test_graph_mst(benchmark, graph, algorithm, expected_total_weight, expected_edges): |
| 36 | + mst = benchmark(networkx.minimum_spanning_tree, graph, algorithm=algorithm) |
| 37 | + |
| 38 | + mst_total_weight = sum(d['weight'] for u, v, d in mst.edges(data=True)) |
| 39 | + assert expected_total_weight == mst_total_weight |
| 40 | + |
| 41 | + mst_edges = set((u, v) for u, v, d in mst.edges(data=True)) |
| 42 | + normalized_set1 = {tuple(sorted(edge)) for edge in expected_edges} |
| 43 | + normalized_set2 = {tuple(sorted(edge)) for edge in mst_edges} |
| 44 | + assert normalized_set1 == normalized_set2 |
0 commit comments