Skip to content

Commit

Permalink
added InriaAIL dataset and datamodule, updated readme
Browse files Browse the repository at this point in the history
  • Loading branch information
isaaccorley committed Aug 27, 2021
1 parent a58877c commit 589b5ec
Show file tree
Hide file tree
Showing 6 changed files with 99 additions and 12 deletions.
36 changes: 36 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ pip install 'git+https://github.com/isaaccorley/torchrs.git#egg=torch-rs[train]'
* [RESISC45 - Remote Sensing Image Scene Classification](https://github.com/isaaccorley/torchrs#remote-sensing-image-scene-classification-resisc45)
* [EuroSAT](https://github.com/isaaccorley/torchrs#eurosat)
* [SAT-4-&-SAT-6](https://github.com/isaaccorley/torchrs#eurosat)
* [Inria Aerial Image Labeling - Building Segmentation](https://github.com/isaaccorley/torchrs#inria-aerial-image-labeling)

### PROBA-V Super Resolution

Expand Down Expand Up @@ -596,6 +597,41 @@ dataset.classes
"""
```

### Inria Aerial Image Labeling

<img src="./assets/inria_ail.png" width="950px"></img>

The [Inria Aerial Image Labeling Dataset](https://project.inria.fr/aerialimagelabeling/) is a building segmentation dataset proposed in ["Can semantic labeling methods generalize to any city? the inria aerial image labeling benchmark", Maggiori et al.](https://ieeexplore.ieee.org/document/8127684) of 360 high resolution (0.3m) 5000x5000 RGB imagery extracted from various international GIS services (e.g. [USGS National Map](https://www.usgs.gov/core-science-systems/national-geospatial-program/national-map). The dataset contains imagery from 10 regions around the world (both urban and rural) with train/test sets split into different cities for the purpose of evaluating if models can generalize across dramatically different locations. The dataset was originally used in the [Inria Aerial Image Labeling Dataset Contest](https://project.inria.fr/aerialimagelabeling/contest/) and the test set ground truth masks have not been disclosed.

The dataset can be downloaded (26GB) using `scripts/download_inria_ail.sh` and instantiated below:

```python
from torchrs.transforms import Compose, ToTensor
from torchrs.datasets import InriaAIL

transform = Compose([ToTensor()])

dataset = InriaAIL(
root="path/to/dataset/",
split="train", # or 'test'
transform=transform
)

x = dataset[0]
"""
x: dict(
x: (3, 5000, 5000)
mask: (1, 5000, 5000)
region: str
)
"""

dataset.regions
"""
['austin', 'chicago', 'kitsap', 'tyrol', 'vienna']
"""
```

## Models

* [Multi-Image Super Resolution - RAMS](https://github.com/isaaccorley/torchrs#multi-image-super-resolution---rams)
Expand Down
Binary file added assets/inria_ail.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
3 changes: 2 additions & 1 deletion torchrs/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,12 @@
from .advance import ADVANCE
from .sat import SAT4, SAT6
from .hrscd import HRSCD
from .inria_ail import InriaAIL


__all__ = [
"PROBAV", "ETCI2021", "RSVQALR", "RSVQAxBEN", "EuroSATRGB", "EuroSATMS",
"RESISC45", "RSICD", "OSCD", "S2Looking", "LEVIRCDPlus", "FAIR1M",
"SydneyCaptions", "UCMCaptions", "S2MTCP", "ADVANCE", "SAT4", "SAT6",
"HRSCD"
"HRSCD", "InriaAIL"
]
45 changes: 35 additions & 10 deletions torchrs/datasets/inria_ail.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import os
import json
import re
from glob import glob
from typing import List, Dict

import torch
import tifffile
import numpy as np
from PIL import Image

from torchrs.transforms import Compose, ToTensor

Expand All @@ -22,25 +22,50 @@ class InriaAIL(torch.utils.data.Dataset):
correspond exactly to those of the color images. In the case of the reference data, the tiles are
single-channel images with values 255 for the building class and 0 for the not building class.'
"""
splits = ["train", "test"]

def __init__(
self,
root: str = ".data/inria_ail",
root: str = ".data/AerialImageDataset",
split: str = "train",
transform: Compose = Compose([ToTensor()]),
):
self.root = root
self.split = split
self.transform = transform
self.images = self.load_images(self.image_root)
self.images = self.load_images(root, split)
self.regions = sorted(list(set(image["region"] for image in self.images)))

@staticmethod
def load_images(path: str) -> List[Dict]:
pass
def load_images(path: str, split: str) -> List[Dict]:
images = sorted(glob(os.path.join(path, split, "images", "*.tif")))
pattern = re.compile("[a-zA-Z]+")
regions = [re.findall(pattern, os.path.basename(image))[0] for image in images]

if split == "train":
targets = sorted(glob(os.path.join(path, split, "gt", "*.tif")))
else:
targets = [None] * len(images)

files = [
dict(image=image, target=target, region=region)
for image, target, region in zip(images, targets, regions)
]
return files

def __len__(self) -> int:
return len(self.images)

def __getitem__(self, idx: int) -> Dict:
image_path, target_path = self.images[idx]["image"], self.images[idx]["target"]
x, y = np.load(image_path), np.load(target_path)
x, y = self.transform([x, y])
return dict(x=x, mask=y)
x = np.array(Image.open(image_path))

if self.split == "train":
y = np.array(Image.open(target_path))
y = np.clip(y, a_min=0, a_max=1)
x, y = self.transform([x, y])
output = dict(x=x, mask=y, region=self.images[idx]["region"])
else:
x = self.transform(x)
output = dict(x=x)

return output
3 changes: 2 additions & 1 deletion torchrs/train/datamodules/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,13 @@
from .advance import ADVANCEDataModule
from .sat import SAT4DataModule, SAT6DataModule
from .hrscd import HRSCDDataModule
from .inria_ail import InriaAILDataModule


__all__ = [
"BaseDataModule", "PROBAVDataModule", "ETCI2021DataModule", "RSVQALRDataModule",
"RSVQAxBENDataModule", "EuroSATRGBDataModule", "EuroSATMSDataModule", "RESISC45DataModule",
"RSICDDataModule", "OSCDDataModule", "S2LookingDataModule", "LEVIRCDPlusDataModule",
"FAIR1MDataModule", "SydneyCaptionsDataModule", "UCMCaptionsDataModule", "S2MTCPDataModule",
"ADVANCEDataModule", "SAT4DataModule", "SAT6DataModule", "HRSCDDataModule"
"ADVANCEDataModule", "SAT4DataModule", "SAT6DataModule", "HRSCDDataModule", "InriaAILDataModule"
]
24 changes: 24 additions & 0 deletions torchrs/train/datamodules/inria_ail.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
from typing import Optional

from torchrs.transforms import Compose, ToTensor
from torchrs.datasets.utils import dataset_split
from torchrs.train.datamodules import BaseDataModule
from torchrs.datasets import InriaAIL


class InriaAILDataModule(BaseDataModule):

def __init__(
self,
root: str = ".data/AerialImageDataset",
transform: Compose = Compose([ToTensor()]),
*args, **kwargs
):
super().__init__(*args, **kwargs)
self.root = root
self.transform = transform

def setup(self, stage: Optional[str] = None):
train_dataset = InriaAIL(root=self.root, split="train", transform=self.transform)
self.train_dataset, self.val_dataset = dataset_split(train_dataset, val_pct=self.val_split)
self.test_dataset = InriaAIL(root=self.root, split="test", transform=self.transform)

0 comments on commit 589b5ec

Please sign in to comment.