This is an implementation of several unsupervised object discovery models (Slot Attention, SLATE, GNM) in PyTorch.
This repo is in active development. Expect some breaking changes.
The initial code for this repo was forked from untitled-ai/slot_attention.
- Poetry
- Python >= 3.9
- CUDA-enabled computing device
- Clone the repo: `git clone https://github.com/HHousen/slot-attention-pytorch/ && cd slot-attention-pytorch`.
- Install requirements and activate the environment: `poetry install`, then `poetry shell`.
- Download the CLEVR (with masks) dataset (or the original CLEVR dataset by running `./data_scripts/download_clevr.sh /tmp/CLEVR`). More details about the datasets are below.
- Modify the hyperparameters in `object_discovery/params.py` to fit your needs. Make sure to change `data_root` to the location of your dataset.
- Train a model: `python -m object_discovery.train`.
Code to load these models can be adapted from predict.py; a minimal loading sketch follows the table below.
Model | Dataset | Download |
---|---|---|
Slot Attention | CLEVR6 Masks | Google Drive |
Slot Attention | Sketchy | Google Drive |
GNM | CLEVR6 Masks | Google Drive |
Slot Attention | ClevrTex6 | Google Drive |
GNM | ClevrTex6 | Google Drive |
SLATE | CLEVR6 Masks | Google Drive |
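If you just want to inspect one of these checkpoints before wiring up a full model, something like the following is a starting point (a minimal sketch: the filename is hypothetical, and it assumes the files are standard PyTorch Lightning checkpoints, which store weights under a `state_dict` key; predict.py is the authoritative reference):

```python
import torch

# Load a downloaded checkpoint on CPU and inspect its contents.
# The filename below is a placeholder for whichever file you downloaded.
ckpt = torch.load("slot_attention_clevr6.ckpt", map_location="cpu")
print(ckpt.keys())  # Lightning checkpoints include "state_dict", "epoch", etc.

# After constructing the matching model with the hyperparameters it was
# trained with (see predict.py and object_discovery/params.py):
# model.load_state_dict(ckpt["state_dict"])
```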
Train a model by running `python -m object_discovery.train`.
Hyperparameters can be changed in `object_discovery/params.py`. `training_params` has global parameters that apply to all model types. These parameters can be overridden if the same key is present in `slot_attention_params` or `slate_params`. Change the global parameter `model_type` to `sa` to use Slot Attention (`SlotAttentionModel` in `slot_attention_model.py`) or `slate` to use SLATE (`SLATE` in `slate_model.py`). This will determine which model's set of parameters will be merged with `training_params`.
Perform inference by modifying and running the predict.py script.
Our implementations are based on several open-source repositories.
- Slot Attention ("Object-Centric Learning with Slot Attention"): untitled-ai/slot_attention & Official
- SLATE ("Illiterate DALL-E Learns to Compose"): Official
- GNM ("Generative Neurosymbolic Machines"): karazijal/clevrtex & Official
Select a dataset by changing the `dataset` parameter in `object_discovery/params.py` to the name of the dataset: `clevr`, `shapes3d`, or `ravens`. Then, set the `data_root` parameter to the location of the data. The code for loading supported datasets is in `object_discovery/data.py`.
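For example, assuming the dict-like parameter structure implied above and that these keys live in `training_params` (the path is a placeholder), the relevant edit looks like:

```python
# In object_discovery/params.py (illustrative only):
training_params["dataset"] = "clevr"            # or "shapes3d", "ravens"
training_params["data_root"] = "/path/to/data"  # placeholder path
```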
- CLEVR: Download by executing download_clevr.sh.
- CLEVR (with masks): Original TFRecords Download / Our HDF5 PyTorch Version.
  - This dataset is a regenerated version of CLEVR but with ground-truth segmentation masks. This enables the training script to calculate the Adjusted Rand Index (ARI) during validation runs.
  - The dataset contains 100,000 images with a resolution of 240x320 pixels. The dataloader splits them into 70K train, 15K validation, and 15K test. Test images are not used by the object_discovery/train.py script.
  - We convert the original TFRecords dataset to HDF5 for easy use with PyTorch (see the loading sketch after this list). This was done using the data_scripts/preprocess_clevr_with_masks.py script, which takes approximately 2 hours to execute depending on your machine.
- 3D Shapes: Official Google Cloud Bucket
- RAVENS Robot Data: Official Train & Official Test
  - We generated a dataset similar in structure to CLEVR (with masks) but of robotic images using RAVENS. Our modified version of RAVENS used to generate the dataset is HHousen/ravens.
  - The dataset contains 85,002 images, split into 70,002 train and 15K validation/test.
- Sketchy: Download and process by following the directions in applied-ai-lab/genesis / Download Our Processed Version.
  - Dataset details are in the paper Scaling data-driven robotics with reward sketching and batch reinforcement learning.
- ClevrTex: Download by executing download_clevrtex.sh. Our dataloader needs to index the entire dataset before training can begin. This can take around 2 hours. Thus, it is recommended to download our pre-made index from this Google Drive folder and put it in `./data/cache/`.
- Tetrominoes: Original TFRecords Download / Our HDF5 PyTorch Version.
  - There are 1,000,000 samples in the dataset. However, following the Slot Attention paper, we only use the first 60K samples for training.
  - We convert the original TFRecords dataset to HDF5 for easy use with PyTorch. This was done using the data_scripts/preprocess_tetrominoes.py script, which takes approximately 2 hours to execute depending on your machine.
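For the HDF5 conversions above (CLEVR with masks and Tetrominoes), a plain PyTorch `Dataset` along the following lines is enough to start experimenting. This is a minimal sketch: the `"image"`/`"mask"` key names and the uint8 HWC layout are assumptions, and `object_discovery/data.py` is the loader the repo actually uses.

```python
import h5py
import torch
from torch.utils.data import Dataset

class Hdf5ImageDataset(Dataset):
    """Reads images and masks from an HDF5 file over an index range."""

    def __init__(self, path: str, start: int, end: int):
        self.path, self.start, self.end = path, start, end
        self._file = None  # opened lazily so each DataLoader worker gets its own handle

    def __len__(self):
        return self.end - self.start

    def __getitem__(self, idx):
        if self._file is None:
            self._file = h5py.File(self.path, "r")
        i = self.start + idx
        image = torch.from_numpy(self._file["image"][i]).float() / 255.0  # uint8 HWC -> float
        mask = torch.from_numpy(self._file["mask"][i].astype("int64"))
        return image.permute(2, 0, 1), mask  # HWC -> CHW

# E.g., the 70K-image train split described above (filename is a placeholder):
# train_set = Hdf5ImageDataset("clevr_with_masks.h5", 0, 70_000)
```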
To log outputs to wandb, run `wandb login YOUR_API_KEY` and set `is_logging_enabled=True` in `SlotAttentionParams`.
If you use a dataset with ground-truth segmentation masks, then the Adjusted Rand Index (ARI), a clustering similarity score, will be logged for each validation loop. We convert the implementation from deepmind/multi_object_datasets to PyTorch in object_discovery/segmentation_metrics.py.
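For intuition, ARI treats the two segmentations as clusterings of pixels and scores their agreement. A minimal NumPy/scikit-learn version for a single image might look like this (illustrative only; the training script uses the PyTorch port in `object_discovery/segmentation_metrics.py`, and the background-exclusion convention here is an assumption):

```python
import numpy as np
from sklearn.metrics import adjusted_rand_score

def ari(true_mask: np.ndarray, pred_mask: np.ndarray, ignore_background: bool = True) -> float:
    """ARI between two integer label maps of shape (H, W)."""
    true_labels, pred_labels = true_mask.ravel(), pred_mask.ravel()
    if ignore_background:
        # FG-ARI convention: score foreground pixels only, assuming label 0
        # is the background in the ground-truth masks.
        keep = true_labels != 0
        true_labels, pred_labels = true_labels[keep], pred_labels[keep]
    return adjusted_rand_score(true_labels, pred_labels)
```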
(Visualization images: Slot Attention CLEVR10 and Slot Attention Sketchy.)
Visualizations (above) for a model trained on CLEVR6 predicting on CLEVR10 (with no increase in number of slots) and a model trained and predicting on Sketchy. The order from left to right of the images is original, reconstruction, raw predicted segmentation mask, processed segmentation mask, and then the slots.
(Visualization images: Slot Attention ClevrTex6 and GNM ClevrTex6.)
The Slot Attention visualization image order is the same as in the above visualizations. For GNM, the order is original, reconstruction, ground truth segmentation mask, prediction segmentation mask (repeated 4 times).
(Visualization images: SLATE CLEVR6 and GNM CLEVR6.)
For SLATE, the image order is original, dVAE reconstruction, autoregressive reconstruction, and then the pixels each slot pays attention to.
- untitled-ai/slot_attention: An unofficial implementation of Slot Attention from which this repo was forked.
- Slot Attention: Official Code / "Object-Centric Learning with Slot Attention".
- SLATE: Official Code / "Illiterate DALL-E Learns to Compose".
- IODINE: Official Code / "Multi-Object Representation Learning with Iterative Variational Inference". In the Slot Attention paper, IODINE was frequently used for comparison. The IODINE code was helpful when creating this repo.
- Multi-Object Datasets: deepmind/multi_object_datasets. This is the original source of the CLEVR (with masks) dataset.
- Implicit Slot Attention: "Object Representations as Fixed Points: Training Iterative Refinement Algorithms with Implicit Differentiation". This paper explains a one-line change that improves the optimization of Slot Attention while simultaneously making backpropagation have constant space and time complexity.
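Concretely, the change amounts to detaching the slot state before one final update, so gradients flow only through the last iteration. The sketch below uses a toy stand-in for one attention-plus-GRU update (the real update lives in `slot_attention_model.py`); only the final detached call reflects the paper's trick.

```python
import torch

def slot_attention_step(slots: torch.Tensor, inputs: torch.Tensor) -> torch.Tensor:
    # Toy stand-in for one Slot Attention update, illustrative only.
    attn = torch.softmax(slots @ inputs.transpose(-1, -2), dim=-2)
    return slots + attn @ inputs

num_iterations = 3
inputs = torch.randn(1, 16, 64, requires_grad=True)  # (batch, tokens, dim)
slots = torch.randn(1, 4, 64)                        # (batch, slots, dim)

# Run the usual iterations; with the trick below, this unrolled graph no
# longer needs to be backpropagated through.
for _ in range(num_iterations):
    slots = slot_attention_step(slots, inputs)

# The "one-line change": a final update from a detached slot state, giving a
# first-order approximation of implicit differentiation with constant
# backprop cost.
slots = slot_attention_step(slots.detach(), inputs)
```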