Skip to content

C++ implementation of Prim's Minimum Spanning Tree Algorithm #685

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Merged
merged 21 commits into from
Jul 1, 2025
151 changes: 150 additions & 1 deletion pydatastructs/graphs/_backend/cpp/Algorithms.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,11 @@
#include <queue>
#include <string>
#include <unordered_set>
#include <variant>
#include "GraphEdge.hpp"
#include "AdjacencyList.hpp"
#include "AdjacencyMatrix.hpp"


static PyObject* breadth_first_search_adjacency_list(PyObject* self, PyObject* args, PyObject* kwargs) {
PyObject* graph_obj;
const char* source_name;
Expand Down Expand Up @@ -153,3 +154,151 @@ static PyObject* breadth_first_search_adjacency_matrix(PyObject* self, PyObject*

Py_RETURN_NONE;
}

// Builds a minimum spanning tree of an adjacency-list graph with Prim's algorithm.
//
// Python-level call: minimum_spanning_tree_prim_adjacency_list(graph)
//   graph -- an AdjacencyListGraphType instance (enforced by the "O!" format).
//
// Returns a new reference to a freshly constructed AdjacencyListGraphType
// object. For every tree edge (u, v) it stores two GraphEdge objects, keyed
// make_edge_key(u, v) and make_edge_key(v, u), both carrying the same weight.
// Only nodes reachable from the starting node are included. An empty input
// graph yields an empty result graph.
//
// Returns nullptr with a Python exception set when argument parsing fails,
// when an allocation fails, or when an adjacency entry has no matching
// node/edge record in the input graph.
static PyObject* minimum_spanning_tree_prim_adjacency_list(PyObject* self, PyObject* args, PyObject* kwargs) {
    PyObject* graph_obj;
    static const char* kwlist[] = {"graph", nullptr};

    if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O!", const_cast<char**>(kwlist),
                                     &AdjacencyListGraphType, &graph_obj)) {
        return nullptr;
    }

    AdjacencyListGraph* graph = reinterpret_cast<AdjacencyListGraph*>(graph_obj);

    using EdgeValue = std::variant<std::monostate, int64_t, double, std::string>;

    // Candidate edge held in the priority queue: endpoints plus a copy of the weight.
    struct EdgeTuple {
        std::string source;
        std::string target;
        EdgeValue value;
        DataType value_type;

        // Orders first by value_type, then by the stored value of that type.
        // Edges without a comparable value never compare greater.
        bool operator>(const EdgeTuple& other) const {
            if (value_type != other.value_type)
                return value_type > other.value_type;
            if (std::holds_alternative<int64_t>(value))
                return std::get<int64_t>(value) > std::get<int64_t>(other.value);
            if (std::holds_alternative<double>(value))
                return std::get<double>(value) > std::get<double>(other.value);
            if (std::holds_alternative<std::string>(value))
                return std::get<std::string>(value) > std::get<std::string>(other.value);
            return false;
        }
    };

    // std::greater turns the max-heap into a min-heap on operator> above.
    std::priority_queue<EdgeTuple, std::vector<EdgeTuple>, std::greater<>> pq;
    std::unordered_set<std::string> visited;

    PyObject* mst_graph = PyObject_CallObject(reinterpret_cast<PyObject*>(&AdjacencyListGraphType), nullptr);
    if (mst_graph == nullptr) {
        return nullptr;  // exception already set by the failed constructor call
    }
    AdjacencyListGraph* mst = reinterpret_cast<AdjacencyListGraph*>(mst_graph);

    // Nothing to span: avoid dereferencing begin() of an empty map.
    if (graph->node_map.empty()) {
        return mst_graph;
    }

    // Releases the partially built result; optionally raises RuntimeError.
    // Pass nullptr when a Python exception is already pending.
    auto fail = [&](const char* message) -> PyObject* {
        if (message != nullptr) {
            PyErr_SetString(PyExc_RuntimeError, message);
        }
        Py_DECREF(mst_graph);
        return nullptr;
    };

    // Registers `node` in the result graph, taking one reference for it.
    auto adopt_node = [&](const std::string& name, AdjacencyListGraphNode* node) {
        Py_INCREF(node);
        mst->nodes.push_back(node);
        mst->node_map[name] = node;
    };

    // Returns the node named `name` from the result graph, adopting it from the
    // input graph on first use. Returns nullptr if the input graph lacks it.
    auto ensure_in_mst = [&](const std::string& name) -> AdjacencyListGraphNode* {
        const auto known = mst->node_map.find(name);
        if (known != mst->node_map.end()) {
            return known->second;
        }
        const auto original = graph->node_map.find(name);
        if (original == graph->node_map.end() || original->second == nullptr) {
            return nullptr;
        }
        adopt_node(name, original->second);
        return original->second;
    };

    // Pushes every edge from `from` to a not-yet-visited neighbour onto the queue.
    // Returns false if an adjacency entry has no edge record in the input graph.
    auto enqueue_edges_from = [&](const std::string& from, AdjacencyListGraphNode* node) -> bool {
        for (const auto& [adj_name, _] : node->adjacent) {
            if (visited.count(adj_name)) continue;

            // find() instead of operator[]: never insert into the caller's graph.
            const auto edge_it = graph->edges.find(make_edge_key(from, adj_name));
            if (edge_it == graph->edges.end() || edge_it->second == nullptr) {
                return false;
            }
            GraphEdge* edge = edge_it->second;

            EdgeTuple candidate;
            candidate.source = from;
            candidate.target = adj_name;
            candidate.value_type = edge->value_type;

            switch (edge->value_type) {
                case DataType::Int:
                    candidate.value = std::get<int64_t>(edge->value);
                    break;
                case DataType::Double:
                    candidate.value = std::get<double>(edge->value);
                    break;
                case DataType::String:
                    candidate.value = std::get<std::string>(edge->value);
                    break;
                default:
                    candidate.value = std::monostate{};
            }

            pq.push(candidate);
        }
        return true;
    };

    // Allocates a GraphEdge src -> dst carrying `weight` and stores it in the
    // result graph. Returns false (exception set by PyObject_New) on failure.
    auto add_mst_edge = [&](const std::string& src_name, const std::string& dst_name,
                            AdjacencyListGraphNode* src, AdjacencyListGraphNode* dst,
                            const EdgeTuple& weight) -> bool {
        // PyObject_New already initialises the object header, so no separate
        // PyObject_Init call is needed.
        GraphEdge* created = PyObject_New(GraphEdge, &GraphEdgeType);
        if (created == nullptr) {
            return false;
        }
        // The variant member lives in raw storage; construct it in place.
        new (&created->value) EdgeValue(weight.value);
        created->value_type = weight.value_type;
        Py_INCREF(src);
        Py_INCREF(dst);
        created->source = reinterpret_cast<PyObject*>(src);
        created->target = reinterpret_cast<PyObject*>(dst);
        mst->edges[make_edge_key(src_name, dst_name)] = created;
        return true;
    };

    const auto start_it = graph->node_map.begin();
    const std::string start = start_it->first;
    AdjacencyListGraphNode* start_node = start_it->second;
    if (start_node == nullptr) {
        return fail("graph contains a null node entry");
    }

    visited.insert(start);
    adopt_node(start, start_node);

    if (!enqueue_edges_from(start, start_node)) {
        return fail("adjacency entry has no matching edge in the graph");
    }

    while (!pq.empty()) {
        EdgeTuple edge = pq.top();
        pq.pop();

        // Skip edges whose target was already reached through a cheaper edge.
        if (!visited.insert(edge.target).second) continue;

        AdjacencyListGraphNode* u = ensure_in_mst(edge.source);
        AdjacencyListGraphNode* v = ensure_in_mst(edge.target);
        if (u == nullptr || v == nullptr) {
            return fail("adjacency entry refers to a node missing from the graph");
        }

        // NOTE(review): u and v are the same node objects the input graph holds,
        // so these writes also land in the input graph's nodes -- confirm intended.
        Py_INCREF(v);
        Py_INCREF(u);
        u->adjacent[edge.target] = reinterpret_cast<PyObject*>(v);
        v->adjacent[edge.source] = reinterpret_cast<PyObject*>(u);

        // Store the tree edge in both directions.
        if (!add_mst_edge(edge.source, edge.target, u, v, edge) ||
            !add_mst_edge(edge.target, edge.source, v, u, edge)) {
            return fail(nullptr);
        }

        if (!enqueue_edges_from(edge.target, v)) {
            return fail("adjacency entry has no matching edge in the graph");
        }
    }

    return mst_graph;
}
1 change: 1 addition & 0 deletions pydatastructs/graphs/_backend/cpp/algorithms.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
// Method table: maps each Python-visible function name to its C++
// implementation. All entries take positional and keyword arguments.
static PyMethodDef AlgorithmsMethods[] = {
    {
        "bfs_adjacency_list",
        reinterpret_cast<PyCFunction>(breadth_first_search_adjacency_list),
        METH_VARARGS | METH_KEYWORDS,
        "Run BFS on adjacency list with callback"
    },
    {
        "bfs_adjacency_matrix",
        reinterpret_cast<PyCFunction>(breadth_first_search_adjacency_matrix),
        METH_VARARGS | METH_KEYWORDS,
        "Run BFS on adjacency matrix with callback"
    },
    {
        "minimum_spanning_tree_prim_adjacency_list",
        reinterpret_cast<PyCFunction>(minimum_spanning_tree_prim_adjacency_list),
        METH_VARARGS | METH_KEYWORDS,
        "Run Prim's algorithm on adjacency list"
    },
    // Sentinel entry terminating the table.
    {nullptr, nullptr, 0, nullptr}
};

Expand Down
24 changes: 14 additions & 10 deletions pydatastructs/graphs/algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,16 +338,20 @@ def minimum_spanning_tree(graph, algorithm, **kwargs):
should be used only for such graphs. Using with other
types of graphs may lead to unwanted results.
"""
raise_if_backend_is_not_python(
minimum_spanning_tree, kwargs.get('backend', Backend.PYTHON))
import pydatastructs.graphs.algorithms as algorithms
func = "_minimum_spanning_tree_" + algorithm + "_" + graph._impl
if not hasattr(algorithms, func):
raise NotImplementedError(
"Currently %s algoithm for %s implementation of graphs "
"isn't implemented for finding minimum spanning trees."
%(algorithm, graph._impl))
return getattr(algorithms, func)(graph)
backend = kwargs.get('backend', Backend.PYTHON)
if backend == Backend.PYTHON:
import pydatastructs.graphs.algorithms as algorithms
func = "_minimum_spanning_tree_" + algorithm + "_" + graph._impl
if not hasattr(algorithms, func):
raise NotImplementedError(
"Currently %s algoithm for %s implementation of graphs "
"isn't implemented for finding minimum spanning trees."
%(algorithm, graph._impl))
return getattr(algorithms, func)(graph)
else:
from pydatastructs.graphs._backend.cpp._algorithms import minimum_spanning_tree_prim_adjacency_list
if graph._impl == "adjacency_list" and algorithm == 'prim':
return minimum_spanning_tree_prim_adjacency_list(graph)

def _minimum_spanning_tree_parallel_kruskal_adjacency_list(graph, num_threads):
mst = _generate_mst_object(graph)
Expand Down
8 changes: 4 additions & 4 deletions pydatastructs/graphs/tests/test_adjacency_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,16 +67,16 @@ def test_adjacency_list():
g2.add_vertex(v)
g2.add_edge('v_4', 'v', 0)
g2.add_edge('v_5', 'v', 0)
g2.add_edge('v_6', 'v', 0)
g2.add_edge('v_6', 'v', "h")
assert g2.is_adjacent('v_4', 'v') is True
assert g2.is_adjacent('v_5', 'v') is True
assert g2.is_adjacent('v_6', 'v') is True
e1 = g2.get_edge('v_4', 'v')
e2 = g2.get_edge('v_5', 'v')
e3 = g2.get_edge('v_6', 'v')
assert (str(e1)) == "('v_4', 'v')"
assert (str(e2)) == "('v_5', 'v')"
assert (str(e3)) == "('v_6', 'v')"
assert (str(e1)) == "('v_4', 'v', 0)"
assert (str(e2)) == "('v_5', 'v', 0)"
assert (str(e3)) == "('v_6', 'v', h)"
g2.remove_edge('v_4', 'v')
assert g2.is_adjacent('v_4', 'v') is False
g2.remove_vertex('v')
Expand Down
41 changes: 41 additions & 0 deletions pydatastructs/graphs/tests/test_algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,46 @@ def _test_minimum_spanning_tree(func, ds, algorithm, *args):
for k, v in mst.edge_weights.items():
assert (k, v.value) in expected_mst

def _test_minimum_spanning_tree_cpp(ds, algorithm, *args):
    """Exercise Prim's minimum spanning tree on the C++ backend.

    The checks run only when ``ds == 'List'`` and ``algorithm == "prim"``;
    for any other combination the function returns without doing anything.
    ``*args`` is accepted but not used.
    """
    if not (ds == 'List' and algorithm == "prim"):
        return

    def _make_nodes(labels):
        # One C++-backed node per label, each with data 0.
        return [AdjacencyListGraphNode(label, 0, backend = Backend.CPP)
                for label in labels]

    def _link_both_ways(graph, links):
        # Insert each weighted edge in both directions, forward first.
        for src, dst, weight in links:
            graph.add_edge(src, dst, weight)
            graph.add_edge(dst, src, weight)

    # Five-node graph: the spanning tree must keep the four listed edges.
    a1, b1, c1, d1, e1 = _make_nodes(['a', 'b', 'c', 'd', 'e'])
    g = Graph(a1, b1, c1, d1, e1, backend = Backend.CPP)
    _link_both_ways(g, [
        (a1.name, c1.name, 10),
        (a1.name, d1.name, 7),
        (c1.name, d1.name, 9),
        (d1.name, b1.name, 32),
        (d1.name, e1.name, 23),
    ])
    mst = minimum_spanning_tree(g, "prim", backend = Backend.CPP)
    expected_mst = ["('a', 'd', 7)", "('d', 'c', 9)", "('e', 'd', 23)", "('b', 'd', 32)",
                    "('d', 'a', 7)", "('c', 'd', 9)", "('d', 'e', 23)", "('d', 'b', 32)"]
    for src, dst in [('a', 'd'), ('e', 'd'), ('d', 'c'), ('b', 'd')]:
        assert str(mst.get_edge(src, dst)) in expected_mst
    # Four tree edges, each stored in both directions.
    assert mst.num_edges() == 8

    # Four-node path graph: all three edges belong to the tree.
    a, b, c, d = _make_nodes(['0', '1', '2', '3'])
    g2 = Graph(a, b, c, d, backend = Backend.CPP)
    _link_both_ways(g2, [
        ('0', '1', 74),
        ('0', '3', 55),
        ('1', '2', 74),
    ])
    mst2 = minimum_spanning_tree(g2, "prim", backend = Backend.CPP)
    assert mst2.num_edges() == 6

fmst = minimum_spanning_tree
fmstp = minimum_spanning_tree_parallel
_test_minimum_spanning_tree(fmst, "List", "kruskal")
Expand All @@ -193,6 +233,7 @@ def _test_minimum_spanning_tree(func, ds, algorithm, *args):
_test_minimum_spanning_tree(fmstp, "List", "kruskal", 3)
_test_minimum_spanning_tree(fmstp, "Matrix", "kruskal", 3)
_test_minimum_spanning_tree(fmstp, "List", "prim", 3)
_test_minimum_spanning_tree_cpp("List", "prim")

def test_strongly_connected_components():

Expand Down
Loading
Loading