A Python package dedicated to the classification and exploration of Cell Painting data. This package relies on Lightning for training and evaluating the DenseNet model.
- Ensure Python >= 3.11 and conda are installed on your machine. The recommended installer for conda is Miniforge.
- Clone this repository
$ git clone https://github.com/jhuapl-bio/DeepPaint.git
- Navigate to the DeepPaint directory (containing the README)
$ cd DeepPaint
- Create a conda virtual environment from the environment.yml file and activate it
$ conda env create -n <env_name> -f environment.yml
$ conda activate <env_name>
- Install the DeepPaint package with pip
$ pip install .
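Before moving on, it can help to confirm that the activated environment meets the prerequisites. The sketch below is a hypothetical helper (check_environment is not part of DeepPaint): it checks only the interpreter version and whether the deep_paint and lightning packages can be imported, and it does not inspect conda itself.

```python
import importlib.util
import sys


def check_environment(min_version=(3, 11), packages=("deep_paint", "lightning")):
    """Return a list of human-readable problems (empty if everything looks fine).

    Hypothetical helper: only the Python version and package importability
    are checked; the conda environment itself is not inspected.
    """
    problems = []
    if sys.version_info[:2] < tuple(min_version):
        found = ".".join(map(str, sys.version_info[:3]))
        wanted = ".".join(map(str, min_version))
        problems.append(f"Python >= {wanted} required, found {found}")
    for name in packages:
        # find_spec returns None when a top-level module cannot be located.
        if importlib.util.find_spec(name) is None:
            problems.append(f"package '{name}' is not importable")
    return problems
```

Calling print(check_environment()) inside the activated environment should print an empty list once DeepPaint and its dependencies are installed.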
The DeepPaint package can be run as a module with the command python -m deep_paint to invoke the CLI. This is the entry point for training and evaluating models.
Four commands are available:
- fit: Train or fine-tune a model
- validate: Run one evaluation epoch on a validation set
- test: Run one test epoch on a test set
- predict: Get predictions from a trained model on part or all of a dataset
These commands correspond to the methods of the same name on lightning.pytorch.Trainer. All commands accept a --config argument that specifies a configuration file.
The configuration files used for training, getting model predictions, and getting model embeddings are available in the configs directory. Be sure to update the paths in these configuration files; they are marked with comments for convenience.
The configuration file is a YAML file that contains all the necessary parameters for training, evaluating, or testing a model. The YAML file is divided into the following fields:
| Field | Subclass | Description | Required? |
|---|---|---|---|
| model | LightningModule | Model architecture and hyperparameters | ✅ |
| data | LightningDataModule | Data preprocessing and augmentation | ✅ |
| trainer | Trainer | Training arguments | ✅ |
| optimizer | Optimizer | Optimizer | ❌ |
| lr_scheduler | LRScheduler | Learning rate scheduler | ❌ |
| ckpt_path | N/A | Path to model checkpoint | ❌ |
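The required/optional split in the table can be checked before launching a run. This is a minimal sketch, assuming the YAML file has already been loaded into a Python dict (for example with PyYAML's yaml.safe_load); check_fields and the two field sets are illustrative names, not DeepPaint code. A real config may legitimately contain top-level keys this sketch does not know about, so "unknown" entries are hints rather than errors.

```python
from typing import Any

# Field names taken from the configuration table; the sets themselves are illustrative.
REQUIRED_FIELDS = {"model", "data", "trainer"}
OPTIONAL_FIELDS = {"optimizer", "lr_scheduler", "ckpt_path"}


def check_fields(config: dict[str, Any]) -> tuple[list[str], list[str]]:
    """Return (missing required fields, unrecognised fields) for a loaded config."""
    keys = set(config)
    missing = sorted(REQUIRED_FIELDS - keys)
    unknown = sorted(keys - REQUIRED_FIELDS - OPTIONAL_FIELDS)
    return missing, unknown
```

For instance, a config that misspells optimizer as optimzer and omits data would report data as missing and optimzer as unknown.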
All fields except trainer and ckpt_path require a class_path parameter containing the full import path to the class. The keyword arguments for the class constructor are then supplied under the init_args parameter.
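The class_path/init_args convention amounts to "import this class, then call it with these keyword arguments". The sketch below is a simplified re-implementation for illustration only, not the parser DeepPaint actually uses; missing_class_paths and instantiate are hypothetical helpers, and a standard-library class stands in for a LightningModule so the example runs without PyTorch.

```python
import importlib
from typing import Any

# In YAML, a field might look like this (class and values are illustrative,
# not taken from the DeepPaint configs):
#
# optimizer:
#   class_path: torch.optim.AdamW
#   init_args:
#     lr: 0.0001
#     weight_decay: 0.01

# Fields that, per the configuration table, do not use class_path/init_args.
PLAIN_FIELDS = {"trainer", "ckpt_path"}


def missing_class_paths(config: dict[str, Any]) -> list[str]:
    """Names of fields (other than trainer/ckpt_path) that lack a class_path."""
    return sorted(
        name
        for name, value in config.items()
        if name not in PLAIN_FIELDS
        and not (isinstance(value, dict) and "class_path" in value)
    )


def instantiate(spec: dict[str, Any]) -> Any:
    """Import spec['class_path'] and call it with spec['init_args'] as keyword arguments."""
    module_name, _, class_name = spec["class_path"].rpartition(".")
    if not module_name:
        raise ValueError(f"class_path must be a full dotted path, got {spec['class_path']!r}")
    cls = getattr(importlib.import_module(module_name), class_name)
    return cls(**spec.get("init_args", {}))


# Illustrative only: datetime.timedelta stands in for a real model class.
delta = instantiate({"class_path": "datetime.timedelta", "init_args": {"days": 2, "hours": 3}})
print(delta)  # 2 days, 3:00:00
```

A full configuration parser would also need to handle nested class_path entries and type validation, which this sketch omits.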
- Train a model:
python -m deep_paint fit --config /path/to/your_config.yaml
- Run a validation epoch:
python -m deep_paint validate --config /path/to/your_config.yaml
- Run a test epoch:
python -m deep_paint test --config /path/to/your_config.yaml
- Get model predictions:
python -m deep_paint predict --config /path/to/your_config.yaml
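The four commands differ only in the subcommand, so they are easy to drive from Python, for example to run validation and testing back to back. This is a hypothetical wrapper (build_command and run are not part of DeepPaint) that assumes deep_paint is installed in the interpreter running it; by default it only prints the command instead of executing it.

```python
import shlex
import subprocess
import sys

# The four subcommands exposed by the DeepPaint CLI.
COMMANDS = ("fit", "validate", "test", "predict")


def build_command(subcommand: str, config_path: str) -> list[str]:
    """Assemble the argument list for `python -m deep_paint <subcommand> --config <path>`."""
    if subcommand not in COMMANDS:
        raise ValueError(f"unknown subcommand {subcommand!r}; expected one of {COMMANDS}")
    # sys.executable is the interpreter running this script, so the command
    # uses the same (activated) environment in which DeepPaint was installed.
    return [sys.executable, "-m", "deep_paint", subcommand, "--config", config_path]


def run(subcommand: str, config_path: str, dry_run: bool = True) -> int:
    """Print the command (dry run) or execute it, returning the exit code."""
    cmd = build_command(subcommand, config_path)
    if dry_run:
        print(shlex.join(cmd))
        return 0
    # check=True raises CalledProcessError if the CLI exits with a non-zero status.
    return subprocess.run(cmd, check=True).returncode
```

For example, looping over ("validate", "test") and calling run(sub, "/path/to/your_config.yaml", dry_run=False) would run both evaluations with the same configuration file.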
A custom script has been created to extract embeddings from a trained model. The script can be run with the following command:
python -m deep_paint.utils.embeddings --config /path/to/your_config.yaml
This config file differs slightly from the one used for the four main commands. Refer to the configs directory for examples.
The results directory contains the following subdirectories:
- checkpoints: Contains model checkpoints
- configs: Contains configuration files used for training, getting model predictions, and getting model embeddings
- embeddings: Contains embeddings extracted from the model on the test set of the RxRx2 data
- logs: Contains CSV files extracted from TensorBoard logs
- metadata: Contains custom metadata used for training the DenseNet model
- predictions: Contains model predictions on the test set of the RxRx2 data
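The logs subdirectory holds CSV files extracted from TensorBoard. As a hedged sketch, assuming these files follow the layout of TensorBoard's own scalar CSV export (columns Wall time, Step, Value), which is not confirmed here, the helpers below read one file into (step, value) pairs and pick the best step. read_scalar_log and best_step are hypothetical names, and the column names can be overridden if the actual files differ.

```python
import csv
from collections.abc import Iterable


def read_scalar_log(
    lines: Iterable[str], step_col: str = "Step", value_col: str = "Value"
) -> list[tuple[int, float]]:
    """Parse scalar rows from CSV text into (step, value) pairs, sorted by step.

    The default column names mirror TensorBoard's scalar CSV export; they are an
    assumption about the files in results/logs.
    """
    reader = csv.DictReader(lines)
    return sorted((int(float(row[step_col])), float(row[value_col])) for row in reader)


def best_step(rows: list[tuple[int, float]], mode: str = "max") -> tuple[int, float]:
    """Return the (step, value) pair with the highest ("max") or lowest ("min") value."""
    if not rows:
        raise ValueError("no rows to choose from")
    if mode not in ("max", "min"):
        raise ValueError(f"mode must be 'max' or 'min', got {mode!r}")
    pick = max if mode == "max" else min
    return pick(rows, key=lambda pair: pair[1])
```

Because csv.DictReader accepts any iterable of strings, an open file handle works directly: opening a CSV from results/logs with newline="" and passing the handle to read_scalar_log, then calling best_step(rows, mode="min") for a loss curve.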
The RxRx2 dataset was used for training and evaluation of the DenseNet model. The dataset is freely available to download from the RxRx.ai website.
The checkpoints directory contains model checkpoints for the binary and multiclass DenseNet models. These checkpoints can be used to load the trained models and make predictions.
The notebooks directory contains Jupyter notebooks that demonstrate the performance of the DenseNet model on the RxRx2 dataset. The notebooks contain visualizations of the model predictions and embeddings.