import matplotlib
matplotlib.use('Agg')
import keras
import numpy as np
import tensorflow as tf
import os
from matplotlib import pyplot as plt
from scipy.cluster.hierarchy import dendrogram
from sklearn.cluster import AgglomerativeClustering
from sklearn.metrics import silhouette_score, silhouette_samples

class Cluster():
    """
    A class for conducting a clustering study on the weights of a single
    layer of a trained keras model instance.
    """

    def __init__(self, model, weights_pth, layer_name, max_clusters=None):
        """
        model       : keras model architecture (keras.models.Model)
        weights_pth : saved weights path (str)
        layer_name  : name of the layer whose filter weights are clustered
        max_clusters: maximum number of clusters
        """
        self.model = model
        self.weights = weights_pth
        self.model.load_weights(self.weights)
        self.layer = layer_name
        self.layer_idx = 0
        # locate the index of the requested layer
        for idx, layer in enumerate(self.model.layers):
            if layer.name == self.layer:
                self.layer_idx = idx
        # kernel weights of the selected layer (k x k x in_c x out_c for a conv layer)
        self.weights = self.model.layers[self.layer_idx].get_weights()[0]

    def _get_distances_(self, X, model, mode='l2'):
        """
        Walks the merge tree of a fitted AgglomerativeClustering model and
        returns, for every merge, the merge distance and the number of
        samples in the merged cluster (used to build a linkage matrix).

        X     : data matrix of shape (n_samples, n_features)
        model : fitted sklearn.cluster.AgglomerativeClustering instance
        mode  : how distances of previously merged clusters are propagated
                ('l2', 'max', 'cosine' or 'actual')
        """
        distances = []
        weights = []
        children = model.children_

        dims = (X.shape[1], 1)
        # caches map the id of each newly created (merged) node to its
        # propagated distance and its number of member samples
        distCache = {}
        weightCache = {}
        for childs in children:
            c1 = X[childs[0]].reshape(dims)
            c2 = X[childs[1]].reshape(dims)
            c1Dist = 0
            c1W = 1
            c2Dist = 0
            c2W = 1
            if childs[0] in distCache.keys():
                c1Dist = distCache[childs[0]]
                c1W = weightCache[childs[0]]
            if childs[1] in distCache.keys():
                c2Dist = distCache[childs[1]]
                c2W = weightCache[childs[1]]
            d = np.linalg.norm(c1 - c2)
            # d = np.squeeze(np.dot(c1.T, c2)/ (np.linalg.norm(c1)*np.linalg.norm(c2)))

            # centroid of the merged cluster, weighted by member counts
            cc = ((c1W * c1) + (c2W * c2)) / (c1W + c2W)

            X = np.vstack((X, cc.T))

            newChild_id = X.shape[0] - 1

            # How to deal with a higher level cluster merging with a lower distance:
            if mode == 'l2':  # increase the higher level cluster distance using an l2 norm
                added_dist = (c1Dist ** 2 + c2Dist ** 2) ** 0.5
                dNew = (d ** 2 + added_dist ** 2) ** 0.5
            elif mode == 'max':  # if the previous clusters had a higher distance, use that one
                dNew = max(d, c1Dist, c2Dist)
            elif mode == 'cosine':
                dNew = np.squeeze(np.dot(c1Dist, c2Dist) / (np.linalg.norm(c1Dist) * np.linalg.norm(c2Dist)))
            elif mode == 'actual':  # keep the actual distance
                dNew = d

            wNew = (c1W + c2W)
            distCache[newChild_id] = dNew
            weightCache[newChild_id] = wNew

            distances.append(dNew)
            weights.append(wNew)
        return distances, weights

    def _plot_dendrogram_(self, X, model, threshold=.7):
        """
        Builds a scipy-style linkage matrix ([child_a, child_b, distance,
        n_members] per merge) from a fitted clustering model, estimates the
        number of clusters implied by the distance threshold, and refits
        AgglomerativeClustering with that number of clusters.
        """
        # Create linkage matrix and then plot the dendrogram
        distance, weight = self._get_distances_(X, model)
        linkage_matrix = np.column_stack([model.children_, distance, weight]).astype(float)

        threshold = threshold * np.max(distance)

        # first merge whose distance exceeds the threshold; its largest child id
        # indicates where the tree is cut
        sorted_ = linkage_matrix[np.argsort(distance)]
        splitnode = np.max(sorted_[sorted_[:, 2] > threshold][0, (0, 1)])

        # assuming a roughly balanced merge tree, estimate the depth of that
        # split and the number of clusters remaining at that depth
        level = np.log((-.5 * splitnode) / (1. * X.shape[0]) + 1.) / np.log(.5)
        nclusters = int(np.round((1. * X.shape[0]) / (2. ** level))) - 1

        model = AgglomerativeClustering(n_clusters=max(2, nclusters)).fit(X)
        distance, weight = self._get_distances_(X, model)
        linkage_matrix = np.column_stack([model.children_, distance, weight]).astype(float)
        labels = model.labels_

        sil = silhouette_score(X, labels, metric='euclidean')
        print("[INFO: BioExp Clustering] Layer: {}, Nclusters: {}, Labels: {}, Freq. of each label: {}, Clustering Score: {}".format(
            self.layer, nclusters, np.unique(labels),
            [sum(labels == i) for i in np.unique(labels)], sil))

        return linkage_matrix, labels

    def get_clusters(self, threshold=0.8,
                     normalize=False,
                     position=True,
                     save_path=None):
        """
        Performs hierarchical clustering on the layer's weight space.

        save_path : path to save the dendrogram image
        threshold : fraction of the maximum merge distance at which to cut the tree
        normalize : squeeze values between 0 and 1
        position  : encode spatial position information into the features
        """

        shape = np.array(self.weights.shape)

        # radial distance of every kernel position from the origin,
        # used to encode spatial position into the features
        coord = []
        for sh in shape[:-2]:
            coord.append(np.linspace(0, (1. if normalize else sh), sh))

        distance = np.sqrt(np.sum([x ** 2 for x in np.meshgrid(*coord, indexing='ij')], axis=0))
        distance = distance[..., None]

        # average over the input-channel axis: one k x k map per output filter
        X = np.mean(self.weights, axis=-2)
        # X = self.weights

        if normalize: X = (X - np.min(X)) / (np.max(X) - np.min(X))  # min-max normalization
        if position: X = X * distance

        # one row per output filter
        X = X.reshape(-1, shape[-1]).T
        model = AgglomerativeClustering().fit(X)

        # build the linkage matrix and final cluster labels, then plot the dendrogram
        linkage_matrix, labels = self._plot_dendrogram_(X, model, threshold=threshold)

        plt.figure(figsize=(20, 10))
        plt.title('Hierarchical Clustering Dendrogram')
        R = dendrogram(linkage_matrix, truncate_mode='level')
        plt.xlabel("Number of points in node (or index of point if no parenthesis).")

        if save_path:
            os.makedirs(save_path, exist_ok=True)
            plt.savefig(os.path.join(save_path, '{}_dendrogram.png'.format(self.layer)), bbox_inches='tight')
            self.plot_silhouette(X, labels, save_path)
        else:
            plt.show()

        return labels

    def plot_silhouette(self, X, labels, save_path):
        r"""
        Plots the per-sample silhouette values of every cluster together with
        the average silhouette score, and saves the figure under save_path.
        """
        fig = plt.figure()
        fig.set_size_inches(10, 5)
        n_clusters = len(np.unique(labels))
        y_lower = 10
        plt.xlim([-0.1, 0.3])
        plt.ylim([0, len(X) + (n_clusters + 1) * 10])
        svalues = silhouette_samples(X, labels)
        silhouette_avg = np.mean(svalues)

        for i in np.unique(labels):
            ith_cluster_silhouette_values = svalues[labels == i]
            ith_cluster_silhouette_values.sort()

            size_cluster_i = ith_cluster_silhouette_values.shape[0]
            y_upper = y_lower + size_cluster_i

            color = plt.cm.nipy_spectral(float(i) / n_clusters)
            plt.fill_betweenx(np.arange(y_lower, y_upper),
                              0, ith_cluster_silhouette_values,
                              facecolor=color, edgecolor=color, alpha=0.7)

            # Label the silhouette plots with their cluster numbers at the middle
            plt.text(-0.05, y_lower + 0.5 * size_cluster_i, str(i))

            # Compute the new y_lower for the next plot
            y_lower = y_upper + 10  # 10 for the 0 samples

        plt.xlabel("The silhouette coefficient values")
        plt.ylabel("Cluster label")

        # The vertical line for the average silhouette score of all the values
        plt.axvline(x=silhouette_avg, color="red", linestyle="--")

        plt.yticks([])  # Clear the y-axis labels / ticks
        # plt.xticks([-0.1, 0, 0.2, 0.4, 0.6, 0.8, 1])

        plt.suptitle(("Silhouette analysis for agglomerative clustering "
                      "with n_clusters = %d" % n_clusters),
                     fontsize=14, fontweight='bold')

        plt.savefig(os.path.join(save_path, 'layer_{}__silhouette_score.png'.format(self.layer_idx)), dpi=200, bbox_inches='tight')


    def plot_weights(self, labels, save_path=None):
        """
        Plots the min-max normalized filter weights of each cluster as an
        image; self.weights has shape k x k x in_c x out_c.
        """
        shape = self.weights.shape
        normweights = (self.weights - np.min(self.weights)) / (np.max(self.weights) - np.min(self.weights))
        features = []
        for label in np.unique(labels):
            wts_idx = np.where(labels == label)[0]
            # one flattened row of weights per filter in this cluster
            wts = normweights[:, :, :, wts_idx].T
            wts = wts.reshape(len(wts_idx), -1)

            features.extend(wts)
            features.extend(np.zeros((3, wts.shape[1])))  # blank rows separating clusters
            """
            feature = np.zeros((s, shape[1]*cls))
            for ii in wt_idx:
                wt = self.weights[:,:,:, ii]
                for i in range(rws):
                    for j in range(cls):
                        try:
                            feature[i*shape[0]: (i + 1)*shape[0],
                                    j*shape[1]: (j + 1)*shape[1]] = wt[:, :, j*rws + i]
                        except:
                            pass

                plt.clf()
                plt.imshow(feature)
                if not save_path:
                    plt.show()
                else:
                    os.makedirs(save_path, exist_ok = True)
                    plt.savefig(os.path.join(save_path, 'cluster_{}_idx_{}.png'.format(label, ii)), bbox_inches='tight')
            """
            plt.clf()
            plt.imshow(wts, cmap='jet')
            if not save_path:
                plt.show()
            else:
                os.makedirs(save_path, exist_ok=True)
                plt.savefig(os.path.join(save_path, 'layer_{}__concept_{}.png'.format(self.layer_idx, label)), dpi=200, bbox_inches='tight')

        plt.clf()
        plt.imshow(features, cmap='jet')
        if not save_path:
            plt.show()
        else:
            os.makedirs(save_path, exist_ok=True)
            plt.savefig(os.path.join(save_path, 'layer_{}__all_concepts.png'.format(self.layer_idx)), dpi=200, bbox_inches='tight')
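

# --- Usage sketch (illustrative, not part of the original module) ---
# A minimal example of how this class might be driven, assuming a small
# stand-in Keras model built inline. The layer name 'conv2d_2', the weights
# file 'demo_weights.h5' and the output folder 'results/clusters' are all
# placeholders chosen for this sketch, not values taken from the original code.
if __name__ == '__main__':
    from keras import layers, models

    # tiny stand-in network; any trained keras.models.Model works the same way
    inp = layers.Input(shape=(64, 64, 3))
    x = layers.Conv2D(32, 3, padding='same', name='conv2d_1')(inp)
    out = layers.Conv2D(16, 3, padding='same', name='conv2d_2')(x)
    demo_model = models.Model(inp, out)
    demo_model.save_weights('demo_weights.h5')  # stands in for real trained weights

    cluster = Cluster(demo_model,
                      weights_pth='demo_weights.h5',
                      layer_name='conv2d_2')
    # cluster the 16 filters of 'conv2d_2', saving dendrogram and silhouette plots
    labels = cluster.get_clusters(threshold=0.8, save_path='results/clusters')
    # visualize the normalized filter weights grouped by cluster
    cluster.plot_weights(labels, save_path='results/clusters')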