Skip to content

Commit 73376ee

Browse files
committed
move modules into scripts to avoid having to import entire local package
1 parent b2e7549 commit 73376ee

File tree

4 files changed

+56
-10
lines changed

4 files changed

+56
-10
lines changed

src/bettercode/simple_workflow/snakemake_workflow/scripts/compute_correlation.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,13 @@
44

55
import pandas as pd
66

7-
from bettercode.simple_workflow.correlation import (
8-
compute_correlation_matrix,
9-
)
7+
8+
def compute_correlation_matrix(
9+
df: pd.DataFrame,
10+
method: str = "spearman",
11+
) -> pd.DataFrame:
12+
"""Compute correlation matrix using the specified method."""
13+
return df.corr(method=method)
1014

1115

1216
def main():

src/bettercode/simple_workflow/snakemake_workflow/scripts/filter_data.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,10 @@
44

55
import pandas as pd
66

7-
from bettercode.simple_workflow.filter_data import (
8-
filter_numerical_columns,
9-
)
7+
8+
def filter_numerical_columns(df: pd.DataFrame) -> pd.DataFrame:
9+
"""Filter a dataframe to keep only numerical columns."""
10+
return df.select_dtypes(include=["number"])
1011

1112

1213
def main():

src/bettercode/simple_workflow/snakemake_workflow/scripts/generate_heatmap.py

Lines changed: 37 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,45 @@
22

33
from pathlib import Path
44

5+
import matplotlib.pyplot as plt
56
import pandas as pd
7+
import seaborn as sns
68

7-
from bettercode.simple_workflow.visualization import (
8-
generate_clustered_heatmap,
9-
)
9+
10+
def generate_clustered_heatmap(
11+
corr_matrix: pd.DataFrame,
12+
output_path: Path | None = None,
13+
figsize: tuple[int, int] = (8, 10),
14+
cmap: str = "coolwarm",
15+
vmin: float = -1.0,
16+
vmax: float = 1.0,
17+
) -> sns.matrix.ClusterGrid:
18+
"""Generate a clustered heatmap from a correlation matrix."""
19+
# Create clustered heatmap
20+
g = sns.clustermap(
21+
corr_matrix,
22+
cmap=cmap,
23+
vmin=vmin,
24+
vmax=vmax,
25+
figsize=figsize,
26+
dendrogram_ratio=(0.1, 0.1),
27+
cbar_pos=(0.02, 0.8, 0.03, 0.15),
28+
xticklabels=False,
29+
yticklabels=True,
30+
)
31+
32+
# Set y-axis label font size
33+
plt.setp(g.ax_heatmap.get_yticklabels(), rotation=0, fontsize=3)
34+
35+
# Set title
36+
g.fig.suptitle("Clustered Correlation Heatmap (Spearman)", y=1.02, fontsize=14)
37+
38+
# Save if output path provided
39+
if output_path is not None:
40+
output_path.parent.mkdir(parents=True, exist_ok=True)
41+
g.savefig(output_path, dpi=300, bbox_inches="tight")
42+
43+
return g
1044

1145

1246
def main():

src/bettercode/simple_workflow/snakemake_workflow/scripts/join_data.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,14 @@
44

55
import pandas as pd
66

7-
from bettercode.simple_workflow.join_data import join_dataframes
7+
8+
def join_dataframes(
9+
df1: pd.DataFrame,
10+
df2: pd.DataFrame,
11+
how: str = "inner",
12+
) -> pd.DataFrame:
13+
"""Join two dataframes based on their index."""
14+
return df1.join(df2, how=how, lsuffix="_mv", rsuffix="_demo")
815

916

1017
def main():

0 commit comments

Comments
 (0)