import numpy as np
import matplotlib.pyplot as plt
+ from sklearn.metrics.pairwise import pairwise_distances


def d(u, v):
@@ -30,7 +31,7 @@ def cost(X, R, M):
    return cost


- def plot_k_means(X, K, max_iter=20, beta=1.0, show_plots=True):
+ def plot_k_means(X, K, max_iter=20, beta=3.0, show_plots=False):
    N, D = X.shape
    M = np.zeros((K, D))
    # R = np.zeros((N, K))
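Annotation (not part of the commit): the signature change above raises the default beta from 1.0 to 3.0 and turns plots off by default. beta controls how soft the responsibilities are; a minimal sketch, with made-up squared distances and the beta values main() tries below:

import numpy as np

dists = np.array([0.5, 2.0, 4.0])  # assumed squared distances d(M[k], x) from one point to three means
for beta in [0.3, 1.0, 3.0, 10.0]:
    expo = np.exp(-beta * dists)
    print(beta, expo / expo.sum())  # responsibilities sum to 1; larger beta -> closer to a hard assignment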
@@ -40,27 +41,41 @@ def plot_k_means(X, K, max_iter=20, beta=1.0, show_plots=True):
    for k in range(K):
        M[k] = X[np.random.choice(N)]

-     costs = np.zeros(max_iter)
+     costs = []
+     k = 0
    for i in range(max_iter):
+         k += 1
        # step 1: determine assignments / responsibilities
        # is this inefficient?
        for k in range(K):
            for n in range(N):
-                 # R[n,k] = np.exp(-beta*d(M[k], X[n])) / np.sum( np.exp(-beta*d(M[j], X[n])) for j in range(K) )
                exponents[n,k] = np.exp(-beta*d(M[k], X[n]))
-
        R = exponents / exponents.sum(axis=1, keepdims=True)
-         # assert(np.abs(R - R2).sum() < 1e-10)
+

        # step 2: recalculate means
-         for k in range(K):
-             M[k] = R[:,k].dot(X) / R[:,k].sum()
+         # decent vectorization
+         # for k in range(K):
+         #     M[k] = R[:,k].dot(X) / R[:,k].sum()
+         # oldM = M

-         costs[i] = cost(X, R, M)
+         # full vectorization
+         M = R.T.dot(X) / R.sum(axis=0, keepdims=True).T
+         # print("diff M:", np.abs(M - oldM).sum())
+
+         c = cost(X, R, M)
+         costs.append(c)
        if i > 0:
-             if np.abs(costs[i] - costs[i-1]) < 1e-5:
+             if np.abs(costs[-1] - costs[-2]) < 1e-5:
                break

+         if len(costs) > 1:
+             if costs[-1] > costs[-2]:
+                 pass
+                 # print("cost increased!")
+                 # print("M:", M)
+                 # print("R.min:", R.min(), "R.max:", R.max())
+
    if show_plots:
        plt.plot(costs)
        plt.title("Costs")
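Annotation (not part of the commit): the diff adds the pairwise_distances import, but the hunks above still fill in the step-1 exponents with the nested n/k loop flagged "is this inefficient?". A hypothetical fully vectorized step 1, assuming d(u, v) is the squared Euclidean distance:

from sklearn.metrics.pairwise import pairwise_distances

sq_dists = pairwise_distances(X, M, metric='sqeuclidean')  # (N, K) squared distances to each mean
exponents = np.exp(-beta * sq_dists)
R = exponents / exponents.sum(axis=1, keepdims=True)       # each row sums to 1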
@@ -71,6 +86,7 @@ def plot_k_means(X, K, max_iter=20, beta=1.0, show_plots=True):
        plt.scatter(X[:,0], X[:,1], c=colors)
        plt.show()

+     print("Final cost", costs[-1])
    return M, R
@@ -98,13 +114,19 @@ def main():
    plt.show()

    K = 3 # luckily, we already know this
-     plot_k_means(X, K)
+     plot_k_means(X, K, beta=1.0, show_plots=True)
+
+     K = 3 # luckily, we already know this
+     plot_k_means(X, K, beta=3.0, show_plots=True)
+
+     K = 3 # luckily, we already know this
+     plot_k_means(X, K, beta=10.0, show_plots=True)

    K = 5 # what happens if we choose a "bad" K?
-     plot_k_means(X, K, max_iter=30)
+     plot_k_means(X, K, max_iter=30, show_plots=True)

    K = 5 # what happens if we change beta?
-     plot_k_means(X, K, max_iter=30, beta=0.3)
+     plot_k_means(X, K, max_iter=30, beta=0.3, show_plots=True)


if __name__ == '__main__':
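Annotation (not part of the commit): the "full vectorization" line replaces the per-cluster mean update with one matrix expression. A standalone sanity check on assumed random data that the two updates agree:

import numpy as np

N, K, D = 300, 3, 2
X = np.random.randn(N, D)
R = np.random.rand(N, K)
R = R / R.sum(axis=1, keepdims=True)  # make each row a valid set of responsibilities

# per-cluster loop (the commented-out "decent vectorization")
M_loop = np.zeros((K, D))
for k in range(K):
    M_loop[k] = R[:, k].dot(X) / R[:, k].sum()

# single matrix expression (the new "full vectorization")
M_vec = R.T.dot(X) / R.sum(axis=0, keepdims=True).T
print("max abs difference:", np.abs(M_loop - M_vec).max())  # should be ~1e-16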