23
23
24
24
# Set k-means parameters
25
25
# There are 3 types of iris flowers, see if we can predict them
26
- k = 3
26
+ k = 3
27
27
generations = 25
28
28
29
29
data_points = tf .Variable (iris .data )
41
41
point_matrix = tf .reshape (tf .tile (data_points , [1 , k ]), [num_pts , k , num_feats ])
42
42
distances = tf .reduce_sum (tf .square (point_matrix - centroid_matrix ), axis = 2 )
43
43
44
- #Find the group it belongs to with tf.argmin()
44
+ # Find the group it belongs to with tf.argmin()
45
45
centroid_group = tf .argmin (distances , 1 )
46
46
47
+
47
48
# Find the group average
48
49
def data_group_avg (group_ids , data ):
49
50
# Sum each group
@@ -52,7 +53,8 @@ def data_group_avg(group_ids, data):
52
53
num_total = tf .unsorted_segment_sum (tf .ones_like (data ), group_ids , 3 )
53
54
# Calculate average
54
55
avg_by_group = sum_total / num_total
55
- return (avg_by_group )
56
+ return avg_by_group
57
+
56
58
57
59
means = data_group_avg (centroid_group , data_points )
58
60
@@ -73,18 +75,20 @@ def data_group_avg(group_ids, data):
73
75
74
76
[centers , assignments ] = sess .run ([centroids , cluster_labels ])
75
77
78
+
76
79
# Find which group assignments correspond to which group labels
77
80
# First, need a most common element function
78
81
def most_common (my_list ):
79
- return (max (set (my_list ), key = my_list .count ))
82
+ return max (set (my_list ), key = my_list .count )
83
+
80
84
81
85
label0 = most_common (list (assignments [0 :50 ]))
82
86
label1 = most_common (list (assignments [50 :100 ]))
83
87
label2 = most_common (list (assignments [100 :150 ]))
84
88
85
- group0_count = np .sum (assignments [0 :50 ]== label0 )
86
- group1_count = np .sum (assignments [50 :100 ]== label1 )
87
- group2_count = np .sum (assignments [100 :150 ]== label2 )
89
+ group0_count = np .sum (assignments [0 :50 ] == label0 )
90
+ group1_count = np .sum (assignments [50 :100 ] == label1 )
91
+ group2_count = np .sum (assignments [100 :150 ] == label2 )
88
92
89
93
accuracy = (group0_count + group1_count + group2_count )/ 150.
90
94
@@ -108,17 +112,15 @@ def most_common(my_list):
108
112
# Get k-means classifications for the grid points
109
113
xx_pt = list (xx .ravel ())
110
114
yy_pt = list (yy .ravel ())
111
- xy_pts = np .array ([[x ,y ] for x ,y in zip (xx_pt , yy_pt )])
115
+ xy_pts = np .array ([[x , y ] for x , y in zip (xx_pt , yy_pt )])
112
116
mytree = cKDTree (reduced_centers )
113
117
dist , indexes = mytree .query (xy_pts )
114
118
115
119
# Put the result into a color plot
116
120
indexes = indexes .reshape (xx .shape )
117
121
plt .figure (1 )
118
122
plt .clf ()
119
- plt .imshow (indexes , interpolation = 'nearest' ,
120
- extent = (xx .min (), xx .max (), yy .min (), yy .max ()),
121
- cmap = plt .cm .Paired ,
123
+ plt .imshow (indexes , interpolation = 'nearest' , extent = (xx .min (), xx .max (), yy .min (), yy .max ()), cmap = plt .cm .Paired ,
122
124
aspect = 'auto' , origin = 'lower' )
123
125
124
126
# Plot each of the true iris data groups
@@ -128,12 +130,9 @@ def most_common(my_list):
128
130
temp_group = reduced_data [(i * 50 ):(50 )* (i + 1 )]
129
131
plt .plot (temp_group [:, 0 ], temp_group [:, 1 ], symbols [i ], markersize = 10 , label = label_name [i ])
130
132
# Plot the centroids as a white X
131
- plt .scatter (reduced_centers [:, 0 ], reduced_centers [:, 1 ],
132
- marker = 'x' , s = 169 , linewidths = 3 ,
133
- color = 'w' , zorder = 10 )
134
- plt .title ('K-means clustering on Iris Dataset\n '
135
- 'Centroids are marked with white cross' )
133
+ plt .scatter (reduced_centers [:, 0 ], reduced_centers [:, 1 ], marker = 'x' , s = 169 , linewidths = 3 , color = 'w' , zorder = 10 )
134
+ plt .title ('K-means clustering on Iris Dataset Centroids are marked with white cross' )
136
135
plt .xlim (x_min , x_max )
137
136
plt .ylim (y_min , y_max )
138
137
plt .legend (loc = 'lower right' )
139
- plt .show ()
138
+ plt .show ()
0 commit comments