
trVAE training on query data indefinitely hang on "Instantiating Dataset" #258

Open
@li-xuyang28

Description

Dear developers/maintainers,

I used trVAE a while ago, before it became part of scArches. Now I'm trying to use the scArches implementation to annotate a new unlabeled dataset. Training on the reference/source data (a bit over 2M cells) was successful: it took a bit more than 13 h, plateaued, and stopped after 226 iterations. However, when I train on the query dataset (~230k cells), it hangs at "Instantiating Dataset" indefinitely (>24 h). I was wondering if you could kindly advise.

The code, up to training on the query set:

########### Loading packages #################
import scanpy as sc
import torch
import scarches as sca
from scarches.dataset.trvae.data_handling import remove_sparsity
import matplotlib.pyplot as plt
import numpy as np
import gdown
import pickle
print("Loaded packages...")
import warnings
warnings.filterwarnings('ignore')
import psutil
mem = psutil.virtual_memory()
print(f"Available memory: {mem.available / 1024**3:.2f} GB")

########### Train on reference ###############
print("Read data...")
ref = sc.read_h5ad(REF_DATA_PATH)
early_stopping_kwargs = {
    "early_stopping_metric": "val_unweighted_loss",
    "threshold": 0,
    "patience": 20,
    "reduce_lr": True,
    "lr_patience": 13,
    "lr_factor": 0.1,
}

condition_key = 'donor_id'
source_conditions = ref.obs[condition_key].unique().tolist()
cell_type_key = 'supercluster_term'

trvae = sca.models.TRVAE(
    adata=ref,
    condition_key=condition_key,
    conditions=source_conditions,
    hidden_layer_sizes=[128, 128, 128],
    recon_loss="zinb",
)

trvae.train(
    n_epochs=500,
    alpha_epoch_anneal=200,
    early_stopping_kwargs=early_stopping_kwargs
)
trvae.save(REF_MODEL_PATH, overwrite=True)

########### Train on query ###############
adata = sc.read_h5ad(QUERY_DATA_PATH)

new_trvae = sca.models.TRVAE.load_query_data(adata=adata, reference_model=REF_MODEL_PATH)
new_trvae.train(
    n_epochs=500,
    alpha_epoch_anneal=200,
    early_stopping_kwargs=early_stopping_kwargs,
    weight_decay=0
)
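One thing I would rule out first (a sketch of my own, not scArches API; the helper names `genes_match` and `ensure_sparse` are made up): `load_query_data` expects the query to use the same genes as the reference model, and a dense `.X` on ~230k cells can make dataset construction extremely slow, so it is worth checking both before training.

```python
# Hypothetical pre-flight checks before load_query_data (not scArches API).
import scipy.sparse as sp

def genes_match(query_var_names, ref_var_names):
    """True iff the query uses the same genes, in the same order, as the reference."""
    return list(query_var_names) == list(ref_var_names)

def ensure_sparse(X):
    """Return X as CSR; converting once up front avoids repeated densification."""
    return X if sp.issparse(X) else sp.csr_matrix(X)

# Usage (assuming `adata` is the query AnnData and `ref` the reference):
# assert genes_match(adata.var_names, ref.var_names)
# adata.X = ensure_sparse(adata.X)
```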

Session info:

gdown    5.2.0
matplotlib 3.9.2
numpy 1.26.1
psutil 6.1.0
scanpy 1.9.6
scarches 0.6.1
session_info 1.0.0
torch 2.5.1+cu124
----
Python 3.9.20 (main, Oct 3 2024, 07:27:41) [GCC 11.2.0]
Linux-5.14.0-427.42.1.el9_4.x86_64-with-glibc2.34

I checked that memory usage stayed well below 50% (1500 GB RAM; gres = 6x NVIDIA Tesla T4). I did use MMD and am aware that it slows training down considerably, but I would still expect at least the progress bar to appear. Any help would be greatly appreciated!
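To see exactly where the process is blocked, a stdlib trick (my own suggestion, unrelated to scArches; `watch_for_hang` and `stop_watching` are names I made up) is to have `faulthandler` dump every thread's stack periodically while training runs:

```python
# Diagnostic sketch: periodically dump all thread stacks so the
# blocking call inside "Instantiating Dataset" shows up in the log.
import faulthandler
import sys

def watch_for_hang(interval_s=600, out=sys.stderr):
    """Print every thread's stack to `out` each `interval_s` seconds until cancelled."""
    faulthandler.dump_traceback_later(interval_s, repeat=True, file=out)

def stop_watching():
    """Cancel the periodic stack dumps."""
    faulthandler.cancel_dump_traceback_later()

# Usage: call watch_for_hang() just before new_trvae.train(...), and
# stop_watching() once it completes; the tracebacks show where it hangs.
```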
