Skip to content

Commit 440a2c6

Browse files
committed
kmeans
1 parent 6d93980 commit 440a2c6

File tree

1 file changed

+106
-0
lines changed

1 file changed

+106
-0
lines changed

python/01_intro/as_kmeans.py

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
# download "s3.txt" from http://cs.uef.fi/sipu/datasets/
2+
# wget http://cs.uef.fi/sipu/datasets/s3.txt
3+
4+
"""
5+
Given a set of points:

1. Select k points at random as the initial centroids.

2. Repeat:
   - label each point with its nearest centroid
   - move each centroid to the mean of the points labelled with it
11+
"""
12+
from pprint import pprint
13+
from typing import Tuple, Sequence, Mapping, Callable, Iterable
14+
import matplotlib.pyplot as plt
15+
from random import sample
16+
from collections import defaultdict
17+
from functools import partial
18+
from statistics import mean
19+
20+
# A point in the plane, stored as an (x, y) pair.
Point = Tuple[float, float]
# A centroid is represented exactly like any other point.
Centroid = Point
# The points that were assigned to one centroid.
Cluster = Sequence[Point]
# Any callable that takes two points and returns a float "distance".
Dist_func = Callable[[Point, Point], float]
# Score of a labelling as returned by distortion(); kmeans() keeps the lowest.
Distortion = float
25+
26+
27+
def guess_centroids(dataset: Sequence[Point], k: int) -> Sequence[Centroid]:
    """Pick the initial centroids for a k-means run.

    Draws ``k`` entries of ``dataset`` with ``random.sample`` (sampling
    without replacement by position) and returns them as a new list.
    Because positions are sampled, a point that occurs several times in
    ``dataset`` may be chosen more than once.

    Raises:
        ValueError: propagated from ``random.sample`` when ``k`` is negative
            or larger than ``len(dataset)``.
    """
    initial_centroids = sample(dataset, k)
    return initial_centroids
29+
30+
31+
def distance(p: Point, q: Point, /) -> float:
    """Return the squared Euclidean distance between two 2-D points.

    No square root is taken, so the result is the *square* of the usual
    distance. Only the first two coordinates of each argument are read.
    Both arguments are positional-only.
    """
    dx = p[0] - q[0]
    dy = p[1] - q[1]
    # A dimension-agnostic variant would be:
    #   sum((xp - xq) * (xp - xq) for xp, xq in zip(p, q))
    return dx * dx + dy * dy
34+
35+
36+
def label(
    dataset: Sequence[Point], centroids: Sequence[Centroid], dist: Dist_func
) -> Mapping[Centroid, Cluster]:
    """Assign every point of ``dataset`` to its nearest centroid.

    Args:
        dataset: the points to label.
        centroids: candidate centroids; they are used as dictionary keys and
            therefore must be hashable.
        dist: called as ``dist(point, centroid)``; the centroid with the
            smallest value wins. On a tie the centroid that comes first in
            ``centroids`` is chosen (the behaviour of ``min``).

    Returns:
        A plain ``dict`` mapping each centroid that attracted at least one
        point to the list of its points, in dataset order. Centroids that
        attracted no point do not appear as keys, and equal centroids share
        a single key.

    Raises:
        ValueError: if ``centroids`` is empty while ``dataset`` is not
            (raised by ``min``).
    """
    clusters = defaultdict(list)
    for point in dataset:
        nearest = min(centroids, key=partial(dist, point))
        clusters[nearest].append(point)
    # Return a plain dict rather than the defaultdict: otherwise a lookup of
    # an unknown centroid by the caller would silently insert an empty
    # cluster instead of raising KeyError.
    return dict(clusters)
45+
46+
47+
def update_centroids(clusters: Iterable[Cluster]) -> Sequence[Centroid]:
    """Compute the new centroid of each cluster.

    For every cluster the x and y coordinates are averaged separately with
    ``statistics.mean``. The returned list holds one ``(x, y)`` centroid per
    cluster, in the order the clusters were supplied.

    Raises:
        ValueError: if a cluster is empty or its points are not 2-D (the
            coordinate unpacking fails).
    """

    def centre_of(members):
        # Transpose [(x0, y0), (x1, y1), ...] into the x column and y column.
        xs, ys = zip(*members)
        return mean(xs), mean(ys)

    return [centre_of(members) for members in clusters]
53+
54+
55+
def distortion(
    labeled_dataset: Mapping[Centroid, Cluster], distance: Dist_func
) -> float:
    """Score a labelling: the sum over clusters of the mean point distance.

    For each ``centroid -> cluster`` entry, ``distance(centroid, point)`` is
    averaged over the cluster's points with ``statistics.mean``; the
    per-cluster averages are then added up. Every cluster contributes one
    term regardless of how many points it holds.

    Returns:
        The accumulated score, ``0.0`` for an empty mapping.

    Raises:
        statistics.StatisticsError: if any cluster is empty.
    """
    total = 0.0
    for centre, members in labeled_dataset.items():
        total += mean([distance(centre, member) for member in members])
    return total
63+
64+
65+
def _kmeans(
    dataset: Sequence[Point], k: int, n_iter: int, dist: Dist_func
) -> Tuple[Mapping[Centroid, Cluster], Distortion]:
    """Run k-means once from a fresh random initialisation.

    Starts from ``k`` centroids drawn by ``guess_centroids`` and performs
    exactly ``n_iter`` label/update rounds (there is no early stop). The
    points are then labelled one final time against the last centroids.

    Returns:
        ``(labelling, score)`` where ``labelling`` maps each final centroid
        to its points and ``score`` is ``distortion(labelling, dist)``.
    """
    current = guess_centroids(dataset, k)
    for _ in range(n_iter):
        assignment = label(dataset, current, dist)
        current = update_centroids(assignment.values())
    # Label against the centroids produced by the last update (or against
    # the initial guess when n_iter is 0).
    final_labelling = label(dataset, current, dist)
    return final_labelling, distortion(final_labelling, dist)
74+
75+
76+
def kmeans(dataset, k, inner, outer, dist):
    """Run k-means ``outer`` times and keep the lowest-distortion result.

    Each restart calls ``_kmeans(dataset, k, inner, dist)``, i.e. a new
    random initialisation followed by ``inner`` refinement rounds. A restart
    replaces the current best only when its distortion is strictly smaller.

    Returns:
        ``(best_mapping, best_distortion)``; ``({}, inf)`` when ``outer``
        is 0 because no restart is performed.
    """
    winner = {}
    lowest = float("inf")
    for _ in range(outer):
        candidate, score = _kmeans(dataset, k, inner, dist)
        if score < lowest:
            winner, lowest = candidate, score
    return winner, lowest
85+
86+
87+
if __name__ == "__main__":
    # Demo: cluster the "s3" benchmark set and plot the points together with
    # the centroids that were found.

    points: Sequence[Point]

    with open("s3.txt") as f:
        # One whitespace-separated coordinate pair per line. Blank lines are
        # skipped: they would otherwise parse to an empty tuple and break
        # the coordinate unpacking below.
        points = [tuple(map(float, line.split())) for line in f if line.strip()]

    # Transpose the list of (x, y) pairs into an x column and a y column.
    X, Y = zip(*points)

    d, _ = kmeans(points, k=15, inner=10, outer=15, dist=distance)

    # The keys of the returned mapping are the final centroids.
    centroids = list(d)

    Xc, Yc = zip(*centroids)

    plt.scatter(X, Y, s=0.5)
    plt.scatter(Xc, Yc)

    plt.show()

0 commit comments

Comments
 (0)