
Added KITTI dataset #3640


Merged
merged 14 commits on Apr 9, 2021
7 changes: 7 additions & 0 deletions docs/source/datasets.rst
@@ -149,6 +149,13 @@ Kinetics-400
:members: __getitem__
:special-members:

KITTI
~~~~~~~~~

.. autoclass:: Kitti
:members: __getitem__
:special-members:

KMNIST
~~~~~~~~~~~~~

36 changes: 36 additions & 0 deletions test/test_datasets.py
@@ -1702,5 +1702,41 @@ def test_classes(self, config):
self.assertSequenceEqual(dataset.classes, info["classes"])


class KittiTestCase(datasets_utils.ImageDatasetTestCase):
DATASET_CLASS = datasets.Kitti
FEATURE_TYPES = (PIL.Image.Image, (list, type(None))) # test split returns None as target
ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(train=(True, False))

def inject_fake_data(self, tmpdir, config):
kitti_dir = os.path.join(tmpdir, "Kitti", "raw")
os.makedirs(kitti_dir)

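# Number of fake images to create for each split (train=True / train=False).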
split_to_num_examples = {
True: 1,
False: 2,
}

# We need to create all folders (training and testing).
for is_training in (True, False):
num_examples = split_to_num_examples[is_training]

datasets_utils.create_image_folder(
root=kitti_dir,
name=os.path.join("training" if is_training else "testing", "image_2"),
file_name_fn=lambda image_idx: f"{image_idx:06d}.png",
num_examples=num_examples,
)
if is_training:
for image_idx in range(num_examples):
target_file_dir = os.path.join(kitti_dir, "training", "label_2")
os.makedirs(target_file_dir, exist_ok=True)
target_file_name = os.path.join(target_file_dir, f"{image_idx:06d}.txt")
target_contents = "Pedestrian 0.00 0 -0.20 712.40 143.00 810.73 307.92 1.89 0.48 1.20 1.84 1.47 8.41 0.01\n" # noqa
with open(target_file_name, "w") as target_file:
target_file.write(target_contents)

return split_to_num_examples[config["train"]]


if __name__ == "__main__":
unittest.main()
4 changes: 3 additions & 1 deletion torchvision/datasets/__init__.py
@@ -24,6 +24,7 @@
from .hmdb51 import HMDB51
from .ucf101 import UCF101
from .places365 import Places365
from .kitti import Kitti

__all__ = ('LSUN', 'LSUNClass',
'ImageFolder', 'DatasetFolder', 'FakeData',
@@ -34,4 +35,5 @@
'VOCSegmentation', 'VOCDetection', 'Cityscapes', 'ImageNet',
'Caltech101', 'Caltech256', 'CelebA', 'WIDERFace', 'SBDataset',
'VisionDataset', 'USPS', 'Kinetics400', 'HMDB51', 'UCF101',
'Places365')
'Places365', 'Kitti',
)
161 changes: 161 additions & 0 deletions torchvision/datasets/kitti.py
@@ -0,0 +1,161 @@
import csv
import os
from typing import Any, Callable, List, Optional, Tuple

from PIL import Image

from .utils import download_and_extract_archive
from .vision import VisionDataset


class Kitti(VisionDataset):
"""`KITTI <http://www.cvlibs.net/datasets/kitti>`_ Dataset.

Args:
root (string): Root directory where images are downloaded to.
Expects the following folder structure if download=False:

.. code::

<root>
    └── Kitti
        └── raw
            ├── training
            │   ├── image_2
            │   └── label_2
            └── testing
                └── image_2
train (bool, optional): Use ``train`` split if true, else ``test`` split.
Defaults to ``True``.
transform (callable, optional): A function/transform that takes in a PIL image
and returns a transformed version. E.g., ``transforms.ToTensor``
target_transform (callable, optional): A function/transform that takes in the
target and transforms it.
transforms (callable, optional): A function/transform that takes an input sample
and its target and returns a transformed version.
download (bool, optional): If true, downloads the dataset from the internet and
puts it in root directory. If dataset is already downloaded, it is not
downloaded again.

"""

data_url = "https://s3.eu-central-1.amazonaws.com/avg-kitti/"
resources = [
"data_object_image_2.zip",
"data_object_label_2.zip",
]
image_dir_name = "image_2"
labels_dir_name = "label_2"

def __init__(
self,
root: str,
train: bool = True,
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
transforms: Optional[Callable] = None,
download: bool = False,
):
super().__init__(
root,
transform=transform,
target_transform=target_transform,
transforms=transforms,
)
self.images = []
self.targets = []
self.root = root
self.train = train
self._location = "training" if self.train else "testing"

if download:
self.download()
if not self._check_exists():
raise RuntimeError(
"Dataset not found. You may use download=True to download it."
)

image_dir = os.path.join(self._raw_folder, self._location, self.image_dir_name)
if self.train:
labels_dir = os.path.join(self._raw_folder, self._location, self.labels_dir_name)
for img_file in os.listdir(image_dir):
self.images.append(os.path.join(image_dir, img_file))
if self.train:
self.targets.append(
os.path.join(labels_dir, f"{img_file.split('.')[0]}.txt")
)

def __getitem__(self, index: int) -> Tuple[Any, Any]:
"""Get item at a given index.

Args:
index (int): Index
Returns:
tuple: (image, target), where
target is a list of dictionaries with the following keys:

- type: str
- truncated: float
- occluded: int
- alpha: float
- bbox: float[4]
- dimensions: float[3]
- location: float[3]
- rotation_y: float

"""
image = Image.open(self.images[index])
target = self._parse_target(index) if self.train else None
if self.transforms:
image, target = self.transforms(image, target)
return image, target

def _parse_target(self, index: int) -> List:
target = []
with open(self.targets[index]) as inp:
content = csv.reader(inp, delimiter=" ")
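# Each label row has 15 space-separated fields: type, truncated, occluded,
# alpha, bbox (4 values), dimensions (3), location (3), rotation_y.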
for line in content:
target.append({
"type": line[0],
"truncated": float(line[1]),
"occluded": int(line[2]),
"alpha": float(line[3]),
"bbox": [float(x) for x in line[4:8]],
"dimensions": [float(x) for x in line[8:11]],
"location": [float(x) for x in line[11:14]],
"rotation_y": float(line[14]),
})
return target

def __len__(self) -> int:
return len(self.images)

@property
def _raw_folder(self) -> str:
return os.path.join(self.root, self.__class__.__name__, "raw")

def _check_exists(self) -> bool:
"""Check if the data directory exists."""
folders = [self.image_dir_name]
if self.train:
folders.append(self.labels_dir_name)
return all(
os.path.isdir(os.path.join(self._raw_folder, self._location, fname))
for fname in folders
)

def download(self) -> None:
"""Download the KITTI data if it doesn't exist already."""

if self._check_exists():
return

os.makedirs(self._raw_folder, exist_ok=True)

# download files
for fname in self.resources:
download_and_extract_archive(
url=f"{self.data_url}{fname}",
download_root=self._raw_folder,
filename=fname,
)
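A minimal usage sketch of the new class (a sketch only: it assumes the KITTI archives are already extracted under ``<root>/Kitti/raw`` or that ``download=True`` may fetch them; the ``root="data"`` path is purely illustrative):

from torchvision.datasets import Kitti

# Load the training split; download and extract the archives if missing.
dataset = Kitti(root="data", train=True, download=True)

# Each sample is a PIL image plus a list of annotation dictionaries
# (one per labelled object), as parsed by _parse_target above.
image, target = dataset[0]
print(len(dataset), target[0]["type"], target[0]["bbox"])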