Skip to content
This repository has been archived by the owner on Jul 2, 2021. It is now read-only.

Change name to CUBPointDataset and add tests #528

Merged
merged 8 commits into from
Apr 18, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion chainercv/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@
from chainercv.datasets.cityscapes.cityscapes_test_image_dataset import CityscapesTestImageDataset # NOQA
from chainercv.datasets.cityscapes.cityscapes_utils import cityscapes_semantic_segmentation_label_colors # NOQA
from chainercv.datasets.cityscapes.cityscapes_utils import cityscapes_semantic_segmentation_label_names # NOQA
from chainercv.datasets.cub.cub_keypoint_dataset import CUBKeypointDataset # NOQA
from chainercv.datasets.cub.cub_label_dataset import CUBLabelDataset # NOQA
from chainercv.datasets.cub.cub_point_dataset import CUBPointDataset # NOQA
from chainercv.datasets.cub.cub_utils import cub_label_names # NOQA
from chainercv.datasets.directory_parsing_label_dataset import directory_parsing_label_names # NOQA
from chainercv.datasets.directory_parsing_label_dataset import DirectoryParsingLabelDataset # NOQA
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,34 +6,34 @@
from chainercv import utils


class CUBKeypointDataset(CUBDatasetBase):
class CUBPointDataset(CUBDatasetBase):

"""`Caltech-UCSD Birds-200-2011`_ dataset with annotated keypoints.
"""`Caltech-UCSD Birds-200-2011`_ dataset with annotated points.

.. _`Caltech-UCSD Birds-200-2011`:
http://www.vision.caltech.edu/visipedia/CUB-200-2011.html

An index corresponds to each image.

When queried by an index, this dataset returns the corresponding
:obj:`img, keypoint, kp_mask`, a tuple of an image, keypoints
and a keypoint mask that indicates visible keypoints in the image.
:obj:`img, point, mask`, a tuple of an image, points
and a point mask that indicates visible points in the image.
The data type of the three elements are :obj:`float32, float32, bool`.
If :obj:`return_bb = True`, a bounding box :obj:`bb` is appended to the
tuple.
If :obj:`return_prob_map = True`, a probability map :obj:`prob_map` is
appended.

keypoints are packed into a two dimensional array of shape
:math:`(K, 2)`, where :math:`K` is the number of keypoints.
Note that :math:`K=15` in CUB dataset. Also note that not all fifteen
keypoints are visible in an image. When a keypoint is not visible,
the values stored for that keypoint are undefined. The second axis
Points are packed into a two dimensional array of shape
:math:`(P, 2)`, where :math:`P` is the number of points.
Note that :math:`P=15` in CUB dataset. Also note that not all fifteen
points are visible in an image. When a point is not visible,
the coordinates of the point are undefined. The second axis
corresponds to the :math:`y` and :math:`x` coordinates of the
keypoints in the image.
points in the image.

A keypoint mask array indicates whether a keypoint is visible in the
image or not. This is a boolean array of shape :math:`(K,)`.
A point mask array indicates whether a point is visible in the
image or not. This is a boolean array of shape :math:`(P,)`.

A bounding box is a one-dimensional array of shape :math:`(4,)`.
The elements of the bounding box corresponds to
Expand Down Expand Up @@ -67,29 +67,24 @@ class CUBKeypointDataset(CUBDatasetBase):

def __init__(self, data_dir='auto', return_bb=False,
prob_map_dir='auto', return_prob_map=False):
super(CUBKeypointDataset, self).__init__(
super(CUBPointDataset, self).__init__(
data_dir=data_dir, return_bb=return_bb,
prob_map_dir=prob_map_dir, return_prob_map=return_prob_map)

# load keypoint
# load point
parts_loc_file = os.path.join(self.data_dir, 'parts', 'part_locs.txt')
self.kp_dict = collections.OrderedDict()
self.kp_mask_dict = collections.OrderedDict()
self._point_dict = collections.defaultdict(list)
self._point_mask_dict = collections.defaultdict(list)
for loc in open(parts_loc_file):
values = loc.split()
id_ = int(values[0]) - 1

if id_ not in self.kp_dict:
self.kp_dict[id_] = []
if id_ not in self.kp_mask_dict:
self.kp_mask_dict[id_] = []

# (y, x) order
keypoint = [float(v) for v in values[3:1:-1]]
kp_mask = bool(int(values[4]))
point = [float(v) for v in values[3:1:-1]]
point_mask = bool(int(values[4]))

self.kp_dict[id_].append(keypoint)
self.kp_mask_dict[id_].append(kp_mask)
self._point_dict[id_].append(point)
self._point_mask_dict[id_].append(point_mask)

def get_example(self, i):
"""Returns the i-th example.
Expand All @@ -98,32 +93,32 @@ def get_example(self, i):
i (int): The index of the example.

Returns:
tuple of an image, keypoints and a keypoint mask.
tuple of an image, points and a point mask.
The image is in CHW format and its color channel is ordered in
RGB.
If :obj:`return_bb = True`,
a bounding box is appended to the returned value.
If :obj:`return_mask = True`,
If :obj:`return_prob_map = True`,
a probability map is appended to the returned value.

"""
img = utils.read_image(
os.path.join(self.data_dir, 'images', self.paths[i]),
color=True)
keypoint = np.array(self.kp_dict[i], dtype=np.float32)
kp_mask = np.array(self.kp_mask_dict[i], dtype=np.bool)
point = np.array(self._point_dict[i], dtype=np.float32)
point_mask = np.array(self._point_mask_dict[i], dtype=np.bool)

if not self.return_prob_map:
if self.return_bb:
return img, keypoint, kp_mask, self.bbs[i]
return img, point, point_mask, self.bbs[i]
else:
return img, keypoint, kp_mask
return img, point, point_mask

prob_map = utils.read_image(self.prob_map_paths[i],
dtype=np.uint8, color=False)
prob_map = prob_map.astype(np.float32) / 255 # [0, 255] -> [0, 1]
prob_map = prob_map[0] # (1, H, W) --> (H, W)
if self.return_bb:
return img, keypoint, kp_mask, self.bbs[i], prob_map
return img, point, point_mask, self.bbs[i], prob_map
else:
return img, keypoint, kp_mask, prob_map
return img, point, point_mask, prob_map
2 changes: 1 addition & 1 deletion chainercv/visualizations/vis_point.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ def vis_point(img, point, mask=None, ax=None):

>>> import chainercv
>>> import matplotlib.pyplot as plt
>>> dataset = chainercv.datasets.CUBKeypointDataset()
>>> dataset = chainercv.datasets.CUBPointDataset()
>>> img, point, mask = dataset[0]
>>> chainercv.visualizations.vis_point(img, point, mask)
>>> plt.show()
Expand Down
6 changes: 3 additions & 3 deletions docs/source/reference/datasets.rst
Original file line number Diff line number Diff line change
Expand Up @@ -58,9 +58,9 @@ CUBLabelDataset
~~~~~~~~~~~~~~~
.. autoclass:: CUBLabelDataset

CUBKeypointDataset
~~~~~~~~~~~~~~~~~~~
.. autoclass:: CUBKeypointDataset
CUBPointDataset
~~~~~~~~~~~~~~~
.. autoclass:: CUBPointDataset


OnlineProducts
Expand Down
45 changes: 45 additions & 0 deletions tests/datasets_tests/cub_tests/test_cub_point_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
import unittest

import numpy as np

from chainer import testing

from chainercv.datasets import CUBPointDataset
from chainercv.testing import attr
from chainercv.utils import assert_is_bbox
from chainercv.utils import assert_is_point_dataset


@testing.parameterize(*testing.product({
'return_bb': [True, False],
'return_prob_map': [True, False]}
))
class TestCUBPointDataset(unittest.TestCase):

def setUp(self):
self.dataset = CUBPointDataset(return_bb=self.return_bb,
return_prob_map=self.return_prob_map)

@attr.slow
@attr.disk
def test_camvid_dataset(self):
assert_is_point_dataset(
self.dataset, n_point=15, n_example=10)

idx = np.random.choice(np.arange(10))
if self.return_bb:
if self.return_prob_map:
bb = self.dataset[idx][-2]
else:
bb = self.dataset[idx][-1]
assert_is_bbox(bb[np.newaxis])
if self.return_prob_map:
img = self.dataset[idx][0]
prob_map = self.dataset[idx][-1]
self.assertEqual(prob_map.dtype, np.float32)
self.assertEqual(prob_map.shape, img.shape[1:])
self.assertTrue(np.min(prob_map) >= 0)
self.assertTrue(np.max(prob_map) <= 1)


testing.run_module(__name__, __file__)