Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 46 additions & 0 deletions test/datasets_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,16 @@
from collections import defaultdict
from typing import Any, Callable, Dict, Iterator, List, Optional, Sequence, Tuple, Union

import numpy as np

import PIL
import PIL.Image
import pytest
import torch
import torchvision.datasets
import torchvision.io
from common_utils import disable_console_output, get_tmp_dir
from torchvision.transforms.functional import get_dimensions


__all__ = [
Expand Down Expand Up @@ -748,6 +751,49 @@ def size(idx: int) -> Tuple[int, int, int]:
]


def shape_test_for_stereo_gt_w_mask(
    left: PIL.Image.Image, right: PIL.Image.Image, disparity: np.ndarray, valid_mask: np.ndarray
):
    """Assert that a stereo sample with ground truth and a validity mask is shape-consistent.

    Checks performed, in order:

    * ``left`` and ``right`` report the same dimensions via ``get_dimensions``;
    * the first reported dimension (treated as the channel count) is 3;
    * ``disparity`` is a 3-D array of shape ``(1, h, w)``, with ``h`` and ``w`` taken
      from the dimensions of ``left``;
    * ``valid_mask`` is 2-D and matches the trailing two axes of ``disparity``.

    Raises:
        AssertionError: if any of the checks above fails.
    """
    dims_left = get_dimensions(left)
    dims_right = get_dimensions(right)
    channels, height, width = dims_left

    # both views of the stereo pair must agree in size
    assert dims_left == dims_right

    # general shape expectations for the images and the disparity map
    assert channels == 3
    assert disparity.ndim == 3
    assert disparity.shape == (1, height, width)

    # the valid mask must cover exactly the spatial extent of the disparity map
    disp_h, disp_w = disparity.shape[1:]
    mask_h, mask_w = valid_mask.shape
    assert (disp_h, disp_w) == (mask_h, mask_w)


def shape_test_for_stereo_gt_no_mask(left: PIL.Image.Image, right: PIL.Image.Image, disparity: np.ndarray):
    """Assert that a stereo sample with ground truth but no validity mask is shape-consistent.

    Checks performed, in order:

    * ``left`` and ``right`` report the same dimensions via ``get_dimensions``;
    * the first reported dimension (treated as the channel count) is 3;
    * ``disparity`` is a 3-D array of shape ``(1, h, w)``, with ``h`` and ``w`` taken
      from the dimensions of ``left``.

    Raises:
        AssertionError: if any of the checks above fails.
    """
    dims_left = get_dimensions(left)
    dims_right = get_dimensions(right)
    channels, height, width = dims_left

    # both views of the stereo pair must agree in size
    assert dims_left == dims_right

    # general shape expectations for the images and the disparity map
    assert channels == 3
    assert disparity.ndim == 3
    expected_disparity_shape = (1, height, width)
    assert disparity.shape == expected_disparity_shape


def shape_test_for_stereo_no_gt(left: PIL.Image.Image, right: PIL.Image.Image, disparity: None):
    """Assert that a stereo sample without ground truth is well-formed.

    Checks performed, in order:

    * ``left`` and ``right`` report the same dimensions via ``get_dimensions``;
    * the first reported dimension (treated as the channel count) is 3;
    * ``disparity`` is ``None``.

    Raises:
        AssertionError: if any of the checks above fails.
    """
    dims_left = get_dimensions(left)
    dims_right = get_dimensions(right)
    # only the channel count is inspected here; height and width are ignored
    channels, _height, _width = dims_left

    # both views of the stereo pair must agree in size
    assert dims_left == dims_right

    # three-channel images and no disparity at all
    assert channels == 3
    assert disparity is None


@requires_lazy_imports("av")
def create_video_file(
root: Union[pathlib.Path, str],
Expand Down
168 changes: 168 additions & 0 deletions test/test_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import unittest
import xml.etree.ElementTree as ET
import zipfile
from typing import Union

import datasets_utils
import numpy as np
Expand Down Expand Up @@ -2671,5 +2672,172 @@ def inject_fake_data(self, tmpdir: str, config):
return len(sampled_classes) * num_images_per_class[config["split"]]


class Kitti2012StereoTestCase(datasets_utils.ImageDatasetTestCase):
    """Test case for ``datasets.Kitti2012Stereo`` driven by a generated fake directory tree."""

    DATASET_CLASS = datasets.Kitti2012Stereo
    ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(split=("train", "test"))
    FEATURE_TYPES = (PIL.Image.Image, PIL.Image.Image, (np.ndarray, type(None)), (np.ndarray, type(None)))

    def inject_fake_data(self, tmpdir, config):
        """Create ``Kitti2012/<split>ing`` under ``tmpdir`` and fill it with fake images.

        Left/right images go to ``colored_0`` / ``colored_1``; for the ``train`` split a
        ``disp_noc`` folder with single-channel images is created as well.

        Returns:
            The number of examples written for the requested split (0 for unknown splits).
        """
        split = config["split"]
        split_dir = pathlib.Path(tmpdir) / "Kitti2012" / (split + "ing")
        # makedirs also creates the intermediate "Kitti2012" directory
        os.makedirs(split_dir, exist_ok=True)

        examples_per_split = {"train": 4, "test": 3}
        num_examples = examples_per_split.get(split, 0)

        # left and right views share the naming scheme and image size
        for folder in ("colored_0", "colored_1"):
            datasets_utils.create_image_folder(
                root=split_dir,
                name=folder,
                file_name_fn=lambda i: f"{i:06d}_10.png",
                num_examples=num_examples,
                size=(3, 100, 200),
            )

        if split == "train":
            # Kitti2012 uses a single channel image for disparities
            datasets_utils.create_image_folder(
                root=split_dir,
                name="disp_noc",
                file_name_fn=lambda i: f"{i:06d}.png",
                num_examples=num_examples,
                size=(1, 100, 200),
            )

        return num_examples

    def test_train_splits(self):
        # train samples carry a disparity map but never a mask
        with self.create_dataset(split="train") as (dataset, _):
            for sample in dataset:
                left, right, disparity, mask = sample
                assert mask is None
                datasets_utils.shape_test_for_stereo_gt_no_mask(left, right, disparity)

    def test_test_split(self):
        # test samples carry neither a disparity map nor a mask
        with self.create_dataset(split="test") as (dataset, _):
            for sample in dataset:
                left, right, disparity, mask = sample
                assert mask is None
                datasets_utils.shape_test_for_stereo_no_gt(left, right, disparity)

    def test_bad_input(self):
        # an unknown split name must be rejected with a ValueError
        expected_message = "Unknown value 'bad' for argument split"
        with pytest.raises(ValueError, match=expected_message):
            with self.create_dataset(split="bad"):
                pass


class Kitti2015StereoTestCase(datasets_utils.ImageDatasetTestCase):
    """Test case for ``datasets.Kitti2015Stereo`` driven by a generated fake directory tree."""

    DATASET_CLASS = datasets.Kitti2015Stereo
    ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(split=("train", "test"))
    FEATURE_TYPES = (PIL.Image.Image, PIL.Image.Image, (np.ndarray, type(None)), (np.ndarray, type(None)))

    def inject_fake_data(self, tmpdir, config):
        """Create ``Kitti2015/<split>ing`` under ``tmpdir`` and fill it with fake images.

        Left/right images go to ``image_2`` / ``image_3``; for the ``train`` split the
        ``disp_occ_0`` and ``disp_occ_1`` folders with single-channel images are created
        as well.

        Returns:
            The number of examples written for the requested split (0 for unknown splits).
        """
        split = config["split"]
        split_dir = pathlib.Path(tmpdir) / "Kitti2015" / (split + "ing")
        # makedirs also creates the intermediate "Kitti2015" directory
        os.makedirs(split_dir, exist_ok=True)

        examples_per_split = {"train": 4, "test": 6}
        num_examples = examples_per_split.get(split, 0)

        # left and right views share the naming scheme and image size
        for folder in ("image_2", "image_3"):
            datasets_utils.create_image_folder(
                root=split_dir,
                name=folder,
                file_name_fn=lambda i: f"{i:06d}_10.png",
                num_examples=num_examples,
                size=(3, 100, 200),
            )

        if split == "train":
            # Kitti2015 uses a single channel image for disparities
            for folder in ("disp_occ_0", "disp_occ_1"):
                datasets_utils.create_image_folder(
                    root=split_dir,
                    name=folder,
                    file_name_fn=lambda i: f"{i:06d}.png",
                    num_examples=num_examples,
                    size=(1, 100, 200),
                )

        return num_examples

    def test_train_splits(self):
        # train samples carry a disparity map but never a mask
        with self.create_dataset(split="train") as (dataset, _):
            for sample in dataset:
                left, right, disparity, mask = sample
                assert mask is None
                datasets_utils.shape_test_for_stereo_gt_no_mask(left, right, disparity)

    def test_test_split(self):
        # test samples carry neither a disparity map nor a mask
        with self.create_dataset(split="test") as (dataset, _):
            for sample in dataset:
                left, right, disparity, mask = sample
                assert mask is None
                datasets_utils.shape_test_for_stereo_no_gt(left, right, disparity)

    def test_bad_input(self):
        # an unknown split name must be rejected with a ValueError
        expected_message = "Unknown value 'bad' for argument split"
        with pytest.raises(ValueError, match=expected_message):
            with self.create_dataset(split="bad"):
                pass


class CarlaStereoTestCase(datasets_utils.ImageDatasetTestCase):
    """Test case for ``datasets.CarlaStereo`` driven by a generated fake directory tree."""

    DATASET_CLASS = datasets.CarlaStereo
    # ``type(None)`` rather than ``None``: every entry of a tuple handed to ``isinstance``
    # must be a type, otherwise the check raises ``TypeError`` as soon as the value does
    # not match an earlier entry. This also matches the Kitti stereo test cases.
    FEATURE_TYPES = (PIL.Image.Image, PIL.Image.Image, (np.ndarray, type(None)))

    @staticmethod
    def _create_scene_folders(num_examples: int, root_dir: Union[str, pathlib.Path]):
        """Create ``num_examples`` scene folders named ``scene_<i>`` below ``root_dir``.

        Each scene folder receives a left/right image pair (``im0.png``, ``im1.png``) and
        two fake PFM disparity files (``disp0GT.pfm``, ``disp1GT.pfm``).
        """
        # make the root_dir if it does not exist
        os.makedirs(root_dir, exist_ok=True)

        for i in range(num_examples):
            scene_dir = pathlib.Path(root_dir) / f"scene_{i}"
            os.makedirs(scene_dir, exist_ok=True)
            # populate with left right images
            datasets_utils.create_image_file(root=scene_dir, name="im0.png", size=(100, 100))
            datasets_utils.create_image_file(root=scene_dir, name="im1.png", size=(100, 100))
            # ground-truth disparities for both views
            datasets_utils.make_fake_pfm_file(100, 100, file_name=str(scene_dir / "disp0GT.pfm"))
            datasets_utils.make_fake_pfm_file(100, 100, file_name=str(scene_dir / "disp1GT.pfm"))

    def inject_fake_data(self, tmpdir, config):
        """Create ``carla-highres/trainingF`` under ``tmpdir`` and populate it with scenes.

        Returns:
            The number of scene folders created.
        """
        carla_dir = pathlib.Path(tmpdir) / "carla-highres"
        os.makedirs(carla_dir, exist_ok=True)

        split_dir = pathlib.Path(carla_dir) / "trainingF"
        os.makedirs(split_dir, exist_ok=True)

        num_examples = 6
        self._create_scene_folders(num_examples=num_examples, root_dir=split_dir)

        return num_examples

    def test_train_splits(self):
        # every sample is a (left, right, disparity) triple with ground truth and no mask
        with self.create_dataset() as (dataset, _):
            for left, right, disparity in dataset:
                datasets_utils.shape_test_for_stereo_gt_no_mask(left, right, disparity)


# Allow running this test module directly as a script.
if __name__ == "__main__":
    unittest.main()
4 changes: 4 additions & 0 deletions torchvision/datasets/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from ._optical_flow import FlyingChairs, FlyingThings3D, HD1K, KittiFlow, Sintel
from ._stereo_matching import CarlaStereo, Kitti2012Stereo, Kitti2015Stereo
from .caltech import Caltech101, Caltech256
from .celeba import CelebA
from .cifar import CIFAR10, CIFAR100
Expand Down Expand Up @@ -105,4 +106,7 @@
"FGVCAircraft",
"EuroSAT",
"RenderedSST2",
"Kitti2012Stereo",
"Kitti2015Stereo",
"CarlaStereo",
)
Loading