Skip to content
This repository has been archived by the owner on Jul 2, 2021. It is now read-only.

Commit

Permalink
Merge branch 'master' into train-ssd
Browse files Browse the repository at this point in the history
  • Loading branch information
Hakuyume committed Jun 25, 2017
2 parents 7aaf2b7 + 739d3c8 commit ba11474
Show file tree
Hide file tree
Showing 19 changed files with 298 additions and 102 deletions.
2 changes: 1 addition & 1 deletion chainercv/datasets/voc/voc_detection_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ class VOCDetectionDataset(chainer.dataset.DatasetMixin):
:math:`(R, 4)`, where :math:`R` is the number of bounding boxes in
the image. The second axis represents attributes of the bounding box.
They are :obj:`(y_min, x_min, y_max, x_max)`, where the
four attributes are coordinates of the bottom left and the top right
four attributes are coordinates of the top left and the bottom right
vertices.
The labels are packed into a one dimensional tensor of shape :math:`(R,)`.
Expand Down
2 changes: 1 addition & 1 deletion chainercv/links/model/segnet/segnet_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,6 @@ def predict(self, imgs):
dtype = score.dtype
score = resize(score, (H, W)).astype(dtype)

label = np.argmax(score, axis=0)
label = np.argmax(score, axis=0).astype(np.int32)
labels.append(label)
return labels
6 changes: 3 additions & 3 deletions chainercv/links/model/ssd/multibox_coder.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,8 +231,8 @@ def decode(self, mb_loc, mb_conf, nms_thresh, score_thresh):
label.append(xp.array((l,) * len(bbox_l)))
score.append(score_l)

bbox = xp.vstack(bbox)
label = xp.hstack(label).astype(int)
score = xp.hstack(score)
bbox = xp.vstack(bbox).astype(np.float32)
label = xp.hstack(label).astype(np.int32)
score = xp.hstack(score).astype(np.float32)

return bbox, label, score
12 changes: 6 additions & 6 deletions chainercv/transforms/bbox/crop_bbox.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,29 +4,29 @@
def crop_bbox(
bbox, y_slice=None, x_slice=None,
allow_outside_center=True, return_param=False):
"""Crop bounding boxes.
"""Translate bounding boxes to fit within the cropped area of an image.
This method is mainly used together with image cropping.
This method translates the coordinates of bounding boxes like
:func:`~chainercv.transforms.translate_bbox`. In addition,
this function truncates the bounding boxes to fit within the cropping area.
If a bounding box does not overlap with the cropping area,
this function truncates the bounding boxes to fit within the cropped area.
If a bounding box does not overlap with the cropped area,
this bounding box will be removed.
The bounding boxes are expected to be packed into a two dimensional
tensor of shape :math:`(R, 4)`, where :math:`R` is the number of
bounding boxes in the image. The second axis represents attributes of
the bounding box. They are :obj:`(y_min, x_min, y_max, x_max)`,
where the four attributes are coordinates of the bottom left and the
top right vertices.
where the four attributes are coordinates of the top left and the
bottom right vertices.
Args:
bbox (~numpy.ndarray): Bounding boxes to be transformed. The shape is
:math:`(R, 4)`. :math:`R` is the number of bounding boxes.
y_slice (slice): The slice of y axis.
x_slice (slice): The slice of x axis.
allow_outside_center (bool): If this argument is :obj:`False`,
bounding boxes whose centers are outside of the cropping area
bounding boxes whose centers are outside of the cropped area
are removed. The default value is :obj:`True`.
return_param (bool): If :obj:`True`, this function returns
indices of kept bounding boxes.
Expand Down
4 changes: 2 additions & 2 deletions chainercv/transforms/bbox/flip_bbox.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@ def flip_bbox(bbox, size, y_flip=False, x_flip=False):
tensor of shape :math:`(R, 4)`, where :math:`R` is the number of
bounding boxes in the image. The second axis represents attributes of
the bounding box. They are :obj:`(y_min, x_min, y_max, x_max)`,
where the four attributes are coordinates of the bottom left and the
top right vertices.
where the four attributes are coordinates of the top left and the
bottom right vertices.
Args:
bbox (~numpy.ndarray): An array whose shape is :math:`(R, 4)`.
Expand Down
4 changes: 2 additions & 2 deletions chainercv/transforms/bbox/resize_bbox.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@ def resize_bbox(bbox, in_size, out_size):
tensor of shape :math:`(R, 4)`, where :math:`R` is the number of
bounding boxes in the image. The second axis represents attributes of
the bounding box. They are :obj:`(y_min, x_min, y_max, x_max)`,
where the four attributes are coordinates of the bottom left and the
top right vertices.
where the four attributes are coordinates of the top left and the
bottom right vertices.
Args:
bbox (~numpy.ndarray): An array whose shape is :math:`(R, 4)`.
Expand Down
4 changes: 2 additions & 2 deletions chainercv/transforms/bbox/translate_bbox.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@ def translate_bbox(bbox, y_offset=0, x_offset=0):
tensor of shape :math:`(R, 4)`, where :math:`R` is the number of
bounding boxes in the image. The second axis represents attributes of
the bounding box. They are :obj:`(y_min, x_min, y_max, x_max)`,
where the four attributes are coordinates of the bottom left and the
top right vertices.
where the four attributes are coordinates of the top left and the
bottom right vertices.
Args:
bbox (~numpy.ndarray): Bounding boxes to be transformed. The shape is
Expand Down
2 changes: 2 additions & 0 deletions chainercv/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@
from chainercv.utils.iterator import unzip # NOQA
from chainercv.utils.testing import assert_is_bbox # NOQA
from chainercv.utils.testing import assert_is_detection_dataset # NOQA
from chainercv.utils.testing import assert_is_detection_link # NOQA
from chainercv.utils.testing import assert_is_image # NOQA
from chainercv.utils.testing import assert_is_semantic_segmentation_dataset # NOQA
from chainercv.utils.testing import assert_is_semantic_segmentation_link # NOQA
from chainercv.utils.testing import ConstantStubLink # NOQA
from chainercv.utils.testing import generate_random_bbox # NOQA
26 changes: 18 additions & 8 deletions chainercv/utils/download.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from __future__ import division
from __future__ import print_function

import hashlib
import os
import shutil
Expand All @@ -13,22 +15,28 @@
import time

from chainer.dataset.download import get_dataset_directory
from chainer.dataset.download import get_dataset_root


def _reporthook(count, block_size, total_size):
global start_time
if count == 0:
start_time = time.time()
print(' % Total Recv Speed Time left')
return
duration = time.time() - start_time
progress_size = int(count * block_size)
progress_size = count * block_size
try:
speed = int(progress_size / (1024 * duration))
speed = progress_size / duration
except ZeroDivisionError:
speed = float('inf')
percent = int(count * block_size * 100 / total_size)
sys.stdout.write('\r...{}, {} MB, {} KB/s, {} seconds passed'.format(
percent, progress_size / (1024 * 1024), speed, duration))
percent = progress_size / total_size * 100
eta = int((total_size - progress_size) / speed)
sys.stdout.write(
'\r{:3.0f} {:4.0f}MiB {:4.0f}MiB {:6.0f}KiB/s {:4d}:{:02d}:{:02d}'
.format(
percent, total_size / (1 << 20), progress_size / (1 << 20),
speed / (1 << 10), eta // 60 // 60, (eta // 60) % 60, eta % 60))
sys.stdout.flush()


Expand All @@ -51,12 +59,12 @@ def cached_download(url):
str: Path to the downloaded file.
"""
cache_root = get_dataset_directory('_dl_cache')
cache_root = os.path.join(get_dataset_root(), '_dl_cache')
try:
os.makedirs(cache_root)
except OSError:
if not os.path.exists(cache_root):
raise RuntimeError('cannot create download cache directory')
raise

lock_path = os.path.join(cache_root, '_dl_lock')
urlhash = hashlib.md5(url.encode('utf-8')).hexdigest()
Expand All @@ -69,7 +77,9 @@ def cached_download(url):
temp_root = tempfile.mkdtemp(dir=cache_root)
try:
temp_path = os.path.join(temp_root, 'dl')
print('Downloading from {}...'.format(url))
print('Downloading ...')
print('From: {:s}'.format(url))
print('To: {:s}'.format(cache_path))
request.urlretrieve(url, temp_path, _reporthook)
with filelock.FileLock(lock_path):
shutil.move(temp_path, cache_path)
Expand Down
2 changes: 2 additions & 0 deletions chainercv/utils/testing/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from chainercv.utils.testing.assertions import assert_is_bbox # NOQA
from chainercv.utils.testing.assertions import assert_is_detection_dataset # NOQA
from chainercv.utils.testing.assertions import assert_is_detection_link # NOQA
from chainercv.utils.testing.assertions import assert_is_image # NOQA
from chainercv.utils.testing.assertions import assert_is_semantic_segmentation_dataset # NOQA
from chainercv.utils.testing.assertions import assert_is_semantic_segmentation_link # NOQA
from chainercv.utils.testing.constant_stub_link import ConstantStubLink # NOQA
from chainercv.utils.testing.generate_random_bbox import generate_random_bbox # NOQA
2 changes: 2 additions & 0 deletions chainercv/utils/testing/assertions/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from chainercv.utils.testing.assertions.assert_is_bbox import assert_is_bbox # NOQA
from chainercv.utils.testing.assertions.assert_is_detection_dataset import assert_is_detection_dataset # NOQA
from chainercv.utils.testing.assertions.assert_is_detection_link import assert_is_detection_link # NOQA
from chainercv.utils.testing.assertions.assert_is_image import assert_is_image # NOQA
from chainercv.utils.testing.assertions.assert_is_semantic_segmentation_dataset import assert_is_semantic_segmentation_dataset # NOQA
from chainercv.utils.testing.assertions.assert_is_semantic_segmentation_link import assert_is_semantic_segmentation_link # NOQA
59 changes: 59 additions & 0 deletions chainercv/utils/testing/assertions/assert_is_detection_link.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
import numpy as np
import six

from chainercv.utils.testing.assertions.assert_is_bbox import assert_is_bbox


def assert_is_detection_link(link, n_fg_class):
    """Checks if a link satisfies detection link APIs.

    This function checks if a given link satisfies detection link APIs
    or not.
    If the link does not satisfy the APIs, this function raises an
    :class:`AssertionError`.

    The check calls :obj:`link.predict` once with two randomly generated
    :obj:`float32` images of shapes :math:`(3, 480, 640)` and
    :math:`(3, 480, 320)` and validates the returned values.

    Args:
        link: A link to be checked. It must provide a :obj:`predict`
            method that takes a list of images.
        n_fg_class (int): The number of foreground classes.

    """

    # Two images of different widths are fed as a single batch.
    imgs = [
        np.random.randint(0, 256, size=(3, 480, 640)).astype(np.float32),
        np.random.randint(0, 256, size=(3, 480, 320)).astype(np.float32)]

    result = link.predict(imgs)
    assert len(result) == 3, \
        'Link must return three elements: bboxes, labels and scores.'
    bboxes, labels, scores = result

    # Every output must hold exactly one entry per input image.
    assert len(bboxes) == len(imgs), \
        'The length of bboxes must be same as that of imgs.'
    assert len(labels) == len(imgs), \
        'The length of labels must be same as that of imgs.'
    assert len(scores) == len(imgs), \
        'The length of scores must be same as that of imgs.'

    for bbox, label, score in six.moves.zip(bboxes, labels, scores):
        assert_is_bbox(bbox)

        assert isinstance(label, np.ndarray), \
            'label must be a numpy.ndarray.'
        assert label.dtype == np.int32, \
            'The type of label must be numpy.int32.'
        # ndim is checked instead of shape[1:] so that a 0-d array is
        # rejected here with an AssertionError; otherwise len() below
        # would raise TypeError for it.
        assert label.ndim == 1, \
            'The shape of label must be (*,).'
        assert len(label) == len(bbox), \
            'The length of label must be same as that of bbox.'
        # min()/max() are undefined for empty arrays, so the range check
        # is skipped when no object is detected.
        if len(label) > 0:
            assert label.min() >= 0 and label.max() < n_fg_class, \
                'The value of label must be in [0, n_fg_class - 1].'

        assert isinstance(score, np.ndarray), \
            'score must be a numpy.ndarray.'
        assert score.dtype == np.float32, \
            'The type of score must be numpy.float32.'
        # Same reasoning as for label: reject 0-d arrays explicitly.
        assert score.ndim == 1, \
            'The shape of score must be (*,).'
        assert len(score) == len(bbox), \
            'The length of score must be same as that of bbox.'
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
import numpy as np
import six


def assert_is_semantic_segmentation_link(link, n_class):
    """Checks if a link satisfies semantic segmentation link APIs.

    The link's :obj:`predict` method is called once with two randomly
    generated :obj:`float32` images of shapes :math:`(3, 480, 640)` and
    :math:`(3, 480, 320)`. Each returned label is then checked to be a
    :obj:`numpy.int32` array with the same height and width as the
    corresponding image and with values in :math:`[0, n\_class - 1]`.
    If the link does not satisfy the APIs, this function raises an
    :class:`AssertionError`.

    Args:
        link: A link to be checked.
        n_class (int): The number of classes including background.

    """

    # Images of different widths, generated in this fixed order.
    sizes = ((3, 480, 640), (3, 480, 320))
    imgs = [
        np.random.randint(0, 256, size=size).astype(np.float32)
        for size in sizes]

    labels = link.predict(imgs)
    # One label map is expected per input image.
    assert len(labels) == len(imgs), \
        'The length of labels must be same as that of imgs.'

    for img, label in six.moves.zip(imgs, labels):
        assert isinstance(label, np.ndarray), \
            'label must be a numpy.ndarray.'
        assert label.dtype == np.int32, \
            'The type of label must be numpy.int32.'
        # Drop the channel axis of the image to get the expected (H, W).
        assert label.shape == img.shape[1:], \
            'The shape of label must be (H, W).'
        assert 0 <= label.min() and label.max() < n_class, \
            'The value of label must be in [0, n_class - 1].'
16 changes: 12 additions & 4 deletions docs/source/reference/utils.rst
Original file line number Diff line number Diff line change
Expand Up @@ -57,19 +57,27 @@ Testing Utilities

assert_is_bbox
~~~~~~~~~~~~~~
.. autofunctions:: assert_is_bbox
.. autofunction:: assert_is_bbox

assert_is_detection_dataset
~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autofunctions:: assert_is_detection_dataset
.. autofunction:: assert_is_detection_dataset

assert_is_detection_link
~~~~~~~~~~~~~~~~~~~~~~~~
.. autofunction:: assert_is_detection_link

assert_is_image
~~~~~~~~~~~~~~~
.. autofunctions:: assert_is_image
.. autofunction:: assert_is_image

assert_is_semantic_segmentation_dataset
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autofunctions:: assert_is_semantic_segmentation_dataset
.. autofunction:: assert_is_semantic_segmentation_dataset

assert_is_semantic_segmentation_link
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autofunction:: assert_is_semantic_segmentation_link

ConstantStubLink
~~~~~~~~~~~~~~~~
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
from chainer import testing
from chainer.testing import attr

from chainercv.utils import assert_is_detection_link

from dummy_faster_rcnn import DummyFasterRCNN


Expand Down Expand Up @@ -58,39 +60,13 @@ def test_call_gpu(self):
self.link.to_gpu()
self.check_call()

def check_predict(self):
imgs = [
_random_array(np, (3, 640, 480)),
_random_array(np, (3, 320, 320))]

bboxes, labels, scores = self.link.predict(imgs)

self.assertEqual(len(bboxes), len(imgs))
self.assertEqual(len(labels), len(imgs))
self.assertEqual(len(scores), len(imgs))

for bbox, label, score in zip(bboxes, labels, scores):
self.assertIsInstance(bbox, np.ndarray)
self.assertEqual(bbox.dtype, np.float32)
self.assertEqual(bbox.ndim, 2)
self.assertLessEqual(bbox.shape[0], self.n_roi)
self.assertEqual(bbox.shape[1], 4)

self.assertIsInstance(label, np.ndarray)
self.assertEqual(label.dtype, np.int32)
self.assertEqual(label.shape, (bbox.shape[0],))

self.assertIsInstance(score, np.ndarray)
self.assertEqual(score.dtype, np.float32)
self.assertEqual(score.shape, (bbox.shape[0],))

def test_predict_cpu(self):
self.check_predict()
assert_is_detection_link(self.link, self.n_class - 1)

@attr.gpu
def test_predict_gpu(self):
self.link.to_gpu()
self.check_predict()
assert_is_detection_link(self.link, self.n_class - 1)


@testing.parameterize(
Expand Down
Loading

0 comments on commit ba11474

Please sign in to comment.