diff --git a/chainercv/links/__init__.py b/chainercv/links/__init__.py index 9f1bc25836..ad253412ea 100644 --- a/chainercv/links/__init__.py +++ b/chainercv/links/__init__.py @@ -14,6 +14,7 @@ from chainercv.links.model.fpn.faster_rcnn_fpn_resnet import MaskRCNNFPNResNet101 # NOQA from chainercv.links.model.fpn.faster_rcnn_fpn_resnet import MaskRCNNFPNResNet50 # NOQA from chainercv.links.model.light_head_rcnn.light_head_rcnn_resnet101 import LightHeadRCNNResNet101 # NOQA +from chainercv.links.model.mobilenet import MobileNetV2 # NOQA from chainercv.links.model.resnet import ResNet101 # NOQA from chainercv.links.model.resnet import ResNet152 # NOQA from chainercv.links.model.resnet import ResNet50 # NOQA diff --git a/chainercv/links/model/mobilenet/__init__.py b/chainercv/links/model/mobilenet/__init__.py new file mode 100644 index 0000000000..b6ffbb5718 --- /dev/null +++ b/chainercv/links/model/mobilenet/__init__.py @@ -0,0 +1,4 @@ +from chainercv.links.model.mobilenet.expanded_conv_2d import ExpandedConv2D # NOQA +from chainercv.links.model.mobilenet.mobilenet_v2 import MobileNetV2 # NOQA +from chainercv.links.model.mobilenet.tf_conv_2d_bn_activ import TFConv2DBNActiv # NOQA +from chainercv.links.model.mobilenet.tf_convolution_2d import TFConvolution2D # NOQA diff --git a/chainercv/links/model/mobilenet/expanded_conv_2d.py b/chainercv/links/model/mobilenet/expanded_conv_2d.py new file mode 100644 index 0000000000..47403572c0 --- /dev/null +++ b/chainercv/links/model/mobilenet/expanded_conv_2d.py @@ -0,0 +1,95 @@ +import chainer +from chainer.functions import clipped_relu + +from chainercv.links.model.mobilenet.tf_conv_2d_bn_activ import TFConv2DBNActiv +from chainercv.links.model.mobilenet.util import expand_input_by_factor + + +class ExpandedConv2D(chainer.Chain): + """An expanded convolution 2d layer + + in --> expand conv (pointwise conv) --> depthwise conv --> + project conv (pointwise conv) --> out + + Args: + in_channels (int): The number of channels of the 
input array. + out_channels (int): The number of channels of the output array. + expand_pad (int, tuple of ints, 'SAME' or 'VALID'): + Pad of expand conv filter application. + depthwise_stride (int or tuple of ints): + Stride of depthwise conv filter application. + depthwise_ksize (int or tuple of ints): + Kernel size of depthwise conv filter application. + depthwise_pad (int, tuple of ints, 'SAME' or 'VALID'): + Pad of depthwise conv filter application. + project_pad (int, tuple of ints, 'SAME' or 'VALID'): + Pad of project conv filter application. + initialW (callable): Initial weight value used in + the convolutional layers. + bn_kwargs (dict): Keyword arguments passed to initialize + :class:`chainer.links.BatchNormalization`. + """ + + def __init__(self, + in_channels, + out_channels, + expansion_size=expand_input_by_factor(6), + expand_pad='SAME', + depthwise_stride=1, + depthwise_ksize=3, + depthwise_pad='SAME', + project_pad='SAME', + initialW=None, + bn_kwargs={}): + super(ExpandedConv2D, self).__init__() + with self.init_scope(): + if callable(expansion_size): + self.inner_size = expansion_size(num_inputs=in_channels) + else: + self.inner_size = expansion_size + + def relu6(x): + return clipped_relu(x, 6.) 
+ if self.inner_size > in_channels: + self.expand = TFConv2DBNActiv( + in_channels, + self.inner_size, + ksize=1, + pad=expand_pad, + nobias=True, + initialW=initialW, + bn_kwargs=bn_kwargs, + activ=relu6) + depthwise_in_channels = self.inner_size + else: + depthwise_in_channels = in_channels + self.depthwise = TFConv2DBNActiv( + depthwise_in_channels, + self.inner_size, + ksize=depthwise_ksize, + stride=depthwise_stride, + pad=depthwise_pad, + nobias=True, + initialW=initialW, + groups=depthwise_in_channels, + bn_kwargs=bn_kwargs, + activ=relu6) + self.project = TFConv2DBNActiv( + self.inner_size, + out_channels, + ksize=1, + pad=project_pad, + nobias=True, + initialW=initialW, + bn_kwargs=bn_kwargs, + activ=None) + + def __call__(self, x): + h = x + if hasattr(self, "expand"): + h = self.expand(x) + h = self.depthwise(h) + h = self.project(h) + if h.shape == x.shape: + h += x + return h diff --git a/chainercv/links/model/mobilenet/mobilenet_v2.py b/chainercv/links/model/mobilenet/mobilenet_v2.py new file mode 100644 index 0000000000..6691df20e5 --- /dev/null +++ b/chainercv/links/model/mobilenet/mobilenet_v2.py @@ -0,0 +1,239 @@ +import numpy as np + +import chainer +from chainer.functions import average_pooling_2d +from chainer.functions import clipped_relu +from chainer.functions import softmax +from chainer.functions import squeeze + +from chainercv.links.model.mobilenet.expanded_conv_2d import ExpandedConv2D +from chainercv.links.model.mobilenet.tf_conv_2d_bn_activ import TFConv2DBNActiv +from chainercv.links.model.mobilenet.tf_convolution_2d import TFConvolution2D +from chainercv.links.model.mobilenet.util import _make_divisible +from chainercv.links.model.mobilenet.util import expand_input_by_factor +from chainercv.links.model.pickable_sequential_chain import \ + PickableSequentialChain +from chainercv import utils + + +""" +Implementation of Mobilenet V2, converting the weights from the pretrained +Tensorflow model from 
+https://github.com/tensorflow/models/tree/master/research/slim/nets/mobilenet +This MobileNetV2 implementation is based on @alexisVallet's one. +@okdshin modified it for ChainerCV. +""" + + +def _depth_multiplied_output_channels(base_out_channels, + multiplier, + divisable_by=8, + min_depth=8): + return _make_divisible(base_out_channels * multiplier, divisable_by, + min_depth) + + +_tf_mobilenetv2_mean = np.asarray( + [128] * 3, dtype=np.float)[:, np.newaxis, np.newaxis] + +# RGB order +_imagenet_mean = np.array( + [123.68, 116.779, 103.939], dtype=np.float32)[:, np.newaxis, np.newaxis] + + +class MobileNetV2(PickableSequentialChain): + """MobileNetV2 Network. + + This is a pickable sequential link. + The network can choose output layers from set of all + intermediate layers. + The attribute :obj:`pick` is the names of the layers that are going + to be picked by :meth:`__call__`. + The attribute :obj:`layer_names` is the names of all layers + that can be picked. + + Examples: + + >>> model = MobileNetV2() + # By default, __call__ returns a probability score (after Softmax). + >>> prob = model(imgs) + >>> model.pick = 'expanded_conv_5' + # This is layer expanded_conv_5. + >>> expanded_conv_5 = model(imgs) + >>> model.pick = ['expanded_conv_5', 'conv1'] + >>> # These are layers expanded_conv_5 and conv1 (before Pool). + >>> expanded_conv_5, conv1 = model(imgs) + + .. seealso:: + :class:`chainercv.links.model.PickableSequentialChain` + + When :obj:`pretrained_model` is the path of a pre-trained chainer model + serialized as a :obj:`.npz` file in the constructor, this chain model + automatically initializes all the parameters with it. + When a string in the prespecified set is provided, a pretrained model is + loaded from weights distributed on the Internet. + The list of pretrained models supported are as follows: + + * :obj:`imagenet`: Loads weights trained with ImageNet. 
\ + When :obj:`arch=='tf'`, the weights distributed \ + at tensorflow/models + `<https://github.com/tensorflow/models/tree/master/research/slim/nets/mobilenet>`_ \ # NOQA + are used. + + Args: + n_class (int): The number of classes. If :obj:`None`, + the default values are used. + If a supported pretrained model is used, + the number of classes used to train the pretrained model + is used. Otherwise, the number of classes in ILSVRC 2012 dataset + is used. + pretrained_model (string): The destination of the pre-trained + chainer model serialized as a :obj:`.npz` file. + If this is one of the strings described + above, it automatically loads weights stored under a directory + :obj:`$CHAINER_DATASET_ROOT/pfnet/chainercv/models/`, + where :obj:`$CHAINER_DATASET_ROOT` is set as + :obj:`$HOME/.chainer/dataset` unless you specify another value + by modifying the environment variable. + mean (numpy.ndarray): A mean value. If :obj:`None`, + the default values are used. + If a supported pretrained model is used, + the mean value used to train the pretrained model is used. + Otherwise, the ImageNet per-channel mean (RGB order) is used. + initialW (callable): Initializer for the weights. + initial_bias (callable): Initializer for the biases. + + """ + # Batch normalization replicating default tensorflow slim parameters + # as used in the original tensorflow implementation. 
+ _bn_tf_default_params = { + "decay": 0.999, + "eps": 0.001, + "dtype": chainer.config.dtype + } + + _models = { + 'tf': { + 1.0: { + 'imagenet': { + 'param': { + 'n_class': 1001, # first element is background + 'mean': _tf_mobilenetv2_mean, + }, + 'overwritable': ('mean',), + 'url': + 'https://chainercv-models.preferred.jp/mobilenet_v2_depth_multiplier_1.0_imagenet_converted_2019_05_13.npz', # NOQA + } + }, + 1.4: { + 'imagenet': { + 'param': { + 'n_class': 1001, # first element is background + 'mean': _tf_mobilenetv2_mean, + }, + 'overwritable': ('mean',), + 'url': + 'https://chainercv-models.preferred.jp/mobilenet_v2_depth_multiplier_1.4_imagenet_converted_2019_05_13.npz', # NOQA + } + } + } + } + + def __init__(self, + n_class=None, + pretrained_model=None, + mean=None, + initialW=None, + initial_bias=None, + arch='tf', + depth_multiplier=1., + bn_kwargs=_bn_tf_default_params, + thousand_categories_mode=False): + if depth_multiplier <= 0: + raise ValueError('depth_multiplier must be greater than 0') + + param, path = utils.prepare_pretrained_model({ + 'n_class': n_class, + 'mean': mean, + }, pretrained_model, self._models[arch][depth_multiplier], { + 'n_class': 1000, + 'mean': _imagenet_mean, + }) + self.mean = param['mean'] + self.n_class = param['n_class'] + + super(MobileNetV2, self).__init__() + + def relu6(x): + return clipped_relu(x, 6.) 
+ with self.init_scope(): + conv_out_channels = _depth_multiplied_output_channels( + 32, depth_multiplier) + self.conv = TFConv2DBNActiv( + in_channels=3, + out_channels=conv_out_channels, + stride=2, + ksize=3, + nobias=True, + activ=relu6, + initialW=initialW, + bn_kwargs=bn_kwargs) + expanded_out_channels = _depth_multiplied_output_channels( + 16, depth_multiplier) + self.expanded_conv = ExpandedConv2D( + expansion_size=expand_input_by_factor(1, divisible_by=1), + in_channels=conv_out_channels, + out_channels=expanded_out_channels, + initialW=initialW, + bn_kwargs=bn_kwargs) + in_channels = expanded_out_channels + out_channels_list = (24, ) * 2 + (32, ) * 3 + (64, ) * 4 + ( + 96, ) * 3 + (160, ) * 3 + (320, ) + for i, out_channels in enumerate(out_channels_list): + layer_id = i + 1 + if layer_id in (1, 3, 6, 13): + stride = 2 + else: + stride = 1 + multiplied_out_channels = _depth_multiplied_output_channels( + out_channels, depth_multiplier) + setattr(self, "expanded_conv_{}".format(layer_id), + ExpandedConv2D( + in_channels=in_channels, + out_channels=multiplied_out_channels, + depthwise_stride=stride, + initialW=initialW, + bn_kwargs=bn_kwargs)) + in_channels = multiplied_out_channels + if depth_multiplier < 1: + conv1_out_channels = 1280 + else: + conv1_out_channels = _depth_multiplied_output_channels( + 1280, depth_multiplier) + self.conv1 = TFConv2DBNActiv( + in_channels=in_channels, + out_channels=conv1_out_channels, + ksize=1, + nobias=True, + initialW=initialW, + activ=relu6, + bn_kwargs=bn_kwargs) + self.global_average_pool = \ + lambda x: average_pooling_2d(x, ksize=x.shape[2:4], stride=1) + self.logits_conv = TFConvolution2D( + in_channels=conv1_out_channels, + out_channels=self.n_class, + ksize=1, + nobias=False, # bias is needed + initialW=initialW, + initial_bias=initial_bias, + ) + self.squeeze = lambda x: squeeze(x, axis=(2, 3)) + self.softmax = softmax + + if path: + chainer.serializers.load_npz(path, self) + + if thousand_categories_mode and 
1000 < n_class: + self.logits_conv.W.data = np.delete(self.logits_conv.W.data, 0, 0) + self.logits_conv.b.data = np.delete(self.logits_conv.b.data, 0) diff --git a/chainercv/links/model/mobilenet/tf_conv_2d_bn_activ.py b/chainercv/links/model/mobilenet/tf_conv_2d_bn_activ.py new file mode 100644 index 0000000000..67e59821b9 --- /dev/null +++ b/chainercv/links/model/mobilenet/tf_conv_2d_bn_activ.py @@ -0,0 +1,130 @@ +import chainer +from chainer.functions import relu +from chainer.links import BatchNormalization + +from chainercv.links.model.mobilenet.tf_convolution_2d import TFConvolution2D + +try: + from chainermn.links import MultiNodeBatchNormalization +except ImportError: + pass + + +class TFConv2DBNActiv(chainer.Chain): + """TFConvolution2D --> Batch Normalization --> Activation + + This is a chain that sequentially applies a two-dimensional convolution, + a batch normalization and an activation. + This chain is similar to :class:`chainer.links.Conv2DBNActiv`, + but this uses TFConvolution2D instead of Convolution2D. + Especially, `pad` is different from it. + + The arguments are the same as that of + :class:`chainer.links.Convolution2D` + except for :obj:`activ` and :obj:`bn_kwargs`. + :obj:`bn_kwargs` can include :obj:`comm` key and a communicator of + ChainerMN as the value to use + :class:`chainermn.links.MultiNodeBatchNormalization`. If + :obj:`comm` is not included in :obj:`bn_kwargs`, + :class:`chainer.links.BatchNormalization` link from Chainer is used. + Note that the default value for the :obj:`nobias` + is changed to :obj:`True`. + + Example: + + There are several ways to initialize a :class:`TFConv2DBNActiv`. + + 1. Give the first three arguments explicitly: + + >>> l = TFConv2DBNActiv(5, 10, 3) + + 2. Omit :obj:`in_channels` or fill it with :obj:`None`: + + In these ways, attributes are initialized at runtime based on + the channel size of the input. 
+ + >>> l = TFConv2DBNActiv(10, 3) + >>> l = TFConv2DBNActiv(None, 10, 3) + + Args: + in_channels (int or None): Number of channels of input arrays. + If :obj:`None`, parameter initialization will be deferred until the + first forward data pass at which time the size will be determined. + out_channels (int): Number of channels of output arrays. + ksize (int or tuple of ints): Size of filters (a.k.a. kernels). + :obj:`ksize=k` and :obj:`ksize=(k, k)` are equivalent. + stride (int or tuple of ints): Stride of filter applications. + :obj:`stride=s` and :obj:`stride=(s, s)` are equivalent. + pad (int, tuple of ints, 'SAME' or 'VALID'): Spatial padding width for + input arrays. :obj:`pad=p` and :obj:`pad=(p, p)` are equivalent. + dilate (int or tuple of ints): Dilation factor of filter applications. + :obj:`dilate=d` and :obj:`dilate=(d, d)` are equivalent. + groups (int): The number of groups to use grouped convolution. The + default is one, where grouped convolution is not used. + nobias (bool): If :obj:`True`, + then this link does not use the bias term. + initialW (callable): Initial weight value. If :obj:`None`, the default + initializer is used. + May also be a callable that takes :obj:`numpy.ndarray` or + :obj:`cupy.ndarray` and edits its value. + initial_bias (callable): Initial bias value. If :obj:`None`, the bias + is set to 0. + May also be a callable that takes :obj:`numpy.ndarray` or + :obj:`cupy.ndarray` and edits its value. + activ (callable): An activation function. The default value is + :func:`chainer.functions.relu`. If this is :obj:`None`, + no activation is applied (i.e. the activation is the identity + function). + bn_kwargs (dict): Keyword arguments passed to initialize + :class:`chainer.links.BatchNormalization`. If a ChainerMN + communicator (:class:`~chainermn.communicators.CommunicatorBase`) + is given with the key :obj:`comm`, + :obj:`~chainermn.links.MultiNodeBatchNormalization` will be used + for the batch normalization. 
Otherwise, + :obj:`~chainer.links.BatchNormalization` will be used. + + """ + + def __init__(self, + in_channels, + out_channels, + ksize=None, + stride=1, + pad='SAME', + dilate=1, + groups=1, + nobias=True, + initialW=None, + initial_bias=None, + activ=relu, + bn_kwargs={}): + if ksize is None: + out_channels, ksize, in_channels = in_channels, out_channels, None + + self.activ = activ + super(TFConv2DBNActiv, self).__init__() + with self.init_scope(): + self.conv = TFConvolution2D( + in_channels, + out_channels, + ksize, + stride, + pad, + nobias, + initialW, + initial_bias, + dilate=dilate, + groups=groups) + if 'comm' in bn_kwargs: + self.bn = MultiNodeBatchNormalization(out_channels, + **bn_kwargs) + else: + self.bn = BatchNormalization(out_channels, **bn_kwargs) + + def __call__(self, x): + h = self.conv(x) + h = self.bn(h) + if self.activ is None: + return h + else: + return self.activ(h) diff --git a/chainercv/links/model/mobilenet/tf_convolution_2d.py b/chainercv/links/model/mobilenet/tf_convolution_2d.py new file mode 100644 index 0000000000..90218525ac --- /dev/null +++ b/chainercv/links/model/mobilenet/tf_convolution_2d.py @@ -0,0 +1,91 @@ +import numpy as np + +import chainer +from chainer.functions import pad +from chainer.links import Convolution2D +from chainer.utils import conv + + +def _pair(x): + if hasattr(x, '__getitem__'): + return x + return x, x + + +def _get_pad(in_size, ksize, stride, tf_padding): + if tf_padding == 'SAME': + tf_out_size = int(np.ceil(float(in_size) / stride)) + elif tf_padding == 'VALID': + tf_out_size = int(np.ceil(float(in_size - ksize + 1) / stride)) + pad = int(stride * tf_out_size - in_size + ksize - stride) + assert conv.get_conv_outsize(in_size + pad, ksize, stride, + 0) == tf_out_size + return pad + + +def _tf_padding(x, ksize, stride, tf_padding): + pad_2 = _get_pad(x.shape[2], ksize[0], stride[0], tf_padding) + pad_3 = _get_pad(x.shape[3], ksize[1], stride[1], tf_padding) + if pad_2 or pad_3: + return pad( + x, 
+ ((0, 0), (0, 0), (pad_2 // 2, int(np.ceil(float(pad_2) / 2))), + (pad_3 // 2, int(np.ceil(float(pad_3) / 2)))), + mode='constant') + else: + return x + + +class TFConvolution2D(chainer.Chain): + """Tensorflow compatible Convolution2D + + This is a Convolution2D chain that imitates Tensorflow's tf.nn.conv2d. + + The arguments are the same as that of + :class:`chainer.links.Convolution2D` except for `pad`. + :obj:`pad` can be set TF's "SAME" or "VALID" in addition to integer value. + If integer value is set, + this chain is equal to :class:`chainer.links.Convolution2D`. + """ + + def __init__(self, + in_channels, + out_channels, + ksize=None, + stride=1, + pad='SAME', + nobias=False, + initialW=None, + initial_bias=None, + **kwargs): + super(TFConvolution2D, self).__init__() + if ksize is None: + out_channels, ksize, in_channels = in_channels, out_channels, None + + if pad in ('SAME', 'VALID'): # TF compatible pad + self.padding = lambda x: _tf_padding(x, _pair(self.conv.ksize), + _pair(self.conv.stride), pad) + conv_pad = 0 + else: + self.padding = None + assert isinstance(pad, int) + conv_pad = pad + + with self.init_scope(): + self.conv = Convolution2D(in_channels, out_channels, ksize, stride, + conv_pad, nobias, initialW, initial_bias, + **kwargs) + + @property + def W(self): + return self.conv.W + + @property + def b(self): + return self.conv.b + + def forward(self, x): + if self.padding is None: + return self.conv(x) + else: + return self.conv(self.padding(x)) diff --git a/chainercv/links/model/mobilenet/util.py b/chainercv/links/model/mobilenet/util.py new file mode 100644 index 0000000000..c9f66be7ca --- /dev/null +++ b/chainercv/links/model/mobilenet/util.py @@ -0,0 +1,13 @@ +# utility functions taken straight-up from the source project. 
+def _make_divisible(v, divisor, min_value=None): + if min_value is None: + min_value = divisor + new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) + # Make sure that round down does not go down by more than 10%. + if new_v < 0.9 * v: + new_v += divisor + return new_v + + +def expand_input_by_factor(n, divisible_by=8): + return lambda num_inputs: _make_divisible(num_inputs * n, divisible_by) diff --git a/examples/classification/README.md b/examples/classification/README.md index 048cfe6979..0490d1ebb7 100644 --- a/examples/classification/README.md +++ b/examples/classification/README.md @@ -8,6 +8,8 @@ Single crop error rates of the models with the weights converted from Caffe weig | Model | Top 1 | Original Top 1 | |:-:|:-:|:-:| +| MobileNetV2-1.0 | 28.3 % | 28.0 % [6] | +| MobileNetV2-1.4 | 24.3 % | 25.3 % [6] | | VGG16 | 29.0 % | 28.5 % [1] | | ResNet50 (`arch=he`) | 24.8 % | 24.7 % [2] | | ResNet101 (`arch=he`) | 23.6 % | 23.6 % [2] | @@ -22,6 +24,8 @@ Ten crop error rate. | Model | Top 1 | Original Top 1 | |:-:|:-:|:-:| +| MobileNetV2-1.0 | 25.6 % | | +| MobileNetV2-1.4 | 22.4 % | | | VGG16 | 27.1 % | | | ResNet50 (`arch=he`) | 23.0 % | 22.9 % [2] | | ResNet101 (`arch=he`) | 21.8 % | 21.8 % [2] | @@ -37,7 +41,7 @@ The results can be reproduced by the following command. These scores are obtained using OpenCV backend. If Pillow is used, scores would differ. ``` -$ python eval_imagenet.py [--model vgg16|resnet50|resnet101|resnet152|se-resnet50|se-resnet101|se-resnet152] [--pretrained-model <model_path>] [--batchsize <batchsize>] [--gpu <gpu>] [--crop center|10] +$ python eval_imagenet.py [--model mobilenet_v2_1.0|mobilenet_v2_1.4|vgg16|resnet50|resnet101|resnet152|se-resnet50|se-resnet101|se-resnet152] [--pretrained-model <model_path>] [--batchsize <batchsize>] [--gpu <gpu>] [--crop center|10] ``` ### Trained model @@ -104,3 +108,4 @@ The ImageNet Large Scale Visual Recognition Challenge (ILSVRC) dataset has 1000 3. Jie Hu, Li Shen, Gang Sun. "Squeeze-and-Excitation Networks" CVPR 2018 4. 
https://github.com/hujie-frank/SENet 5. Priya Goyal, Piotr Dollár, Ross Girshick, Pieter Noordhuis, Lukasz Wesolowski, Aapo Kyrola, Andrew Tulloch, Yangqing Jia, Kaiming He. "Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour" https://arxiv.org/abs/1706.02677 +6. Mark Sandler, Andrew Howard, Menglong Zhu, Andrey Zhmoginov, Liang-Chieh Chen. "MobileNetV2: Inverted Residuals and Linear Bottlenecks" https://arxiv.org/abs/1801.04381 diff --git a/examples/classification/eval_imagenet.py b/examples/classification/eval_imagenet.py index 32a06ca793..959dcaa600 100644 --- a/examples/classification/eval_imagenet.py +++ b/examples/classification/eval_imagenet.py @@ -9,6 +9,7 @@ from chainercv.datasets import directory_parsing_label_names from chainercv.datasets import DirectoryParsingLabelDataset from chainercv.links import FeaturePredictor +from chainercv.links import MobileNetV2 from chainercv.links import ResNet101 from chainercv.links import ResNet152 from chainercv.links import ResNet50 @@ -35,6 +36,8 @@ 'se-resnet152': (SEResNet152, {}, 32, 'center', None), 'se-resnext50': (SEResNeXt50, {}, 32, 'center', None), 'se-resnext101': (SEResNeXt101, {}, 32, 'center', None), + 'mobilenet_v2_1.0': (MobileNetV2, {}, 32, 'center', None), + 'mobilenet_v2_1.4': (MobileNetV2, {}, 32, 'center', None) } diff --git a/examples/mobilenet/README.md b/examples/mobilenet/README.md new file mode 100644 index 0000000000..a2d3cb19c3 --- /dev/null +++ b/examples/mobilenet/README.md @@ -0,0 +1,13 @@ +# MobileNet + +For evaluation, please go to [`examples/classification`](https://github.com/chainer/chainercv/tree/master/examples/classification). + +## Convert TensorFlow model +Convert TensorFlow's `*.ckpt` to `*.npz`. + +``` +$ python tfckpt2npz.py mobilenetv2 <source>.ckpt <target>.npz +``` + +The pretrained `.ckpt` for mobilenet can be downloaded from here. 
+https://github.com/tensorflow/models/tree/master/research/slim/nets/mobilenet diff --git a/examples/mobilenet/tfckpt2npz.py b/examples/mobilenet/tfckpt2npz.py new file mode 100644 index 0000000000..89673f4060 --- /dev/null +++ b/examples/mobilenet/tfckpt2npz.py @@ -0,0 +1,188 @@ +import argparse + +import chainer +import tensorflow as tf + +from chainercv.links import MobileNetV2 + + +def load_expanded_conv(econv, expand_params, depthwise_params, project_params): + if hasattr(econv, 'expand'): + assert expand_params is not None + c_to_p = [(econv.expand.conv, econv.expand.bn, expand_params)] + else: + assert expand_params is None + c_to_p = [] + c_to_p.extend([(econv.depthwise.conv, econv.depthwise.bn, + depthwise_params), (econv.project.conv, econv.project.bn, + project_params)]) + + for conv, bn, params in c_to_p: + init_conv_with_tf_weights(conv, params["weights"]) + init_bn_with_tf_params(bn, params["beta"], params["gamma"], + params["moving_mean"], + params["moving_variance"]) + + +def init_conv_with_tf_weights(conv, weights, bias=None): + # Shifting input and output channel dimensions. 
+ weights = weights.transpose((3, 2, 0, 1)) + if conv.W.shape != weights.shape: # for depthwise conv + weights = weights.transpose((1, 0, 2, 3)) + conv.W.data[:] = weights.data[:] + + if bias is not None: + conv.b.data[:] = bias.data[:] + + +def init_bn_with_tf_params(bn, beta, gamma, moving_mean, moving_variance): + beta = beta.flatten().astype(chainer.config.dtype) + bn.beta.initializer = chainer.initializers.Constant( + beta, dtype=chainer.config.dtype) + bn.beta.initialize(shape=beta.shape) + gamma = gamma.flatten().astype(chainer.config.dtype) + bn.gamma.initializer = chainer.initializers.Constant( + gamma, dtype=chainer.config.dtype) + bn.gamma.initialize(shape=gamma.shape) + bn.avg_mean = moving_mean.flatten().astype(chainer.config.dtype) + bn.avg_var = moving_variance.flatten().astype(chainer.config.dtype) + + +def get_tensor(ckpt_reader, name, ema_ratio=0.999): + if (name + '/ExponentialMovingAverage' + ) in ckpt_reader.get_variable_to_shape_map().keys(): + base = ckpt_reader.get_tensor(name) + ema = ckpt_reader.get_tensor(name + '/ExponentialMovingAverage') + + return (1.0 - ema_ratio) * base + ema_ratio * ema + else: + return ckpt_reader.get_tensor(name) + + +def load_mobilenetv2_from_tensorflow_checkpoint(model, checkpoint_filename): + ckpt_reader = tf.train.NewCheckpointReader(checkpoint_filename) + + # Loading weights for the expanded convolutions. 
+ tf_scope_to_expanded_conv = { + "MobilenetV2/expanded_conv": model.expanded_conv, + } + for i in range(16): + tf_scope_to_expanded_conv["MobilenetV2/expanded_conv_{}".format( + i + 1)] = getattr(model, "expanded_conv_{}".format(i + 1)) + for tf_scope, expanded_conv in tf_scope_to_expanded_conv.items(): + print("Loading weights for %s" % tf_scope) + # Expand convolution parameters + if hasattr(expanded_conv, 'expand'): + expand_params = { + "weights": + get_tensor(ckpt_reader, tf_scope + '/expand/weights'), + "beta": + get_tensor(ckpt_reader, tf_scope + '/expand/BatchNorm/beta'), + "gamma": + get_tensor(ckpt_reader, tf_scope + '/expand/BatchNorm/gamma'), + "moving_mean": + get_tensor(ckpt_reader, + tf_scope + '/expand/BatchNorm/moving_mean'), + "moving_variance": + get_tensor(ckpt_reader, + tf_scope + '/expand/BatchNorm/moving_variance') + } + else: + print("Skipping expanded convolution for {}".format(tf_scope)) + expand_params = None + # Depthwise convolution parameters + depthwise_params = { + "weights": + get_tensor(ckpt_reader, tf_scope + '/depthwise/depthwise_weights'), + "beta": + get_tensor(ckpt_reader, tf_scope + '/depthwise/BatchNorm/beta'), + "gamma": + get_tensor(ckpt_reader, tf_scope + '/depthwise/BatchNorm/gamma'), + "moving_mean": + get_tensor(ckpt_reader, + tf_scope + '/depthwise/BatchNorm/moving_mean'), + "moving_variance": + get_tensor(ckpt_reader, + tf_scope + '/depthwise/BatchNorm/moving_variance') + } + + # Project convolution parameters + project_params = { + "weights": + get_tensor(ckpt_reader, tf_scope + '/project/weights'), + "beta": + get_tensor(ckpt_reader, tf_scope + '/project/BatchNorm/beta'), + "gamma": + get_tensor(ckpt_reader, tf_scope + '/project/BatchNorm/gamma'), + "moving_mean": + get_tensor(ckpt_reader, + tf_scope + '/project/BatchNorm/moving_mean'), + "moving_variance": + get_tensor(ckpt_reader, + tf_scope + '/project/BatchNorm/moving_variance'), + } + load_expanded_conv( + expanded_conv, + expand_params=expand_params, + 
depthwise_params=depthwise_params, + project_params=project_params, + ) + # Similarly loading the vanilla convolutions. + # Initial convolution + init_conv_with_tf_weights( + model.conv.conv, + weights=get_tensor(ckpt_reader, 'MobilenetV2/Conv/weights')) + init_bn_with_tf_params( + model.conv.bn, + beta=get_tensor(ckpt_reader, 'MobilenetV2/Conv/BatchNorm/beta'), + gamma=get_tensor(ckpt_reader, 'MobilenetV2/Conv/BatchNorm/gamma'), + moving_mean=get_tensor(ckpt_reader, + 'MobilenetV2/Conv/BatchNorm/moving_mean'), + moving_variance=get_tensor( + ckpt_reader, 'MobilenetV2/Conv/BatchNorm/moving_variance')) + # Final convolution before dropout (conv1) + init_conv_with_tf_weights( + model.conv1.conv, + weights=get_tensor(ckpt_reader, 'MobilenetV2/Conv_1/weights')) + init_bn_with_tf_params( + model.conv1.bn, + beta=get_tensor(ckpt_reader, 'MobilenetV2/Conv_1/BatchNorm/beta'), + gamma=get_tensor(ckpt_reader, 'MobilenetV2/Conv_1/BatchNorm/gamma'), + moving_mean=get_tensor(ckpt_reader, + 'MobilenetV2/Conv_1/BatchNorm/moving_mean'), + moving_variance=get_tensor( + ckpt_reader, 'MobilenetV2/Conv_1/BatchNorm/moving_variance')) + # Logits convolution + init_conv_with_tf_weights( + model.logits_conv, + weights=get_tensor(ckpt_reader, + 'MobilenetV2/Logits/Conv2d_1c_1x1/weights'), + bias=get_tensor(ckpt_reader, + 'MobilenetV2/Logits/Conv2d_1c_1x1/biases')) + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument( + 'model_name', choices=('mobilenetv2', ), default='mobilenetv2') + parser.add_argument('pretrained_model') + parser.add_argument('--n-class', type=int, default=1001) + parser.add_argument('--depth-multiplier', type=float, default=1.0) + parser.add_argument('output', nargs='?', default=None) + args = parser.parse_args() + + model = MobileNetV2(args.n_class, depth_multiplier=args.depth_multiplier) + load_mobilenetv2_from_tensorflow_checkpoint(model, args.pretrained_model) + + if args.output is None: + output = 
'{}_{}_imagenet_convert.npz'.format(args.model_name,
+                                                     args.depth_multiplier)
+    else:
+        output = args.output
+    model.conv.conv.W.array /= 255.0  # scaling [0, 255] -> [0, 1.0]
+    chainer.serializers.save_npz(output, model)
+    print("output: ", output)
+
+
+if __name__ == '__main__':
+    main()
diff --git a/tests/links_tests/model_tests/mobilenet_tests/__init__.py b/tests/links_tests/model_tests/mobilenet_tests/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/tests/links_tests/model_tests/mobilenet_tests/test_expanded_conv_2d.py b/tests/links_tests/model_tests/mobilenet_tests/test_expanded_conv_2d.py
new file mode 100644
index 0000000000..0e2a769926
--- /dev/null
+++ b/tests/links_tests/model_tests/mobilenet_tests/test_expanded_conv_2d.py
@@ -0,0 +1,106 @@
+import unittest
+
+import numpy as np
+
+import chainer
+from chainer.backends import cuda
+from chainer import testing
+
+from chainercv.links.model.mobilenet import ExpandedConv2D
+from chainercv.utils.testing import attr
+
+
+@testing.parameterize(*testing.product({
+    'expansion_size': [1, 2, 3],
+}))
+class TestExpandedConv2D(unittest.TestCase):
+    """Check ExpandedConv2D forward and backward passes with fixed weights."""
+
+    in_channels = 1
+    out_channels = 1
+    expand_pad = 'SAME'
+    depthwise_ksize = 3
+    depthwise_pad = 'SAME'
+    depthwise_stride = 1
+    project_pad = 'SAME'
+
+    def setUp(self):
+        self.x = np.random.uniform(
+            -1, 1, (5, self.in_channels, 5, 5)).astype(np.float32)
+        self.gy = np.random.uniform(
+            -1, 1, (5, self.out_channels, 5, 5)).astype(np.float32)
+
+        # Convolution is the identity function.
+        expand_initialW = np.ones(
+            (self.expansion_size, self.in_channels, 1, 1), dtype=np.float32)
+        identity_kernel = [[0, 0, 0], [0, 1, 0], [0, 0, 0]]
+        depthwise_initialW = np.array(
+            identity_kernel * self.expansion_size,
+            dtype=np.float32).reshape((self.expansion_size, 1, 3, 3))
+        project_initialW = np.ones(
+            (self.out_channels, self.expansion_size, 1, 1), dtype=np.float32)
+        bn_kwargs = {'decay': 0.8}
+        self.l = ExpandedConv2D(
+            self.in_channels, self.out_channels,
+            expansion_size=self.expansion_size,
+            expand_pad=self.expand_pad,
+            depthwise_stride=self.depthwise_stride,
+            depthwise_ksize=self.depthwise_ksize,
+            depthwise_pad=self.depthwise_pad,
+            project_pad=self.project_pad, bn_kwargs=bn_kwargs)
+        if self.expansion_size > self.in_channels:
+            self.l.expand.conv.W.array = expand_initialW
+        self.l.depthwise.conv.W.array = depthwise_initialW
+        self.l.project.conv.W.array = project_initialW
+
+    def check_forward(self, x_data):
+        x = chainer.Variable(x_data)
+        # Make the batch normalization to be the identity function.
+        # Guard the expand layer with the same condition as setUp.
+        if self.expansion_size > self.in_channels:
+            self.l.expand.bn.avg_var[:] = 1
+            self.l.expand.bn.avg_mean[:] = 0
+        self.l.depthwise.bn.avg_var[:] = 1
+        self.l.depthwise.bn.avg_mean[:] = 0
+        self.l.project.bn.avg_var[:] = 1
+        self.l.project.bn.avg_mean[:] = 0
+        with chainer.using_config('train', False):
+            y = self.l(x)
+
+        self.assertIsInstance(y, chainer.Variable)
+        self.assertIsInstance(y.array, self.l.xp.ndarray)
+
+        # Expected value is x + scale * clip(x, 0, 6).
+        x_cpu = cuda.to_cpu(x_data)
+        if self.expansion_size > self.in_channels:
+            scale = self.expansion_size
+        else:
+            scale = 1
+        expected = x_cpu + scale * np.maximum(np.minimum(x_cpu, 6), 0)
+        np.testing.assert_almost_equal(
+            cuda.to_cpu(y.array), expected, decimal=4)
+
+    def test_forward_cpu(self):
+        self.check_forward(self.x)
+
+    @attr.gpu
+    def test_forward_gpu(self):
+        self.l.to_gpu()
+        self.check_forward(cuda.to_gpu(self.x))
+
+    def check_backward(self, x_data, y_grad):
+        x = chainer.Variable(x_data)
+        y = self.l(x)
+        y.grad = y_grad
+        y.backward()
+
+    def test_backward_cpu(self):
+        self.check_backward(self.x, self.gy)
+
+    @attr.gpu
+    def test_backward_gpu(self):
+        self.l.to_gpu()
+        self.check_backward(cuda.to_gpu(self.x), cuda.to_gpu(self.gy))
+
+
+testing.run_module(__name__, __file__)
diff --git a/tests/links_tests/model_tests/mobilenet_tests/test_mobilenet.py b/tests/links_tests/model_tests/mobilenet_tests/test_mobilenet.py
new file mode 100644
index 0000000000..e389bcb14a
--- /dev/null
+++ b/tests/links_tests/model_tests/mobilenet_tests/test_mobilenet.py
@@ -0,0 +1,91 @@
+import unittest
+
+import numpy as np
+
+from chainer.testing import attr
+from chainer import Variable
+
+from chainercv.links import MobileNetV2
+from chainercv.utils import testing
+
+
+@testing.parameterize(*(
+    testing.product_dict(
+        [
+            {'pick': 'softmax', 'shapes': (1, 200), 'n_class': 200},
+            {'pick': 'conv1',
+             'shapes': (1, 1280, 7, 7), 'n_class': None},
+            {'pick': ['expanded_conv_2', 'conv'],
+             'shapes': ((1, 24, 56, 56), (1, 32, 112, 112)), 'n_class': None},
+        ],
+        [
+            {'model_class': MobileNetV2},
+        ],
+        [
+            {'arch': 'tf'},
+        ]
+    )
+))
+class TestMobileNetCall(unittest.TestCase):
+    """Check the shapes of the features picked from MobileNetV2."""
+
+    def setUp(self):
+        self.link = self.model_class(
+            n_class=self.n_class, pretrained_model=None, arch=self.arch)
+        self.link.pick = self.pick
+
+    def check_call(self):
+        xp = self.link.xp
+
+        x = Variable(xp.asarray(np.random.uniform(
+            -1, 1, (1, 3, 224, 224)).astype(np.float32)))
+        features = self.link(x)
+        if isinstance(features, tuple):
+            for activation, shape in zip(features, self.shapes):
+                self.assertEqual(activation.shape, shape)
+        else:
+            self.assertEqual(features.shape, self.shapes)
+            self.assertEqual(features.dtype, np.float32)
+
+    @attr.slow
+    def test_call_cpu(self):
+        self.check_call()
+
+    @attr.gpu
+    @attr.slow
+    def test_call_gpu(self):
+        self.link.to_gpu()
+        self.check_call()
+
+
+@testing.parameterize(*testing.product({
+    'model': [MobileNetV2],
+    'n_class': [None, 500, 1001],
+    'pretrained_model': ['imagenet'],
+    # Pass the shape as ``size``; positionally it would be read as ``low``.
+    'mean': [None, np.random.uniform(size=(3, 1, 1)).astype(np.float32)],
+    'arch': ['tf'],
+}))
+class TestMobileNetPretrained(unittest.TestCase):
+    """Check which n_class values the imagenet weights accept."""
+
+    @attr.slow
+    def test_pretrained(self):
+        kwargs = {
+            'n_class': self.n_class,
+            'pretrained_model': self.pretrained_model,
+            'mean': self.mean,
+            'arch': self.arch,
+        }
+
+        if self.pretrained_model == 'imagenet':
+            valid = self.n_class in {None, 1001}
+
+        if valid:
+            self.model(**kwargs)
+        else:
+            with self.assertRaises(ValueError):
+                self.model(**kwargs)
+
+
+testing.run_module(__name__, __file__)
diff --git a/tests/links_tests/model_tests/mobilenet_tests/test_tf_conv_2d_bn_activ.py b/tests/links_tests/model_tests/mobilenet_tests/test_tf_conv_2d_bn_activ.py
new file mode 100644
index 0000000000..3f4cd52ddf
--- /dev/null
+++ b/tests/links_tests/model_tests/mobilenet_tests/test_tf_conv_2d_bn_activ.py
@@ -0,0 +1,190 @@
+import unittest
+
+import numpy as np
+
+import chainer
+from chainer.backends import cuda
+from chainer.functions import relu
+from chainer import testing
+from chainermn import create_communicator
+
+from chainercv.links.model.mobilenet import TFConv2DBNActiv
+from chainercv.utils.testing import attr
+
+
+def _add_one(x):
+    return x + 1
+
+
+@testing.parameterize(*testing.product({
+    'dilate': [1, 2],
+    'pad': [1, 'SAME'],
+    'args_style': ['explicit', 'None', 'omit'],
+    'activ': ['relu', 'add_one', None],
+}))
+class TestTFConv2DBNActiv(unittest.TestCase):
+
+    in_channels = 1
+    out_channels = 1
+    ksize = 3
+    stride = 1
+    pad = 1
+
+    def setUp(self):
+        if self.activ == 'relu':
+            activ = relu
+        elif self.activ == 'add_one':
+            activ = _add_one
+        elif self.activ is None:
+            activ = None
+        self.x = np.random.uniform(
+            -1, 1, (5, self.in_channels, 5, 5)).astype(np.float32)
+        self.gy = np.random.uniform(
+            -1, 1, (5, self.out_channels, 5, 5)).astype(np.float32)
+
+        # Convolution is the identity function.
+        initialW = np.array([[0, 0, 0], [0, 1, 0], [0, 0, 0]],
+                            dtype=np.float32).reshape((1, 1, 3, 3))
+        bn_kwargs = {'decay': 0.8}
+        initial_bias = 0
+        if self.args_style == 'explicit':
+            self.l = TFConv2DBNActiv(
+                self.in_channels, self.out_channels, self.ksize,
+                self.stride, self.pad, self.dilate,
+                initialW=initialW, initial_bias=initial_bias,
+                activ=activ, bn_kwargs=bn_kwargs)
+        elif self.args_style == 'None':
+            self.l = TFConv2DBNActiv(
+                None, self.out_channels, self.ksize, self.stride, self.pad,
+                self.dilate, initialW=initialW, initial_bias=initial_bias,
+                activ=activ, bn_kwargs=bn_kwargs)
+        elif self.args_style == 'omit':
+            self.l = TFConv2DBNActiv(
+                self.out_channels, self.ksize, stride=self.stride,
+                pad=self.pad, dilate=self.dilate, initialW=initialW,
+                initial_bias=initial_bias, activ=activ, bn_kwargs=bn_kwargs)
+
+    def check_forward(self, x_data):
+        x = chainer.Variable(x_data)
+        # Make the batch normalization to be the identity function.
+        self.l.bn.avg_var[:] = 1
+        self.l.bn.avg_mean[:] = 0
+        with chainer.using_config('train', False):
+            y = self.l(x)
+
+        self.assertIsInstance(y, chainer.Variable)
+        self.assertIsInstance(y.array, self.l.xp.ndarray)
+
+        if self.dilate == 1:
+            _x_data = x_data
+        elif self.dilate == 2:
+            _x_data = x_data[:, :, 1:-1, 1:-1]
+        if self.activ == 'relu':
+            np.testing.assert_almost_equal(
+                cuda.to_cpu(y.array), np.maximum(cuda.to_cpu(_x_data), 0),
+                decimal=4
+            )
+        elif self.activ == 'add_one':
+            np.testing.assert_almost_equal(
+                cuda.to_cpu(y.array), cuda.to_cpu(_x_data) + 1,
+                decimal=4
+            )
+        elif self.activ is None:
+            np.testing.assert_almost_equal(
+                cuda.to_cpu(y.array), cuda.to_cpu(_x_data),
+                decimal=4
+            )
+
+    def test_forward_cpu(self):
+        self.check_forward(self.x)
+
+    @attr.gpu
+    def test_forward_gpu(self):
+        self.l.to_gpu()
+        self.check_forward(cuda.to_gpu(self.x))
+
+    def check_backward(self, x_data, y_grad):
+        x = chainer.Variable(x_data)
+        y = self.l(x)
+        if self.dilate == 1:
+            y.grad = y_grad
+        elif self.dilate == 2:
+            y.grad = y_grad[:, :, 1:-1, 1:-1]
+        y.backward()
+
+    def test_backward_cpu(self):
+        self.check_backward(self.x, self.gy)
+
+    @attr.gpu
+    def test_backward_gpu(self):
+        self.l.to_gpu()
+        self.check_backward(cuda.to_gpu(self.x), cuda.to_gpu(self.gy))
+
+
+@attr.mpi
+class TestTFConv2DMultiNodeBNActiv(unittest.TestCase):
+
+    in_channels = 1
+    out_channels = 1
+    ksize = 3
+    stride = 1
+    pad = 1
+    dilate = 1
+
+    def setUp(self):
+        self.x = np.random.uniform(
+            -1, 1, (5, self.in_channels, 5, 5)).astype(np.float32)
+        self.gy = np.random.uniform(
+            -1, 1, (5, self.out_channels, 5, 5)).astype(np.float32)
+
+        # Convolution is the identity function.
+        initialW = np.array([[0, 0, 0], [0, 1, 0], [0, 0, 0]],
+                            dtype=np.float32).reshape((1, 1, 3, 3))
+        bn_kwargs = {'decay': 0.8, 'comm': create_communicator('naive')}
+        initial_bias = 0
+        activ = relu
+        self.l = TFConv2DBNActiv(
+            self.in_channels, self.out_channels, self.ksize, self.stride,
+            self.pad, self.dilate, initialW=initialW,
+            initial_bias=initial_bias, activ=activ, bn_kwargs=bn_kwargs)
+
+    def check_forward(self, x_data):
+        x = chainer.Variable(x_data)
+        # Make the batch normalization to be the identity function.
+        self.l.bn.avg_var[:] = 1
+        self.l.bn.avg_mean[:] = 0
+        with chainer.using_config('train', False):
+            y = self.l(x)
+
+        self.assertIsInstance(y, chainer.Variable)
+        self.assertIsInstance(y.array, self.l.xp.ndarray)
+
+        np.testing.assert_almost_equal(
+            cuda.to_cpu(y.array), np.maximum(cuda.to_cpu(x_data), 0),
+            decimal=4
+        )
+
+    def test_multi_node_batch_normalization_forward_cpu(self):
+        self.check_forward(self.x)
+
+    @attr.gpu
+    def test_multi_node_batch_normalization_forward_gpu(self):
+        self.l.to_gpu()
+        self.check_forward(cuda.to_gpu(self.x))
+
+    def check_backward(self, x_data, y_grad):
+        x = chainer.Variable(x_data)
+        y = self.l(x)
+        y.grad = y_grad
+        y.backward()
+
+    def test_multi_node_batch_normalization_backward_cpu(self):
+        self.check_backward(self.x, self.gy)
+
+    @attr.gpu
+    def test_multi_node_batch_normalization_backward_gpu(self):
+        self.l.to_gpu()
+        self.check_backward(cuda.to_gpu(self.x), cuda.to_gpu(self.gy))
+
+
+testing.run_module(__name__, __file__)
diff --git a/tests/links_tests/model_tests/mobilenet_tests/test_tf_convolution_2d.py b/tests/links_tests/model_tests/mobilenet_tests/test_tf_convolution_2d.py
new file mode 100644
index 0000000000..591ab53dcd
--- /dev/null
+++ b/tests/links_tests/model_tests/mobilenet_tests/test_tf_convolution_2d.py
@@ -0,0 +1,89 @@
+import unittest
+
+import numpy as np
+
+import chainer
+from chainer.backends import cuda
+from chainer import testing
+
+from chainercv.links.model.mobilenet import TFConvolution2D
+from chainercv.utils.testing import attr
+
+
+@testing.parameterize(*testing.product({
+    'pad': [1, 'SAME'],
+    'args_style': ['explicit', 'None', 'omit'],
+}))
+class TestTFConvolution2D(unittest.TestCase):
+    """Check TFConvolution2D forward and backward with an identity kernel."""
+
+    in_channels = 1
+    out_channels = 1
+    ksize = 3
+    stride = 1
+    pad = 1
+    dilate = 1
+
+    def setUp(self):
+        self.x = np.random.uniform(
+            -1, 1, (5, self.in_channels, 5, 5)).astype(np.float32)
+        self.gy = np.random.uniform(
+            -1, 1, (5, self.out_channels, 5, 5)).astype(np.float32)
+
+        # Convolution is the identity function.
+        initialW = np.array([[0, 0, 0], [0, 1, 0], [0, 0, 0]],
+                            dtype=np.float32).reshape((1, 1, 3, 3))
+        initial_bias = 0
+        if self.args_style == 'explicit':
+            self.l = TFConvolution2D(
+                self.in_channels, self.out_channels, self.ksize,
+                self.stride, self.pad, self.dilate,
+                initialW=initialW, initial_bias=initial_bias)
+        elif self.args_style == 'None':
+            self.l = TFConvolution2D(
+                None, self.out_channels, self.ksize, self.stride, self.pad,
+                self.dilate, initialW=initialW, initial_bias=initial_bias)
+        elif self.args_style == 'omit':
+            self.l = TFConvolution2D(
+                self.out_channels, self.ksize, stride=self.stride,
+                pad=self.pad, dilate=self.dilate, initialW=initialW,
+                initial_bias=initial_bias)
+
+    def check_forward(self, x_data):
+        x = chainer.Variable(x_data)
+        with chainer.using_config('train', False):
+            y = self.l(x)
+
+        self.assertIsInstance(y, chainer.Variable)
+        self.assertIsInstance(y.array, self.l.xp.ndarray)
+
+        # With the identity kernel the output must equal the input.
+        np.testing.assert_almost_equal(
+            cuda.to_cpu(y.array), cuda.to_cpu(x_data),
+            decimal=4
+        )
+
+    def test_forward_cpu(self):
+        self.check_forward(self.x)
+
+    @attr.gpu
+    def test_forward_gpu(self):
+        self.l.to_gpu()
+        self.check_forward(cuda.to_gpu(self.x))
+
+    def check_backward(self, x_data, y_grad):
+        x = chainer.Variable(x_data)
+        y = self.l(x)
+        y.grad = y_grad
+        y.backward()
+
+    def test_backward_cpu(self):
+        self.check_backward(self.x, self.gy)
+
+    @attr.gpu
+    def test_backward_gpu(self):
+        self.l.to_gpu()
+        self.check_backward(cuda.to_gpu(self.x), cuda.to_gpu(self.gy))
+
+
+testing.run_module(__name__, __file__)
diff --git a/tests/links_tests/model_tests/test_feature_predictor.py b/tests/links_tests/model_tests/test_feature_predictor.py
index a9ff13b6e4..e64def12a3 100644
--- a/tests/links_tests/model_tests/test_feature_predictor.py
+++ b/tests/links_tests/model_tests/test_feature_predictor.py
@@ -79,7 +79,7 @@ def test_gpu(self):
     'crop_size': [192, (192, 256), (256, 192)],
     'scale_size': [None, 256, (256, 256)],
     'in_channels': [1, 3],
-    'mean': [None, np.float32(1)]
+    'mean': [None, np.float32(1)],
 }))
 class TestFeaturePredictor(unittest.TestCase):
 