Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
104 changes: 57 additions & 47 deletions scripts/count_svs.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"""Count SVs per sample in a DuckDB database
"""Count SVs per sample

usage: python count_svs.py <counts_db> <sv_db>

Expand Down Expand Up @@ -40,66 +40,76 @@
from pathlib import Path
from collections.abc import Sequence

import duckdb


def validate_filters(con: duckdb.DuckDBPyConnection) -> None:
    """Validate every row of the sv_filters table.

    Raises:
        ValueError: if any filter's min_svlen is negative, or greater than
            its max_svlen.
    """
    sql = "SELECT svtype, min_svlen, max_svlen FROM sv_filters;"
    filters = con.sql(sql).fetchall()
    for f in filters:
        # f = (svtype, min_svlen, max_svlen)
        if f[1] < 0:
            raise ValueError("Min SV length must be >= 0")
        if f[1] > f[2]:
            raise ValueError("Min SV length must be <= max SV length")


def count_svs(con: duckdb.DuckDBPyConnection, filter_id: int) -> None:
    """Create a per-sample SV count table for one filter.

    Creates/replaces table sv_counts_<filter_id> holding, for each sample,
    the number of SVs in the attached sv_db that match the filter's svtype
    and [min_svlen, max_svlen] range.
    """
    sql = """SELECT svtype, min_svlen, max_svlen
        FROM sv_filters
        WHERE id = ?;
    """
    filters = con.execute(sql, [filter_id]).fetchall()
    # filter_id is interpolated into the table name; it comes from the
    # sv_filters id column (an integer), not user input — presumably safe,
    # but confirm ids are integers.
    sql = (
        f"CREATE OR REPLACE TABLE sv_counts_{filter_id}"
        " AS SELECT sample, COUNT(*) AS count"
        " FROM sv_db.svs"
        " WHERE svtype = ? AND svlen >= ? AND svlen <= ?"
        " GROUP BY sample"
    )

    # filters[0] = (svtype, min_svlen, max_svlen) bound to the three ?'s.
    con.execute(sql, list(filters[0]))


def make_tables(counts_db: Path, sv_db: Path) -> None:
    """Build one sv_counts_<id> table per row of sv_filters.

    Opens counts_db, validates the filters, attaches sv_db under the alias
    sv_db, then counts SVs for each filter id.
    """
    with duckdb.connect(counts_db) as con:
        validate_filters(con)
        # NOTE(review): sv_db is interpolated into the SQL string; a path
        # containing a single quote would break or inject this statement —
        # confirm the caller controls this path.
        con.sql(f"ATTACH '{sv_db}' AS sv_db;")
        filter_ids = con.sql("SELECT id FROM sv_filters;").fetchall()
        for i in filter_ids:
            count_svs(con, i[0])
import pandas as pd


def validate_filters(filters: pd.DataFrame) -> None:
    """Check that every SV filter has a sane length range.

    Raises:
        ValueError: if any min_svlen is negative, or greater than its
            max_svlen.
    """
    min_lengths = filters["min_svlen"]
    if (min_lengths < 0).any():
        raise ValueError("Min SV length must be >= 0")
    if (min_lengths > filters["max_svlen"]).any():
        raise ValueError("Min SV length must be <= max SV length")


def count_svs(svs: pd.DataFrame, filters: pd.DataFrame) -> pd.DataFrame:
    """Count SVs per sample for each filter.

    For every filter row (svtype, min_svlen, max_svlen) the matching SVs are
    counted per sample; each block of counts is tagged with the filter that
    produced it. Filters matching no SVs contribute no rows.
    """
    per_filter = []
    for spec in filters.itertuples(index=False):
        mask = (svs["svtype"] == spec.svtype) & svs["svlen"].between(
            spec.min_svlen, spec.max_svlen
        )
        selected = svs.loc[mask]
        if selected.empty:
            continue
        counts = (
            selected.groupby("sample")
            .size()
            .reset_index(name="count")
            .assign(
                svtype=spec.svtype,
                min_svlen=spec.min_svlen,
                max_svlen=spec.max_svlen,
            )
        )
        per_filter.append(counts)
    if per_filter:
        return pd.concat(per_filter, ignore_index=True)
    # No filter matched anything: return an empty frame with the same schema.
    return pd.DataFrame(
        columns=["sample", "count", "svtype", "min_svlen", "max_svlen"]
    )


def main(argv: Sequence[str] | None = None) -> int:
    """Command-line entry point: count SVs per sample.

    Reads the SVs and filters TSVs, validates the filters, counts SVs per
    sample for each filter, and writes the counts to the output TSV.

    Raises:
        FileNotFoundError: if either input TSV does not exist.
        ValueError: if the filters fail validation.
    """
    parser = argparse.ArgumentParser(description="Count SVs per sample")
    parser.add_argument(
        "svs_tsv",
        metavar="SVS_TSV",
        type=Path,
        help="Path to the SVs TSV file",
    )
    parser.add_argument(
        "filters_tsv",
        metavar="FILTERS_TSV",
        type=Path,
        help="Path to the filters TSV file",
    )
    parser.add_argument(
        "output_tsv",
        metavar="OUTPUT_TSV",
        type=Path,
        help="Path to the output TSV file",
    )
    args = parser.parse_args(argv)

    retval = 0

    if not args.svs_tsv.is_file():
        raise FileNotFoundError("SVs TSV file must exist")
    if not args.filters_tsv.is_file():
        raise FileNotFoundError("Filters TSV file must exist")

    svs = pd.read_csv(args.svs_tsv, sep="\t")
    filters = pd.read_csv(args.filters_tsv, sep="\t")

    validate_filters(filters)
    counts_df = count_svs(svs, filters)
    counts_df.to_csv(args.output_tsv, sep="\t", index=False)

    return retval

Expand Down
176 changes: 103 additions & 73 deletions scripts/determine_outlier_samples.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,91 +14,112 @@
from pathlib import Path
from collections.abc import Sequence

import duckdb


def find_filter_outliers(
    con: duckdb.DuckDBPyConnection, filter_id: int, iqr_mult: float
) -> None:
    """Find SV-count outlier samples for one filter.

    Computes the median and IQR of per-sample counts in sv_counts_<filter_id>
    and writes samples outside median +/- iqr_mult * IQR to
    outliers_<filter_id>.
    """
    # quant = [q25, median, q75]; DuckDB lists are 1-indexed, so quant[2] is
    # the median and quant[3] - quant[1] is the IQR.
    sql = (
        "SELECT quant[2], quant[3] - quant[1]"
        " FROM"
        " (SELECT quantile_cont(count, [0.25, 0.5, 0.75]) AS quant"
        f" FROM sv_counts_{filter_id});"
    )
    results = con.sql(sql).fetchall()[0]
    median = results[0]
    iqr = results[1]
    # $1 = median, $2 = IQR, $3 = IQR multiplier.
    sql = (
        f"CREATE OR REPLACE TABLE outliers_{filter_id}"
        " AS"
        " SELECT sample"
        f" FROM sv_counts_{filter_id}"
        " WHERE count < $1 - $2 * $3 OR count > $1 + $2 * $3"
    )
    con.execute(sql, [median, iqr, iqr_mult])


def find_sv_count_outliers(con: duckdb.DuckDBPyConnection, iqr_mult: float) -> None:
    """Run find_filter_outliers for every filter id in sv_filters."""
    filter_ids = con.sql("SELECT id FROM sv_filters;").fetchall()
    for i in filter_ids:
        find_filter_outliers(con, i[0], iqr_mult)


def find_wgd_outliers(con: duckdb.DuckDBPyConnection, min_wgd: float, max_wgd: float) -> None:
    """Record samples whose WGD score is outside [min_wgd, max_wgd].

    NOTE(review): "COPY TO wgd_outliers FROM (SELECT ...)" does not look like
    valid DuckDB syntax — COPY targets files, and populating a table from a
    query would be INSERT INTO ... SELECT. Confirm this statement runs.
    """
    sql = (
        "COPY TO wgd_outliers"
        " FROM (SELECT sample"
        " FROM wgd_scores"
        " WHERE score < ? OR score > ?)"
    )
    con.execute(sql, [min_wgd, max_wgd])


def make_wgd_tables(con: duckdb.DuckDBPyConnection) -> None:
    """Create (or reset) the wgd_scores and wgd_outliers tables."""
    con.sql("CREATE OR REPLACE TABLE wgd_scores (sample VARCHAR, score FLOAT);")
    con.sql("CREATE OR REPLACE TABLE wgd_outliers (sample VARCHAR);")


def load_wgd_table(con: duckdb.DuckDBPyConnection, scores_path: Path) -> None:
    """Load tab-separated, headerless (sample, score) rows into wgd_scores.

    NOTE(review): scores_path is interpolated into the SQL string — this
    looks like a SQL injection vulnerability if the path can contain quotes;
    confirm the caller controls it.
    """
    con.sql(f"COPY wgd_scores FROM '{scores_path}' (DELIMITER '\\t', HEADER false);")
import pandas as pd


def find_sv_count_outliers(
    sv_counts: pd.DataFrame, filters: pd.DataFrame, iqr_mult: float
) -> pd.DataFrame:
    """Find outlier samples using SV counts.

    For each filter, selects that filter's per-sample counts and flags
    samples whose count lies outside median +/- iqr_mult * IQR.

    Args:
        sv_counts: per-sample counts with columns
            sample, count, svtype, min_svlen, max_svlen (one block per filter,
            as produced by count_svs).
        filters: filter definitions with columns svtype, min_svlen, max_svlen.
        iqr_mult: IQR multiplier for the outlier bounds.

    Returns:
        The outlier rows of sv_counts (same columns), or an empty frame with
        that schema when no filter yields outliers.
    """
    filter_outliers = []
    for _, row in filters.iterrows():
        # BUGFIX: sv_counts has no `svlen` column — each counts row is tagged
        # with the filter that produced it, so select by exact match on the
        # filter's svtype/min_svlen/max_svlen instead of a range test.
        counts = sv_counts[
            (sv_counts["svtype"] == row["svtype"])
            & (sv_counts["min_svlen"] == row["min_svlen"])
            & (sv_counts["max_svlen"] == row["max_svlen"])
        ]
        if counts.empty:
            continue

        quantiles = counts["count"].quantile([0.25, 0.75])
        iqr = quantiles[0.75] - quantiles[0.25]
        median = counts["count"].median()
        lower_bound = median - iqr * iqr_mult
        upper_bound = median + iqr * iqr_mult

        outliers = counts[
            (counts["count"] < lower_bound) | (counts["count"] > upper_bound)
        ]
        # Rows already carry svtype/min_svlen/max_svlen from sv_counts.
        filter_outliers.append(outliers)

    if not filter_outliers:
        return pd.DataFrame(
            columns=["sample", "count", "svtype", "min_svlen", "max_svlen"]
        )
    return pd.concat(filter_outliers)


def find_wgd_outliers(
    wgd_scores: pd.DataFrame, min_wgd: float, max_wgd: float
) -> pd.DataFrame:
    """Return the rows of wgd_scores whose score is outside [min_wgd, max_wgd]."""
    too_low = wgd_scores["score"] < min_wgd
    too_high = wgd_scores["score"] > max_wgd
    return wgd_scores.loc[too_low | too_high]


def determine_outliers(
    sv_counts_tsv: Path,
    sv_filters_tsv: Path,
    iqr_mult: float,
    sv_count_outliers_path: Path,
    wgd_outlier_samples_path: Path,
    wgd_scores: Path | None = None,
    min_wgd: float | None = None,
    max_wgd: float | None = None,
):
    """Write SV-count and WGD outlier sample tables.

    Args:
        sv_counts_tsv: per-sample SV counts TSV (output of count_svs).
        sv_filters_tsv: SV filter definitions TSV.
        iqr_mult: IQR multiplier for SV-count outlier bounds.
        sv_count_outliers_path: output TSV for SV-count outliers.
        wgd_outlier_samples_path: output TSV for WGD outliers (created empty
            when no WGD scores are given).
        wgd_scores: optional path to a headerless (sample, score) TSV.
        min_wgd: lower WGD score bound; required when wgd_scores is given.
        max_wgd: upper WGD score bound; required when wgd_scores is given.

    Raises:
        FileNotFoundError: if wgd_scores is given but does not exist.
        ValueError: if WGD bounds are missing or min_wgd > max_wgd.
    """
    sv_counts = pd.read_csv(sv_counts_tsv, sep="\t")
    sv_filters = pd.read_csv(sv_filters_tsv, sep="\t")
    sv_count_outliers = find_sv_count_outliers(sv_counts, sv_filters, iqr_mult)
    sv_count_outliers.to_csv(sv_count_outliers_path, sep="\t", index=False)

    # Explicit None check (a Path object is always truthy anyway).
    if wgd_scores is not None:
        if not wgd_scores.is_file():
            raise FileNotFoundError("WGD scores file must exist if given")
        if min_wgd is None or max_wgd is None:
            raise ValueError(
                "Min and max WGD scores must be given if WGD scores are given"
            )
        if min_wgd > max_wgd:
            raise ValueError("Min WGD score must be <= max WGD score")

        wgd_scores_df = pd.read_csv(wgd_scores, sep="\t", names=["sample", "score"])
        wgd_outliers = find_wgd_outliers(wgd_scores_df, min_wgd, max_wgd)
        wgd_outliers.to_csv(wgd_outlier_samples_path, sep="\t", index=False, header=False)
    else:
        # No WGD input: still create the expected (empty) output file.
        wgd_outlier_samples_path.touch()


def main(argv: Sequence[str] | None = None) -> int:
parser = argparse.ArgumentParser(
description="Determine outlier samples in GATK-SV callset"
)
parser.add_argument(
"sv_counts_db",
metavar="SV_COUNTS_DB",
help="Path to the input SV counts DuckDB database",
"sv_counts_tsv",
metavar="SV_COUNTS_TSV",
help="Path to the input SV counts TSV",
type=Path,
)
parser.add_argument(
"sv_filters_tsv",
metavar="SV_FILTERS_TSV",
help="Path to the input SV filters TSV",
type=Path,
)
parser.add_argument(
"sv_count_outlier_samples_tsv",
metavar="SV_COUNT_OUTLIER_SAMPLES_TSV",
help="Output path for SV count outlier samples",
type=Path,
)
parser.add_argument(
"wgd_outlier_samples_tsv",
metavar="WGD_OUTLIER_SAMPLES_TSV",
help="Output path for WGD outlier samples",
type=Path,
)
parser.add_argument(
Expand Down Expand Up @@ -132,13 +153,22 @@ def main(argv: Sequence[str] | None = None) -> int:

retval = 0

if not args.sv_counts_db.is_file():
raise FileNotFoundError("Counts database must exist")
if not args.sv_counts_tsv.is_file():
raise FileNotFoundError("Counts TSV must exist")
if not args.sv_filters_tsv.is_file():
raise FileNotFoundError("Filters TSV must exist")
if args.iqr_mult < 0:
raise ValueError("IQR multiplier must be greater than or equal to 0")

determine_outliers(
args.sv_counts_db, args.iqr_mult, args.wgd_scores, args.min_wgd, args.max_wgd
args.sv_counts_tsv,
args.sv_filters_tsv,
args.iqr_mult,
args.sv_count_outlier_samples_tsv,
args.wgd_outlier_samples_tsv,
args.wgd_scores,
args.min_wgd,
args.max_wgd,
)

return retval
Expand Down
Loading