diff --git a/__tests__/unit/k-means.spec.ts b/__tests__/unit/k-means.spec.ts new file mode 100644 index 0000000..92b8e06 --- /dev/null +++ b/__tests__/unit/k-means.spec.ts @@ -0,0 +1,391 @@ +import { kMeans } from '../../packages/graph/src' +import propertiesGraphData from '../data/cluster-origin-properties-data.json'; +import { Graph } from "@antv/graphlib"; +import { dataPropertiesTransformer, dataLabelDataTransformer } from '../utils/data'; + + +describe('kMeans abnormal demo', () => { + it('no properties demo: ', () => { + const noPropertiesData = { + nodes: [ + { + id: 'node-0', + data: {}, + }, + { + id: 'node-1', + data: {}, + }, + { + id: 'node-2', + data: {}, + }, + { + id: 'node-3', + data: {}, + } + ], + } + const graph = new Graph(noPropertiesData); + const { clusters, clusterEdges } = kMeans(graph, 2); + expect(clusters.length).toBe(1); + expect(clusterEdges.length).toBe(0); + }); +}); + + +describe('kMeans normal demo', () => { + it('simple data demo: ', () => { + const simpleGraphData = { + nodes: [ + { + id: 'node-0', + properties: { + amount: 10, + city: '10001', + } + }, + { + id: 'node-1', + properties: { + amount: 10000, + city: '10002', + } + }, + { + id: 'node-2', + properties: { + amount: 3000, + city: '10003', + } + }, + { + id: 'node-3', + properties: { + amount: 3200, + city: '10003', + } + }, + { + id: 'node-4', + properties: { + amount: 2000, + city: '10003', + } + } + ], + edges: [ + { + id: 'edge-0', + source: 'node-0', + target: 'node-1', + }, + { + id: 'edge-1', + source: 'node-0', + target: 'node-2', + }, + { + id: 'edge-4', + source: 'node-3', + target: 'node-2', + }, + { + id: 'edge-5', + source: 'node-2', + target: 'node-1', + }, + { + id: 'edge-6', + source: 'node-4', + target: 'node-1', + }, + ] + } + const data = dataPropertiesTransformer(simpleGraphData); + const graph = new Graph(data); + const { clusters, nodeToCluster } = kMeans(graph, 3); + expect(clusters.length).toBe(3); + const nodes = graph.getAllNodes(); + + + 
expect(nodeToCluster.get(nodes[2].id)).toEqual(nodeToCluster.get(nodes[3].id)); + expect(nodeToCluster.get(nodes[2].id)).toEqual(nodeToCluster.get(nodes[4].id)); + }); + + + it('complex data demo: ', () => { + const data = dataLabelDataTransformer(propertiesGraphData); + const graph = new Graph(data); + const { clusters,nodeToCluster } = kMeans(graph, 3); + expect(clusters.length).toBe(3); + const nodes = graph.getAllNodes(); + expect(nodeToCluster.get(nodes[0].id)).toEqual(nodeToCluster.get(nodes[1].id)); + expect(nodeToCluster.get(nodes[0].id)).toEqual(nodeToCluster.get(nodes[2].id)); + expect(nodeToCluster.get(nodes[0].id)).toEqual(nodeToCluster.get(nodes[3].id)); + expect(nodeToCluster.get(nodes[0].id)).toEqual(nodeToCluster.get(nodes[4].id)); + expect(nodeToCluster.get(nodes[5].id)).toEqual(nodeToCluster.get(nodes[6].id)); + expect(nodeToCluster.get(nodes[5].id)).toEqual(nodeToCluster.get(nodes[7].id)); + expect(nodeToCluster.get(nodes[5].id)).toEqual(nodeToCluster.get(nodes[8].id)); + expect(nodeToCluster.get(nodes[5].id)).toEqual(nodeToCluster.get(nodes[9].id)); + expect(nodeToCluster.get(nodes[5].id)).toEqual(nodeToCluster.get(nodes[10].id)); + expect(nodeToCluster.get(nodes[11].id)).toEqual(nodeToCluster.get(nodes[12].id)); + expect(nodeToCluster.get(nodes[11].id)).toEqual(nodeToCluster.get(nodes[13].id)); + expect(nodeToCluster.get(nodes[11].id)).toEqual(nodeToCluster.get(nodes[14].id)); + expect(nodeToCluster.get(nodes[11].id)).toEqual(nodeToCluster.get(nodes[15].id)); + expect(nodeToCluster.get(nodes[11].id)).toEqual(nodeToCluster.get(nodes[16].id)); + }); + + it('demo use involvedKeys: ', () => { + const simpleGraphData = { + nodes: [ + { + id: 'node-0', + properties: { + amount: 10, + city: '10001', + } + }, + { + id: 'node-1', + properties: { + amount: 10000, + city: '10002', + } + }, + { + id: 'node-2', + properties: { + amount: 3000, + city: '10003', + } + }, + { + id: 'node-3', + properties: { + amount: 3200, + city: '10003', + } + }, + { + id: 
'node-4', + properties: { + amount: 2000, + city: '10003', + } + } + ], + edges: [ + { + id: 'edge-0', + source: 'node-0', + target: 'node-1', + }, + { + id: 'edge-1', + source: 'node-0', + target: 'node-2', + }, + { + id: 'edge-4', + source: 'node-3', + target: 'node-2', + }, + { + id: 'edge-5', + source: 'node-2', + target: 'node-1', + }, + { + id: 'edge-6', + source: 'node-4', + target: 'node-1', + }, + ] + } + const data = dataPropertiesTransformer(simpleGraphData); + const involvedKeys = ['amount']; + const graph = new Graph(data); + const { clusters ,nodeToCluster} = kMeans(graph, 3, involvedKeys); + expect(clusters.length).toBe(3); + const nodes = graph.getAllNodes(); + expect(nodeToCluster.get(nodes[2].id)).toEqual(nodeToCluster.get(nodes[3].id)); + expect(nodeToCluster.get(nodes[2].id)).toEqual(nodeToCluster.get(nodes[4].id)); + }); + + it('demo use uninvolvedKeys: ', () => { + const simpleGraphData = { + nodes: [ + { + id: 'node-0', + properties: { + amount: 10, + city: '10001', + } + }, + { + id: 'node-1', + properties: { + amount: 10000, + city: '10002', + } + }, + { + id: 'node-2', + properties: { + amount: 3000, + city: '10003', + } + }, + { + id: 'node-3', + properties: { + amount: 3200, + city: '10003', + } + }, + { + id: 'node-4', + properties: { + amount: 2000, + city: '10003', + } + } + ], + edges: [ + { + id: 'edge-0', + source: 'node-0', + target: 'node-1', + }, + { + id: 'edge-1', + source: 'node-0', + target: 'node-2', + }, + { + id: 'edge-4', + source: 'node-3', + target: 'node-2', + }, + { + id: 'edge-5', + source: 'node-2', + target: 'node-1', + }, + { + id: 'edge-6', + source: 'node-4', + target: 'node-1', + }, + ] + } + const data = dataPropertiesTransformer(simpleGraphData); + const graph = new Graph(data); + const uninvolvedKeys = ['id', 'city']; + const { clusters,nodeToCluster } = kMeans(graph, 3, [], uninvolvedKeys); + expect(clusters.length).toBe(3); + const nodes = graph.getAllNodes(); data + 
expect(nodeToCluster.get(nodes[2].id)).toEqual(nodeToCluster.get(nodes[3].id)); + expect(nodeToCluster.get(nodes[2].id)).toEqual(nodeToCluster.get(nodes[4].id)); + }); + +}); + +describe('kMeans All properties values are numeric demo', () => { + it('all properties values are numeric demo: ', () => { + const allPropertiesValuesNumericData = { + nodes: [ + { + id: 'node-0', + properties: { + max: 1000000, + mean: 900000, + min: 800000, + } + }, + { + id: 'node-1', + properties: { + max: 1600000, + mean: 1100000, + min: 600000, + } + }, + { + id: 'node-2', + properties: { + max: 5000, + mean: 3500, + min: 2000, + } + }, + { + id: 'node-3', + properties: { + max: 9000, + mean: 7500, + min: 6000, + } + } + ], + edges: [], + } + const data = dataPropertiesTransformer(allPropertiesValuesNumericData); + const graph = new Graph(data); + const { clusters, clusterEdges,nodeToCluster } = kMeans(graph, 2); + expect(clusters.length).toBe(2); + expect(clusterEdges.length).toBe(0); + const nodes = graph.getAllNodes(); + expect(nodeToCluster.get(nodes[0].id)).toEqual(nodeToCluster.get(nodes[1].id)); + expect(nodeToCluster.get(nodes[2].id)).toEqual(nodeToCluster.get(nodes[3].id)); + }); + it('only one property and the value are numeric demo: ', () => { + const allPropertiesValuesNumericData = { + nodes: [ + { + id: 'node-0', + properties: { + num: 10, + } + }, + { + id: 'node-1', + properties: { + num: 12, + } + }, + { + id: 'node-2', + properties: { + num: 56, + } + }, + { + id: 'node-3', + properties: { + num: 300, + } + }, + { + id: 'node-4', + properties: { + num: 350, + } + } + ], + edges: [], + } + const data = dataPropertiesTransformer(allPropertiesValuesNumericData); + const graph = new Graph(data); + const { clusters, clusterEdges,nodeToCluster } = kMeans(graph, 2); + expect(clusters.length).toBe(2); + expect(clusterEdges.length).toBe(0); + const nodes = graph.getAllNodes(); + expect(nodeToCluster.get(nodes[0].id)).toEqual(nodeToCluster.get(nodes[1].id)); + 
expect(nodeToCluster.get(nodes[0].id)).toEqual(nodeToCluster.get(nodes[2].id)); + expect(nodeToCluster.get(nodes[3].id)).toEqual(nodeToCluster.get(nodes[4].id)); + }); + +}); + diff --git a/__tests__/unit/louvain.spec.ts b/__tests__/unit/louvain.spec.ts index 639db91..3f060af 100644 --- a/__tests__/unit/louvain.spec.ts +++ b/__tests__/unit/louvain.spec.ts @@ -73,6 +73,21 @@ describe('Louvain', () => { expect(clusteredData.clusterEdges[0].data.count).toBe(13); expect(clusteredData.clusterEdges[1].data.count).toBe(10); expect(clusteredData.clusterEdges[1].data.weight).toBe(14); + expect(clusteredData.nodeToCluster.get('0')).toBe('1'); + expect(clusteredData.nodeToCluster.get('1')).toBe('1'); + expect(clusteredData.nodeToCluster.get('2')).toBe('1'); + expect(clusteredData.nodeToCluster.get('3')).toBe('1'); + expect(clusteredData.nodeToCluster.get('4')).toBe('1'); + expect(clusteredData.nodeToCluster.get('5')).toBe('2'); + expect(clusteredData.nodeToCluster.get('6')).toBe('2'); + expect(clusteredData.nodeToCluster.get('7')).toBe('2'); + expect(clusteredData.nodeToCluster.get('8')).toBe('2'); + expect(clusteredData.nodeToCluster.get('9')).toBe('2'); + expect(clusteredData.nodeToCluster.get('10')).toBe('3'); + expect(clusteredData.nodeToCluster.get('11')).toBe('3'); + expect(clusteredData.nodeToCluster.get('12')).toBe('3'); + expect(clusteredData.nodeToCluster.get('13')).toBe('3'); + expect(clusteredData.nodeToCluster.get('14')).toBe('3'); }); // it('louvain with large graph', () => { // https://gw.alipayobjects.com/os/antvdemo/assets/data/relations.json diff --git a/__tests__/utils/data.ts b/__tests__/utils/data.ts index 5d33b89..2914ea5 100644 --- a/__tests__/utils/data.ts +++ b/__tests__/utils/data.ts @@ -21,3 +21,32 @@ export const dataTransformer = (data: { }), }; }; + +export const dataPropertiesTransformer = (data: { nodes: { id: NodeID, [key: string]: any }[], edges: { source: NodeID, target: NodeID, [key: string]: any }[] }): { nodes: INode[], edges: IEdge[] } 
=> { + const { nodes, edges } = data; + return { + nodes: nodes.map((n) => { + const { id, properties, ...rest } = n; + return { id, data: { ...properties, ...rest } }; + }), + edges: edges.map((e, i) => { + const { id, source, target, ...rest } = e; + return { id: id ? id : `edge-${i}`, target, source, data: rest }; + }), + }; +}; + + +export const dataLabelDataTransformer = (data: { nodes: { id: NodeID, [key: string]: any }[], edges: { source: NodeID, target: NodeID, [key: string]: any }[] }): { nodes: INode[], edges: IEdge[] } => { + const { nodes, edges } = data; + return { + nodes: nodes.map((n) => { + const { id, label, data } = n; + return { id, data: { label, ...data } }; + }), + edges: edges.map((e, i) => { + const { id, source, target, ...rest } = e; + return { id: id ? id : `edge-${i}`, target, source, data: rest }; + }), + }; +}; \ No newline at end of file diff --git a/package.json b/package.json index 0c4fcca..2ee91e3 100644 --- a/package.json +++ b/package.json @@ -21,7 +21,7 @@ "build:ci": "pnpm -r run build:ci", "prepare": "husky install", "test": "jest", - "test_one": "jest ./__tests__/unit/detect-cycle.spec.ts", + "test_one": "jest ./__tests__/unit/k-means.spec.ts", "coverage": "jest --coverage", "build:site": "vite build", "deploy": "gh-pages -d site/dist", diff --git a/packages/graph/src/index.ts b/packages/graph/src/index.ts index ae8dce6..a02ba64 100644 --- a/packages/graph/src/index.ts +++ b/packages/graph/src/index.ts @@ -11,4 +11,5 @@ export * from './nodes-cosine-similarity'; export * from './gaddi'; export * from './connected-component'; export * from './mst'; -export * from './detect-cycle'; \ No newline at end of file +export * from './k-means'; +export * from './detect-cycle'; diff --git a/packages/graph/src/k-means.ts b/packages/graph/src/k-means.ts new file mode 100644 index 0000000..4ae1bdd --- /dev/null +++ b/packages/graph/src/k-means.ts @@ -0,0 +1,223 @@ +import { isEqual, uniq } from '@antv/util'; +import { Edge, ID } from 
'@antv/graphlib'; // completes the `import { Edge, ID } from` statement opened on the preceding source line

// ---------------------------------------------------------------------------
// packages/graph/src/k-means.ts
// ---------------------------------------------------------------------------
import { getAllProperties, oneHot, getDistance } from './utils';
import { Vector } from "./vector";
import { ClusterData, DistanceType, Graph, EdgeData, Cluster } from './types';

/** Upper bound on refinement passes, used when the centroids never settle. */
const MAX_ITERATIONS = 1000;

/**
 * Returns the feature row that seeds a centroid.
 * @param distanceType The distance type to use for centroid calculation.
 * @param allPropertiesWeight The weight matrix of all properties (one row per node).
 * @param index The row to use as the centroid.
 * @returns The row at `index` (same array reference, not a copy) for Euclidean
 * distance; an empty array for any other distance type.
 */
const getCentroid = (distanceType: DistanceType, allPropertiesWeight: number[][], index: number): number[] => {
  switch (distanceType) {
    case DistanceType.EuclideanDistance:
      return allPropertiesWeight[index];
    default:
      return [];
  }
};

/**
 * Performs the k-means clustering algorithm on a graph.
 *
 * The first centroid is chosen with `Math.random()`, so cluster ids are not
 * stable between runs. The remaining centroids are seeded with the node whose
 * average distance to the already chosen centroids is largest.
 *
 * @param graph The graph to perform clustering on.
 * @param k The requested number of clusters. Default is 3. It is reduced to the
 * number of nodes or the number of distinct feature rows when either is smaller.
 * @param involvedKeys The keys of properties to be considered for clustering. Default is an empty array.
 * @param uninvolvedKeys The keys of properties to be ignored for clustering. Default is an empty array.
 * @param distanceType The distance type to use for clustering. Default is DistanceType.EuclideanDistance.
 * @returns The clusters, the edges between clusters (each with a `count` of the
 * original edges it stands for) and a node-id to cluster-id map. When no
 * clustering can be done, a single cluster "0" holding every node is returned
 * together with an empty `nodeToCluster` map.
 */
export const kMeans = (
  graph: Graph,
  k: number = 3,
  involvedKeys: string[] = [],
  uninvolvedKeys: string[] = [],
  distanceType: DistanceType = DistanceType.EuclideanDistance,
): ClusterData => {
  const nodes = graph.getAllNodes();
  const edges = graph.getAllEdges();
  const nodeToOriginIdx = new Map();
  const nodeToCluster = new Map();
  const defaultClusterInfo: ClusterData = {
    clusters: [
      {
        id: "0",
        nodes,
      }
    ],
    clusterEdges: [],
    nodeToCluster,
  };

  // Euclidean distance needs attributes: bail out if any node has no `data`.
  if (distanceType === DistanceType.EuclideanDistance && !nodes.every((node) => node.data)) {
    return defaultClusterInfo;
  }
  let properties = [];
  let allPropertiesWeight: number[][] = [];
  if (distanceType === DistanceType.EuclideanDistance) {
    properties = getAllProperties(nodes);
    allPropertiesWeight = oneHot(properties, involvedKeys, uninvolvedKeys) as number[][];
  }
  if (!allPropertiesWeight.length) {
    return defaultClusterInfo;
  }

  // Count distinct feature rows. The separator matters: with join('') the rows
  // [1, 11] and [11, 1] would both become "111" and be counted once.
  const distinctRowCount = new Set(allPropertiesWeight.map((row) => row.join(','))).size;
  // k cannot exceed the number of nodes or the number of distinct rows.
  const finalK = Math.min(k, nodes.length, distinctRowCount);
  if (finalK < 1) {
    // k <= 0: there is no cluster to assign nodes to.
    return defaultClusterInfo;
  }
  for (let i = 0; i < nodes.length; i++) {
    nodeToOriginIdx.set(nodes[i].id, i);
  }

  // ---- Seeding ------------------------------------------------------------
  const centroids: number[][] = [];
  const centroidIndexList: number[] = [];
  const clusters: Cluster[] = [];
  for (let i = 0; i < finalK; i++) {
    let seedIndex = 0;
    if (i === 0) {
      // The first centroid is a random node.
      seedIndex = Math.floor(Math.random() * nodes.length);
    } else {
      // Pick the node with the largest average distance to the existing
      // centroids, skipping nodes whose feature row equals a chosen centroid.
      let maxDistance = -Infinity;
      for (let m = 0; m < nodes.length; m++) {
        if (centroidIndexList.includes(m)) continue;
        const candidate = getCentroid(distanceType, allPropertiesWeight, m);
        if (centroids.some((centroid) => isEqual(centroid, candidate))) continue;
        let totalDistance = 0;
        for (let j = 0; j < centroids.length; j++) {
          totalDistance += getDistance(allPropertiesWeight[m], centroids[j], distanceType);
        }
        const avgDistance = totalDistance / centroids.length;
        if (avgDistance > maxDistance) {
          maxDistance = avgDistance;
          seedIndex = m;
        }
      }
    }
    centroids[i] = getCentroid(distanceType, allPropertiesWeight, seedIndex);
    centroidIndexList.push(seedIndex);
    clusters[i] = {
      id: `${i}`,
      nodes: [nodes[seedIndex]]
    };
    nodeToCluster.set(nodes[seedIndex].id, `${i}`);
  }

  // ---- Refinement ---------------------------------------------------------
  let iterations = 0;
  while (true) {
    // Assignment step: move every node to its nearest centroid.
    for (let i = 0; i < nodes.length; i++) {
      // On the first pass the seed nodes already sit in their own cluster.
      if (iterations === 0 && centroidIndexList.includes(i)) continue;
      let minDistanceIndex = 0;
      let minDistance = Infinity;
      for (let j = 0; j < centroids.length; j++) {
        const distance = getDistance(allPropertiesWeight[i], centroids[j], distanceType);
        if (distance < minDistance) {
          minDistance = distance;
          minDistanceIndex = j;
        }
      }
      // Remove the node from the cluster it belonged to, if any.
      const previousClusterId = nodeToCluster.get(nodes[i].id);
      if (previousClusterId !== undefined) {
        const members = clusters[Number(previousClusterId)].nodes;
        for (let n = members.length - 1; n >= 0; n--) {
          if (members[n].id === nodes[i].id) {
            members.splice(n, 1);
          }
        }
      }
      nodeToCluster.set(nodes[i].id, `${minDistanceIndex}`);
      clusters[minDistanceIndex].nodes.push(nodes[i]);
    }

    // Update step: move each centroid to the mean of its members.
    let centroidsMoved = false;
    for (let i = 0; i < clusters.length; i++) {
      const clusterNodes = clusters[i].nodes;
      // A cluster without members has no mean; keep its previous centroid
      // instead of averaging over zero nodes.
      if (!clusterNodes.length) continue;
      let totalVector = new Vector([]);
      for (let j = 0; j < clusterNodes.length; j++) {
        totalVector = totalVector.add(new Vector(allPropertiesWeight[nodeToOriginIdx.get(clusterNodes[j].id)]));
      }
      const avgVector = totalVector.avg(clusterNodes.length);
      if (!avgVector.equal(new Vector(centroids[i]))) {
        centroidsMoved = true;
        centroids[i] = avgVector.getArr();
      }
    }
    iterations++;

    // Stop once every node has a cluster and no centroid moved in this pass,
    // or when the iteration cap is reached.
    const allAssigned = nodes.every((node) => nodeToCluster.get(node.id) !== undefined);
    if ((allAssigned && !centroidsMoved) || iterations >= MAX_ITERATIONS) {
      break;
    }
  }

  // ---- Cluster edges ------------------------------------------------------
  // Collapse the original edges into one edge per (source cluster, target
  // cluster) pair; `data.count` records how many original edges it replaces.
  const clusterEdges: Edge[] = [];
  const clusterEdgeMap: {
    [key: string]: Edge
  } = {};
  let edgeIndex = 0;
  edges.forEach((edge) => {
    const { source, target } = edge;
    const sourceClusterId = nodeToCluster.get(source);
    const targetClusterId = nodeToCluster.get(target);
    const newEdgeId = `${sourceClusterId}---${targetClusterId}`;
    if (clusterEdgeMap[newEdgeId]) {
      (clusterEdgeMap[newEdgeId].data.count as number)++;
    } else {
      const newEdge = {
        id: edgeIndex++,
        source: sourceClusterId,
        target: targetClusterId,
        data: { count: 1 },
      };
      clusterEdgeMap[newEdgeId] = newEdge;
      clusterEdges.push(newEdge);
    }
  });

  return { clusters, clusterEdges, nodeToCluster };
};

// ---------------------------------------------------------------------------
// packages/graph/src/types.ts (declarations touched by this change)
// ---------------------------------------------------------------------------
export interface IMSTAlgorithmOpt {
  prim: IMSTAlgorithm;
  kruskal: IMSTAlgorithm;
}

/** Distance metrics available to the clustering helpers. */
export enum DistanceType {
  EuclideanDistance = 'euclideanDistance',
}

// ---------------------------------------------------------------------------
// packages/graph/src/utils.ts (imports and helpers added by this change)
// ---------------------------------------------------------------------------
import { Vector } from "./vector";
import { DistanceType, KeyValueMap, NodeData } from "./types";

/**
 * Computes the distance between two feature vectors.
 * @param item The first vector.
 * @param otherItem The second vector.
 * @param distanceType The metric to use. Default is DistanceType.EuclideanDistance.
 * @returns The distance, or 0 for a distance type that has no implementation.
 */
export const getDistance = (item: number[], otherItem: number[], distanceType: DistanceType = DistanceType.EuclideanDistance): number => {
  switch (distanceType) {
    case DistanceType.EuclideanDistance:
      return euclideanDistance(item, otherItem);
    default:
      return 0;
  }
};

/**
 * Euclidean (L2) distance between two vectors of equal length.
 * NOTE(review): vectors of different length yield 0, i.e. they look identical
 * to the caller — confirm this is intended before relying on it.
 */
function euclideanDistance(source: number[], target: number[]): number {
  if (source.length !== target.length) return 0;
  let squaredSum = 0;
  source.forEach((value, i) => {
    squaredSum += Math.pow(value - target[i], 2);
  });
  return Math.sqrt(squaredSum);
}