Description
Dear developers/maintainers,
I used trVAE a while ago, before it became part of scArches. I am now trying to use the scArches implementation to annotate a new, unlabeled dataset. Training on the reference/source dataset (a bit over 2M cells) was successful: it took a bit more than 13 h, plateaued, and stopped after 226 iterations. However, when I try to train on the query dataset (~230k cells), it hangs at "Instantiating Dataset" indefinitely (>24 h). Could you kindly advise?
The code up to and including the training on the query set:
########### Loading packages #################
import scanpy as sc
import torch
import scarches as sca
from scarches.dataset.trvae.data_handling import remove_sparsity
import matplotlib.pyplot as plt
import numpy as np
import gdown
import pickle
print("Loaded packages...")
import warnings
warnings.filterwarnings('ignore')
import psutil
mem = psutil.virtual_memory()
print(f"Available memory: {mem.available / 1024**3:.2f} GB")
########### Train on reference ###############
print("Read data...")
ref = sc.read_h5ad(REF_DATA_PATH)
early_stopping_kwargs = {
    "early_stopping_metric": "val_unweighted_loss",
    "threshold": 0,
    "patience": 20,
    "reduce_lr": True,
    "lr_patience": 13,
    "lr_factor": 0.1,
}
condition_key = 'donor_id'
source_conditions = ref.obs[condition_key].unique().tolist()
cell_type_key = 'supercluster_term'
trvae = sca.models.TRVAE(
    adata=ref,
    condition_key=condition_key,
    conditions=source_conditions,
    hidden_layer_sizes=[128, 128, 128],
    recon_loss="zinb",
)
trvae.train(
    n_epochs=500,
    alpha_epoch_anneal=200,
    early_stopping_kwargs=early_stopping_kwargs,
)
trvae.save(REF_MODEL_PATH, overwrite=True)
########### Train on query ###############
adata = sc.read_h5ad(QUERY_DATA_PATH)
new_trvae = sca.models.TRVAE.load_query_data(adata=adata, reference_model=REF_MODEL_PATH)
new_trvae.train(
    n_epochs=500,
    alpha_epoch_anneal=200,
    early_stopping_kwargs=early_stopping_kwargs,
    weight_decay=0,
)
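To help pinpoint where the query training stalls, periodic stack dumps could be captured around the hanging call using Python's standard `faulthandler` module. This is only a minimal sketch: `run_with_stack_dumps` and `slow_step` are hypothetical names introduced for illustration, and `slow_step` is a stand-in for the real `new_trvae.train(...)` call.

```python
import faulthandler
import tempfile
import time


def run_with_stack_dumps(fn, interval_s, log_path):
    """Run fn() while dumping all thread stacks to log_path every interval_s seconds.

    Hypothetical helper for illustration; not part of scArches.
    """
    with open(log_path, "w") as log:
        # repeat=True re-arms the timer, so a long hang yields a series of dumps
        faulthandler.dump_traceback_later(interval_s, repeat=True, file=log)
        try:
            return fn()
        finally:
            # Stop the watchdog before the log file is closed
            faulthandler.cancel_dump_traceback_later()


def slow_step():
    # Stand-in (assumption) for the call that appears to hang,
    # e.g. lambda: new_trvae.train(...)
    time.sleep(0.5)
    return "done"


log_path = tempfile.NamedTemporaryFile(suffix=".log", delete=False).name
result = run_with_stack_dumps(slow_step, interval_s=0.1, log_path=log_path)

with open(log_path) as fh:
    dump = fh.read()

# The dump names the Python frame that was executing at each timeout
print(result, "slow_step" in dump)
```

With the real training call and a longer interval (for example 600 seconds), the log should show which Python function is active while "Instantiating Dataset" is displayed, which I can attach here if useful.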
Session info:
gdown 5.2.0
matplotlib 3.9.2
numpy 1.26.1
psutil 6.1.0
scanpy 1.9.6
scarches 0.6.1
session_info 1.0.0
torch 2.5.1+cu124
----
Python 3.9.20 (main, Oct 3 2024, 07:27:41) [GCC 11.2.0]
Linux-5.14.0-427.42.1.el9_4.x86_64-with-glibc2.34
I checked that memory usage stayed well below 50% (1500 GB of memory, gres = 6x Nvidia Tesla T4). I did use MMD and am aware that it slows the process down considerably, but I would expect at least the progress bar to show up. Any help would be greatly appreciated!