|
19 | 19 | 'strongly_connected_components', |
20 | 20 | 'depth_first_search', |
21 | 21 | 'shortest_paths', |
22 | | - 'topological_sort' |
| 22 | + 'topological_sort', |
| 23 | + 'topological_sort_parallel' |
23 | 24 | ] |
24 | 25 |
|
25 | 26 | Stack = Queue = deque |
@@ -772,27 +773,107 @@ def topological_sort(graph: Graph, algorithm: str) -> list: |
772 | 773 | return getattr(algorithms, func)(graph) |
773 | 774 |
|
774 | 775 | def _kahn_adjacency_list(graph: Graph) -> list: |
775 | | - S = set(graph.vertices) |
776 | | - in_degree = dict() |
| 776 | + S = Queue() |
| 777 | + in_degree = {u: 0 for u in graph.vertices} |
777 | 778 | for u in graph.vertices: |
778 | 779 | for v in graph.neighbors(u): |
779 | | - if v.name not in in_degree: |
780 | | - in_degree[v.name] = 0 |
781 | 780 | in_degree[v.name] += 1 |
782 | | - if v.name in S: |
783 | | - S.remove(v.name) |
| 781 | + for u in graph.vertices: |
| 782 | + if in_degree[u] == 0: |
| 783 | + S.append(u) |
| 784 | + in_degree.pop(u) |
784 | 785 |
|
785 | 786 | L = [] |
786 | 787 | while S: |
787 | | - n = S.pop() |
| 788 | + n = S.popleft() |
788 | 789 | L.append(n) |
789 | 790 | for m in graph.neighbors(n): |
790 | 791 | graph.remove_edge(n, m.name) |
791 | 792 | in_degree[m.name] -= 1 |
792 | 793 | if in_degree[m.name] == 0: |
793 | | - S.add(m.name) |
| 794 | + S.append(m.name) |
794 | 795 | in_degree.pop(m.name) |
795 | 796 |
|
796 | 797 | if in_degree: |
797 | 798 | raise ValueError("Graph is not acyclic.") |
798 | 799 | return L |
| 800 | + |
| 801 | +def topological_sort_parallel(graph: Graph, algorithm: str, num_threads: int) -> list: |
| 802 | + """ |
| 803 | + Performs topological sort on the given graph using given algorithm using |
| 804 | + given number of threads. |
| 805 | +
|
| 806 | + Parameters |
| 807 | + ========== |
| 808 | +
|
| 809 | + graph: Graph |
| 810 | + The graph under consideration. |
| 811 | + algorithm: str |
| 812 | + The algorithm to be used. |
| 813 | + Currently, following are supported, |
| 814 | + 'kahn' -> Kahn's algorithm as given in [1]. |
| 815 | + num_threads: int |
| 816 | + The maximum number of threads to be used. |
| 817 | +
|
| 818 | + Returns |
| 819 | + ======= |
| 820 | +
|
| 821 | + list |
| 822 | + The list of topologically sorted vertices. |
| 823 | +
|
| 824 | + Examples |
| 825 | + ======== |
| 826 | +
|
| 827 | + >>> from pydatastructs import Graph, AdjacencyListGraphNode, topological_sort_parallel |
| 828 | + >>> v_1 = AdjacencyListGraphNode('v_1') |
| 829 | + >>> v_2 = AdjacencyListGraphNode('v_2') |
| 830 | + >>> graph = Graph(v_1, v_2) |
| 831 | + >>> graph.add_edge('v_1', 'v_2') |
| 832 | + >>> topological_sort_parallel(graph, 'kahn', 1) |
| 833 | + ['v_1', 'v_2'] |
| 834 | +
|
| 835 | + References |
| 836 | + ========== |
| 837 | +
|
| 838 | + .. [1] https://en.wikipedia.org/wiki/Topological_sorting#Kahn's_algorithm |
| 839 | + """ |
| 840 | + import pydatastructs.graphs.algorithms as algorithms |
| 841 | + func = "_" + algorithm + "_" + graph._impl + '_parallel' |
| 842 | + if not hasattr(algorithms, func): |
| 843 | + raise NotImplementedError( |
| 844 | + "Currently %s algorithm isn't implemented for " |
| 845 | + "performing topological sort on %s graphs."%(algorithm, graph._impl)) |
| 846 | + return getattr(algorithms, func)(graph, num_threads) |
| 847 | + |
| 848 | +def _kahn_adjacency_list_parallel(graph: Graph, num_threads: int) -> list: |
| 849 | + num_vertices = len(graph.vertices) |
| 850 | + |
| 851 | + def _collect_source_nodes(graph: Graph) -> list: |
| 852 | + S = [] |
| 853 | + in_degree = {u: 0 for u in graph.vertices} |
| 854 | + for u in graph.vertices: |
| 855 | + for v in graph.neighbors(u): |
| 856 | + in_degree[v.name] += 1 |
| 857 | + for u in in_degree: |
| 858 | + if in_degree[u] == 0: |
| 859 | + S.append(u) |
| 860 | + return list(S) |
| 861 | + |
| 862 | + def _job(graph: Graph, u: str): |
| 863 | + for v in graph.neighbors(u): |
| 864 | + graph.remove_edge(u, v.name) |
| 865 | + |
| 866 | + L = [] |
| 867 | + source_nodes = _collect_source_nodes(graph) |
| 868 | + while source_nodes: |
| 869 | + with ThreadPoolExecutor(max_workers=num_threads) as Executor: |
| 870 | + for node in source_nodes: |
| 871 | + L.append(node) |
| 872 | + Executor.submit(_job, graph, node) |
| 873 | + for node in source_nodes: |
| 874 | + graph.remove_vertex(node) |
| 875 | + source_nodes = _collect_source_nodes(graph) |
| 876 | + |
| 877 | + if len(L) != num_vertices: |
| 878 | + raise ValueError("Graph is not acyclic.") |
| 879 | + return L |
0 commit comments