src/decima/cli/__init__.py (3 changes: 2 additions & 1 deletion)
@@ -6,8 +6,8 @@
 from decima.cli.attributions import cli_attributions
 from decima.cli.query_cell import cli_query_cell
 from decima.cli.vep import cli_predict_variant_effect
+from decima.cli.finetune import cli_finetune
 from decima.cli.vep import cli_vep_ensemble
-# from decima.cli.finetune import cli_finetune


 logger = logging.getLogger("decima")
@@ -33,6 +33,7 @@ def main():
 main.add_command(cli_attributions, name="attributions")
 main.add_command(cli_query_cell, name="query-cell")
 main.add_command(cli_predict_variant_effect, name="vep")
+main.add_command(cli_finetune, name="finetune")
 main.add_command(cli_vep_ensemble, name="vep-ensemble")
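A quick way to verify the re-enabled wiring is click's built-in test runner; a minimal sketch, assuming the package is importable (the `main` group and the `finetune` registration are taken from the diff above):

```python
# Minimal sketch: confirm the new "finetune" subcommand is registered on the
# `main` group. CliRunner is click's standard test harness, so no shell or
# installed entry point is needed.
from click.testing import CliRunner

from decima.cli import main

runner = CliRunner()
result = runner.invoke(main, ["finetune", "--help"])
print(result.output)  # should list --name, --model, --matrix-file, ...
```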
src/decima/cli/finetune.py (85 changes: 48 additions & 37 deletions)
@@ -1,6 +1,5 @@
 """Finetune the Decima model."""
 
-import os
+import logging
 import click
 import anndata
 import wandb
@@ -9,62 +8,74 @@
 
 
 @click.command()
-@click.option("--name", required=True, help="Project name")
-@click.option("--dir", required=True, help="Data directory path")
-@click.option("--lr", default=0.001, type=float, help="Learning rate")
-@click.option("--weight", required=True, type=float, help="Weight parameter")
-@click.option("--grad", required=True, type=int, help="Gradient accumulation steps")
-@click.option("--replicate", default=0, type=int, help="Replication number")
-@click.option("--bs", default=4, type=int, help="Batch size")
-def cli_finetune(name, dir, lr, weight, grad, replicate, bs):
+@click.option("--name", required=True, help="Name of the run.")
+@click.option("--model", default="0", type=str, help="Model checkpoint path or replicate number. A path loads that checkpoint; a replicate number loads the corresponding pretrained replicate.")
+@click.option("--matrix-file", required=True, help="Matrix file path.")
+@click.option("--h5-file", required=True, help="H5 file path.")
+@click.option("--outdir", required=True, help="Output directory path to save model checkpoints.")
+@click.option("--learning-rate", default=0.001, type=float, help="Learning rate.")
+@click.option("--loss-total-weight", required=True, type=float, help="Total weight parameter for the loss function.")
+@click.option("--gradient-accumulation", required=True, type=int, help="Gradient accumulation steps.")
+@click.option("--batch-size", default=4, type=int, help="Batch size.")
+@click.option("--max-seq-shift", default=5000, type=int, help="Maximum shift for augmentation.")
+@click.option("--gradient-clipping", default=0.0, type=float, help="Gradient clipping.")
+@click.option("--save-top-k", default=1, type=int, help="Number of checkpoints to save.")
+@click.option("--epochs", default=1, type=int, help="Number of epochs.")
+@click.option("--logger", default="wandb", type=str, help="Logger to use. Default: wandb.")
+@click.option("--num-workers", default=16, type=int, help="Number of workers.")
+@click.option("--seed", default=0, type=int, help="Random seed.")
+def cli_finetune(name, model, matrix_file, h5_file, outdir, learning_rate, loss_total_weight, gradient_accumulation, batch_size, max_seq_shift, gradient_clipping, save_top_k, epochs, logger, num_workers, seed):
     """Finetune the Decima model."""
-    wandb.login(host="https://genentech.wandb.io")
-    run = wandb.init(project="decima", dir=name, name=name)
-
-    matrix_file = os.path.join(dir, "aggregated.h5ad")
-    h5_file = os.path.join(dir, "data.h5")
-    print(f"Data paths: {matrix_file}, {h5_file}")
-
-    print("Reading anndata")
+    train_logger = logger
+    logger = logging.getLogger("decima")
+    logger.info(f"Data paths: matrix_file={matrix_file}, h5_file={h5_file}")
+    logger.info("Reading anndata")
     ad = anndata.read_h5ad(matrix_file)
 
-    print("Making dataset objects")
+    logger.info("Making dataset objects")
     train_dataset = HDF5Dataset(
         h5_file=h5_file,
         ad=ad,
         key="train",
-        max_seq_shift=5000,
+        max_seq_shift=max_seq_shift,
         augment_mode="random",
-        seed=0,
+        seed=seed,
     )
     val_dataset = HDF5Dataset(h5_file=h5_file, ad=ad, key="val", max_seq_shift=0)
 
     train_params = {
         "optimizer": "adam",
-        "batch_size": bs,
-        "num_workers": 16,
+        "name": name,
+        "batch_size": batch_size,
+        "num_workers": num_workers,
         "devices": 0,
-        "logger": "wandb",
-        "save_dir": dir,
-        "max_epochs": 15,
-        "lr": lr,
-        "total_weight": weight,
-        "accumulate_grad_batches": grad,
+        "logger": train_logger,
+        "save_dir": outdir,
+        "max_epochs": epochs,
+        "lr": learning_rate,
+        "total_weight": loss_total_weight,
+        "accumulate_grad_batches": gradient_accumulation,
         "loss": "poisson_multinomial",
-        "pairs": ad.uns["disease_pairs"].values,
+        # "pairs": ad.uns["disease_pairs"].values,
+        "clip": gradient_clipping,
+        "save_top_k": save_top_k,
+        "pin_memory": True,
     }
     model_params = {
         "n_tasks": ad.shape[0],
-        "replicate": replicate,
+        "replicate": model,
    }
-    print(f"train_params: {train_params}")
-    print(f"model_params: {model_params}")
+    logger.info(f"train_params: {train_params}")
+    logger.info(f"model_params: {model_params}")
 
-    print("Initializing model")
+    logger.info("Initializing model")
     model = LightningModel(model_params=model_params, train_params=train_params)
 
-    print("Training")
+    logger.info("Training")
+    if train_logger == "wandb":
+        wandb.login(host="https://genentech.wandb.io")
+        run = wandb.init(project="decima", dir=name, name=name)
     model.train_on_dataset(train_dataset, val_dataset)
     train_dataset.close()
     val_dataset.close()
-    run.finish()
+    if train_logger == "wandb":
+        run.finish()
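For reference, a sketch of an in-process invocation with the renamed flags. All paths and hyperparameter values are placeholders, and the run name and checkpoint layout are assumptions, not recommendations:

```python
# Sketch of invoking the reworked finetune CLI in-process. All paths and
# hyperparameter values below are placeholders, not recommended settings.
from click.testing import CliRunner

from decima.cli.finetune import cli_finetune

runner = CliRunner()
result = runner.invoke(
    cli_finetune,
    [
        "--name", "demo-run",
        "--model", "0",                      # pretrained replicate (or a .ckpt path)
        "--matrix-file", "aggregated.h5ad",  # placeholder path
        "--h5-file", "data.h5",              # placeholder path
        "--outdir", "checkpoints",
        "--learning-rate", "1e-4",
        "--loss-total-weight", "0.2",        # required: total weight for the loss
        "--gradient-accumulation", "4",      # required: grad accumulation steps
        "--epochs", "1",
        "--seed", "42",
    ],
)
print(result.exit_code)
```

Passing `--model` a checkpoint path instead of a replicate number should exercise the other loading branch described in the option's help text.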
src/decima/cli/predict_genes.py (120 changes: 57 additions & 63 deletions)
@@ -1,73 +1,67 @@
 """Make predictions for all genes using an HDF5 file created by Decima's ``write_hdf5.py``."""
 
-import os
 import click
-import anndata
-import numpy as np
-import torch
-from decima.constants import DECIMA_CONTEXT_SIZE
-from decima.model.lightning import LightningModel
-from decima.data.read_hdf5 import list_genes
-from decima.data.dataset import HDF5Dataset
-
-# TODO: input can be just a h5ad file rather than a combination of h5 and matrix file.
+from decima.tools.inference import predict_gene_expression


 @click.command()
-@click.option("--device", type=int, help="Which GPU to use.")
-@click.option("--ckpts", multiple=True, required=True, help="Path to the model checkpoint(s).")
-@click.option("--h5_file", required=True, help="Path to h5 file indexed by genes.")
-@click.option("--matrix_file", required=True, help="Path to h5ad file containing genes to predict.")
-@click.option("--out_file", required=True, help="Output file path.")
+@click.option("-o", "--output", type=click.Path(), help="Path to the output h5ad file.")
+@click.option(
+    "--genes",
+    type=str,
+    default=None,
+    help="Comma-separated list of genes to predict. Default: None (all genes). If provided, only these genes are predicted.",
+)
+@click.option(
+    "-m",
+    "--model",
+    type=str,
+    default="ensemble",
+    help="Model to use: `0`, `1`, `2`, `3`, `ensemble`, or `path/to/model.ckpt`.",
+)
+@click.option(
+    "--metadata",
+    type=click.Path(exists=True),
+    default=None,
+    help="Path to the metadata anndata file. Default: None.",
+)
+@click.option(
+    "--device",
+    type=str,
+    default=None,
+    help="Device to use. Default: None, which automatically selects the best device.",
+)
+@click.option("--batch-size", type=int, default=8, help="Batch size for the model. Default: 8.")
+@click.option("--num-workers", type=int, default=4, help="Number of workers for the loader. Default: 4.")
 @click.option("--max_seq_shift", default=0, help="Maximum jitter for augmentation.")
-def cli_predict_genes(device, ckpts, h5_file, matrix_file, out_file, max_seq_shift):
-    """Make predictions for all genes."""
-    torch.set_float32_matmul_precision("medium")
-
-    # TODO: device is unused, set the device appropriately
-    os.environ["CUDA_VISIBLE_DEVICES"] = str(device)
-    device = torch.device(0)
+@click.option("--genome", type=str, default="hg38", help="Genome build. Default: hg38.")
+@click.option(
+    "--save-replicates",
+    is_flag=True,
+    help="Save the per-replicate predictions in the output file. Default: False.",
+)
+def cli_predict_genes(
+    output, genes, model, metadata, device, batch_size, num_workers, max_seq_shift, genome, save_replicates
+):
+    if model in ["0", "1", "2", "3"]:
+        model = int(model)
 
-    print("Loading anndata")
-    ad = anndata.read_h5ad(matrix_file)
-    assert np.all(list_genes(h5_file, key=None) == ad.var_names.tolist())
-
-    print("Making dataset")
-    ds = HDF5Dataset(
-        key=None,
-        h5_file=h5_file,
-        ad=ad,
-        seq_len=DECIMA_CONTEXT_SIZE,
-        max_seq_shift=max_seq_shift,
-    )
+    if isinstance(device, str) and device.isdigit():
+        device = int(device)
 
-    print("Loading models from checkpoint")
-    models = [LightningModel.load_from_checkpoint(f).eval() for f in ckpts]
+    if genes is not None:
+        genes = genes.split(",")
 
-    print("Computing predictions")
-    preds = (
-        np.stack([model.predict_on_dataset(ds, devices=0, batch_size=6, num_workers=16) for model in models]).mean(0).T
-    )
-    ad.layers["preds"] = preds
+    if save_replicates and (model != "ensemble"):
+        raise ValueError("`--save-replicates` is only supported for the ensemble model (`--model ensemble`).")
 
-    print("Computing correlations per gene")
-    ad.var["pearson"] = [np.corrcoef(ad.X[:, i], ad.layers["preds"][:, i])[0, 1] for i in range(ad.shape[1])]
-    ad.var["size_factor_pearson"] = [np.corrcoef(ad.X[:, i], ad.obs["size_factor"])[0, 1] for i in range(ad.shape[1])]
-    print(
-        f"Mean Pearson Correlation per gene: True: {ad.var.pearson.mean().round(2)} Size Factor: {ad.var.size_factor_pearson.mean().round(2)}"
+    ad = predict_gene_expression(
+        genes=genes,
+        model=model,
+        metadata_anndata=metadata,
+        device=device,
+        batch_size=batch_size,
+        num_workers=num_workers,
+        max_seq_shift=max_seq_shift,
+        genome=genome,
+        save_replicates=save_replicates,
     )
 
-    print("Computing correlation per track")
-    for dataset in ad.var.dataset.unique():
-        key = f"{dataset}_pearson"
-        ad.obs[key] = [
-            np.corrcoef(
-                ad[i, ad.var.dataset == dataset].X,
-                ad[i, ad.var.dataset == dataset].layers["preds"],
-            )[0, 1]
-            for i in range(ad.shape[0])
-        ]
-        print(f"Mean Pearson Correlation per pseudobulk over {dataset} genes: {ad.obs[key].mean().round(2)}")
-
-    print("Saved")
-    ad.write_h5ad(out_file)
+    ad.write_h5ad(output)
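Since the command is now a thin wrapper around `decima.tools.inference.predict_gene_expression`, the same prediction can be scripted directly. A sketch using only the keyword names visible at the call site above; the gene symbols and output path are placeholders, and the behavior of `metadata_anndata=None` is an assumption:

```python
# Sketch of calling the underlying API directly. Keyword names mirror the
# call site in cli_predict_genes; gene names and output path are placeholders.
from decima.tools.inference import predict_gene_expression

ad = predict_gene_expression(
    genes=["GATA1", "SPI1"],  # None would predict all genes
    model="ensemble",         # or 0-3 for a single replicate, or a .ckpt path
    metadata_anndata=None,    # assumed to fall back to packaged metadata
    device=None,              # auto-select the best device
    batch_size=8,
    num_workers=4,
    max_seq_shift=0,
    genome="hg38",
    save_replicates=False,
)
ad.write_h5ad("predictions.h5ad")  # same output the CLI writes via --output
```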
src/decima/core/result.py (24 changes: 21 additions & 3 deletions)
@@ -157,7 +157,25 @@ def predicted_expression_matrix(self, genes: Optional[List[str]] = None) -> pd.DataFrame:
         else:
             return pd.DataFrame(self.anndata[:, genes].layers["preds"], index=self.cells, columns=genes)
 
-    def prepare_one_hot(self, gene: str, variants: Optional[List[Dict]] = None) -> torch.Tensor:
+    def _pad_gene_metadata(self, gene_meta: pd.Series, padding: int = 0) -> pd.Series:
+        """
+        Pad the gene's genomic window by `padding` bases on each side.
+
+        Args:
+            gene_meta: Gene metadata.
+            padding: Number of bases added to each side of the window.
+
+        Returns:
+            pd.Series: Padded gene metadata.
+        """
+        gene_meta = gene_meta.copy()
+        gene_meta["start"] = gene_meta["start"] - padding
+        gene_meta["end"] = gene_meta["end"] + padding
+        gene_meta["gene_mask_start"] = gene_meta["gene_mask_start"] + padding
+        gene_meta["gene_mask_end"] = gene_meta["gene_mask_end"] + padding
+        return gene_meta
+
+    def prepare_one_hot(self, gene: str, variants: Optional[List[Dict]] = None, padding: int = 0) -> torch.Tensor:
         """Prepare one-hot encoding for a gene.
 
         Args:
@@ -167,15 +185,15 @@ def prepare_one_hot(self, gene: str, variants: Optional[List[Dict]] = None) -> torch.Tensor:
             torch.Tensor: One-hot encoding of the gene
         """
         assert gene in self.genes, f"{gene} is not in the anndata object"
-        gene_meta = self.gene_metadata.loc[gene]
+        gene_meta = self._pad_gene_metadata(self.gene_metadata.loc[gene], padding)
 
         if variants is None:
             seq = intervals_to_strings(gene_meta, genome="hg38")
             gene_start, gene_end = gene_meta.gene_mask_start, gene_meta.gene_mask_end
         else:
             seq, (gene_start, gene_end) = prepare_seq_alt_allele(gene_meta, variants)
 
-        mask = np.zeros(shape=(1, DECIMA_CONTEXT_SIZE))
+        mask = np.zeros(shape=(1, DECIMA_CONTEXT_SIZE + padding * 2))
         mask[0, gene_start:gene_end] += 1
         mask = torch.from_numpy(mask).float()
 
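The coordinate bookkeeping in `_pad_gene_metadata` can be checked in isolation: widening the genomic window by `padding` on each side shifts the window-relative mask coordinates right by `padding`, so the mask still covers the same genomic bases. A self-contained sketch with made-up coordinates:

```python
# Self-contained check of the padding arithmetic (coordinates are made up).
import pandas as pd

gene_meta = pd.Series({
    "start": 1_000_000,      # genomic window start
    "end": 1_524_288,        # genomic window end
    "gene_mask_start": 200,  # mask start, relative to the window
    "gene_mask_end": 1_200,  # mask end, relative to the window
})
padding = 100

padded = gene_meta.copy()
padded["start"] -= padding
padded["end"] += padding
padded["gene_mask_start"] += padding
padded["gene_mask_end"] += padding

# The window grows by 2 * padding, matching the enlarged mask
# np.zeros(shape=(1, DECIMA_CONTEXT_SIZE + padding * 2)) in prepare_one_hot.
assert padded["end"] - padded["start"] == (gene_meta["end"] - gene_meta["start"]) + 2 * padding
# The mask still points at the same genomic positions after the shift.
assert padded["start"] + padded["gene_mask_start"] == gene_meta["start"] + gene_meta["gene_mask_start"]
print(padded)
```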