Skip to content

Commit

Permalink
feat: algorithm MST(Minimum Spanning Tree) (#72)
Browse files Browse the repository at this point in the history
* feat: v5 algorithm mst

* test: mst unit test

* fix: fix lint

* fix: fix lint

* fix: graph get related edge api

* chore: translate chinese annotation

* chore: remove the unnecessary var
  • Loading branch information
zqqcee authored Sep 18, 2023
1 parent b28d667 commit 8c2e18d
Show file tree
Hide file tree
Showing 10 changed files with 374 additions and 18 deletions.
113 changes: 113 additions & 0 deletions __tests__/unit/mst.spec.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
import { minimumSpanningTree } from "../../packages/graph/src";
import { Graph } from "@antv/graphlib";

const data = {
nodes: [
{
id: 'A',
data: {},
},
{
id: 'B',
data: {},
},
{
id: 'C',
data: {},
},
{
id: 'D',
data: {},
},
{
id: 'E',
data: {},
},
{
id: 'F',
data: {},
},
{
id: 'G',
data: {},
},
],
edges: [
{
id: 'edge1',
source: 'A',
target: 'B',
data: {
weight: 1,
}
},
{
id: 'edge2',
source: 'B',
target: 'C',
data: {
weight: 1,
}
},
{
id: 'edge3',
source: 'A',
target: 'C',
data: {
weight: 2,
}
},
{
id: 'edge4',
source: 'D',
target: 'A',
data: {
weight: 3,
}
},
{
id: 'edge5',
source: 'D',
target: 'E',
data: {
weight: 4,
}
},
{
id: 'edge6',
source: 'E',
target: 'F',
data: {
weight: 2,
}
},
{
id: 'edge7',
source: 'F',
target: 'D',
data: {
weight: 3,
}
},
],
};
const graph = new Graph(data);
describe('minimumSpanningTree', () => {
it('test kruskal algorithm', () => {
let result = minimumSpanningTree(graph, 'weight');
let totalWeight = 0;
for (let edge of result) {
totalWeight += edge.data.weight;
}
expect(totalWeight).toEqual(10);
});

it('test prim algorithm', () => {
let result = minimumSpanningTree(graph, 'weight', 'prim');
let totalWeight = 0;
for (let edge of result) {
totalWeight += edge.data.weight;
}
expect(totalWeight).toEqual(10);
});
});
2 changes: 1 addition & 1 deletion package.json
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
"build:ci": "pnpm -r run build:ci",
"prepare": "husky install",
"test": "jest",
"test_one": "jest ./__tests__/unit/nodes-cosine-similarity.spec.ts",
"test_one": "jest ./__tests__/unit/mst.spec.ts",
"coverage": "jest --coverage",
"build:site": "vite build",
"deploy": "gh-pages -d site/dist",
Expand Down
2 changes: 1 addition & 1 deletion packages/graph/src/cosine-similarity.ts
Original file line number Diff line number Diff line change
Expand Up @@ -24,4 +24,4 @@ export const cosineSimilarity = (
// Calculate the cosine similarity between the item vector and the target element vector
const cosineSimilarity = norm2Product ? dot / norm2Product : 0;
return cosineSimilarity;
}
};
1 change: 1 addition & 0 deletions packages/graph/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,4 @@ export * from './dfs';
export * from './cosine-similarity';
export * from './nodes-cosine-similarity';
export * from './gaddi';
export * from './mst';
106 changes: 106 additions & 0 deletions packages/graph/src/mst.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
import UnionFind from './structs/union-find';
import MinBinaryHeap from './structs/binary-heap';
import { Graph, IEdge, IMSTAlgorithm, IMSTAlgorithmOpt } from './types';
import { clone } from '@antv/util';

/**
Calculates the Minimum Spanning Tree (MST) of a graph using the Prim's algorithm.The MST is a subset of edges that forms a tree connecting all nodes with the minimum possible total edge weight.
@param graph - The graph for which the MST needs to be calculated.
@param weightProps - Optional. The property name in the edge data object that represents the weight of the edge.If provided, the algorithm will consider the weight of edges based on this property.If not provided, the algorithm will assume all edges have a weight of 0.
@returns An array of selected edges that form the Minimum Spanning Tree (MST) of the graph.
*/
const primMST: IMSTAlgorithm = (graph, weightProps?) => {
const selectedEdges: IEdge[] = [];
const nodes = graph.getAllNodes();
const edges = graph.getAllEdges();
if (nodes.length === 0) {
return selectedEdges;
}
// From the first node
const currNode = nodes[0];
const visited = new Set();
visited.add(currNode);
// Using binary heap to maintain the weight of edges from other nodes that have joined the node
const compareWeight = (a: IEdge, b: IEdge) => {
if (weightProps) {
a.data;
return (a.data[weightProps] as number) - (b.data[weightProps] as number);
}
return 0;
};
const edgeQueue = new MinBinaryHeap<IEdge>(compareWeight);

graph.getRelatedEdges(currNode.id, 'both').forEach((edge) => {
edgeQueue.insert(edge);
});
while (!edgeQueue.isEmpty()) {
// Select the node with the least edge weight between the added node and the added node
const currEdge: IEdge = edgeQueue.delMin();
const source = currEdge.source;
const target = currEdge.target;
if (visited.has(source) && visited.has(target)) continue;
selectedEdges.push(currEdge);
if (!visited.has(source)) {
visited.add(source);
graph.getRelatedEdges(source, 'both').forEach((edge) => {
edgeQueue.insert(edge);
});
}
if (!visited.has(target)) {
visited.add(target);
graph.getRelatedEdges(target, 'both').forEach((edge) => {
edgeQueue.insert(edge);
});
}
}
return selectedEdges;
};

/**
Calculates the Minimum Spanning Tree (MST) of a graph using the Kruskal's algorithm.The MST is a subset of edges that forms a tree connecting all nodes with the minimum possible total edge weight.
@param graph - The graph for which the MST needs to be calculated.
@param weightProps - Optional. The property name in the edge data object that represents the weight of the edge.If provided, the algorithm will consider the weight of edges based on this property.If not provided, the algorithm will assume all edges have a weight of 0.
@returns An array of selected edges that form the Minimum Spanning Tree (MST) of the graph.
*/
const kruskalMST: IMSTAlgorithm = (graph, weightProps?) => {
const selectedEdges: IEdge[] = [];
const nodes = graph.getAllNodes();
const edges = graph.getAllEdges();
if (nodes.length === 0) {
return selectedEdges;
}
// If you specify weight, all edges are sorted by weight from smallest to largest
const weightEdges = clone(edges);
if (weightProps) {
weightEdges.sort((a: IEdge, b: IEdge) => {
return (a.data[weightProps] as number) - (b.data[weightProps] as number);
});
}
const disjointSet = new UnionFind(nodes.map((n) => n.id));
// Starting with the edge with the least weight, if the two nodes connected by this edge are not in the same connected component in graph G, the edge is added.
while (weightEdges.length > 0) {
const curEdge = weightEdges.shift();
const source = curEdge.source;
const target = curEdge.target;
if (!disjointSet.connected(source, target)) {
selectedEdges.push(curEdge);
disjointSet.union(source, target);
}
}
return selectedEdges;
};

/**
Calculates the Minimum Spanning Tree (MST) of a graph using either Prim's or Kruskal's algorithm.The MST is a subset of edges that forms a tree connecting all nodes with the minimum possible total edge weight.
@param graph - The graph for which the MST needs to be calculated.
@param weightProps - Optional. The property name in the edge data object that represents the weight of the edge.If provided, the algorithm will consider the weight of edges based on this property.If not provided, the algorithm will assume all edges have a weight of 0.
@param algo - Optional. The algorithm to use for calculating the MST. Can be either 'prim' for Prim's algorithm, 'kruskal' for Kruskal's algorithm, or undefined to use the default algorithm (Kruskal's algorithm).
@returns An array of selected edges that form the Minimum Spanning Tree (MST) of the graph.
*/
export const minimumSpanningTree = (graph: Graph, weightProps?: string, algo?: 'prim' | 'kruskal' | undefined): IEdge[] => {
const algos: IMSTAlgorithmOpt = {
'prim': primMST,
'kruskal': kruskalMST,
};
return (algo && algos[algo](graph, weightProps)) || kruskalMST(graph, weightProps);
};
6 changes: 3 additions & 3 deletions packages/graph/src/nodes-cosine-similarity.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@ export const nodesCosineSimilarity = (
allCosineSimilarity: number[],
similarNodes: NodeSimilarity[],
} => {
const similarNodes = clone(nodes.filter(node => node.id !== seedNode.id));
const seedNodeIndex = nodes.findIndex(node => node.id === seedNode.id);
const similarNodes = clone(nodes.filter((node) => node.id !== seedNode.id));
const seedNodeIndex = nodes.findIndex((node) => node.id === seedNode.id);
// Collection of all node properties
const properties = getAllProperties(nodes);
// One-hot feature vectors for all node properties
Expand All @@ -40,4 +40,4 @@ export const nodesCosineSimilarity = (
// Sort the returned nodes according to cosine similarity
similarNodes.sort((a: NodeSimilarity, b: NodeSimilarity) => b.data.cosineSimilarity - a.data.cosineSimilarity);
return { allCosineSimilarity, similarNodes };
}
};
87 changes: 87 additions & 0 deletions packages/graph/src/structs/binary-heap.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@

export default class MinBinaryHeap<T> {
list: T[];

compareFn: (a?: T, b?: T) => number;

constructor(compareFn: (a: T, b: T) => number) {
this.compareFn = compareFn || (() => 0);
this.list = [];
}

getLeft(index: number) {
return 2 * index + 1;
}

getRight(index: number) {
return 2 * index + 2;
}

getParent(index: number) {
if (index === 0) {
return null;
}
return Math.floor((index - 1) / 2);
}

isEmpty() {
return this.list.length <= 0;
}

top() {
return this.isEmpty() ? undefined : this.list[0];
}

delMin() {
const top = this.top();
const bottom = this.list.pop();
if (this.list.length > 0) {
this.list[0] = bottom;
this.moveDown(0);
}
return top;
}

insert(value: T) {
if (value !== null) {
this.list.push(value);
const index = this.list.length - 1;
this.moveUp(index);
return true;
}
return false;
}

moveUp(index: number) {
let i = index;
let parent = this.getParent(i);
while (i && i > 0 && this.compareFn(this.list[parent], this.list[i]) > 0) {
// swap
const tmp = this.list[parent];
this.list[parent] = this.list[i];
this.list[i] = tmp;
i = parent;
parent = this.getParent(i);
}
}

moveDown(index: number) {
let element = index;
const left = this.getLeft(index);
const right = this.getRight(index);
const size = this.list.length;
if (left !== null && left < size && this.compareFn(this.list[element], this.list[left]) > 0) {
element = left;
} else if (
right !== null &&
right < size &&
this.compareFn(this.list[element], this.list[right]) > 0
) {
element = right;
}
if (index !== element) {
[this.list[index], this.list[element]] = [this.list[element], this.list[index]];
this.moveDown(element);
}
}
}
44 changes: 44 additions & 0 deletions packages/graph/src/structs/union-find.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
/**
* Disjoint set to support quick union
*/
export default class UnionFind {
count: number;

parent: { [key: number | string]: number | string };

constructor(items: (number | string)[]) {
this.count = items.length;
this.parent = {};
for (const i of items) {
this.parent[i] = i;
}
}

// find the root of the item
find(item: (number | string)) {
let resItem = item;
while (this.parent[resItem] !== resItem) {
resItem = this.parent[resItem];
}
return resItem;
}

union(a: (number | string), b: (number | string)) {
const rootA = this.find(a);
const rootB = this.find(b);
if (rootA === rootB) return;
// make the element with smaller root the parent
if (rootA < rootB) {
if (this.parent[b] !== b) this.union(this.parent[b], a);
this.parent[b] = this.parent[a];
} else {
if (this.parent[a] !== a) this.union(this.parent[a], b);
this.parent[a] = this.parent[b];
}
}

// Determine that A and B are connected
connected(a: (number | string), b: (number | string)) {
return this.find(a) === this.find(b);
}
}
Loading

0 comments on commit 8c2e18d

Please sign in to comment.