Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 7 additions & 6 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,9 @@ install_requires =
wandb
numpy
torch
grelu<1.0.7
pyfaidx
genomepy
grelu>=1.0.9
lightning
torchmetrics
bioframe
Expand All @@ -67,19 +69,18 @@ install_requires =
h5py
pyBigWig
pyarrow
tangermeme<0.5
tangermeme

[options.packages.find]
where = src
exclude =
tests

[options.extras_require]
# Add here additional requirements for extra features, to install with:
# `pip install decima[PDF]` like:
# PDF = ReportLab; RXP
optional =
vep =
cyvcf2
all =
%(vep)s

# Add here test requirements (semicolon/line-separated)
testing =
Expand Down
9 changes: 8 additions & 1 deletion src/decima/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import sys
from decima.constants import NUM_CELLS, DECIMA_CONTEXT_SIZE
from decima.core.result import DecimaResult
from decima.interpret.save_attributions import predict_save_attributions
from decima.vep import predict_variant_effect
Expand All @@ -20,4 +21,10 @@
del version, PackageNotFoundError


__all__ = ["DecimaResult", "predict_variant_effect", "predict_save_attributions"]
__all__ = [
"DecimaResult",
"predict_variant_effect",
"predict_save_attributions",
"NUM_CELLS",
"DECIMA_CONTEXT_SIZE",
]
3 changes: 3 additions & 0 deletions src/decima/cli/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from decima.cli.attributions import cli_attributions
from decima.cli.query_cell import cli_query_cell
from decima.cli.vep import cli_predict_variant_effect
from decima.cli.vep import cli_vep_ensemble
# from decima.cli.finetune import cli_finetune


Expand All @@ -32,6 +33,8 @@ def main():
main.add_command(cli_attributions, name="attributions")
main.add_command(cli_query_cell, name="query-cell")
main.add_command(cli_predict_variant_effect, name="vep")
main.add_command(cli_vep_ensemble, name="vep-ensemble")


if __name__ == "__main__":
main()
43 changes: 42 additions & 1 deletion src/decima/cli/vep.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import click
from decima.constants import DECIMA_CONTEXT_SIZE
from decima.utils.dataframe import ensemble_predictions
from decima.vep import predict_variant_effect


Expand All @@ -22,7 +23,7 @@
@click.option(
"--model",
type=str,
default="0",
default="ensemble",
help="Model to use for variant effect prediction either replicate number or path to the model.",
)
@click.option(
Expand Down Expand Up @@ -62,6 +63,11 @@
help="Column name for gene names. Default: None.",
)
@click.option("--genome", type=str, default="hg38", help="Genome build. Default: hg38.")
@click.option(
"--save-replicates",
is_flag=True,
help="Save the replicates in the output parquet file. Default: False.",
)
def cli_predict_variant_effect(
variants,
output_pq,
Expand All @@ -78,6 +84,7 @@ def cli_predict_variant_effect(
include_cols,
gene_col,
genome,
save_replicates,
):
"""Predict variant effect and save to parquet

Expand All @@ -94,6 +101,12 @@ def cli_predict_variant_effect(
>>> decima vep -v "data/sample.vcf" -o "vep_results.parquet" --gene-col "gene_name" # use gene_name column as gene names if these option passed genes and variants mapped based on these column not based on the genomic locus based on the annotaiton.

>>> decima vep -v "data/sample.vcf" -o "vep_results.parquet" --distance-type tss --min-distance 50000 --max-distance 100000 # predict for variants within 50kb of the TSS and 100kb of the TSS

>>> decima vep -v "data/sample.vcf" -o "vep_results.parquet" --save-replicates # save the replicates in the output parquet file

>>> decima vep -v "data/sample.vcf" -o "vep_results.parquet" --genome "hg38" # use hg38 genome build

>>> decima vep -v "data/sample.vcf" -o "vep_results.parquet" --genome "path/to/fasta/hg38.fa" # use custom genome build
"""
if model in ["0", "1", "2", "3"]: # replicate index
model = int(model)
Expand All @@ -104,6 +117,9 @@ def cli_predict_variant_effect(
if include_cols:
include_cols = include_cols.split(",")

if save_replicates and (model != "ensemble"):
raise ValueError("`--save-replicates` is only supported for ensemble model (`--model ensemble`).")

predict_variant_effect(
variants,
output_pq=output_pq,
Expand All @@ -120,4 +136,29 @@ def cli_predict_variant_effect(
include_cols=include_cols,
gene_col=gene_col,
genome=genome,
save_replicates=save_replicates,
)


@click.command()
@click.option("-f", "--files", type=str, help="Path to the parquet files to ensemble. Can be passed multiple times.")
@click.option("-o", "--output_pq", type=click.Path(), help="Path to the output parquet file.")
@click.option(
"--save-replicates",
default=False,
type=bool,
is_flag=True,
help="Save the replicates in the output parquet file. Default: False.",
)
def cli_vep_ensemble(files, output_pq, save_replicates=False):
"""Ensemble variant effect predictions from multiple parquet files

Examples:

>>> decima vep-ensemble -f "data/sample_rep0.parquet,data/sample_rep1.parquet,data/sample_rep2.parquet" -o "vep_results.parquet"

>>> decima vep-ensemble -f "data/sample_rep*.parquet" -o "vep_results.parquet" --save-replicates
"""
if "," in files:
files = files.split(",")
ensemble_predictions(files=files, output_pq=output_pq, save_replicates=save_replicates)
2 changes: 2 additions & 0 deletions src/decima/constants.py
Original file line number Diff line number Diff line change
@@ -1 +1,3 @@
DECIMA_CONTEXT_SIZE = 524288
SUPPORTED_GENOMES = {"hg38"}
NUM_CELLS = 8856
6 changes: 5 additions & 1 deletion src/decima/core/result.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,11 @@ def prepare_one_hot(self, gene: str, variants: Optional[List[Dict]] = None) -> t

def gene_sequence(self, gene: str, stranded: bool = True) -> str:
"""Get sequence for a gene."""
assert gene in self.genes, f"{gene} is not in the anndata object"
try:
assert gene in self.genes, f"{gene} is not in the anndata object"
except AssertionError:
print(gene)
print(self.genes)
gene_meta = self.gene_metadata.loc[gene]
if not stranded:
gene_meta = {"chrom": gene_meta.chrom, "start": gene_meta.start, "end": gene_meta.end}
Expand Down
22 changes: 21 additions & 1 deletion src/decima/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,13 +216,14 @@ def overlap_genes(
min_distance=0,
max_distance=float("inf"),
):
assert min_distance < max_distance, "`min_distance` must be less than `max_distance`"
include_cols = include_cols or list()

df_variants = df_variants.copy().astype({"chrom": str})
if not df_variants["chrom"].str.startswith("chr").any():
warnings.warn("Chromosome names do not have 'chr' prefix. Adding it to the chromosome names.")
df_variants["chrom"] = "chr" + df_variants["chrom"].astype(str)
df_variants["start"] = df_variants.pos
df_variants["start"] = df_variants.pos.astype(int)
df_variants["end"] = df_variants["start"] + 1

if gene_col is not None:
Expand All @@ -237,6 +238,12 @@ def overlap_genes(
df_variants = df_variants.rename(columns={gene_col: "gene"})
df = df_variants.merge(df_genes, how="left", on="gene", suffixes=("", "_gene"))
else:
if "gene" in df_variants.columns:
warnings.warn(
"Gene column `gene` found in variant file."
" Overwriting with `gene` column with genes based on the overlap based on genomic coordinates."
)
del df_variants["gene"] # remove gene column from df_genes to avoid duplicate column names
df = bioframe.overlap(df_genes, df_variants, how="inner", suffixes=("_gene", ""))

if df.shape[0] == 0:
Expand Down Expand Up @@ -363,3 +370,16 @@ def collate_fn(self, batch):
"seq": default_collate([i["seq"] for i in batch]),
"warning": list(flatten([b["warning"] for b in batch])),
}

def __str__(self):
return (
"VariantDataset("
f"{self.variants.shape[0]} variants "
f"from {list(self.variants.chrom.unique())} "
f"between {self.variants.start.min()} "
f"and {self.variants.end.max()} bp from TSS"
")"
)

def __repr__(self):
return self.__str__()
47 changes: 40 additions & 7 deletions src/decima/hub/__init__.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,29 @@
import os
from typing import Union, Optional
import wandb
from pathlib import Path
from tempfile import TemporaryDirectory
import anndata
from grelu.resources import get_artifact, DEFAULT_WANDB_HOST
from decima.model.lightning import LightningModel
from decima.model.lightning import LightningModel, EnsembleLightningModel


def login_wandb():
try:
wandb.login(host=DEFAULT_WANDB_HOST, anonymous="never", timeout=0)
wandb.login(host=os.environ.get("WANDB_HOST", DEFAULT_WANDB_HOST), anonymous="never", timeout=0)
except wandb.errors.UsageError: # login anonymously if not logged in already
wandb.login(host=DEFAULT_WANDB_HOST, relogin=True, anonymous="must", timeout=0)
wandb.login(host=os.environ.get("WANDB_HOST", DEFAULT_WANDB_HOST), relogin=True, anonymous="must", timeout=0)


def get_model_name(model: Union[str, int] = 0) -> str:
if isinstance(model, int):
return f"decima_rep{model}"
elif isinstance(model, str):
return model
else:
raise ValueError(
f"Invalid model: {model} it need to be a string of model_name on wandb or an integer of replicate number {0, 1, 2, 3}"
)


def load_decima_model(model: Union[str, int] = 0, device: Optional[str] = None):
Expand All @@ -32,22 +44,40 @@ def load_decima_model(model: Union[str, int] = 0, device: Optional[str] = None):
"""
if isinstance(model, LightningModel):
return model
elif model == "ensemble":
model = EnsembleLightningModel(
[
load_decima_model(0, device),
load_decima_model(1, device),
load_decima_model(2, device),
load_decima_model(3, device),
]
)
model.name = "ensemble"
return model
elif isinstance(model, str):
model_name = get_model_name(model)
if Path(model).exists():
return LightningModel.load_from_checkpoint(model, map_location=device)
model_name = model
model = LightningModel.load_from_checkpoint(model, map_location=device)
model.name = model_name
return model
elif isinstance(model, int):
model_name = f"decima_rep{model}"
model_name = get_model_name(model)
else:
raise ValueError(
f"Invalid model: {model} it need to be a string of model_name on wandb "
"or an integer of replicate number {0, 1, 2, 3}, or a path to a local model"
)

if model_name.upper() in os.environ:
return LightningModel.load_from_checkpoint(os.environ[model_name.upper()], map_location=device)

art = get_artifact(model_name, project="decima")
with TemporaryDirectory() as d:
art.download(d)
return LightningModel.load_from_checkpoint(Path(d) / "model.ckpt", map_location=device)
model = LightningModel.load_from_checkpoint(Path(d) / "model.ckpt", map_location=device)
model.name = str(model_name)
return model


def load_decima_metadata(path: Optional[str] = None):
Expand All @@ -62,6 +92,9 @@ def load_decima_metadata(path: Optional[str] = None):
if path is not None:
return anndata.read_h5ad(path)

if "DECIMA_METADATA" in os.environ:
return anndata.read_h5ad(os.environ["DECIMA_METADATA"])

art = get_artifact("decima_metadata", project="decima")
with TemporaryDirectory() as d:
art.download(d)
Expand Down
2 changes: 1 addition & 1 deletion src/decima/hub/download.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
def download_hg38():
"""Download hg38 genome from UCSC."""
logger.info("Downloading hg38 genome...")
genomepy.install_genome("hg38", provider="UCSC")
genomepy.install_genome(provider="url", name="http://hgdownload.soe.ucsc.edu/goldenPath/hg38/bigZips/hg38.fa.gz")


def download_decima_weights():
Expand Down
3 changes: 2 additions & 1 deletion src/decima/interpret/attributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -498,7 +498,8 @@ def peaks_to_bed(self):
df["start"], df["end"] = self.end - df["end"], self.end - df["start"]

df["strand"] = "."
df["score"] = -np.log10(df["p-value"] + 1e-50)
# np.maximum because of https://github.com/jmschrei/tangermeme/issues/40
df["score"] = -np.log10(np.maximum(df["p-value"], 0) + 1e-50)
df["score"] = df["score"].astype(int).clip(lower=0, upper=50)
return df[["chrom", "start", "end", "name", "score", "strand"]]

Expand Down
3 changes: 3 additions & 0 deletions src/decima/model/decima_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@ def __init__(self, n_tasks: int, replicate: int = 0, mask=True, init_borzoi=True
attn_dropout=0.0,
n_heads=8,
n_pos_features=32,
# backward compatibility with grelu<1.0.7
norm_kwargs={"eps": 1e-5},
act_func="gelu",
final_act_func=None,
final_pool_func=None,
)
Expand Down
Loading