Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 90 additions & 0 deletions src/kruskal_mst.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
class DisjointSet:
def __init__(self, vertices):
"""
Initialize a Disjoint Set data structure for Kruskal's algorithm.

Args:
vertices (int): Number of vertices in the graph
"""
self.parent = list(range(vertices))
self.rank = [0] * vertices

def find(self, item):
"""
Find the root of an item with path compression.

Args:
item (int): Vertex to find the root for

Returns:
int: Root of the vertex
"""
if self.parent[item] != item:
self.parent[item] = self.find(self.parent[item])
return self.parent[item]

def union(self, x, y):
"""
Union of two sets by rank.

Args:
x (int): First vertex
y (int): Second vertex

Returns:
bool: True if union was successful, False if already in same set
"""
xroot = self.find(x)
yroot = self.find(y)

if xroot == yroot:
return False

# Union by rank
if self.rank[xroot] < self.rank[yroot]:
self.parent[xroot] = yroot
elif self.rank[xroot] > self.rank[yroot]:
self.parent[yroot] = xroot
else:
self.parent[yroot] = xroot
self.rank[xroot] += 1

return True

def kruskal_mst(num_vertices, edges):
"""
Implement Kruskal's algorithm to find Minimum Spanning Tree.

Args:
num_vertices (int): Number of vertices in the graph
edges (list): List of edges, where each edge is (weight, u, v)

Returns:
list: Edges in the Minimum Spanning Tree
"""
# Input validation
if num_vertices <= 0:
raise ValueError("Number of vertices must be positive")

if not edges:
return []

# Sort edges by weight in ascending order
sorted_edges = sorted(edges)

# Initialize Disjoint Set
disjoint_set = DisjointSet(num_vertices)

# MST will store the resultant minimum spanning tree
mst = []

for weight, u, v in sorted_edges:
# If including this edge doesn't cause a cycle, add it to MST
if disjoint_set.union(u, v):
mst.append((weight, u, v))

# Stop when MST has num_vertices - 1 edges
if len(mst) == num_vertices - 1:
break

return mst
82 changes: 82 additions & 0 deletions tests/test_kruskal_mst.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
import pytest
from src.kruskal_mst import kruskal_mst, DisjointSet

def test_disjoint_set():
"""Test Disjoint Set data structure functionality."""
ds = DisjointSet(5)

# Initial state: each vertex is in its own set
for i in range(5):
assert ds.find(i) == i

# Union of sets
ds.union(0, 1)
ds.union(2, 3)

# Check roots after union
assert ds.find(0) == ds.find(1)
assert ds.find(2) == ds.find(3)
assert ds.find(0) != ds.find(2)

def test_kruskal_empty_graph():
"""Test Kruskal's algorithm with an empty graph."""
assert kruskal_mst(0, []) == []
assert kruskal_mst(1, []) == []

def test_kruskal_invalid_input():
"""Test Kruskal's algorithm with invalid inputs."""
with pytest.raises(ValueError):
kruskal_mst(-1, [])

def test_kruskal_simple_graph():
"""Test Kruskal's algorithm with a simple graph."""
# Simple graph with 4 vertices
edges = [
(1, 0, 1), # weight 1, connecting vertices 0 and 1
(4, 1, 2), # weight 4, connecting vertices 1 and 2
(2, 0, 2), # weight 2, connecting vertices 0 and 2
(3, 1, 3), # weight 3, connecting vertices 1 and 3
(5, 2, 3), # weight 5, connecting vertices 2 and 3
]

mst = kruskal_mst(4, edges)

# Expect 3 edges in MST for 4 vertices
assert len(mst) == 3

# Check total weight of MST
total_weight = sum(edge[0] for edge in mst)
assert total_weight == 6 # 1 + 2 + 3

def test_kruskal_disconnected_graph():
"""Test Kruskal's algorithm with a disconnected graph."""
edges = [
(1, 0, 1), # weight 1, connecting vertices 0 and 1
(5, 2, 3), # weight 5, connecting vertices 2 and 3
(10, 4, 5) # weight 10, connecting vertices 4 and 5
]

mst = kruskal_mst(6, edges)

# Expect 2 edges in partial MST
assert len(mst) == 2

# Ensure the minimum weight edges are selected
assert sorted(mst) == [(1, 0, 1), (5, 2, 3)]

def test_kruskal_complete_graph():
"""Test Kruskal's algorithm with a complete graph."""
edges = [
(1, 0, 1), (2, 0, 2), (3, 0, 3),
(2, 1, 2), (4, 1, 3),
(5, 2, 3)
]

mst = kruskal_mst(4, edges)

# Always expect 3 edges in MST for a 4-vertex graph
assert len(mst) == 3

# Check that the total weight is minimized
total_weight = sum(edge[0] for edge in mst)
assert total_weight == 6 # Minimum total weight