-
Notifications
You must be signed in to change notification settings - Fork 23
feat: label propagation algorithm #78
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 6 commits
bb56990
083644c
b417394
0fa86c2
a138a93
2daf480
4c5cf9d
231ffec
0e142ba
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Large diffs are not rendered by default.
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,54 @@ | ||
| import { Graph } from "@antv/graphlib"; | ||
| import { labelPropagation } from '../../packages/graph/src'; | ||
| import { dataTransformer } from "../utils/data"; | ||
| import labelPropagationTestData from '../data/label-propagation-test-data.json'; | ||
|
|
||
|
|
||
describe('label propagation', () => {
  it('simple label propagation', () => {
    // Three fully connected groups of five nodes (0-4, 5-9, 10-14), plus
    // self-loops on node 0, two weighted edges, and a few edges (some
    // duplicated) linking the groups together.
    const oldData = {
      nodes: [
        { id: '0' }, { id: '1' }, { id: '2' }, { id: '3' }, { id: '4' },
        { id: '5' }, { id: '6' }, { id: '7' }, { id: '8' }, { id: '9' },
        { id: '10' }, { id: '11' }, { id: '12' }, { id: '13' }, { id: '14' },
      ],
      edges: [
        { source: '0', target: '1' }, { source: '0', target: '2' }, { source: '0', target: '3' }, { source: '0', target: '4' },
        { source: '1', target: '2' }, { source: '1', target: '3' }, { source: '1', target: '4' },
        { source: '2', target: '3' }, { source: '2', target: '4' },
        { source: '3', target: '4' },
        { source: '0', target: '0' },
        { source: '0', target: '0' },
        { source: '0', target: '0' },

        { source: '5', target: '6', weight: 5 }, { source: '5', target: '7' }, { source: '5', target: '8' }, { source: '5', target: '9' },
        { source: '6', target: '7' }, { source: '6', target: '8' }, { source: '6', target: '9' },
        { source: '7', target: '8' }, { source: '7', target: '9' },
        { source: '8', target: '9' },

        { source: '10', target: '11' }, { source: '10', target: '12' }, { source: '10', target: '13' }, { source: '10', target: '14' },
        { source: '11', target: '12' }, { source: '11', target: '13' }, { source: '11', target: '14' },
        { source: '12', target: '13' }, { source: '12', target: '14' },
        { source: '13', target: '14', weight: 5 },

        { source: '0', target: '5' },
        { source: '5', target: '10' },
        { source: '10', target: '0' },
        { source: '10', target: '0' },
      ]
    };
    const data = dataTransformer(oldData);
    const graph = new Graph(data);
    const clusteredData = labelPropagation(graph, false, 'weight');
    expect(clusteredData.clusters.length).not.toBe(0);
    expect(clusteredData.clusterEdges.length).not.toBe(0);

    // Every node must end up in exactly one of the returned clusters.
    const clusteredNodeCount = clusteredData.clusters.reduce(
      (sum, cluster) => sum + cluster.nodes.length,
      0
    );
    expect(clusteredNodeCount).toBe(oldData.nodes.length);
  });

  it('label propagation with large graph', () => {
    const data = dataTransformer(labelPropagationTestData);
    const graph = new Graph(data);
    const clusteredData = labelPropagation(graph, false, 'weight');
    expect(clusteredData.clusters.length).not.toBe(0);
    expect(clusteredData.clusterEdges.length).not.toBe(0);
  });
});
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,197 @@ | ||
| import { uniqueId } from "./utils"; | ||
| import { ClusterData, INode, IEdge, Graph, Matrix } from "./types"; | ||
| import { ID } from "@antv/graphlib"; | ||
|
|
||
/**
 * Builds a sparse adjacency matrix for the graph.
 *
 * Row/column indices follow the order of `graph.getAllNodes()`. A cell is set
 * to 1 when an edge connects the two nodes and is left unset otherwise; edge
 * weights are not read here. Edges whose endpoints are not known nodes are
 * skipped.
 *
 * @param graph the graph to read nodes and edges from
 * @param directed when false, each edge is mirrored so the matrix is symmetric
 * @returns one row per node
 * @throws Error when the graph yields no node list
 */
function getAdjMatrix(graph: Graph, directed: boolean) {
  const allNodes = graph.getAllNodes();
  if (!allNodes) {
    throw new Error("invalid nodes data!");
  }

  const adjacency: Matrix[] = [];
  // node id -> row/column position in the matrix
  const indexById = new Map<string | number, number>();
  for (let position = 0; position < allNodes.length; position++) {
    indexById.set(allNodes[position].id, position);
    const emptyRow: number[] = [];
    adjacency.push(emptyRow);
  }

  const allEdges = graph.getAllEdges();
  if (!allEdges) {
    return adjacency;
  }

  for (const { source, target } of allEdges) {
    const from = indexById.get(source);
    const to = indexById.get(target);
    // skip edges that reference a node the graph does not contain
    if (from === undefined || to === undefined) {
      continue;
    }
    adjacency[from][to] = 1;
    if (!directed) {
      adjacency[to][from] = 1;
    }
  }
  return adjacency;
}
|
|
||
/**
 * Label propagation algorithm for community detection.
 *
 * Every node starts in its own cluster. On each pass, each node sums the
 * adjacency values of its neighbors per neighbor cluster and looks for the
 * cluster(s) with the largest sum. If its own cluster is the single best one
 * the node stays; otherwise it moves to one of the best clusters other than
 * its own, chosen at random. Iteration stops when a whole pass moves no node
 * or after `maxIteration` passes.
 *
 * Propagation uses the adjacency matrix returned by `getAdjMatrix`;
 * `weightPropertyName` is only read when aggregating the edges between the
 * resulting clusters.
 *
 * @param graph the graph to cluster
 * @param directed whether the graph is directed, false by default
 * @param weightPropertyName name of the edge data property holding the edge
 *   weight; a missing or falsy value counts as 1
 * @param maxIteration maximum number of propagation passes
 * @returns the non-empty clusters, the edges aggregated between clusters
 *   (with summed `weight` and a `count` of merged edges), and a map from node
 *   id to cluster id
 */
export const labelPropagation = (
  graph: Graph,
  directed: boolean = false,
  weightPropertyName: string = "weight",
  maxIteration: number = 1000
): ClusterData => {
  // the origin data
  const nodes = graph.getAllNodes();
  const edges = graph.getAllEdges();

  const clusters: { [key: string]: { id: string; nodes: INode[] } } = {};
  const nodeMap: { [key: ID]: { node: INode; idx: number } } = {};
  const nodeToCluster = new Map<ID, string>();
  // init the clusters and nodeMap: one singleton cluster per node
  nodes.forEach((node, i) => {
    // pass the node index so two clusters can never share an id
    const cid: string = uniqueId(i);
    nodeToCluster.set(node.id, cid);
    clusters[cid] = {
      id: cid,
      nodes: [node],
    };
    nodeMap[node.id] = {
      node,
      idx: i,
    };
  });

  // the adjacent matrix of the nodes
  const adjMatrix = getAdjMatrix(graph, directed);
  /**
   * neighbor nodes (id for key and adjacency value for value) for each node
   * neighbors = {
   *   id(node_id): { id(neighbor_1_id): value, id(neighbor_2_id): value, ... },
   *   ...
   * }
   */
  const neighbors: { [key: ID]: { [key: ID]: number } } = {};
  adjMatrix.forEach((row, i) => {
    const iid = nodes[i].id;
    neighbors[iid] = {};
    row.forEach((entry, j) => {
      if (!entry) return;
      const jid = nodes[j].id;
      neighbors[iid][jid] = entry;
    });
  });

  let iter = 0;
  while (iter < maxIteration) {
    let changed = false;
    nodes.forEach((node) => {
      const currentClusterId = nodeToCluster.get(node.id);
      if (currentClusterId === undefined) return;

      // accumulate the adjacency value contributed by each neighboring cluster
      const neighborClusters: { [key: string]: number } = {};
      Object.keys(neighbors[node.id]).forEach((neighborId) => {
        const neighborWeight = neighbors[node.id][neighborId];
        const neighborNode = nodeMap[neighborId].node;
        const neighborClusterId = nodeToCluster.get(neighborNode.id);
        if (neighborClusterId === undefined) return;
        if (!neighborClusters[neighborClusterId]) {
          neighborClusters[neighborClusterId] = 0;
        }
        neighborClusters[neighborClusterId] += neighborWeight;
      });

      // find the cluster(s) with max weight
      let maxWeight = -Infinity;
      let bestClusterIds: string[] = [];
      Object.keys(neighborClusters).forEach((clusterId) => {
        if (maxWeight < neighborClusters[clusterId]) {
          maxWeight = neighborClusters[clusterId];
          bestClusterIds = [clusterId];
        } else if (maxWeight === neighborClusters[clusterId]) {
          bestClusterIds.push(clusterId);
        }
      });

      // already in the single best cluster: nothing to do
      if (bestClusterIds.length === 1 && bestClusterIds[0] === currentClusterId) {
        return;
      }

      // never "move" a node into the cluster it is already in
      const selfClusterIdx = bestClusterIds.indexOf(currentClusterId);
      if (selfClusterIdx >= 0) bestClusterIds.splice(selfClusterIdx, 1);
      if (!bestClusterIds.length) return;

      changed = true;

      // remove from origin cluster
      const selfCluster = clusters[currentClusterId];
      const nodeInSelfClusterIdx = selfCluster.nodes.indexOf(node);
      if (nodeInSelfClusterIdx >= 0) {
        selfCluster.nodes.splice(nodeInSelfClusterIdx, 1);
      }

      // move the node to one of the best clusters, chosen at random
      const randomIdx = Math.floor(Math.random() * bestClusterIds.length);
      const bestCluster = clusters[bestClusterIds[randomIdx]];
      bestCluster.nodes.push(node);
      nodeToCluster.set(node.id, bestCluster.id);
    });
    if (!changed) break;
    iter++;
  }

  // delete the empty clusters
  Object.keys(clusters).forEach((clusterId) => {
    const cluster = clusters[clusterId];
    if (!cluster.nodes || !cluster.nodes.length) {
      delete clusters[clusterId];
    }
  });

  // get the cluster edges: merge all edges sharing the same
  // (source cluster, target cluster) pair into a single edge
  const clusterEdges: IEdge[] = [];
  const clusterEdgeMap: { [key: string]: IEdge } = {};
  // declared outside the loop so each aggregated edge gets its own id
  let nextClusterEdgeId = 0;
  edges.forEach((edge) => {
    const { source, target } = edge;
    const weight = (edge.data[weightPropertyName] || 1) as number;
    const sourceClusterId = nodeToCluster.get(nodeMap[source].node.id);
    const targetClusterId = nodeToCluster.get(nodeMap[target].node.id);
    if (sourceClusterId === undefined || targetClusterId === undefined) return;

    const clusterPairKey = `${sourceClusterId}---${targetClusterId}`;
    const existing = clusterEdgeMap[clusterPairKey];
    if (existing) {
      existing.data.weight = (existing.data.weight as number) + weight;
      existing.data.count = (existing.data.count as number) + 1;
    } else {
      const newEdge = {
        id: nextClusterEdgeId++,
        source: sourceClusterId,
        target: targetClusterId,
        data: {
          weight,
          count: 1,
        },
      };
      clusterEdgeMap[clusterPairKey] = newEdge;
      clusterEdges.push(newEdge);
    }
  });

  const clustersArray: { id: string; nodes: INode[] }[] = Object.keys(
    clusters
  ).map((clusterId) => clusters[clusterId]);
  return {
    clusters: clustersArray,
    clusterEdges,
    nodeToCluster,
  };
};
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,5 +1,4 @@ | ||
| import { Node, PlainObject } from "@antv/graphlib"; | ||
| import { Vector } from "./vector"; | ||
| import { DistanceType, KeyValueMap, NodeData } from "./types"; | ||
| import { uniq } from "@antv/util"; | ||
|
|
||
|
|
@@ -101,3 +100,10 @@ function euclideanDistance(source: number[], target: number[]) { | |
| return Math.sqrt(res); | ||
| } | ||
|
|
||
/**
 * Generates a pseudo-random id of the form `<index>-<10 digits>`.
 *
 * The digits come from two independent `Math.random()` draws, five digits
 * each. `toFixed` is used instead of splitting the default string form, so
 * draws such as `0` or `1e-7` (which print without a decimal point) still
 * yield exactly five digits instead of throwing.
 *
 * @param index prefix placed before the random digits, 0 by default
 * @returns the generated id
 */
export const uniqueId = (index: number = 0): string => {
  // "0.ddddd" (or "1.00000" after rounding) -> the five fractional digits
  const fiveRandomDigits = (): string => Math.random().toFixed(5).slice(2, 7);
  const random1 = fiveRandomDigits();
  const random2 = fiveRandomDigits();
  return `${index}-${random1}${random2}`;
};
|
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
是否能统一为英文注释~