Skip to content

Commit 8c2e18d

Browse files
authored
feat: algorithm MST(Minimum Spanning Tree) (#72)
* feat: v5 algorithm mst * test: mst unit test * fix: fix lint * fix: fix lint * fix: graph get related edge api * chore: translate chinese annotation * chore: remove the unnecessary var
1 parent b28d667 commit 8c2e18d

File tree

10 files changed

+374
-18
lines changed

10 files changed

+374
-18
lines changed

__tests__/unit/mst.spec.ts

Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
import { minimumSpanningTree } from "../../packages/graph/src";
2+
import { Graph } from "@antv/graphlib";
3+
4+
const data = {
5+
nodes: [
6+
{
7+
id: 'A',
8+
data: {},
9+
},
10+
{
11+
id: 'B',
12+
data: {},
13+
},
14+
{
15+
id: 'C',
16+
data: {},
17+
},
18+
{
19+
id: 'D',
20+
data: {},
21+
},
22+
{
23+
id: 'E',
24+
data: {},
25+
},
26+
{
27+
id: 'F',
28+
data: {},
29+
},
30+
{
31+
id: 'G',
32+
data: {},
33+
},
34+
],
35+
edges: [
36+
{
37+
id: 'edge1',
38+
source: 'A',
39+
target: 'B',
40+
data: {
41+
weight: 1,
42+
}
43+
},
44+
{
45+
id: 'edge2',
46+
source: 'B',
47+
target: 'C',
48+
data: {
49+
weight: 1,
50+
}
51+
},
52+
{
53+
id: 'edge3',
54+
source: 'A',
55+
target: 'C',
56+
data: {
57+
weight: 2,
58+
}
59+
},
60+
{
61+
id: 'edge4',
62+
source: 'D',
63+
target: 'A',
64+
data: {
65+
weight: 3,
66+
}
67+
},
68+
{
69+
id: 'edge5',
70+
source: 'D',
71+
target: 'E',
72+
data: {
73+
weight: 4,
74+
}
75+
},
76+
{
77+
id: 'edge6',
78+
source: 'E',
79+
target: 'F',
80+
data: {
81+
weight: 2,
82+
}
83+
},
84+
{
85+
id: 'edge7',
86+
source: 'F',
87+
target: 'D',
88+
data: {
89+
weight: 3,
90+
}
91+
},
92+
],
93+
};
94+
const graph = new Graph(data);
95+
describe('minimumSpanningTree', () => {
96+
it('test kruskal algorithm', () => {
97+
let result = minimumSpanningTree(graph, 'weight');
98+
let totalWeight = 0;
99+
for (let edge of result) {
100+
totalWeight += edge.data.weight;
101+
}
102+
expect(totalWeight).toEqual(10);
103+
});
104+
105+
it('test prim algorithm', () => {
106+
let result = minimumSpanningTree(graph, 'weight', 'prim');
107+
let totalWeight = 0;
108+
for (let edge of result) {
109+
totalWeight += edge.data.weight;
110+
}
111+
expect(totalWeight).toEqual(10);
112+
});
113+
});

package.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
"build:ci": "pnpm -r run build:ci",
2222
"prepare": "husky install",
2323
"test": "jest",
24-
"test_one": "jest ./__tests__/unit/nodes-cosine-similarity.spec.ts",
24+
"test_one": "jest ./__tests__/unit/mst.spec.ts",
2525
"coverage": "jest --coverage",
2626
"build:site": "vite build",
2727
"deploy": "gh-pages -d site/dist",

packages/graph/src/cosine-similarity.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,4 +24,4 @@ export const cosineSimilarity = (
2424
// Calculate the cosine similarity between the item vector and the target element vector
2525
const cosineSimilarity = norm2Product ? dot / norm2Product : 0;
2626
return cosineSimilarity;
27-
}
27+
};

packages/graph/src/index.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,3 +9,4 @@ export * from './dfs';
99
export * from './cosine-similarity';
1010
export * from './nodes-cosine-similarity';
1111
export * from './gaddi';
12+
export * from './mst';

packages/graph/src/mst.ts

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
import UnionFind from './structs/union-find';
2+
import MinBinaryHeap from './structs/binary-heap';
3+
import { Graph, IEdge, IMSTAlgorithm, IMSTAlgorithmOpt } from './types';
4+
import { clone } from '@antv/util';
5+
6+
/**
7+
Calculates the Minimum Spanning Tree (MST) of a graph using the Prim's algorithm.The MST is a subset of edges that forms a tree connecting all nodes with the minimum possible total edge weight.
8+
@param graph - The graph for which the MST needs to be calculated.
9+
@param weightProps - Optional. The property name in the edge data object that represents the weight of the edge.If provided, the algorithm will consider the weight of edges based on this property.If not provided, the algorithm will assume all edges have a weight of 0.
10+
@returns An array of selected edges that form the Minimum Spanning Tree (MST) of the graph.
11+
*/
12+
const primMST: IMSTAlgorithm = (graph, weightProps?) => {
13+
const selectedEdges: IEdge[] = [];
14+
const nodes = graph.getAllNodes();
15+
const edges = graph.getAllEdges();
16+
if (nodes.length === 0) {
17+
return selectedEdges;
18+
}
19+
// From the first node
20+
const currNode = nodes[0];
21+
const visited = new Set();
22+
visited.add(currNode);
23+
// Using binary heap to maintain the weight of edges from other nodes that have joined the node
24+
const compareWeight = (a: IEdge, b: IEdge) => {
25+
if (weightProps) {
26+
a.data;
27+
return (a.data[weightProps] as number) - (b.data[weightProps] as number);
28+
}
29+
return 0;
30+
};
31+
const edgeQueue = new MinBinaryHeap<IEdge>(compareWeight);
32+
33+
graph.getRelatedEdges(currNode.id, 'both').forEach((edge) => {
34+
edgeQueue.insert(edge);
35+
});
36+
while (!edgeQueue.isEmpty()) {
37+
// Select the node with the least edge weight between the added node and the added node
38+
const currEdge: IEdge = edgeQueue.delMin();
39+
const source = currEdge.source;
40+
const target = currEdge.target;
41+
if (visited.has(source) && visited.has(target)) continue;
42+
selectedEdges.push(currEdge);
43+
if (!visited.has(source)) {
44+
visited.add(source);
45+
graph.getRelatedEdges(source, 'both').forEach((edge) => {
46+
edgeQueue.insert(edge);
47+
});
48+
}
49+
if (!visited.has(target)) {
50+
visited.add(target);
51+
graph.getRelatedEdges(target, 'both').forEach((edge) => {
52+
edgeQueue.insert(edge);
53+
});
54+
}
55+
}
56+
return selectedEdges;
57+
};
58+
59+
/**
60+
Calculates the Minimum Spanning Tree (MST) of a graph using the Kruskal's algorithm.The MST is a subset of edges that forms a tree connecting all nodes with the minimum possible total edge weight.
61+
@param graph - The graph for which the MST needs to be calculated.
62+
@param weightProps - Optional. The property name in the edge data object that represents the weight of the edge.If provided, the algorithm will consider the weight of edges based on this property.If not provided, the algorithm will assume all edges have a weight of 0.
63+
@returns An array of selected edges that form the Minimum Spanning Tree (MST) of the graph.
64+
*/
65+
const kruskalMST: IMSTAlgorithm = (graph, weightProps?) => {
66+
const selectedEdges: IEdge[] = [];
67+
const nodes = graph.getAllNodes();
68+
const edges = graph.getAllEdges();
69+
if (nodes.length === 0) {
70+
return selectedEdges;
71+
}
72+
// If you specify weight, all edges are sorted by weight from smallest to largest
73+
const weightEdges = clone(edges);
74+
if (weightProps) {
75+
weightEdges.sort((a: IEdge, b: IEdge) => {
76+
return (a.data[weightProps] as number) - (b.data[weightProps] as number);
77+
});
78+
}
79+
const disjointSet = new UnionFind(nodes.map((n) => n.id));
80+
// Starting with the edge with the least weight, if the two nodes connected by this edge are not in the same connected component in graph G, the edge is added.
81+
while (weightEdges.length > 0) {
82+
const curEdge = weightEdges.shift();
83+
const source = curEdge.source;
84+
const target = curEdge.target;
85+
if (!disjointSet.connected(source, target)) {
86+
selectedEdges.push(curEdge);
87+
disjointSet.union(source, target);
88+
}
89+
}
90+
return selectedEdges;
91+
};
92+
93+
/**
94+
Calculates the Minimum Spanning Tree (MST) of a graph using either Prim's or Kruskal's algorithm.The MST is a subset of edges that forms a tree connecting all nodes with the minimum possible total edge weight.
95+
@param graph - The graph for which the MST needs to be calculated.
96+
@param weightProps - Optional. The property name in the edge data object that represents the weight of the edge.If provided, the algorithm will consider the weight of edges based on this property.If not provided, the algorithm will assume all edges have a weight of 0.
97+
@param algo - Optional. The algorithm to use for calculating the MST. Can be either 'prim' for Prim's algorithm, 'kruskal' for Kruskal's algorithm, or undefined to use the default algorithm (Kruskal's algorithm).
98+
@returns An array of selected edges that form the Minimum Spanning Tree (MST) of the graph.
99+
*/
100+
export const minimumSpanningTree = (graph: Graph, weightProps?: string, algo?: 'prim' | 'kruskal' | undefined): IEdge[] => {
101+
const algos: IMSTAlgorithmOpt = {
102+
'prim': primMST,
103+
'kruskal': kruskalMST,
104+
};
105+
return (algo && algos[algo](graph, weightProps)) || kruskalMST(graph, weightProps);
106+
};

packages/graph/src/nodes-cosine-similarity.ts

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,8 @@ export const nodesCosineSimilarity = (
2121
allCosineSimilarity: number[],
2222
similarNodes: NodeSimilarity[],
2323
} => {
24-
const similarNodes = clone(nodes.filter(node => node.id !== seedNode.id));
25-
const seedNodeIndex = nodes.findIndex(node => node.id === seedNode.id);
24+
const similarNodes = clone(nodes.filter((node) => node.id !== seedNode.id));
25+
const seedNodeIndex = nodes.findIndex((node) => node.id === seedNode.id);
2626
// Collection of all node properties
2727
const properties = getAllProperties(nodes);
2828
// One-hot feature vectors for all node properties
@@ -40,4 +40,4 @@ export const nodesCosineSimilarity = (
4040
// Sort the returned nodes according to cosine similarity
4141
similarNodes.sort((a: NodeSimilarity, b: NodeSimilarity) => b.data.cosineSimilarity - a.data.cosineSimilarity);
4242
return { allCosineSimilarity, similarNodes };
43-
}
43+
};
Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
2+
export default class MinBinaryHeap<T> {
3+
list: T[];
4+
5+
compareFn: (a?: T, b?: T) => number;
6+
7+
constructor(compareFn: (a: T, b: T) => number) {
8+
this.compareFn = compareFn || (() => 0);
9+
this.list = [];
10+
}
11+
12+
getLeft(index: number) {
13+
return 2 * index + 1;
14+
}
15+
16+
getRight(index: number) {
17+
return 2 * index + 2;
18+
}
19+
20+
getParent(index: number) {
21+
if (index === 0) {
22+
return null;
23+
}
24+
return Math.floor((index - 1) / 2);
25+
}
26+
27+
isEmpty() {
28+
return this.list.length <= 0;
29+
}
30+
31+
top() {
32+
return this.isEmpty() ? undefined : this.list[0];
33+
}
34+
35+
delMin() {
36+
const top = this.top();
37+
const bottom = this.list.pop();
38+
if (this.list.length > 0) {
39+
this.list[0] = bottom;
40+
this.moveDown(0);
41+
}
42+
return top;
43+
}
44+
45+
insert(value: T) {
46+
if (value !== null) {
47+
this.list.push(value);
48+
const index = this.list.length - 1;
49+
this.moveUp(index);
50+
return true;
51+
}
52+
return false;
53+
}
54+
55+
moveUp(index: number) {
56+
let i = index;
57+
let parent = this.getParent(i);
58+
while (i && i > 0 && this.compareFn(this.list[parent], this.list[i]) > 0) {
59+
// swap
60+
const tmp = this.list[parent];
61+
this.list[parent] = this.list[i];
62+
this.list[i] = tmp;
63+
i = parent;
64+
parent = this.getParent(i);
65+
}
66+
}
67+
68+
moveDown(index: number) {
69+
let element = index;
70+
const left = this.getLeft(index);
71+
const right = this.getRight(index);
72+
const size = this.list.length;
73+
if (left !== null && left < size && this.compareFn(this.list[element], this.list[left]) > 0) {
74+
element = left;
75+
} else if (
76+
right !== null &&
77+
right < size &&
78+
this.compareFn(this.list[element], this.list[right]) > 0
79+
) {
80+
element = right;
81+
}
82+
if (index !== element) {
83+
[this.list[index], this.list[element]] = [this.list[element], this.list[index]];
84+
this.moveDown(element);
85+
}
86+
}
87+
}
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
/**
2+
* Disjoint set to support quick union
3+
*/
4+
export default class UnionFind {
5+
count: number;
6+
7+
parent: { [key: number | string]: number | string };
8+
9+
constructor(items: (number | string)[]) {
10+
this.count = items.length;
11+
this.parent = {};
12+
for (const i of items) {
13+
this.parent[i] = i;
14+
}
15+
}
16+
17+
// find the root of the item
18+
find(item: (number | string)) {
19+
let resItem = item;
20+
while (this.parent[resItem] !== resItem) {
21+
resItem = this.parent[resItem];
22+
}
23+
return resItem;
24+
}
25+
26+
union(a: (number | string), b: (number | string)) {
27+
const rootA = this.find(a);
28+
const rootB = this.find(b);
29+
if (rootA === rootB) return;
30+
// make the element with smaller root the parent
31+
if (rootA < rootB) {
32+
if (this.parent[b] !== b) this.union(this.parent[b], a);
33+
this.parent[b] = this.parent[a];
34+
} else {
35+
if (this.parent[a] !== a) this.union(this.parent[a], b);
36+
this.parent[a] = this.parent[b];
37+
}
38+
}
39+
40+
// Determine that A and B are connected
41+
connected(a: (number | string), b: (number | string)) {
42+
return this.find(a) === this.find(b);
43+
}
44+
}

0 commit comments

Comments
 (0)