Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions __tests__/data/label-propagation-test-data.json

Large diffs are not rendered by default.

54 changes: 54 additions & 0 deletions __tests__/unit/label-propagation.spec.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
import { Graph } from "@antv/graphlib";
import { labelPropagation } from '../../packages/graph/src';
import { dataTransformer } from "../utils/data";
import labelPropagationTestData from '../data/label-propagation-test-data.json';


describe('label propagation', () => {
it('simple label propagation', () => {
const oldData = {
nodes: [
{ id: '0' }, { id: '1' }, { id: '2' }, { id: '3' }, { id: '4' },
{ id: '5' }, { id: '6' }, { id: '7' }, { id: '8' }, { id: '9' },
{ id: '10' }, { id: '11' }, { id: '12' }, { id: '13' }, { id: '14' },
],
edges: [
{ source: '0', target: '1' }, { source: '0', target: '2' }, { source: '0', target: '3' }, { source: '0', target: '4' },
{ source: '1', target: '2' }, { source: '1', target: '3' }, { source: '1', target: '4' },
{ source: '2', target: '3' }, { source: '2', target: '4' },
{ source: '3', target: '4' },
{ source: '0', target: '0' },
{ source: '0', target: '0' },
{ source: '0', target: '0' },

{ source: '5', target: '6', weight: 5 }, { source: '5', target: '7' }, { source: '5', target: '8' }, { source: '5', target: '9' },
{ source: '6', target: '7' }, { source: '6', target: '8' }, { source: '6', target: '9' },
{ source: '7', target: '8' }, { source: '7', target: '9' },
{ source: '8', target: '9' },

{ source: '10', target: '11' }, { source: '10', target: '12' }, { source: '10', target: '13' }, { source: '10', target: '14' },
{ source: '11', target: '12' }, { source: '11', target: '13' }, { source: '11', target: '14' },
{ source: '12', target: '13' }, { source: '12', target: '14' },
{ source: '13', target: '14', weight: 5 },

{ source: '0', target: '5' },
{ source: '5', target: '10' },
{ source: '10', target: '0' },
{ source: '10', target: '0' },
]
};
const data = dataTransformer(oldData);
const graph = new Graph(data);
const clusteredData = labelPropagation(graph, false, 'weight');
expect(clusteredData.clusters.length).not.toBe(0);
expect(clusteredData.clusterEdges.length).not.toBe(0);
});

it('label propagation with large graph', () => {
const data = dataTransformer(labelPropagationTestData);
const graph = new Graph(data);
const clusteredData = labelPropagation(graph, false, 'weight');
expect(clusteredData.clusters.length).not.toBe(0);
expect(clusteredData.clusterEdges.length).not.toBe(0);
}
});
52 changes: 26 additions & 26 deletions __tests__/utils/data.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@ import { INode, IEdge } from '../../packages/graph/src/types';
* @return {{nodes:INode[],edges:IEdge[]}} new data
*/
export const dataTransformer = (data: {
nodes: { id: ID; [key: string]: any }[];
edges: { source: ID; target: ID; [key: string]: any }[];
nodes: { id: ID;[key: string]: any }[];
edges: { source: ID; target: ID;[key: string]: any }[];
}): { nodes: INode[]; edges: IEdge[] } => {
const { nodes, edges } = data;
return {
Expand All @@ -22,31 +22,31 @@ export const dataTransformer = (data: {
};
};

export const dataPropertiesTransformer = (data: { nodes: { id: NodeID, [key: string]: any }[], edges: { source: NodeID, target: NodeID, [key: string]: any }[] }): { nodes: INode[], edges: IEdge[] } => {
const { nodes, edges } = data;
return {
nodes: nodes.map((n) => {
const { id, properties, ...rest } = n;
return { id, data: { ...properties, ...rest } };
}),
edges: edges.map((e, i) => {
const { id, source, target, ...rest } = e;
return { id: id ? id : `edge-${i}`, target, source, data: rest };
}),
};
export const dataPropertiesTransformer = (data: { nodes: { id: ID, [key: string]: any }[], edges: { source: ID, target: ID, [key: string]: any }[] }): { nodes: INode[], edges: IEdge[] } => {
const { nodes, edges } = data;
return {
nodes: nodes.map((n) => {
const { id, properties, ...rest } = n;
return { id, data: { ...properties, ...rest } };
}),
edges: edges.map((e, i) => {
const { id, source, target, ...rest } = e;
return { id: id ? id : `edge-${i}`, target, source, data: rest };
}),
};
};


export const dataLabelDataTransformer = (data: { nodes: { id: NodeID, [key: string]: any }[], edges: { source: NodeID, target: NodeID, [key: string]: any }[] }): { nodes: INode[], edges: IEdge[] } => {
const { nodes, edges } = data;
return {
nodes: nodes.map((n) => {
const { id, label, data } = n;
return { id, data: { label, ...data } };
}),
edges: edges.map((e, i) => {
const { id, source, target, ...rest } = e;
return { id: id ? id : `edge-${i}`, target, source, data: rest };
}),
};
export const dataLabelDataTransformer = (data: { nodes: { id: ID, [key: string]: any }[], edges: { source: ID, target: ID, [key: string]: any }[] }): { nodes: INode[], edges: IEdge[] } => {
const { nodes, edges } = data;
return {
nodes: nodes.map((n) => {
const { id, label, data } = n;
return { id, data: { label, ...data } };
}),
edges: edges.map((e, i) => {
const { id, source, target, ...rest } = e;
return { id: id ? id : `edge-${i}`, target, source, data: rest };
}),
};
};
2 changes: 1 addition & 1 deletion package.json
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
"build:ci": "pnpm -r run build:ci",
"prepare": "husky install",
"test": "jest",
"test_one": "jest ./__tests__/unit/k-means.spec.ts",
"test_one": "jest ./__tests__/unit/label-propagation.spec.ts",
"coverage": "jest --coverage",
"build:site": "vite build",
"deploy": "gh-pages -d site/dist",
Expand Down
10 changes: 6 additions & 4 deletions packages/graph/src/detect-cycle.ts
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ export const detectDirectedCycle = (
return true;
},
};
for (let key of Object.keys(unvisitedSet)) {
for (const key of Object.keys(unvisitedSet)) {
depthFirstSearch(graph, key, callbacks, true, false);
}
return cycle;
Expand Down Expand Up @@ -193,8 +193,9 @@ export const detectAllDirectedCycle = (
adjList: { [key: ID]: number[] }
) => {
let closed = false; // whether a path is closed
if (nodeIds && include === false && nodeIds.indexOf(node.id) > -1)
if (nodeIds && !include && nodeIds.indexOf(node.id) > -1) {
return closed;
}
path.push(node);
blocked.add(node);
const neighbors = adjList[node.id];
Expand Down Expand Up @@ -277,7 +278,7 @@ export const detectAllDirectedCycle = (
// 对自环情况 (点连向自身) 特殊处理:记录自环,但不加入adjList
if (
neighbor === node.id &&
!(include === false && nodeIds.indexOf(node.id) > -1)
!(!include && nodeIds.indexOf(node.id) > -1)
) {
allCycles.push({ [node.id]: node });
} else {
Expand Down Expand Up @@ -306,8 +307,9 @@ export const detectAllDirectedCycle = (
});
const startNode = idx2Node[minIdx];
// StartNode is not in the specified node to include. End the search ahead of time.
if (nodeIds && include && nodeIds.indexOf(startNode.id) === -1)
if (nodeIds && include && nodeIds.indexOf(startNode.id) === -1) {
return allCycles;
}
circuit(startNode, startNode, adjList);
nodeIdx = minIdx + 1;
} else {
Expand Down
1 change: 1 addition & 0 deletions packages/graph/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,4 @@ export * from './connected-component';
export * from './mst';
export * from './k-means';
export * from './detect-cycle';
export * from './label-propagation';
197 changes: 197 additions & 0 deletions packages/graph/src/label-propagation.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,197 @@
import { uniqueId } from "./utils";
import { ClusterData, INode, IEdge, Graph, Matrix } from "./types";
import { ID } from "@antv/graphlib";

function getAdjMatrix(graph: Graph, directed: boolean) {
const nodes = graph.getAllNodes();
const matrix: Matrix[] = [];
// map node with index in data.nodes
const nodeMap = new Map<string | number, number>();

if (!nodes) {
throw new Error("invalid nodes data!");
}

if (nodes) {
nodes.forEach((node, i) => {
nodeMap.set(node.id, i);
const row: number[] = [];
matrix.push(row);
});
}

const edges = graph.getAllEdges();
if (edges) {
edges.forEach((edge) => {
const { source, target } = edge;
const sIndex = nodeMap.get(source);
const tIndex = nodeMap.get(target);
if ((!sIndex && sIndex !== 0) || (!tIndex && tIndex !== 0)) return;
matrix[sIndex][tIndex] = 1;
if (!directed) {
matrix[tIndex][sIndex] = 1;
}
});
}
return matrix;
}

/**
* 标签传播算法
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

是否能统一为英文注释~

* @param graphData 图数据
* @param directed 是否有向图,默认为 false
* @param weightPropertyName 权重的属性字段
* @param maxIteration 最大迭代次数
*/
export const labelPropagation = (
graph: Graph,
directed: boolean = false,
weightPropertyName: string = "weight",
maxIteration: number = 1000
): ClusterData => {
// the origin data
const nodes = graph.getAllNodes();
const edges = graph.getAllEdges();

const clusters: { [key: string]: { id: string; nodes: INode[] } } = {};
const nodeMap: { [key: ID]: { node: INode; idx: number } } = {};
const nodeToCluster = new Map<ID, string>();
// init the clusters and nodeMap
nodes.forEach((node, i) => {
const cid: string = uniqueId();
nodeToCluster.set(node.id, cid);
clusters[cid] = {
id: cid,
nodes: [node],
};
nodeMap[node.id] = {
node,
idx: i,
};
});

// the adjacent matrix of calNodes inside clusters
const adjMatrix = getAdjMatrix(graph, directed);
// the sum of each row in adjacent matrix
const ks = [];
/**
* neighbor nodes (id for key and weight for value) for each node
* neighbors = {
* id(node_id): { id(neighbor_1_id): weight(weight of the edge), id(neighbor_2_id): weight(weight of the edge), ... },
* ...
* }
*/
const neighbors: { [key: ID]: { [key: ID]: number } } = {};
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

因为新版的 ID 支持数字,这里使用对象,key 可能出现无法区分 0 和 "0" 的情况。可以测试看看会不会有这个问题

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

确实是会存在这个问题,我修改一下

adjMatrix.forEach((row, i) => {
let k = 0;
const iid = nodes[i].id;
neighbors[iid] = {};
row.forEach((entry, j) => {
if (!entry) return;
k += entry;
const jid = nodes[j].id;
neighbors[iid][jid] = entry;
});
ks.push(k);
});

let iter = 0;

while (iter < maxIteration) {
let changed = false;
nodes.forEach((node) => {
const neighborClusters: { [key: string]: number } = {};
Object.keys(neighbors[node.id]).forEach((neighborId) => {
const neighborWeight = neighbors[node.id][neighborId];
const neighborNode = nodeMap[neighborId].node;

const neighborClusterId = nodeToCluster.get(neighborNode.id);
if (!neighborClusters[neighborClusterId]) {
neighborClusters[neighborClusterId] = 0;
}
neighborClusters[neighborClusterId] += neighborWeight;
});
// find the cluster with max weight
let maxWeight = -Infinity;
let bestClusterIds: string[] = [];
Object.keys(neighborClusters).forEach((clusterId) => {
if (maxWeight < neighborClusters[clusterId]) {
maxWeight = neighborClusters[clusterId];
bestClusterIds = [clusterId];
} else if (maxWeight === neighborClusters[clusterId]) {
bestClusterIds.push(clusterId);
}
});
if (
bestClusterIds.length === 1 &&
bestClusterIds[0] === nodeToCluster.get(node.id)
) {
return;
}
const selfClusterIdx = bestClusterIds.indexOf(nodeToCluster.get(node.id));
if (selfClusterIdx >= 0) bestClusterIds.splice(selfClusterIdx, 1);
if (bestClusterIds && bestClusterIds.length) {
changed = true;

// remove from origin cluster
const selfCluster = clusters[nodeToCluster.get(node.id)];
const nodeInSelfClusterIdx = selfCluster.nodes.indexOf(node);
selfCluster.nodes.splice(nodeInSelfClusterIdx, 1);

// move the node to the best cluster
const randomIdx = Math.floor(Math.random() * bestClusterIds.length);
const bestCluster = clusters[bestClusterIds[randomIdx]];
bestCluster.nodes.push(node);
nodeToCluster.set(node.id, bestCluster.id);
}
});
if (!changed) break;
iter++;
}

// delete the empty clusters
Object.keys(clusters).forEach((clusterId) => {
const cluster = clusters[clusterId];
if (!cluster.nodes || !cluster.nodes.length) {
delete clusters[clusterId];
}
});

// get the cluster edges
const clusterEdges: IEdge[] = [];
const clusterEdgeMap: { [key: string]: IEdge } = {};
edges.forEach((edge) => {
let i = 0;
const { source, target } = edge;
const weight = (edge.data[weightPropertyName] || 1) as number;
const sourceClusterId = nodeToCluster.get(nodeMap[source].node.id);
const targetClusterId = nodeToCluster.get(nodeMap[target].node.id);
const newEdgeId = `${sourceClusterId}---${targetClusterId}`;
if (clusterEdgeMap[newEdgeId]) {
clusterEdgeMap[newEdgeId].data.weight += weight;
(clusterEdgeMap[newEdgeId].data.count as number)++;
} else {
const newEdge = {
id: i++,
source: sourceClusterId,
target: targetClusterId,
data: {
weight,
count: 1,
},
};
clusterEdgeMap[newEdgeId] = newEdge;
clusterEdges.push(newEdge);
}
});

const clustersArray: { id: string; nodes: INode[] }[] = [];
Object.keys(clusters).forEach((clusterId) => {
clustersArray.push(clusters[clusterId]);
});
return {
clusters: clustersArray,
clusterEdges,
nodeToCluster,
};
};
8 changes: 7 additions & 1 deletion packages/graph/src/utils.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import { Node, PlainObject } from "@antv/graphlib";
import { Vector } from "./vector";
import { DistanceType, KeyValueMap, NodeData } from "./types";
import { uniq } from "@antv/util";

Expand Down Expand Up @@ -101,3 +100,10 @@ function euclideanDistance(source: number[], target: number[]) {
return Math.sqrt(res);
}

export const uniqueId = (index: number = 0) => {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

咦我记得有 uniqueId 这个工具方法呀,在 @antv/util 里面

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

okok 我改一下,我看之前的是自己写的就直接挪过来了

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

那这段是不是可以删掉了

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

我的问题😭

const random1 = `${Math.random()}`.split('.')[1].slice(0, 5);
const random2 = `${Math.random()}`.split('.')[1].slice(0, 5);
return `${index}-${random1}${random2}`;
};