Merged
30 commits
67db171
idea of what the code can do
ntalluri Apr 5, 2023
564bc9a
finished error checking and parser
ntalluri Apr 26, 2023
984a4b2
fixed ml tests
ntalluri Apr 26, 2023
5603fc0
still trying to fix test_ml
ntalluri Apr 26, 2023
d73cb85
hac_h and hac_v added
ntalluri Apr 26, 2023
73e8584
fixed tests
ntalluri Apr 27, 2023
97d1735
Revert "fixed tests"
ntalluri May 3, 2023
198b0ab
added the custom color palette and cleaned up a couple lines
ntalluri May 3, 2023
b52645f
removed excess spaces
ntalluri May 3, 2023
c607f18
fixed longernames2 test
ntalluri May 3, 2023
6223a95
Fix merge conflicts
agitter May 3, 2023
c83a8e5
fixed the test cases once more
ntalluri May 3, 2023
dcc84d9
Revert "fixed the test cases once more"
ntalluri May 10, 2023
ec20801
made slight changes to ml.py but will continue to change more, update…
ntalluri Jun 21, 2023
07f8a57
Merge branch 'master' into version2
agitter Jun 21, 2023
90cb60c
fixed util.py and cleaned up code
ntalluri Jun 22, 2023
490b147
Merge branch 'version2' of github.com:ntalluri/spras into version2
ntalluri Jun 22, 2023
ad779cf
used pre-commit
ntalluri Jun 22, 2023
2c13a89
allowed for matching coloring across plots
ntalluri Jun 23, 2023
57b93e6
reviewed code
ntalluri Jun 23, 2023
a6d0fcc
fixed metrics and linkages
ntalluri Jul 1, 2023
5464358
Rename PCA components file to variance
agitter Jul 3, 2023
3b483bb
Increase figure resolution
agitter Jul 3, 2023
e8e67c5
Fix EGFR config file for ML
agitter Jul 3, 2023
8b36293
PyCharm formatting suggestions
agitter Jul 3, 2023
26fe4a8
update to pca files and test
ntalluri Jul 5, 2023
91d6155
fixed pca files and tests
ntalluri Jul 5, 2023
739031c
final checks
ntalluri Jul 5, 2023
f4f8ad1
Fix algorithm name extraction from path
agitter Jul 6, 2023
a36f32e
Fix test failures
agitter Jul 6, 2023
25 changes: 15 additions & 10 deletions Snakefile
@@ -13,7 +13,7 @@ SEP = '/'
wildcard_constraints:
params="params-\w+"

config, datasets, out_dir, algorithm_params, algorithm_directed = process_config(config)
config, datasets, out_dir, algorithm_params, algorithm_directed, pca_params, hac_params = process_config(config)

# TODO consider the best way to pass global configuration information to the run functions
SINGULARITY = "singularity" in config and config["singularity"]
@@ -71,10 +71,12 @@ def make_final_input(wildcards):

if config["analysis"]["ml"]["include"]:
final_input.extend(expand('{out_dir}{sep}{dataset}-pca.png',out_dir=out_dir,sep=SEP,dataset=dataset_labels,algorithm_params=algorithms_with_params))
final_input.extend(expand('{out_dir}{sep}{dataset}-pca-components.txt',out_dir=out_dir,sep=SEP,dataset=dataset_labels,algorithm_params=algorithms_with_params))
final_input.extend(expand('{out_dir}{sep}{dataset}-hac.png',out_dir=out_dir,sep=SEP,dataset=dataset_labels,algorithm_params=algorithms_with_params))
final_input.extend(expand('{out_dir}{sep}{dataset}-hac-clusters.txt',out_dir=out_dir,sep=SEP,dataset=dataset_labels,algorithm_params=algorithms_with_params))
final_input.extend(expand('{out_dir}{sep}{dataset}-pca-variance.txt',out_dir=out_dir,sep=SEP,dataset=dataset_labels,algorithm_params=algorithms_with_params))
final_input.extend(expand('{out_dir}{sep}{dataset}-hac-vertical.png',out_dir=out_dir,sep=SEP,dataset=dataset_labels,algorithm_params=algorithms_with_params))
final_input.extend(expand('{out_dir}{sep}{dataset}-hac-clusters-vertical.txt',out_dir=out_dir,sep=SEP,dataset=dataset_labels,algorithm_params=algorithms_with_params))
final_input.extend(expand('{out_dir}{sep}{dataset}-pca-coordinates.txt',out_dir=out_dir,sep=SEP,dataset=dataset_labels,algorithm_params=algorithms_with_params))
final_input.extend(expand('{out_dir}{sep}{dataset}-hac-horizontal.png',out_dir=out_dir,sep=SEP,dataset=dataset_labels,algorithm_params=algorithms_with_params))
final_input.extend(expand('{out_dir}{sep}{dataset}-hac-clusters-horizontal.txt',out_dir=out_dir,sep=SEP,dataset=dataset_labels,algorithm_params=algorithms_with_params))

if len(final_input) == 0:
# No analysis added yet, so add reconstruction output files if they exist.
@@ -250,15 +252,18 @@ rule ml_analysis:
pathways = expand('{out_dir}{sep}{{dataset}}-{algorithm_params}{sep}pathway.txt', out_dir=out_dir, sep=SEP, algorithm_params=algorithms_with_params)
output:
pca_image = SEP.join([out_dir, '{dataset}-pca.png']),
pca_components= SEP.join([out_dir, '{dataset}-pca-components.txt']),
pca_variance= SEP.join([out_dir, '{dataset}-pca-variance.txt']),
pca_coordinates = SEP.join([out_dir, '{dataset}-pca-coordinates.txt']),
hac_image = SEP.join([out_dir, '{dataset}-hac.png']),
hac_clusters = SEP.join([out_dir, '{dataset}-hac-clusters.txt'])
hac_image_vertical = SEP.join([out_dir, '{dataset}-hac-vertical.png']),
hac_clusters_vertical = SEP.join([out_dir, '{dataset}-hac-clusters-vertical.txt']),
hac_image_horizontal = SEP.join([out_dir, '{dataset}-hac-horizontal.png']),
hac_clusters_horizontal = SEP.join([out_dir, '{dataset}-hac-clusters-horizontal.txt']),

run:
summary_df = ml.summarize_networks(input.pathways)
ml.pca(summary_df, output.pca_image, output.pca_components, output.pca_coordinates)
ml.hac(summary_df, output.hac_image, output.hac_clusters)

ml.pca(summary_df, output.pca_image, output.pca_variance, output.pca_coordinates, **pca_params)
ml.hac_vertical(summary_df, output.hac_image_vertical, output.hac_clusters_vertical, **hac_params)
ml.hac_horizontal(summary_df, output.hac_image_horizontal, output.hac_clusters_horizontal, **hac_params)

# Remove the output directory
rule clean:
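For context, here is a minimal sketch (not part of this PR) of what the expanded `final_input` targets look like. Snakemake's `expand` takes the product of the keyword lists; the suffixes below are exactly the ones added above, while the output directory and dataset labels are hypothetical:

```python
# Plain-Python equivalent of the expand() calls in make_final_input.
# out_dir and datasets are hypothetical; the suffixes mirror the rule above.
from itertools import product

out_dir = "output"                 # hypothetical output directory
datasets = ["data0", "data1"]      # hypothetical dataset labels
suffixes = [
    "pca.png", "pca-variance.txt", "pca-coordinates.txt",
    "hac-vertical.png", "hac-clusters-vertical.txt",
    "hac-horizontal.png", "hac-clusters-horizontal.txt",
]
final_input = [f"{out_dir}/{d}-{s}" for d, s in product(datasets, suffixes)]
print(final_input[:3])  # ['output/data0-pca.png', 'output/data0-pca-variance.txt', ...]
```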
9 changes: 9 additions & 0 deletions config/config.yaml
@@ -118,3 +118,12 @@
# Machine learning analysis (e.g. clustering) of the pathway output files for each dataset
ml:
include: true
# specify how many principal components to calculate
components: 2
# boolean to show the labels on the pca graph
labels: true
# 'ward', 'complete', 'average', 'single'
# if linkage: ward, must use metric: euclidean
linkage: 'ward'
# 'euclidean', 'manhattan', 'cosine'
metric: 'euclidean'
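The ward/euclidean constraint noted in the comments above is enforced in `src/analysis/ml.py`; as a standalone sketch (hypothetical helper name, same method and metric lists as the module):

```python
# Sketch of the config validation mirroring the checks in ml.py.
linkage_methods = ["ward", "complete", "average", "single"]
distance_metrics = ["euclidean", "manhattan", "cosine"]

def check_ml_config(linkage: str, metric: str) -> str:
    """Return the metric to use, falling back to euclidean for ward linkage."""
    if linkage not in linkage_methods:
        raise ValueError(f"linkage={linkage} must be one of {linkage_methods}")
    if metric not in distance_metrics:
        raise ValueError(f"metric={metric} must be one of {distance_metrics}")
    if linkage == "ward" and metric != "euclidean":
        return "euclidean"  # ward linkage only supports euclidean distances
    return metric

assert check_ml_config("ward", "manhattan") == "euclidean"
assert check_ml_config("average", "cosine") == "cosine"
```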
2 changes: 2 additions & 0 deletions config/egfr.yaml
@@ -80,3 +80,5 @@ analysis:
include: false
summary:
include: true
ml:
include: false
1 change: 1 addition & 0 deletions environment.yml
@@ -2,6 +2,7 @@ name: spras
channels:
- conda-forge
dependencies:
- adjusttext=0.7.3.1
- bioconda::snakemake-minimal=7.18.2
- docker-py=5.0
- matplotlib=3.5
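`adjustText` is the new dependency; a minimal sketch of what it is used for, mirroring the `adjust_text` call added to the PCA plot in `ml.py` (the points and file name here are made up):

```python
# adjustText nudges overlapping matplotlib labels apart and can draw
# arrows from each label back to its point.
import matplotlib.pyplot as plt
from adjustText import adjust_text

fig, ax = plt.subplots()
points = [(0.00, 0.00), (0.05, 0.02), (0.10, 0.00)]  # hypothetical close points
ax.scatter(*zip(*points))
texts = [ax.text(x, y, f"algo{i}", size=10) for i, (x, y) in enumerate(points)]
adjust_text(texts, arrowprops=dict(arrowstyle='->', color='black'))
fig.savefig("labeled-points.png", dpi=300)
```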
180 changes: 144 additions & 36 deletions src/analysis/ml.py
@@ -6,7 +6,8 @@
import numpy as np
import pandas as pd
import seaborn as sns
from scipy.cluster.hierarchy import dendrogram
from adjustText import adjust_text
from scipy.cluster.hierarchy import dendrogram, fcluster
from sklearn.cluster import AgglomerativeClustering
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler
@@ -15,7 +16,10 @@

plt.switch_backend('Agg')

linkage_methods = ["ward", "complete", "average", "single"]
distance_metrics = ["euclidean", "manhattan", "cosine"]
NODE_SEP = '|||' # separator between nodes when forming edges in the dataframe
DPI = 300


def summarize_networks(file_paths: Iterable[Union[str, PathLike]]) -> pd.DataFrame:
@@ -29,7 +33,6 @@ def summarize_networks(file_paths: Iterable[Union[str, PathLike]]) -> pd.DataFrame:
"""
# creating a tuple that contains the algorithm column name and edge pairs
edge_tuples = []

for file in file_paths:
try:
# collecting and sorting the edge pairs per algorithm
@@ -54,7 +57,6 @@ def summarize_networks(file_paths: Iterable[Union[str, PathLike]]) -> pd.DataFrame:

# initially construct separate dataframes per algorithm
edge_dataframes = []

# the dataframe is set up per algorithm and a 1 is set for the edge pair that exists in the algorithm
for tup in edge_tuples:
dataframe = pd.DataFrame(
@@ -73,53 +75,86 @@ def summarize_networks(file_paths: Iterable[Union[str, PathLike]]) -> pd.DataFrame:
return concated_df


def pca(dataframe: pd.DataFrame, output_png: str, output_file: str, output_coord: str):
def create_palette(column_names):
"""
Generates a dictionary mapping each column name (algorithm name)
to a unique color from the specified palette.
"""
# TODO: could add a way for the user to customize the color palette?
custom_palette = sns.color_palette("husl", len(column_names))
label_color_map = {label: color for label, color in zip(column_names, custom_palette)}
return label_color_map


def pca(dataframe: pd.DataFrame, output_png: str, output_var: str, output_coord: str, components: int = 2, labels: bool = True):
"""
Performs PCA on the data and creates a scatterplot of the top two principal components.
It saves the plot, the variance explained by each component, and the
coordinates corresponding to the plot of each algorithm in a separate file.
@param dataframe: binary dataframe of edge comparison between algorithms from summarize_networks
@param output_png: the filename to save the scatterplot
@param output_file: the filename to save the variance explained by each component
@param output_var: the filename to save the variance explained by each component
@param output_coord: the filename to save the coordinates of each algorithm
@param components: the number of principal components to calculate (Default is 2)
@param labels: determines if labels will be included in the scatterplot (Default is True)
"""
df = dataframe.reset_index(drop=True)
columns = dataframe.columns
column_names = [element.split('-')[-3] for element in columns] # assume algorithm names do not contain '-'
df = df.transpose() # based on the algorithms rather than the edges
X = df.values

min_shape = min(df.shape)
if components < 2:
raise ValueError(f"components={components} must be greater than or equal to 2 in the config file.")
elif components > min_shape:
print(f"components={components} is not valid. Setting components to {min_shape}.")
components = min_shape
if not isinstance(labels, bool):
raise ValueError(f"labels={labels} must be True or False")

scaler = StandardScaler()
scaler.fit(X) # calc mean and standard deviation
X_scaled = scaler.transform(X)

# choosing the PCA
pca_2 = PCA(n_components=2)
pca_2.fit(X_scaled)
X_pca_2 = pca_2.transform(X_scaled)
variance = pca_2.explained_variance_ratio_ * 100
pca_instance = PCA(n_components=components)
pca_instance.fit(X_scaled)
X_pca = pca_instance.transform(X_scaled)
variance = pca_instance.explained_variance_ratio_ * 100

# making the plot
label_color_map = create_palette(column_names)
plt.figure(figsize=(10, 7))
sns.scatterplot(x=X_pca_2[:, 0], y=X_pca_2[:, 1], s=70)
sns.scatterplot(x=X_pca[:, 0], y=X_pca[:, 1], s=70, hue=column_names, legend=True, palette=label_color_map)
plt.title("PCA")
plt.xlabel(f"PC1 ({variance[0]:.1f}% variance)")
plt.ylabel(f"PC2 ({variance[1]:.1f}% variance)")

# saving the PCA plot
make_required_dirs(output_png)
plt.savefig(output_png)
# saving the coordinates of each algorithm
make_required_dirs(output_coord)
coordinates_df = pd.DataFrame(X_pca, columns = ['PC' + str(i) for i in range(1, components+1)])
coordinates_df.insert(0, 'algorithm', columns.tolist())
coordinates_df.to_csv(output_coord, sep='\t', index=False)

# saving the principal components
make_required_dirs(output_file)
with open(output_file, "w") as f:
for component in variance:
f.write("%s\n" % component)
make_required_dirs(output_var)
with open(output_var, "w") as f:
for component in range(len(variance)):
f.write("PC%d: %s\n" % (component+1, variance[component]))

# saving the coordinates of each algorithm
columns = dataframe.columns.tolist()
data = {'algorithm': columns, 'x': X_pca_2[:, 0], 'y': X_pca_2[:, 1]}
df = pd.DataFrame(data)
make_required_dirs(output_coord)
df.to_csv(output_coord, sep='\t', index=False)
# labeling the graphs
if labels:
x_coord = coordinates_df['PC1'].to_numpy()
y_coord = coordinates_df['PC2'].to_numpy()
texts = []
for i, algorithm in enumerate(column_names):
texts.append(plt.text(x_coord[i], y_coord[i], algorithm, size=10))
adjust_text(texts, force_points=(5.0, 5.0), arrowprops=dict(arrowstyle='->', color='black'))

# saving the PCA plot
make_required_dirs(output_png)
plt.savefig(output_png, dpi=DPI)


# This function is taken from the scikit-learn version 1.2.1 example code
@@ -153,31 +188,104 @@ def plot_dendrogram(model, **kwargs):
dendrogram(linkage_matrix, **kwargs)


def hac(dataframe: pd.DataFrame, output_png: str, output_file: str):
def hac_vertical(dataframe: pd.DataFrame, output_png: str, output_file: str, linkage: str = 'ward', metric: str = 'euclidean'):
"""
Performs hierarchical agglomerative clustering on the dataframe,
creates a dendrogram of the resulting tree,
creates a dendrogram of the resulting tree using seaborn, with scipy forming the cluster groups,
and saves the dendrogram and the cluster labels of said dendrogram in separate files.
@param dataframe: binary dataframe of edge comparison between algorithms from summarize_networks
@param output_png: the file name to save the dendrogram image
@param output_file: the file name to save the clustering labels
@param linkage: the linkage method used to calculate distances between clusters
@param metric: the distance metric used between instances when forming clusters
"""
X = dataframe.reset_index(drop=True)
X = X.transpose()
model = AgglomerativeClustering(distance_threshold=0.5, n_clusters=None)
model = model.fit(X)
if linkage not in linkage_methods:
raise ValueError(f"linkage={linkage} must be one of {linkage_methods}")
if metric not in distance_metrics:
raise ValueError(f"metric={metric} must be one of {distance_metrics}")
if metric == "manhattan":
# clustermap does not support manhattan as a metric but cityblock is equivalent per
# https://docs.scipy.org/doc/scipy/reference/generated/scipy.spatial.distance.pdist.html#scipy.spatial.distance.pdist
metric = "cityblock"
if linkage == "ward":
if metric != "euclidean":
print("For linkage='ward', the metric must be 'euclidean'; setting metric = 'euclidean")
metric = "euclidean"

df = dataframe.reset_index(drop=True)
columns = df.columns
column_names = [element.split('-')[-3] for element in columns] # assume algorithm names do not contain '-'
df = df.transpose()

# create a color map for the given labels
# and map it to the dataframe's index for row coloring in the plot
label_color_map = create_palette(column_names)
row_colors = pd.Series(column_names, index=df.index).map(label_color_map)

# plotting the seaborn figure
plt.figure(figsize=(10, 7))
clustergrid = sns.clustermap(df, metric=metric, method=linkage, row_colors=row_colors, col_cluster=False)
clustergrid.ax_heatmap.remove()
clustergrid.cax.remove()
clustergrid.ax_row_dendrogram.set_visible(True)
clustergrid.ax_col_dendrogram.set_visible(False)
legend_labels = [plt.Rectangle((0, 0), 0, 0, color=label_color_map[label]) for label in label_color_map]
plt.legend(legend_labels, label_color_map.keys(), bbox_to_anchor=(1.02, 1), loc='upper left')

# Use linkage matrix from seaborn clustergrid to generate cluster assignments
# then using fcluster with a distance threshold (t) to make the clusters
linkage_matrix = clustergrid.dendrogram_row.linkage
clusters = fcluster(linkage_matrix, t=0.5, criterion='distance')
cluster_data = {'algorithm': columns.tolist(), 'labels': clusters}
clusters_df = pd.DataFrame(cluster_data)

# saving files
make_required_dirs(output_file)
clusters_df.to_csv(output_file, sep='\t', index=False)
make_required_dirs(output_png)
plt.savefig(output_png, bbox_inches="tight", dpi=DPI)


def hac_horizontal(dataframe: pd.DataFrame, output_png: str, output_file: str, linkage: str = 'ward', metric: str = 'euclidean'):
"""
Performs hierarchical agglomerative clustering on the dataframe,
creates a dendrogram of the resulting tree using scikit-learn and makes cluster groups with scipy,
and saves the dendrogram and the cluster labels of said dendrogram in separate files.
@param dataframe: binary dataframe of edge comparison between algorithms from summarize_networks
@param output_png: the file name to save the dendrogram image
@param output_file: the file name to save the clustering labels
@param linkage: the linkage method used to calculate distances between clusters
@param metric: the distance metric used between instances when forming clusters
"""
if linkage not in linkage_methods:
raise ValueError(f"linkage={linkage} must be one of {linkage_methods}")
if linkage == "ward":
if metric != "euclidean":
print("For linkage='ward', the metric must be 'euclidean'; setting metric = 'euclidean")
metric = "euclidean"
if metric not in distance_metrics:
raise ValueError(f"metric={metric} must be one of {distance_metrics}")

df = dataframe.reset_index(drop=True)
df = df.transpose()

# plotting figure
plt.figure(figsize=(10, 7))
model = AgglomerativeClustering(linkage=linkage, affinity=metric, distance_threshold=0.5, n_clusters=None)
model = model.fit(df)
plt.figure(figsize=(10, 7))
plt.title("Hierarchical Agglomerative Clustering Dendrogram")
plt.xlabel("algorithms")
algo_names = list(dataframe.columns)
plot_dendrogram(model, labels=algo_names, leaf_rotation=90, leaf_font_size=10, color_threshold=0,
truncate_mode=None)
plt.xlabel("algorithms")
make_required_dirs(output_png)
plt.savefig(output_png, bbox_inches="tight")

columns = dataframe.columns.tolist()
data = {'algorithm': columns, 'labels': model.labels_}
df = pd.DataFrame(data)
# saving cluster assignments
cluster_data = {'algorithm': algo_names, 'labels': model.labels_}
clusters_df = pd.DataFrame(cluster_data)

# saving files
make_required_dirs(output_file)
df.to_csv(output_file, sep='\t', index=False)
clusters_df.to_csv(output_file, sep='\t', index=False)
make_required_dirs(output_png)
plt.savefig(output_png, bbox_inches="tight", dpi=DPI)
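A hypothetical end-to-end use of the updated module. The import and pathway paths are placeholders (`summarize_networks` reads real `pathway.txt` files from disk), and the keyword arguments simply mirror the new config options:

```python
# Sketch only: assumes src/analysis is on the import path and the
# pathway files exist; path names are invented for illustration.
from src.analysis import ml

paths = [
    "output/data0-pathlinker-params-ABC/pathway.txt",
    "output/data0-omicsintegrator1-params-DEF/pathway.txt",
]
summary_df = ml.summarize_networks(paths)

ml.pca(summary_df, "pca.png", "pca-variance.txt", "pca-coordinates.txt",
       components=2, labels=True)
ml.hac_vertical(summary_df, "hac-vertical.png", "hac-clusters-vertical.txt",
                linkage="ward", metric="euclidean")
ml.hac_horizontal(summary_df, "hac-horizontal.png", "hac-clusters-horizontal.txt",
                  linkage="average", metric="cosine")
```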
18 changes: 16 additions & 2 deletions src/util.py
@@ -344,8 +344,22 @@ def process_config(config):
f'(current length {hash_length}).')
algorithm_params[alg["name"]][params_hash] = run_dict

return config, datasets, out_dir, algorithm_params, algorithm_directed

analysis_params = config["analysis"] if "analysis" in config else {}
ml_params = analysis_params["ml"] if "ml" in analysis_params else {}

pca_params = {}
if "components" in ml_params:
pca_params["components"] = ml_params["components"]
if "labels" in ml_params:
pca_params["labels"] = ml_params["labels"]

hac_params = {}
if "linkage" in ml_params:
hac_params["linkage"] = ml_params["linkage"]
if "metric" in ml_params:
hac_params["metric"] = ml_params ["metric"]

return config, datasets, out_dir, algorithm_params, algorithm_directed, pca_params, hac_params

def compare_files(file1, file2) -> bool:
"""
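A standalone sketch of how the new `pca_params` and `hac_params` dictionaries flow back into the Snakefile's `ml_analysis` rule (assumed config values; dict comprehensions stand in for the if-in checks above):

```python
# The optional ml settings become keyword arguments via ** unpacking.
config = {"analysis": {"ml": {"components": 2, "labels": True, "linkage": "ward"}}}

ml_params = config.get("analysis", {}).get("ml", {})
pca_params = {k: ml_params[k] for k in ("components", "labels") if k in ml_params}
hac_params = {k: ml_params[k] for k in ("linkage", "metric") if k in ml_params}

assert pca_params == {"components": 2, "labels": True}
assert hac_params == {"linkage": "ward"}  # metric falls back to its default
# In the Snakefile: ml.pca(summary_df, ..., **pca_params)
#                   ml.hac_vertical(summary_df, ..., **hac_params)
```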
@@ -1,4 +1,4 @@
,s1,s2,s3,longName,longName2,empty,spaces
,test-data-s1,test-data-s2,test-data-s3,test-data-longName,test-data-longName2,test-data-empty,test-data-spaces
A|||B,1,1,0,0,0,0,0
C|||D,1,1,0,0,0,0,0
E|||F,1,1,0,0,0,0,0
Expand Down
4 changes: 4 additions & 0 deletions test/ml/expected/expected-hac-horizontal-clusters.txt
@@ -0,0 +1,4 @@
algorithm labels
test-data-s1 2
test-data-s2 1
test-data-s3 0
4 changes: 4 additions & 0 deletions test/ml/expected/expected-hac-vertical-clusters.txt
@@ -0,0 +1,4 @@
algorithm labels
test-data-s1 1
test-data-s2 2
test-data-s3 3
4 changes: 4 additions & 0 deletions test/ml/expected/expected-pca-coordinates.csv
@@ -0,0 +1,4 @@
algorithm PC1 PC2
test-data-s1 -2.006650210482033 -0.9865875190637743
test-data-s2 -1.5276508866841987 1.0799457247533237
test-data-s3 3.534301097166232 -0.0933582056895495
2 changes: 2 additions & 0 deletions test/ml/expected/expected-pca-variance.txt
@@ -0,0 +1,2 @@
PC1: 89.76974544878588
PC2: 10.230254551214124
4 changes: 0 additions & 4 deletions test/ml/expected/expected_clusters.txt

This file was deleted.

2 changes: 0 additions & 2 deletions test/ml/expected/expected_components.txt

This file was deleted.
