Skip to content

Commit

Permalink
Add non-small cell lung cancer model (#14)
Browse files Browse the repository at this point in the history
  • Loading branch information
martinkim0 authored Mar 14, 2023
1 parent 83e9e98 commit a94bae7
Show file tree
Hide file tree
Showing 6 changed files with 152 additions and 6 deletions.
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -138,5 +138,6 @@ local/
.vscode

# Snakemake
*/.snakemake
*/.snakemake/
*/data/
*/models/
8 changes: 4 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@ Code for the pre-trained reference models uploaded to scvi-hub on HuggingFace.

First, install pre-commit from pip and run `pre-commit install` at the root of the repository.

To run the code, create a conda environment, install snakemake in this environment, then run the following command:
To run the uploading workflow for a particular reference model, create a conda
environment snakemake, activate the environment, `cd` into the model's directory, and
run the following command:

```
export HF_API_TOKEN=TOKEN_VAL
snakemake --use-conda --cores all --envvars HF_API_TOKEN
snakemake --forceall --use-conda --cores all --envvars HF_API_TOKEN
```

from the corresponding reference directory's workflow directory.
12 changes: 12 additions & 0 deletions non_small_cell_lung_cancer/Snakefile
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
envvars:
"HF_API_TOKEN"

rule all:
log:
"logs/stderr.log",
"logs/stdout.log",
conda:
"env.yaml"
threads: 16
script:
"main.py"
18 changes: 18 additions & 0 deletions non_small_cell_lung_cancer/config.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
{
"model_url": "https: //zenodo.org/record/7227571/files/core_atlas_scanvi_model.tar.gz",
"known_hash": "60d6c0ccbad89178a359b3bd2f2981638c86b260011d8bd1977c989fbbc5ad7e",
"model_fname": "lung_cancer_scanvi",
"model_path": "full_atlas_hvg_integrated_scvi_scanvi_model",
"adata_path": "full_atlas_hvg_integrated_scvi_integrated_scanvi.h5ad",
"latent_qzm_key": "X_latent_qzm",
"latent_qzv_key": "X_latent_qzv",
"model_dir": "lung_cancer_scanvi_minified",
"citation": "High-resolution single-cell atlas reveals diversity and plasticity of tissue-resident neutrophils in non-small cell lung cancer. S Salcher, G Sturm, L Horvath, G Untergasser, C Kuempers, G Fotakis, E Panizzolo, A Martowicz, M Trebo, G Pall, G Gamerith, M Sykora, F Augustin, K Schmitz, F Finotello, D Rieder, S Perner, S Sopper, D Wolf, A Pircher, Z Trajanoski. Cancer Cell. 2022; 40 (12): 1503-1520.e8. https: //doi.org/10.1016/j.ccell.2022.10.008",
"description": "The single cell lung cancer atlas is a resource integrating more than 1.2 million cells from 309 patients across 29 datasets.",
"tissues": ["lung"],
"training_data_url": "https://zenodo.org/record/7227571/files/core_atlas_scanvi_model.tar.gz",
"training_code_url": "https://github.com/icbi-lab/luca",
"data_modalities": ["rna"],
"license_info": "cc-by-4.0",
"repo_name": "scvi-tools/non_small_cell_lung_cancer"
}
3 changes: 2 additions & 1 deletion non_small_cell_lung_cancer/env.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@ channels:
- defaults
dependencies:
- python=3.10
- scvi-tools==0.20.1
- scvi-tools=0.20.2
- huggingface_hub
- anndata>=0.8.0
- scanpy>=1.9.0
- pytorch
Expand Down
114 changes: 114 additions & 0 deletions non_small_cell_lung_cancer/main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
import json
import os
import pathlib
from pathlib import Path
from typing import Tuple

import anndata as ad
import pooch
import scanpy as sc
import scvi
from scvi.hub import HubMetadata, HubModel, HubModelCardHelper

HF_API_TOKEN = os.environ["HF_API_TOKEN"]


def make_parents(*paths) -> None:
"""Make parent directories of a file path if they do not exist."""
for p in paths:
pathlib.Path(p).parent.mkdir(parents=True, exist_ok=True)


def load_config(config_path: str) -> dict:
"""Load a JSON configuration file as a Python dictionary."""
with open(config_path) as f:
config = json.load(f)
return config


def load_model(config: dict) -> Tuple[scvi.model.SCANVI, ad.AnnData]:
"""Load the model and dataset."""
model_url = config["model_url"]
unzipped = pooch.retrieve(
url=model_url,
fname=config["model_fname"],
known_hash=config["known_hash"],
processor=pooch.Untar(),
)[0]
base_path = Path(unzipped).parent
model_path = os.path.join(base_path, config["model_path"])
adata_path = os.path.join(base_path, config["adata_path"])

adata = sc.read(adata_path)
model = scvi.model.SCANVI.load(model_path, adata=adata)

return model


def minify_model_and_save(model: scvi.model.SCANVI, config: dict):
"""Minify the model and save it to disk."""
latent_qzm_key = config["latent_qzm_key"]
latent_qzv_key = config["latent_qzv_key"]

qzm, qzv = model.get_latent_representation(return_dist=True)
model.adata.obsm[latent_qzm_key] = qzm
model.adata.obsm[latent_qzv_key] = qzv
model.minify_adata(
use_latent_qzm_key=latent_qzm_key, use_latent_qzv_key=latent_qzv_key
)

model_dir = os.path.join("models", config["model_dir"])
make_parents(model_dir)
model.save(model_dir, overwrite=True)


def create_hub_model(config: dict) -> HubModel:
"""Create a HubModel object."""
model_dir = os.path.join("models", config["model_dir"])

metadata = HubMetadata.from_dir(
model_dir,
anndata_version=ad.__version__,
training_data_url=config["training_data_url"],
)

card = HubModelCardHelper.from_dir(
model_dir,
license_info=config["license_info"],
anndata_version=ad.__version__,
data_is_minified=True,
data_is_annotated=True,
tissues=config["tissues"],
training_data_url=config["training_data_url"],
training_code_url=config["training_code_url"],
description=config["description"],
references=config["citation"],
data_modalities=["rna"],
)

return HubModel(model_dir, metadata=metadata, model_card=card)


def upload_hub_model(hubmodel: HubModel, repo_token: str, config: dict):
"""Upload the model to the HuggingFace Hub."""
repo_name = config["repo_name"]
try:
hubmodel.push_to_huggingface_hub(
repo_name=repo_name, repo_token=repo_token, repo_create=True,
)
except Exception as e:
hubmodel.push_to_huggingface_hub(
repo_name=repo_name, repo_token=repo_token, repo_create=False,
)


def main():
config = load_config("config.json")
model = load_model(config)
minify_model_and_save(model, config)
hubmodel = create_hub_model(config)
upload_hub_model(hubmodel, HF_API_TOKEN, config)


if __name__ == "__main__":
main()

0 comments on commit a94bae7

Please sign in to comment.