Skip to content

Commit d8e2c27

Browse files
committed
Update 3/6
2 parents ea0c996 + c0e4c0e commit d8e2c27

File tree

193 files changed

+3248
-319
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

193 files changed

+3248
-319
lines changed

.github/FUNDING.yml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
# These are supported funding model platforms
2+
3+
ko_fi: koriavinash1
4+
liberapay: koriavinash1

.github/workflows/pythonpackage.yml

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
# This workflow will install Python dependencies, run tests and lint with a variety of Python versions
2+
# For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions
3+
4+
name: Python package
5+
6+
on:
7+
push:
8+
branches: [ master ]
9+
pull_request:
10+
branches: [ master ]
11+
12+
jobs:
13+
build:
14+
15+
runs-on: ubuntu-latest
16+
strategy:
17+
matrix:
18+
python-version: [3.5, 3.6, 3.7, 3.8]
19+
20+
steps:
21+
- uses: actions/checkout@v2
22+
- name: Set up Python ${{ matrix.python-version }}
23+
uses: actions/setup-python@v1
24+
with:
25+
python-version: ${{ matrix.python-version }}
26+
- name: Install dependencies
27+
run: |
28+
python -m pip install --upgrade pip
29+
pip install -r requirements.txt
30+
- name: Lint with flake8
31+
run: |
32+
pip install flake8
33+
# stop the build if there are Python syntax errors or undefined names
34+
flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
35+
# exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
36+
flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
37+
- name: Test with pytest
38+
run: |
39+
pip install pytest
40+
pytest

.github/workflows/pythonpublish.yml

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
# This workflow will upload a Python Package using Twine when a release is created
2+
# For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries
3+
4+
name: Upload Python Package
5+
6+
on:
7+
release:
8+
types: [created]
9+
10+
jobs:
11+
deploy:
12+
13+
runs-on: ubuntu-latest
14+
15+
steps:
16+
- uses: actions/checkout@v2
17+
- name: Set up Python
18+
uses: actions/setup-python@v1
19+
with:
20+
python-version: '3.x'
21+
- name: Install dependencies
22+
run: |
23+
python -m pip install --upgrade pip
24+
pip install setuptools wheel twine
25+
- name: Build and publish
26+
env:
27+
TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }}
28+
TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }}
29+
run: |
30+
python setup.py sdist bdist_wheel
31+
twine upload dist/*

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
# Byte-compiled / optimized / DLL files
22
__pycache__/
3+
*Logs*
4+
saved_models/*
35
*trained_models*
6+
*saved_models*
47
*results*
58
*.py[cod]
69
*$py.class

.travis.yml

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
language: python
2+
python:
3+
- "3.6"
4+
- "3.5"
5+
install:
6+
- pip install -r requirements.txt
7+
- pip install .
8+
# command to run tests
9+
script: pytest

BioExp/RCT/rct.py

100644100755
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ def __init__(self, model, test_path):
2323
self.vol_path = glob(test_path)
2424
self.test_image, self.gt = load_vol_brats(self.vol_path[3], slicen = 78, pad = 0)
2525

26+
2627
def mean_swap(self, plot = True, save_path='/home/parth/Interpretable_ML/BioExp/results/RCT'):
2728

2829
channel = 3
@@ -269,3 +270,4 @@ def generate_random_classification(self, mode='random'):
269270
# I.generate_random_classification(mode='swap')
270271
# I.mean_swap(plot = False)
271272

273+

BioExp/clusters/clusters.py

Lines changed: 264 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,264 @@
1+
import matplotlib
2+
matplotlib.use('Agg')
3+
import keras
4+
import numpy as np
5+
import tensorflow as tf
6+
import os
7+
from matplotlib import pyplot as plt
8+
from scipy.cluster.hierarchy import dendrogram
9+
from sklearn.cluster import AgglomerativeClustering
10+
from sklearn.metrics import silhouette_score, silhouette_samples
11+
12+
class Cluster():
    """
    A class for conducting a cluster study on a trained keras model instance.

    The kernel of one named layer is read from the model, every output
    filter is turned into a feature vector, and the filters are grouped
    with agglomerative clustering.
    """

    def __init__(self, model, weights_pth, layer_name, max_clusters = None):
        """
        model       : keras model architecture (keras.models.Model)
        weights_pth : saved weights path (str), loaded into ``model``
        layer_name  : name of the layer whose weights are clustered
        max_clusters: maximum number of clusters; stored on the instance
                      but not used by the clustering code in this class

        Raises ValueError when no layer is called ``layer_name`` or when
        that layer has no weights.
        """
        self.model = model
        self.model.load_weights(weights_pth)
        self.layer = layer_name
        self.max_clusters = max_clusters

        # Look the layer up by name; the last match wins, as before.
        self.layer_idx = None
        for idx, layer in enumerate(self.model.layers):
            if layer.name == self.layer:
                self.layer_idx = idx
        if self.layer_idx is None:
            # Previously this silently fell back to layer 0.
            raise ValueError(
                "layer '{}' not found in model".format(self.layer))

        layer_weights = self.model.layers[self.layer_idx].get_weights()
        if not layer_weights:
            raise ValueError(
                "layer '{}' has no weights to cluster".format(self.layer))
        # Take the first weight array directly. Wrapping the whole list
        # (kernel + bias, different shapes) in np.array builds a ragged
        # array, which NumPy rejects.
        self.weights = np.asarray(layer_weights[0])

    def _get_distances_(self, X, model, mode='l2'):
        """
        Compute one distance and one size for every merge in ``model``.

        X     : 2-D array, one row per sample
        model : object exposing ``children_``: one (left, right) index
                pair per merge, where an index >= len(X) refers to the
                node created by merge ``index - len(X)``
        mode  : how the distance of a merged node is derived
                'l2'     - l2 combination of the centroid distance and
                           the distances of both children
                'max'    - largest of the centroid distance and the
                           children's distances
                'cosine' - cosine distance (1 - cosine similarity)
                           between the two centroids
                'actual' - plain euclidean distance between centroids

        Returns (distances, weights): two lists with one entry per merge;
        ``weights`` holds the number of samples under the merged node.

        Raises ValueError for an unknown ``mode``.
        """
        if mode not in ('l2', 'max', 'cosine', 'actual'):
            raise ValueError("unknown distance mode '{}'".format(mode))

        X = np.asarray(X)
        dims = (X.shape[1], 1)

        # Rows of X followed by the centroid of every merge done so far.
        # A list append keeps each merge O(1) instead of re-stacking X.
        nodes = list(X)
        distCache = {}
        weightCache = {}
        distances = []
        weights = []

        for left, right in model.children_:
            c1 = nodes[left].reshape(dims)
            c2 = nodes[right].reshape(dims)
            # Leaves have no cached entry: distance 0 and weight 1.
            c1Dist = distCache.get(left, 0)
            c1W = weightCache.get(left, 1)
            c2Dist = distCache.get(right, 0)
            c2W = weightCache.get(right, 1)

            d = np.linalg.norm(c1 - c2)
            # Size-weighted centroid of the merged node.
            cc = ((c1W*c1) + (c2W*c2))/(c1W + c2W)
            nodes.append(cc.ravel())
            newChild_id = len(nodes) - 1

            # How to deal with a higher level merge with lower distance:
            if mode == 'l2':
                # Grow the higher level distance using an l2 norm.
                added_dist = ((c1Dist**2 + c2Dist**2)**0.5)
                dNew = (d**2 + added_dist**2)**0.5
            elif mode == 'max':
                # If the previous clusters had a higher distance, use it.
                dNew = max(d, c1Dist, c2Dist)
            elif mode == 'cosine':
                # Compare the centroids themselves; the cached scalar
                # distances carry no direction to take a cosine of.
                similarity = np.squeeze(
                    np.dot(c1.T, c2)/(np.linalg.norm(c1)*np.linalg.norm(c2)))
                dNew = 1. - similarity
            else:
                # 'actual': the plain centroid distance.
                dNew = d

            wNew = (c1W + c2W)
            distCache[newChild_id] = dNew
            weightCache[newChild_id] = wNew

            distances.append(dNew)
            weights.append(wNew)
        return distances, weights

    def _plot_dendrogram_(self, X, model, threshold=.7):
        """
        Pick a cluster count from the merge tree and re-cluster ``X``.

        Despite its name this method draws nothing; it prints a summary
        line and returns the data needed to draw a dendrogram.

        X         : 2-D array, one row per sample
        model     : fitted AgglomerativeClustering instance
        threshold : fraction of the largest merge distance used as cut

        Returns (linkage_matrix, labels) of the re-fitted clustering.
        """
        distance, weight = self._get_distances_(X, model)
        linkage_matrix = np.column_stack(
            [model.children_, distance, weight]).astype(float)

        cutoff = threshold*np.max(distance)

        sorted_ = linkage_matrix[np.argsort(distance)]
        above = sorted_[sorted_[:, 2] > cutoff]
        # First merge above the cut; when nothing exceeds it (for
        # example threshold >= 1) fall back to the largest merge instead
        # of failing with an IndexError.
        split_row = above[0] if len(above) else sorted_[-1]
        splitnode = np.max(split_row[:2])

        n_samples = X.shape[0]
        level = np.log((-.5*splitnode)/(1.*n_samples) + 1.)/np.log(.5)
        # Clamp here so the count that is printed is the count used.
        nclusters = max(2, int(np.round((1.*n_samples)/(2.**level))) - 1)

        model = AgglomerativeClustering(n_clusters=nclusters).fit(X)
        distance, weight = self._get_distances_(X, model)
        linkage_matrix = np.column_stack(
            [model.children_, distance, weight]).astype(float)
        labels = model.labels_

        sil = silhouette_score(X, labels, metric='euclidean')
        print ("[INFO: BioExp Clustering] Layer: {}, Nclusters: {}, Labels: {}, Freq. of each labels: {} Clustering Score: {}".format(self.layer, nclusters, np.unique(labels), [sum(labels == i) for i in np.unique(labels)], sil))

        return linkage_matrix, labels

    def _filter_features_(self, normalize=False, position=True):
        """
        Build the (n_filters, n_features) matrix that gets clustered.

        The kernel is averaged over its input-channel axis (second to
        last); the remaining leading axes are treated as spatial.

        normalize : rescale the averaged weights to [0, 1]
        position  : multiply each spatial location by its distance from
                    the origin of the kernel grid
        """
        weights = np.asarray(self.weights)
        spatial_shape = weights.shape[:-2]

        X = np.mean(weights, axis=-2)

        if normalize:
            low = np.min(X)
            span = np.max(X) - low
            # Subtract the minimum (not the maximum) to land in [0, 1];
            # constant weights map to 0 rather than dividing by zero.
            X = (X - low)/span if span > 0 else np.zeros_like(X)

        # Without spatial axes (e.g. a dense kernel) there is no position
        # to encode, so the weights are left untouched.
        if position and spatial_shape:
            coord = [np.linspace(0, (1. if normalize else sh), sh)
                     for sh in spatial_shape]
            grids = np.meshgrid(*coord, indexing='ij')
            # axis=0 sums over the coordinate grids only, keeping one
            # distance per spatial location instead of a single scalar.
            distance = np.sqrt(np.sum([g**2 for g in grids], axis=0))
            X = X*distance[..., None]

        return X.reshape(-1, weights.shape[-1]).T

    def get_clusters(self, threshold=0.8,
                        normalize=False,
                        position=True,
                        save_path = None):
        """
        Does clustering on feature space

        save_path : directory for the dendrogram and silhouette images;
                    when falsy the dendrogram is shown instead
        threshold : fraction of max distance to cluster
        normalize : to squeeze values between 0, 1
        position  : encode position information

        Returns the cluster label of every output filter.
        """
        X = self._filter_features_(normalize=normalize, position=position)
        model = AgglomerativeClustering().fit(X)

        linkage_matrix, labels = self._plot_dendrogram_(X, model, threshold = threshold)

        fig = plt.figure(figsize=(20, 10))
        plt.title('Hierarchical Clustering Dendrogram')
        dendrogram(linkage_matrix, truncate_mode='level')
        plt.xlabel("Number of points in node (or index of point if no parenthesis).")

        if save_path:
            os.makedirs(save_path, exist_ok=True)
            plt.savefig(os.path.join(save_path, '{}_dendrogram.png'.format(self.layer)), bbox_inches='tight')
            # Release the figure; repeated calls otherwise pile up figures.
            plt.close(fig)
            self.plot_silhouette(X, labels, save_path)
        else:
            plt.show()

        return labels

    def plot_silhouette(self, X, labels, save_path):
        r"""
        Draw one silhouette band per cluster plus the average score.

        X         : 2-D array, one row per sample
        labels    : cluster label of every row of ``X``
        save_path : directory for the image; when falsy the plot is
                    shown instead
        """
        labels = np.asarray(labels)
        fig = plt.figure()
        fig.set_size_inches(10, 5)
        n_clusters = len(np.unique(labels))
        y_lower = 10
        plt.xlim([-0.1, 0.3])
        plt.ylim([0, len(X) + (n_clusters + 1) * 10])
        svalues = silhouette_samples(X, labels)
        silhouette_avg = np.mean(svalues)

        for i in np.unique(labels):
            ith_cluster_silhouette_values = svalues[labels == i]
            ith_cluster_silhouette_values.sort()

            size_cluster_i = ith_cluster_silhouette_values.shape[0]
            y_upper = y_lower + size_cluster_i

            color = plt.cm.nipy_spectral(float(i) / n_clusters)
            plt.fill_betweenx(np.arange(y_lower, y_upper),
                              0, ith_cluster_silhouette_values,
                              facecolor=color, edgecolor=color, alpha=0.7)

            # Label the silhouette plots with their cluster numbers at the middle
            plt.text(-0.05, y_lower + 0.5 * size_cluster_i, str(i))

            # Leave a gap of 10 rows before the next cluster's band
            y_lower = y_upper + 10

        plt.xlabel("The silhouette coefficient values")
        plt.ylabel("Cluster label")

        # The vertical line for average silhouette score of all the values
        plt.axvline(x=silhouette_avg, color="red", linestyle="--")

        plt.yticks([])  # Clear the yaxis labels / ticks

        # The labels come from agglomerative clustering, not KMeans.
        plt.suptitle(("Silhouette analysis for agglomerative clustering on sample data "
                      "with n_clusters = %d" % n_clusters),
                     fontsize=14, fontweight='bold')

        if save_path:
            os.makedirs(save_path, exist_ok=True)
            plt.savefig(os.path.join(save_path, 'layer_{}__silhouette_score.png'.format(self.layer_idx)), dpi=200, bbox_inches='tight')
            plt.close(fig)
        else:
            plt.show()

    def plot_weights(self, labels, save_path=None):
        """
        Plot the min-max normalised weights of every cluster.

        One image is produced per cluster (one row per filter) and one
        with all clusters stacked, separated by three blank rows.

        labels    : cluster label of every output filter (last axis of
                    the weights, e.g. k x k x in_c x out_c)
        save_path : directory for the images; when falsy the plots are
                    shown instead
        """
        labels = np.asarray(labels)
        weights = np.asarray(self.weights)

        low = np.min(weights)
        span = np.max(weights) - low
        # Constant weights would divide by zero; map them to 0 instead.
        if span > 0:
            normweights = (weights - low)/span
        else:
            normweights = np.zeros(weights.shape)

        if save_path:
            os.makedirs(save_path, exist_ok = True)

        features = []
        for label in np.unique(labels):
            wts_idx = np.where(labels == label)[0]
            # Move the filter axis first, then flatten each filter.
            wts = normweights[..., wts_idx].T
            wts = wts.reshape(len(wts_idx), -1)

            features.extend(wts)
            # Three blank rows separate clusters in the combined image.
            features.extend(np.zeros((3, wts.shape[1])))

            plt.clf()
            plt.imshow(wts, cmap='jet')
            if not save_path:
                plt.show()
            else:
                plt.savefig(os.path.join(save_path, 'layer_{}__concept_{}.png'.format(self.layer_idx, label)), dpi=200, bbox_inches='tight')

        plt.clf()
        plt.imshow(features, cmap='jet')
        if not save_path:
            plt.show()
        else:
            plt.savefig(os.path.join(save_path, 'layer_{}__all_concepts.png'.format(self.layer_idx)), dpi=200, bbox_inches='tight')

0 commit comments

Comments
 (0)