Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions __tests__/data/label-propagation-test-data.json

Large diffs are not rendered by default.

54 changes: 54 additions & 0 deletions __tests__/unit/label-propagation.spec.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
import { Graph } from "@antv/graphlib";
import { labelPropagation } from '../../packages/graph/src';
import { dataTransformer } from "../utils/data";
import labelPropagationTestData from '../data/label-propagation-test-data.json';


/**
 * Smoke tests for `labelPropagation`: each case only checks that the
 * algorithm returns at least one cluster and at least one cluster edge.
 */
describe('label propagation', () => {
  it('simple label propagation', () => {
    // Three complete 5-node groups (0-4, 5-9, 10-14) joined by a few bridge
    // edges, plus self-loops on node 0 and a duplicated 10 -> 0 edge.
    const oldData = {
      nodes: [
        { id: '0' }, { id: '1' }, { id: '2' }, { id: '3' }, { id: '4' },
        { id: '5' }, { id: '6' }, { id: '7' }, { id: '8' }, { id: '9' },
        { id: '10' }, { id: '11' }, { id: '12' }, { id: '13' }, { id: '14' },
      ],
      edges: [
        { source: '0', target: '1' }, { source: '0', target: '2' }, { source: '0', target: '3' }, { source: '0', target: '4' },
        { source: '1', target: '2' }, { source: '1', target: '3' }, { source: '1', target: '4' },
        { source: '2', target: '3' }, { source: '2', target: '4' },
        { source: '3', target: '4' },
        { source: '0', target: '0' },
        { source: '0', target: '0' },
        { source: '0', target: '0' },

        { source: '5', target: '6', weight: 5 }, { source: '5', target: '7' }, { source: '5', target: '8' }, { source: '5', target: '9' },
        { source: '6', target: '7' }, { source: '6', target: '8' }, { source: '6', target: '9' },
        { source: '7', target: '8' }, { source: '7', target: '9' },
        { source: '8', target: '9' },

        { source: '10', target: '11' }, { source: '10', target: '12' }, { source: '10', target: '13' }, { source: '10', target: '14' },
        { source: '11', target: '12' }, { source: '11', target: '13' }, { source: '11', target: '14' },
        { source: '12', target: '13' }, { source: '12', target: '14' },
        { source: '13', target: '14', weight: 5 },

        { source: '0', target: '5' },
        { source: '5', target: '10' },
        { source: '10', target: '0' },
        { source: '10', target: '0' },
      ]
    };
    const data = dataTransformer(oldData);
    const graph = new Graph(data);
    const clusteredData = labelPropagation(graph, false, 'weight');
    expect(clusteredData.clusters.length).not.toBe(0);
    expect(clusteredData.clusterEdges.length).not.toBe(0);
  });

  it('label propagation with large graph', () => {
    // Same smoke check against the JSON fixture.
    const data = dataTransformer(labelPropagationTestData);
    const graph = new Graph(data);
    const clusteredData = labelPropagation(graph, false, 'weight');
    expect(clusteredData.clusters.length).not.toBe(0);
    expect(clusteredData.clusterEdges.length).not.toBe(0);
  });
});
52 changes: 26 additions & 26 deletions __tests__/utils/data.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@ import { INode, IEdge } from '../../packages/graph/src/types';
* @return {{nodes:INode[],edges:IEdge[]}} new data
*/
export const dataTransformer = (data: {
nodes: { id: ID; [key: string]: any }[];
edges: { source: ID; target: ID; [key: string]: any }[];
nodes: { id: ID;[key: string]: any }[];
edges: { source: ID; target: ID;[key: string]: any }[];
}): { nodes: INode[]; edges: IEdge[] } => {
const { nodes, edges } = data;
return {
Expand All @@ -22,31 +22,31 @@ export const dataTransformer = (data: {
};
};

export const dataPropertiesTransformer = (data: { nodes: { id: NodeID, [key: string]: any }[], edges: { source: NodeID, target: NodeID, [key: string]: any }[] }): { nodes: INode[], edges: IEdge[] } => {
const { nodes, edges } = data;
return {
nodes: nodes.map((n) => {
const { id, properties, ...rest } = n;
return { id, data: { ...properties, ...rest } };
}),
edges: edges.map((e, i) => {
const { id, source, target, ...rest } = e;
return { id: id ? id : `edge-${i}`, target, source, data: rest };
}),
};
/**
 * Converts raw test data into graph input, flattening each node's
 * `properties` object into its `data`.
 *
 * Nodes: `properties` is spread into `data` first, then every other
 * non-id field, so a top-level field wins over a same-named property.
 * Edges: every field other than `id`, `source` and `target` is moved into
 * `data`. An edge without an id receives `edge-<index>`.
 *
 * @param data raw nodes and edges
 * @returns nodes and edges in `{ id, data }` / `{ id, source, target, data }` form
 */
export const dataPropertiesTransformer = (data: { nodes: { id: ID, [key: string]: any }[], edges: { source: ID, target: ID, [key: string]: any }[] }): { nodes: INode[], edges: IEdge[] } => {
  const { nodes, edges } = data;
  return {
    nodes: nodes.map((n) => {
      const { id, properties, ...rest } = n;
      return { id, data: { ...properties, ...rest } };
    }),
    edges: edges.map((e, i) => {
      const { id, source, target, ...rest } = e;
      // A numeric id of 0 is a valid id and must not trigger the fallback.
      const edgeId = id || id === 0 ? id : `edge-${i}`;
      return { id: edgeId, target, source, data: rest };
    }),
  };
};


export const dataLabelDataTransformer = (data: { nodes: { id: NodeID, [key: string]: any }[], edges: { source: NodeID, target: NodeID, [key: string]: any }[] }): { nodes: INode[], edges: IEdge[] } => {
const { nodes, edges } = data;
return {
nodes: nodes.map((n) => {
const { id, label, data } = n;
return { id, data: { label, ...data } };
}),
edges: edges.map((e, i) => {
const { id, source, target, ...rest } = e;
return { id: id ? id : `edge-${i}`, target, source, data: rest };
}),
};
/**
 * Converts raw test data into graph input, folding each node's `label`
 * into its `data`.
 *
 * Nodes: the result `data` is `{ label, ...data }`, so a `label` key inside
 * the node's own `data` overrides the top-level `label`. Other node fields
 * are dropped.
 * Edges: every field other than `id`, `source` and `target` is moved into
 * `data`. An edge without an id receives `edge-<index>`.
 *
 * @param data raw nodes and edges
 * @returns nodes and edges in `{ id, data }` / `{ id, source, target, data }` form
 */
export const dataLabelDataTransformer = (data: { nodes: { id: ID, [key: string]: any }[], edges: { source: ID, target: ID, [key: string]: any }[] }): { nodes: INode[], edges: IEdge[] } => {
  const { nodes, edges } = data;
  return {
    nodes: nodes.map((n) => {
      const { id, label, data } = n;
      return { id, data: { label, ...data } };
    }),
    edges: edges.map((e, i) => {
      const { id, source, target, ...rest } = e;
      // A numeric id of 0 is a valid id and must not trigger the fallback.
      const edgeId = id || id === 0 ? id : `edge-${i}`;
      return { id: edgeId, target, source, data: rest };
    }),
  };
};
2 changes: 1 addition & 1 deletion package.json
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
"build:ci": "pnpm -r run build:ci",
"prepare": "husky install",
"test": "jest",
"test_one": "jest ./__tests__/unit/k-means.spec.ts",
"test_one": "jest ./__tests__/unit/label-propagation.spec.ts",
"coverage": "jest --coverage",
"build:site": "vite build",
"deploy": "gh-pages -d site/dist",
Expand Down
10 changes: 6 additions & 4 deletions packages/graph/src/detect-cycle.ts
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ export const detectDirectedCycle = (
return true;
},
};
for (let key of Object.keys(unvisitedSet)) {
for (const key of Object.keys(unvisitedSet)) {
depthFirstSearch(graph, key, callbacks, true, false);
}
return cycle;
Expand Down Expand Up @@ -193,8 +193,9 @@ export const detectAllDirectedCycle = (
adjList: { [key: ID]: number[] }
) => {
let closed = false; // whether a path is closed
if (nodeIds && include === false && nodeIds.indexOf(node.id) > -1)
if (nodeIds && !include && nodeIds.indexOf(node.id) > -1) {
return closed;
}
path.push(node);
blocked.add(node);
const neighbors = adjList[node.id];
Expand Down Expand Up @@ -277,7 +278,7 @@ export const detectAllDirectedCycle = (
// 对自环情况 (点连向自身) 特殊处理:记录自环,但不加入adjList
if (
neighbor === node.id &&
!(include === false && nodeIds.indexOf(node.id) > -1)
!(!include && nodeIds.indexOf(node.id) > -1)
) {
allCycles.push({ [node.id]: node });
} else {
Expand Down Expand Up @@ -306,8 +307,9 @@ export const detectAllDirectedCycle = (
});
const startNode = idx2Node[minIdx];
// StartNode is not in the specified node to include. End the search ahead of time.
if (nodeIds && include && nodeIds.indexOf(startNode.id) === -1)
if (nodeIds && include && nodeIds.indexOf(startNode.id) === -1) {
return allCycles;
}
circuit(startNode, startNode, adjList);
nodeIdx = minIdx + 1;
} else {
Expand Down
1 change: 1 addition & 0 deletions packages/graph/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,4 @@ export * from './connected-component';
export * from './mst';
export * from './k-means';
export * from './detect-cycle';
export * from './label-propagation';
197 changes: 197 additions & 0 deletions packages/graph/src/label-propagation.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,197 @@
import { uniqueId } from "@antv/util";
import { ClusterData, INode, IEdge, Graph, Matrix } from "./types";
import { ID } from "@antv/graphlib";

/**
 * Builds an adjacency matrix for the graph, indexed by each node's position
 * in `graph.getAllNodes()`.
 *
 * Rows start out empty and only connected cells are assigned, so rows are
 * sparse: a connected cell holds 1 and any other cell reads as `undefined`.
 * Edges whose source or target is not a known node are ignored.
 *
 * @param graph the graph to read nodes and edges from
 * @param directed when false, each edge is mirrored into the transposed cell
 * @returns one row per node, in node order
 * @throws Error when the graph returns no node list
 */
function getAdjMatrix(graph: Graph, directed: boolean) {
  const allNodes = graph.getAllNodes();
  if (!allNodes) {
    throw new Error("invalid nodes data!");
  }

  // node id -> position of that node in allNodes
  const indexById = new Map<string | number, number>();
  const matrix: Matrix[] = [];
  allNodes.forEach((node, position) => {
    indexById.set(node.id, position);
    matrix.push([]);
  });

  const allEdges = graph.getAllEdges();
  for (const { source, target } of allEdges || []) {
    const rowIdx = indexById.get(source);
    const colIdx = indexById.get(target);
    if (rowIdx === undefined || colIdx === undefined) continue;

    matrix[rowIdx][colIdx] = 1;
    if (!directed) {
      matrix[colIdx][rowIdx] = 1;
    }
  }
  return matrix;
}

/**
 * Performs label propagation clustering on the given graph.
 *
 * Every node starts in its own cluster. On each pass, every node looks at the
 * clusters of its neighbors, sums the adjacency weight per cluster and moves
 * to the cluster with the largest sum. Ties are broken at random among the
 * best clusters other than the node's current one, so results may differ
 * between runs. Iteration stops when a full pass moves no node or after
 * `maxIteration` passes.
 *
 * NOTE: propagation uses the entries of the matrix returned by
 * `getAdjMatrix`; `weightPropertyName` is only read when aggregating the
 * weights of the resulting cluster edges.
 *
 * @param graph The graph object representing the nodes and edges.
 * @param directed A boolean indicating whether the graph is directed or not. Default is false.
 * @param weightPropertyName The name of the edge data property used as the edge weight. Default is 'weight'. A missing or falsy weight counts as 1.
 * @param maxIteration The maximum number of propagation passes. Default is 1000.
 * @returns The clustering result: non-empty clusters, aggregated cluster edges (with summed `weight` and `count` of merged edges), and the node-to-cluster mapping.
 */
export const labelPropagation = (
  graph: Graph,
  directed: boolean = false,
  weightPropertyName: string = "weight",
  maxIteration: number = 1000
): ClusterData => {
  const nodes = graph.getAllNodes();
  const edges = graph.getAllEdges();

  // Start with one singleton cluster per node.
  const clusters = new Map<string, { id: string; nodes: INode[] }>();
  const nodeToCluster = new Map<ID, string>();
  nodes.forEach((node) => {
    const cid: string = uniqueId();
    nodeToCluster.set(node.id, cid);
    clusters.set(cid, { id: cid, nodes: [node] });
  });

  /**
   * Neighbor lookup derived from the adjacency matrix:
   * node id -> (neighbor id -> weight of the connection).
   */
  const adjMatrix = getAdjMatrix(graph, directed);
  const neighbors = new Map<ID, Map<ID, number>>();
  adjMatrix.forEach((row, i) => {
    const nodeNeighbors = new Map<ID, number>();
    row.forEach((entry, j) => {
      if (!entry) return;
      nodeNeighbors.set(nodes[j].id, entry);
    });
    neighbors.set(nodes[i].id, nodeNeighbors);
  });

  for (let iter = 0; iter < maxIteration; iter++) {
    let changed = false;
    nodes.forEach((node) => {
      const ownClusterId = nodeToCluster.get(node.id);
      const nodeNeighbors = neighbors.get(node.id);
      if (ownClusterId === undefined || !nodeNeighbors) return;

      // Sum the connection weight towards each neighboring cluster.
      // Map#forEach passes (value, key), i.e. (weight, neighbor id).
      const neighborClusters = new Map<string, number>();
      nodeNeighbors.forEach((neighborWeight, neighborId) => {
        const neighborClusterId = nodeToCluster.get(neighborId);
        if (neighborClusterId === undefined) return;
        const accumulated = neighborClusters.get(neighborClusterId) || 0;
        neighborClusters.set(neighborClusterId, accumulated + neighborWeight);
      });

      // Collect every cluster that reaches the maximum weight.
      let maxWeight = -Infinity;
      let bestClusterIds: string[] = [];
      neighborClusters.forEach((clusterWeight, clusterId) => {
        if (clusterWeight > maxWeight) {
          maxWeight = clusterWeight;
          bestClusterIds = [clusterId];
        } else if (clusterWeight === maxWeight) {
          bestClusterIds.push(clusterId);
        }
      });

      // The node only moves to a best cluster that differs from its own;
      // with no such candidate (isolated node, or its own cluster is the
      // single best) it stays put.
      const candidates = bestClusterIds.filter(
        (clusterId) => clusterId !== ownClusterId
      );
      if (!candidates.length) return;

      const randomIdx = Math.floor(Math.random() * candidates.length);
      const ownCluster = clusters.get(ownClusterId);
      const bestCluster = clusters.get(candidates[randomIdx]);
      if (!ownCluster || !bestCluster) return;

      changed = true;

      // remove from origin cluster
      const nodeInOwnClusterIdx = ownCluster.nodes.indexOf(node);
      if (nodeInOwnClusterIdx >= 0) {
        ownCluster.nodes.splice(nodeInOwnClusterIdx, 1);
      }

      // move the node to the best cluster
      bestCluster.nodes.push(node);
      nodeToCluster.set(node.id, bestCluster.id);
    });
    if (!changed) break;
  }

  // Keep only the clusters that still contain nodes.
  const clustersArray: { id: string; nodes: INode[] }[] = [];
  clusters.forEach((cluster) => {
    if (cluster.nodes.length) clustersArray.push(cluster);
  });

  // Merge original edges into one edge per (source cluster, target cluster) pair.
  const clusterEdges: IEdge[] = [];
  const clusterEdgeMap = new Map<string, IEdge>();
  // Declared outside the loop so that each cluster edge gets a distinct id.
  let nextClusterEdgeId = 0;
  edges.forEach((edge) => {
    const { source, target } = edge;
    const weight = (edge.data[weightPropertyName] || 1) as number;
    const sourceClusterId = nodeToCluster.get(source);
    const targetClusterId = nodeToCluster.get(target);
    // Skip edges whose endpoints are not nodes of the graph.
    if (sourceClusterId === undefined || targetClusterId === undefined) return;

    const pairKey = `${sourceClusterId}---${targetClusterId}`;
    const existing = clusterEdgeMap.get(pairKey);
    if (existing) {
      existing.data.weight = (existing.data.weight as number) + weight;
      existing.data.count = (existing.data.count as number) + 1;
      return;
    }
    const newEdge = {
      id: nextClusterEdgeId++,
      source: sourceClusterId,
      target: targetClusterId,
      data: {
        weight,
        count: 1,
      },
    };
    clusterEdgeMap.set(pairKey, newEdge);
    clusterEdges.push(newEdge);
  });

  return {
    clusters: clustersArray,
    clusterEdges,
    nodeToCluster,
  };
};
Loading