Can't save SCVI model if using custom datamodule #3684

@deto

Description

I'm currently experimenting with a custom, scDataset-based datamodule to alleviate on-disk data-loading bottlenecks when training scVI on large datasets. However, it looks like the model cannot be saved if it was initialized without an AnnData object.

It seems the issue is that self.adata should be None, but the backing attribute self._adata is never initialized. Even if that is fixed, I believe saving will still fail, because self.registry is similarly not initialized when using a custom datamodule.

from scvi.model import SCVI

# My custom datamodule
dm = scvi_datamodule.scDatasetModule(
    ad_paths=ad_paths,
    batch_key="sample_id",
    batch_size=512,
    num_workers=8,
    val_proportion=0.05,
    test_proportion=0.0,
    var_indices=var_indices,
    obs_indices=obs_indices,
)


model = SCVI(adata=None, n_latent=100, gene_likelihood="nb")

model.train(
    max_epochs=1,  # 1 for testing
    datamodule=dm,
    early_stopping=True,
    early_stopping_warmup_epochs=10,
    plan_kwargs=dict(
        n_epochs_kl_warmup=10,
        reduce_lr_on_plateau=True,
        lr_patience=5,
        lr=0.01,
    ),
)

model.save(scvi_model_dir)  # error here

AttributeError                            Traceback (most recent call last)
Cell In[1], line 64
     49 model = SCVI(adata=None, n_latent=100, gene_likelihood="nb")
     50 model.train(
     51     max_epochs=1,
     52     datamodule=dm,
   (...)     62
     63 )
---> 64 model.save(scvi_model_dir)

File ~/projects/2_healthy_brain_screen/20251124_2k_pert/.venv/lib/python3.13/site-packages/scvi/model/base/_base_model.py:769, in BaseModelClass.save(self, dir_path, prefix, overwrite, save_anndata, save_kwargs, legacy_mudata_format, datamodule, **anndata_write_kwargs)
    766 model_state_dict = self.module.state_dict()
    767 model_state_dict["pyro_param_store"] = pyro.get_param_store().get_state()
--> 769 var_names = self.get_var_names(legacy_mudata_format=legacy_mudata_format)
    771 # get all the user attributes
    772 user_attributes = self._get_user_attributes()

File ~/projects/2_healthy_brain_screen/20251124_2k_pert/.venv/lib/python3.13/site-packages/scvi/model/base/_base_model.py:170, in BaseModelClass.get_var_names(self, legacy_mudata_format)
    167 """Variable names of input data."""
    168 from scvi.model.base._save_load import _get_var_names
--> 170 if self.adata:
    171     return _get_var_names(self.adata, legacy_mudata_format=legacy_mudata_format)
    172 else:

File ~/projects/2_healthy_brain_screen/20251124_2k_pert/.venv/lib/python3.13/site-packages/scvi/model/base/_base_model.py:159, in BaseModelClass.adata(self)
    156 @property
    157 def adata(self) -> None | AnnOrMuData:
    158     """Data attached to model instance."""
--> 159     return self._adata

AttributeError: 'SCVI' object has no attribute '_adata'
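
To illustrate the failure mode in the traceback without depending on scvi-tools, here is a minimal sketch using toy classes. `ToyModel`, `PatchedToyModel`, and `has_data` are hypothetical names made up for this example, not scvi-tools code; the sketch only assumes that the backing attribute is set when data is passed and left undefined otherwise, which is what the error suggests.

```python
class ToyModel:
    """Hypothetical stand-in for a model base class; not scvi-tools code."""

    def __init__(self, adata=None):
        # Assumed behavior: the backing attribute is only set when data is
        # provided, so adata=None leaves self._adata undefined.
        if adata is not None:
            self._adata = adata

    @property
    def adata(self):
        # Raises AttributeError if self._adata was never assigned.
        return self._adata


class PatchedToyModel(ToyModel):
    """Same toy model, but the backing attribute is always initialized."""

    def __init__(self, adata=None):
        self._adata = adata


def has_data(model) -> bool:
    """Return True if data is attached to the model."""
    return model.adata is not None


if __name__ == "__main__":
    try:
        has_data(ToyModel())
    except AttributeError as err:
        print(f"unpatched: {err}")

    print(f"patched: has_data={has_data(PatchedToyModel())}")
```

In this toy version, always assigning `self._adata` in `__init__` is enough for the property to return None. Whether the same change is sufficient in scvi-tools is uncertain, since self.registry appears to have the same problem.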

Versions:

scvi version: 1.4.1
