112 changes: 101 additions & 11 deletions README.md
This repository provides a PyTorch implementation of our method from the [paper]

## Installation

Check out the repo and set up the conda environment:

```bash
conda env create -f environment.yaml
```

Activate the new environment:

```bash
conda activate spoco
```

## Training

This implementation uses `DistributedDataParallel` training. In order to restrict the number of GPUs used for training,
use `CUDA_VISIBLE_DEVICES`, e.g. `CUDA_VISIBLE_DEVICES=0 python spoco_train.py ...` will execute training on `GPU:0`.

### CVPPP dataset

We used the A1 subset of the [CVPPP2017_LSC challenge](https://competitions.codalab.org/competitions/18405) for training. In
order to train with 10% of randomly selected objects, run:

```bash
python spoco_train.py \
--spoco \
...
```

`CVPPP_ROOT_DIR` is assumed to have the following subdirectories:

```
- train:
    - A1:
        - ...

```
Since the CVPPP dataset consists of only `training` and `testing` subdirectories, one has to create the train/val split
manually using the `training` subdir; a possible script is sketched below.

### Cityscapes Dataset

Download the images `leftImg8bit_trainvaltest.zip` and the labels `gtFine_trainvaltest.zip` from
the [Cityscapes website](https://www.cityscapes-dataset.com/downloads)
and extract them into the `CITYSCAPES_ROOT_DIR` of your choice, so it has the following structure:

```
- gtFine:
    - train
    - val
    - test
- leftImg8bit:
    - train
    - val
    - test
```

Create random samplings of each class using the [cityscapesampler.py](spoco/datasets/cityscapesampler.py) script:

```bash
python spoco/datasets/cityscapesampler.py --base_dir CITYSCAPES_ROOT_DIR --class_names person rider car truck bus train motorcycle bicycle
```
This will randomly sample 10%, 20%, ..., 90% of objects from the specified class(es) and save the results in dedicated
directories, e.g. `CITYSCAPES_ROOT_DIR/gtFine/train/darmstadt/car/0.4` will contain a random 40% of objects of class `car`.

One can also sample from all of the objects (people, riders, cars, trucks, buses, trains, motorcycles, bicycles)
collectively by simply running:

```bash
python spoco/datasets/cityscapesampler.py --base_dir CITYSCAPES_ROOT_DIR
```

This will randomly sample 10%, 20%, ..., 90% of **all** objects and save the results in dedicated directories,
e.g. `CITYSCAPES_ROOT_DIR/gtFine/train/darmstadt/all/0.4` will contain a random 40% of all objects.

In order to train with 40% of randomly selected objects of class `car`, run:

```bash
python spoco_train.py \
--spoco \
...
--log-after-iters 500 --max-num-iterations 90000
```

In order to train with a random 40% of all ground truth objects, just remove the `--things-class` argument from the
command above.

## Prediction

Given a model trained on the CVPPP dataset, run the prediction using the following command:

```bash
python spoco_predict.py \
--spoco \
...
--model-feature-maps 16 32 64 128 256 512 \
--output-dir OUTPUT_DIR
```

Results will be saved in the given `OUTPUT_DIR` directory. For each test input image `plantXXX_rgb.png` the following
3 output files will be saved in the `OUTPUT_DIR`:

* `plantXXX_rgb_predictions.h5` - HDF5 file with datasets `/raw` (input image), `/embeddings1` (output from the `f`
  embedding network), `/embeddings2` (output from the `g` momentum contrast network)
* `plantXXX_rgb_predictions_1.png` - output from the `f` embedding network PCA-projected into the RGB-space
* `plantXXX_rgb_predictions_2.png` - output from the `g` momentum contrast network PCA-projected into the RGB-space
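
As an illustration of that projection, a minimal sketch, assuming the embeddings are stored in a `(C, H, W)` layout
(the exact projection and normalization used by `spoco_predict.py` may differ):

```python
import h5py
import numpy as np
from sklearn.decomposition import PCA

with h5py.File('plant001_rgb_predictions.h5', 'r') as f:
    emb = f['embeddings1'][:]      # assumed (C, H, W) embedding volume

c, h, w = emb.shape
pixels = emb.reshape(c, -1).T      # every pixel as a C-dimensional point
proj = PCA(n_components=3).fit_transform(pixels).reshape(h, w, 3)
# min-max normalize each channel so the result can be saved as 8-bit RGB
proj -= proj.min(axis=(0, 1))
proj /= proj.max(axis=(0, 1)) + 1e-8
rgb = (proj * 255).astype(np.uint8)
```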

And similarly for the Cityscapes dataset:

```bash
python spoco_predict.py \
--spoco \
...
```

## Clustering

To produce the final segmentation one needs to cluster the embeddings with an algorithm of choice. Supported
algorithms: mean-shift, HDBSCAN and Consistency Clustering (as described in the paper). E.g. to cluster CVPPP with
HDBSCAN, run:

```bash
python cluster_predictions.py \
--ds-name cvppp \
...
```
Where `PREDICTION_DIR` is the directory where the H5 files containing network predictions are stored. The resulting
segmentation will be saved as a separate dataset (named `segmentation`) inside each of the H5 prediction files.
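
For intuition, a minimal sketch of the HDBSCAN step, assuming the `(C, H, W)` embedding layout from the prediction
files (`min_cluster_size` is an illustrative value, and large images are typically subsampled for speed):

```python
import h5py
import hdbscan

with h5py.File('plant001_rgb_predictions.h5', 'r') as f:
    emb = f['embeddings1'][:]          # assumed (C, H, W) embedding volume

c, h, w = emb.shape
pixels = emb.reshape(c, -1).T          # (H*W, C) points in embedding space
labels = hdbscan.HDBSCAN(min_cluster_size=50).fit_predict(pixels)
segmentation = (labels + 1).reshape(h, w)  # noise (-1) becomes background 0
```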

In order to cluster the Cityscapes predictions, extract the instances of class `car`, and compute the segmentation
scores on the validation set, run:

```bash
python cluster_predictions.py \
--ds-name cityscapes \
...
--things-class car \
--clustering msplus --delta-var 0.5 --delta-dist 2.0
```

Where `SEM_PREDICTION_DIR` is the directory containing the semantic segmentation predictions for your validation images.
We used a pre-trained DeepLabv3 model from [here](https://github.com/VainF/DeepLabV3Plus-Pytorch).

## Training and inference on MitoEM dataset

Download the MitoEM-R dataset from https://mitoem.grand-challenge.org and split the h5 file containing 500 slices
into training and validation sets: the training file should be named `train.h5` and contain 400 slices, and the
validation file should be named `val.h5` and contain 100 slices.
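
A minimal sketch of that split, assuming the downloaded volume is a single `mitoem.h5` file with `raw` and `label`
datasets (the file name and the `raw` key are assumptions; adjust them to the actual download):

```python
import h5py

# read the full 500-slice volume
with h5py.File('mitoem.h5', 'r') as f:
    raw, label = f['raw'][:], f['label'][:]

# first 400 slices go to train.h5, the remaining 100 to val.h5
with h5py.File('train.h5', 'w') as f:
    f.create_dataset('raw', data=raw[:400], compression='gzip')
    f.create_dataset('label', data=label[:400], compression='gzip')

with h5py.File('val.h5', 'w') as f:
    f.create_dataset('raw', data=raw[400:], compression='gzip')
    f.create_dataset('label', data=label[400:], compression='gzip')
```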

Then create the random 1%, 5%, 10% samplings of instances using the [mitoemsampler.py](spoco/datasets/mitoemsampler.py)
script:

```bash
python spoco/datasets/mitoemsampler.py --dataset_dir MITOEM_ROOT_DIR --instance_ratios 0.01 0.05 0.1
```

This will create the following additional datasets inside `MITOEM_ROOT_DIR/train.h5`:

```
- label_0.01
- label_0.05
- label_0.1
```
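
To sanity-check the sampling, one can compare instance counts between the full and the sparse labels (a hypothetical
check, not part of the repo):

```python
import h5py
import numpy as np

with h5py.File('MITOEM_ROOT_DIR/train.h5', 'r') as f:
    full = len(np.unique(f['label'][:])) - 1       # minus the background id
    kept = len(np.unique(f['label_0.01'][:])) - 1
print(f'{kept} of {full} instances kept at ratio 0.01')
```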

### Training on MitoEM

In order to train with 1% of randomly selected instances, run:

```bash
python spoco_train.py \
--spoco \
--ds-name mitoem \
--ds-path MITOEM_ROOT_DIR \
--patch-shape 512 512 \
--stride-shape 512 512 \
--instance-ratio 0.01 \
--batch-size 16 \
--model-name UNet2D \
--model-in-channels 1 \
--model-feature-maps 16 32 64 128 256 512 \
--learning-rate 0.0002 \
--weight-decay 0.00001 \
--cos \
--loss-delta-var 0.5 \
--loss-delta-dist 2.0 \
--loss-unlabeled-push 1.0 \
--loss-instance-weight 1.0 \
--loss-consistency-weight 1.0 \
--kernel-threshold 0.5 \
--checkpoint-dir CHECKPOINT_DIR \
--log-after-iters 500 \
--max-num-iterations 100000
```

### Prediction on MitoEM

The prediction script converts the embeddings to affinities using the formula defined in the paper (see eq. 12 in
Appendix 4).
TODO
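
A strongly hedged sketch of one such conversion: here the affinity between a pixel and its neighbor is assumed to be
the Gaussian kernel `exp(-||e_i - e_j||^2 / (2 * delta_var^2))` of the embedding distance, which is consistent with
the `--kernel-threshold` flag above but should be checked against eq. 12 in Appendix 4:

```python
import numpy as np


def embeddings_to_affinities(emb, delta_var=0.5, offsets=((0, 1), (1, 0))):
    """emb: (C, H, W) embeddings; returns one affinity map per pixel offset."""
    affs = []
    for dy, dx in offsets:
        # distance between each pixel embedding and its (dy, dx) neighbor
        shifted = np.roll(emb, shift=(-dy, -dx), axis=(1, 2))
        dist2 = ((emb - shifted) ** 2).sum(axis=0)
        affs.append(np.exp(-dist2 / (2 * delta_var ** 2)))
    return np.stack(affs)  # values in (0, 1]; e.g. threshold at 0.5
```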
2 changes: 0 additions & 2 deletions environment.yaml
channels:
- conda-forge

dependencies:
- python 3.8
- tqdm
- pytorch
- torchvision
# ...
- scikit-learn
- pyyaml
- hdbscan
- pytest
64 changes: 64 additions & 0 deletions spoco/datasets/mitoemsampler.py
import argparse
from pathlib import Path

import h5py
import numpy as np


def mitoem_sample_instances(label, instance_ratio, random_state):
    """
    Sample a fraction of ground truth objects from the label dataset.

    Args:
        label: np.array, label dataset
        instance_ratio: float, fraction of ground truth objects to sample
        random_state: instance of np.random.RandomState

    Returns:
        np.array, sampled label dataset
    """
    label_img = np.copy(label)
    # skip the first unique value, which is assumed to be the background (0)
    unique_ids = np.unique(label)[1:]
    random_state.shuffle(unique_ids)
    # pick instance_ratio objects
    num_objects = round(instance_ratio * len(unique_ids))
    assert num_objects > 0, 'No objects to sample'
    print(f'Sampled {num_objects} out of {len(unique_ids)} objects. Instance ratio: {instance_ratio}')
    # create a set of object ids left for training
    sampled_ids = set(unique_ids[:num_objects])
    # zero out (i.e. treat as unlabeled) all objects that were not sampled
    for obj_id in unique_ids:
        if obj_id not in sampled_ids:
            label_img[label_img == obj_id] = 0
    return label_img


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset_dir', type=str, help='MitoEM dir containing train.h5 and val.h5 files',
                        required=True)
    parser.add_argument('--instance_ratios', nargs='+', type=float,
                        help='fraction of ground truth objects to sample.', required=True)

    args = parser.parse_args()

    # load label dataset from the train.h5 file
    train_file = Path(args.dataset_dir) / 'train.h5'
    assert train_file.exists(), f'{train_file} does not exist'

    with h5py.File(train_file, 'r+') as f:
        label = f['label'][:]

        for instance_ratio in args.instance_ratios:
            assert 0.0 <= instance_ratio <= 1.0, 'Instance ratio must be in [0, 1]'

            ir = float(instance_ratio)
            rs = np.random.RandomState(47)
            print(f'Sampling {ir * 100}% of MitoEM instances')

            label_sampled = mitoem_sample_instances(label, ir, rs)

            dataset_name = f'label_{instance_ratio}'
            if dataset_name in f:
                del f[dataset_name]
            # save the sampled label dataset
            f.create_dataset(dataset_name, data=label_sampled, compression='gzip')
27 changes: 27 additions & 0 deletions spoco/datasets/utils.py
import collections
from pathlib import Path

import torch
from torch.utils.data import DataLoader

from spoco.datasets.cityscapes import CityscapesDataset
from spoco.datasets.cvppp import CVPPP2017Dataset
from spoco.datasets.volumetric import VolumetricH5Dataset


def create_train_val_loaders(args):
        ...
        train_dataset = CityscapesDataset(args.ds_path, phase='train', class_name=args.things_class, spoco=args.spoco,
                                          instance_ratio=args.instance_ratio)
        val_dataset = CityscapesDataset(args.ds_path, phase='val', class_name=args.things_class, spoco=args.spoco)
    elif args.ds_name == 'mitoem':
        ds_path = Path(args.ds_path)
        train_file = ds_path / 'train.h5'
        val_file = ds_path / 'val.h5'
        assert train_file.exists(), f'Training file {train_file} does not exist'
        assert val_file.exists(), f'Validation file {val_file} does not exist'
        assert len(args.patch_shape) == 3, 'Patch shape must be a 3D tuple'
        assert len(args.stride_shape) == 3, 'Stride shape must be a 3D tuple'
        assert args.patch_shape[0] == 1, 'Patch shape must have a depth of 1: only 2D patches are supported'
        assert args.stride_shape[0] == 1, 'Stride shape must have a depth of 1: only 2D patches are supported'
        train_dataset = VolumetricH5Dataset(train_file, phase='train', patch_shape=args.patch_shape,
                                            stride_shape=args.stride_shape, spoco=args.spoco,
                                            instance_ratio=args.instance_ratio)
        val_dataset = VolumetricH5Dataset(val_file, phase='val', patch_shape=args.patch_shape,
                                          stride_shape=args.stride_shape, spoco=args.spoco)
    else:
        raise RuntimeError(f'Unsupported dataset: {args.ds_name}')

    ...


def create_test_loader(args):
    ...
    if args.ds_name == 'cvppp':
        test_dataset = CVPPP2017Dataset(args.ds_path, phase='test', spoco=args.spoco)
    elif args.ds_name == 'cityscapes':
        test_dataset = CityscapesDataset(args.ds_path, phase='test', class_name=None, spoco=args.spoco)
    elif args.ds_name == 'mitoem':
        ds_path = Path(args.ds_path)
        test_file = ds_path / 'val.h5'
        assert test_file.exists(), f'Test file {test_file} does not exist'
        assert len(args.patch_shape) == 3, 'Patch shape must be a 3D tuple'
        assert len(args.stride_shape) == 3, 'Stride shape must be a 3D tuple'
        assert args.patch_shape[0] == 1, 'Patch shape must have a depth of 1: only 2D patches are supported'
        assert args.stride_shape[0] == 1, 'Stride shape must have a depth of 1: only 2D patches are supported'
        test_dataset = VolumetricH5Dataset(test_file, phase='test', patch_shape=args.patch_shape,
                                           stride_shape=args.stride_shape, spoco=args.spoco)
    else:
        raise RuntimeError(f'Unsupported dataset {args.ds_name}')
