diff --git a/include/mxnet/ndarray.h b/include/mxnet/ndarray.h index ea0b7c9c8379..7c82c20c1fdc 100644 --- a/include/mxnet/ndarray.h +++ b/include/mxnet/ndarray.h @@ -259,7 +259,11 @@ class NDArray { CHECK_GE(shape_[0], idx) << "index out of range"; size_t length = shape_.ProdShape(1, shape_.ndim()); ret.offset_ += idx * length; - ret.shape_ = TShape(shape_.data()+1, shape_.data()+shape_.ndim()); + if (shape_.ndim() > 1) { + ret.shape_ = TShape(shape_.data()+1, shape_.data()+shape_.ndim()); + } else { + ret.shape_ = mshadow::Shape1(1); + } return ret; } /*! diff --git a/nnvm b/nnvm index a1d4ed8e51f2..5dc84b1bffd1 160000 --- a/nnvm +++ b/nnvm @@ -1 +1 @@ -Subproject commit a1d4ed8e51f2fa6131fe6af29664ee4c9ebcd0a8 +Subproject commit 5dc84b1bffd1f08ded76c4e89587fb1dee6ee38a diff --git a/python/mxnet/image.py b/python/mxnet/image.py index 1ef08f24f388..7d66a4451b60 100644 --- a/python/mxnet/image.py +++ b/python/mxnet/image.py @@ -1,11 +1,12 @@ # coding: utf-8 # pylint: disable=no-member, too-many-lines, redefined-builtin, protected-access, unused-import, invalid-name -# pylint: disable=too-many-arguments, too-many-locals, no-name-in-module, too-many-branches +# pylint: disable=too-many-arguments, too-many-locals, no-name-in-module, too-many-branches, too-many-statements """Image IO API of mxnet.""" -from __future__ import absolute_import +from __future__ import absolute_import, print_function import os import random +import logging import numpy as np from . import ndarray as nd from . import _ndarray_internal as _internal @@ -14,6 +15,11 @@ from . import io from . import recordio +try: + import cv2 +except ImportError: + cv2 = None + def imdecode(buf, **kwargs): """Decode an image from string. Requires OpenCV to work. @@ -42,6 +48,15 @@ def scale_down(src_size, size): w, h = sw, float(h*sw)/w return int(w), int(h) +def resize(src, size, interp=2): + """Scale shorter edge to size""" + h, w, _ = src.shape + if h > w: + new_h, new_w = size*h/w, size + else: + new_h, new_w = size, size*w/h + return imresize(src, new_w, new_h, interp=interp) + def fixed_crop(src, x0, y0, w, h, size=None, interp=2): """Crop src at fixed location, and (optionally) resize it to size""" out = nd.crop(src, begin=(y0, x0, 0), end=(y0+h, x0+w, int(src.shape[2]))) @@ -73,63 +88,70 @@ def center_crop(src, size, interp=2): def color_normalize(src, mean, std): """Normalize src with mean and std""" - src = src - mean + src -= mean if std is not None: - src = src / std + src /= std return src -def random_size_crop(src, size, min_area=0.08, ratio=(3.0/4.0, 4.0/3.0), interp=2): +def random_size_crop(src, size, min_area, ratio, interp=2): """Randomly crop src with size. Randomize area and aspect ratio""" h, w, _ = src.shape - area = w*h - for _ in range(10): - new_area = random.uniform(min_area, 1.0) * area - new_ratio = random.uniform(*ratio) - new_w = int(np.sqrt(new_area*new_ratio)) - new_h = int(np.sqrt(new_area/new_ratio)) - - if random.random() < 0.5: - new_w, new_h = new_h, new_w + new_ratio = random.uniform(*ratio) + if new_ratio * h > w: + max_area = w*int(w/new_ratio) + else: + max_area = h*int(h*new_ratio) - if new_w > w or new_h > h: - continue + min_area *= h*w + if max_area < min_area: + return random_crop(src, size, interp) + new_area = random.uniform(min_area, max_area) + new_w = int(np.sqrt(new_area*new_ratio)) + new_h = int(np.sqrt(new_area/new_ratio)) - x0 = random.randint(0, w - new_w) - y0 = random.randint(0, h - new_h) + assert new_w <= w and new_h <= h + x0 = random.randint(0, w - new_w) + y0 = random.randint(0, h - new_h) - out = fixed_crop(src, x0, y0, new_w, new_h, size, interp) - return out, (x0, y0, new_w, new_h) + out = fixed_crop(src, x0, y0, new_w, new_h, size, interp) + return out, (x0, y0, new_w, new_h) - return random_crop(src, size) +def ScaleAug(size, interp=2): + """Make scale shorter edge to size augumenter""" + def aug(src): + """Augumenter body""" + return [resize(src, size, interp)] + return aug def RandomCropAug(size, interp=2): """Make random crop augumenter""" def aug(src): """Augumenter body""" - return random_crop(src, size, interp)[0] + return [random_crop(src, size, interp)[0]] return aug -def RandomSizedCropAug(size, min_area=0.08, ratio=(3.0/4.0, 4.0/3.0), interp=2): +def RandomSizedCropAug(size, min_area, ratio, interp=2): """Make random crop with random resizing and random aspect ratio jitter augumenter""" def aug(src): """Augumenter body""" - return random_size_crop(src, size, min_area, ratio, interp)[0] + return [random_size_crop(src, size, min_area, ratio, interp)[0]] return aug def CenterCropAug(size, interp=2): """Make center crop augmenter""" def aug(src): """Augumenter body""" - return center_crop(src, size, interp)[0] + return [center_crop(src, size, interp)[0]] return aug def RandomOrderAug(ts): """Apply list of augmenters in random order""" def aug(src): """Augumenter body""" + src = [src] random.shuffle(ts) - for i in ts: - src = i(src) + for t in ts: + src = [j for i in src for j in t(i)] return src return aug @@ -142,7 +164,7 @@ def baug(src): """Augumenter body""" alpha = 1.0 + random.uniform(-brightness, brightness) src *= alpha - return src + return [src] ts.append(baug) if contrast > 0: @@ -152,8 +174,8 @@ def caug(src): gray = src*coef gray = (3.0*(1.0-alpha)/gray.size)*nd.sum(gray) src *= alpha - src = src + gray - return src + src += gray + return [src] ts.append(caug) if saturation > 0: @@ -164,8 +186,8 @@ def saug(src): gray = nd.sum(gray, axis=2, keepdims=True) gray *= (1.0-alpha) src *= alpha - src = src + gray - return src + src += gray + return [src] ts.append(saug) return RandomOrderAug(ts) @@ -175,8 +197,8 @@ def aug(src): """Augumenter body""" alpha = np.random.normal(0, alphastd, size=(3,)) rgb = np.dot(eigvec*alpha, eigval) - src = src + nd.array(rgb) - return src + src += nd.array(rgb) + return [src] return aug def ColorNormalizeAug(mean, std): @@ -185,7 +207,7 @@ def ColorNormalizeAug(mean, std): std = nd.array(std) def aug(src): """Augumenter body""" - return color_normalize(src, mean, std) + return [color_normalize(src, mean, std)] return aug def HorizontalFlipAug(p): @@ -194,7 +216,7 @@ def aug(src): """Augumenter body""" if random.random() < p: src = nd.flip(src, axis=1) - return src + return [src] return aug def CastAug(): @@ -202,18 +224,22 @@ def CastAug(): def aug(src): """Augumenter body""" src = src.astype(np.float32) - return src + return [src] return aug -def CreateAugmenter(data_shape, rand_crop=False, rand_resize=False, rand_mirror=False, +def CreateAugmenter(data_shape, scale=0, rand_crop=False, rand_resize=False, rand_mirror=False, mean=None, std=None, brightness=0, contrast=0, saturation=0, pca_noise=0, inter_method=2): """Create augumenter list""" auglist = [] + + if scale > 0: + auglist.append(ScaleAug(scale, inter_method)) + crop_size = (data_shape[2], data_shape[1]) if rand_resize: assert rand_crop - auglist.append(RandomSizedCropAug(crop_size, inter_method)) + auglist.append(RandomSizedCropAug(crop_size, 0.3, (3.0/4.0, 4.0/3.0), inter_method)) elif rand_crop: auglist.append(RandomCropAug(crop_size, inter_method)) else: @@ -238,7 +264,8 @@ def CreateAugmenter(data_shape, rand_crop=False, rand_resize=False, rand_mirror= mean = np.array([123.68, 116.28, 103.53]) if std is True: std = np.array([58.395, 57.12, 57.375]) - if mean: + if mean is not None: + assert std is not None auglist.append(ColorNormalizeAug(mean, std)) return auglist @@ -289,6 +316,7 @@ def __init__(self, batch_size, data_shape, label_width=1, super(ImageIter, self).__init__() assert path_imgrec or path_imglist if path_imgrec: + print('loading recordio...') if path_imgidx: self.imgrec = recordio.MXIndexedRecordIO(path_imgidx, path_imgrec, 'r') self.imgidx = list(self.imgrec.keys) @@ -299,6 +327,7 @@ def __init__(self, batch_size, data_shape, label_width=1, self.imgrec = None if path_imglist: + print('loading image list...') with open(path_imglist) as fin: imglist = {} imgkeys = [] @@ -315,7 +344,10 @@ def __init__(self, batch_size, data_shape, label_width=1, assert len(data_shape) == 3 and data_shape[0] == 3 self.provide_data = [('data', (batch_size,) + data_shape)] - self.provide_label = [('softmax_label', (batch_size, label_width))] + if label_width > 1: + self.provide_label = [('softmax_label', (batch_size, label_width))] + else: + self.provide_label = [('softmax_label', (batch_size,))] self.batch_size = batch_size self.data_shape = data_shape self.label_width = label_width @@ -370,6 +402,8 @@ def next_sample(self): return label, img else: s = self.imgrec.read() + if s is None: + raise StopIteration header, img = recordio.unpack(s) return header.label, img @@ -377,16 +411,22 @@ def next(self): batch_size = self.batch_size c, h, w = self.data_shape batch_data = nd.zeros((batch_size, h, w, c)) - batch_label = nd.zeros((batch_size, self.label_width)) + batch_label = nd.zeros(self.provide_label[0][1]) i = 0 try: - for i in range(batch_size): - label, data = self.next_sample() - data = imdecode(data) + while i < batch_size: + label, s = self.next_sample() + data = [imdecode(s)] + if len(data[0].shape) == 0: + logging.debug('Invalid image, skipping.') + continue for aug in self.auglist: - data = aug(data) - batch_data[i][:] = data - batch_label[i][:] = label + data = [ret for src in data for ret in aug(src)] + for d in data: + assert i < batch_size, 'Batch size must be multiples of augmenter output length' + batch_data[i][:] = d + batch_label[i][:] = label + i += 1 except StopIteration: if not i: raise StopIteration diff --git a/python/mxnet/recordio.py b/python/mxnet/recordio.py index 34348fe67dbe..d86709b07d1f 100644 --- a/python/mxnet/recordio.py +++ b/python/mxnet/recordio.py @@ -16,9 +16,8 @@ from .base import c_str try: import cv2 - opencv_available = True except ImportError: - opencv_available = False + cv2 = None class MXRecordIO(object): """Python interface for read/write RecordIO data formmat @@ -233,7 +232,7 @@ def unpack_img(s, iscolor=-1): """ header, s = unpack(s) img = np.fromstring(s, dtype=np.uint8) - assert opencv_available + assert cv2 is not None img = cv2.imdecode(img, iscolor) return header, img @@ -257,7 +256,7 @@ def pack_img(header, img, quality=95, img_fmt='.jpg'): s : str The packed string """ - assert opencv_available + assert cv2 is not None jpg_formats = ['.JPG', '.JPEG'] png_formats = ['.PNG'] encode_params = None diff --git a/src/io/image_aug_default.cc b/src/io/image_aug_default.cc index e36f201dbc28..7da7134cab6b 100644 --- a/src/io/image_aug_default.cc +++ b/src/io/image_aug_default.cc @@ -23,6 +23,8 @@ namespace io { /*! \brief image augmentation parameters*/ struct DefaultImageAugmentParam : public dmlc::Parameter { + /*! \brief scale shorter edge to size before applying other augmentations */ + int scale; /*! \brief whether we do random cropping */ bool rand_crop; /*! \brief where to nonrandom crop on y */ @@ -65,6 +67,9 @@ struct DefaultImageAugmentParam : public dmlc::Parameter src.cols) { + new_height = param_.scale*src.rows/src.cols; + new_width = param_.scale; + } else { + new_height = param_.scale; + new_width = param_.scale*src.cols/src.rows; + } + CHECK((param_.inter_method >= 1 && param_.inter_method <= 4) || + (param_.inter_method >= 9 && param_.inter_method <= 10)) + << "invalid inter_method: valid value 0,1,2,3,9,10"; + int interpolation_method = GetInterMethod(param_.inter_method, + src.cols, src.rows, new_width, new_height, prnd); + cv::resize(src, res, cv::Size(new_width, new_height), + 0, 0, interpolation_method); + } else { + res = src; + } // normal augmentation by affine transformation. if (param_.max_rotate_angle > 0 || param_.max_shear_ratio > 0.0f @@ -196,30 +220,28 @@ class DefaultImageAugmenter : public ImageAugmenter { float ws = ratio * hs; // new width and height float new_width = std::max(param_.min_img_size, - std::min(param_.max_img_size, scale * src.cols)); + std::min(param_.max_img_size, scale * res.cols)); float new_height = std::max(param_.min_img_size, - std::min(param_.max_img_size, scale * src.rows)); + std::min(param_.max_img_size, scale * res.rows)); cv::Mat M(2, 3, CV_32F); M.at(0, 0) = hs * a - s * b * ws; M.at(1, 0) = -b * ws; M.at(0, 1) = hs * b + s * a * ws; M.at(1, 1) = a * ws; - float ori_center_width = M.at(0, 0) * src.cols + M.at(0, 1) * src.rows; - float ori_center_height = M.at(1, 0) * src.cols + M.at(1, 1) * src.rows; + float ori_center_width = M.at(0, 0) * res.cols + M.at(0, 1) * res.rows; + float ori_center_height = M.at(1, 0) * res.cols + M.at(1, 1) * res.rows; M.at(0, 2) = (new_width - ori_center_width) / 2; M.at(1, 2) = (new_height - ori_center_height) / 2; CHECK((param_.inter_method >= 1 && param_.inter_method <= 4) || - (param_.inter_method >= 9 && param_.inter_method <= 10)) - << "invalid inter_method: valid value 0,1,2,3,9,10"; + (param_.inter_method >= 9 && param_.inter_method <= 10)) + << "invalid inter_method: valid value 0,1,2,3,9,10"; int interpolation_method = GetInterMethod(param_.inter_method, - src.cols, src.rows, new_width, new_height, prnd); - cv::warpAffine(src, temp_, M, cv::Size(new_width, new_height), + res.cols, res.rows, new_width, new_height, prnd); + cv::warpAffine(res, temp_, M, cv::Size(new_width, new_height), interpolation_method, cv::BORDER_CONSTANT, cv::Scalar(param_.fill_value, param_.fill_value, param_.fill_value)); res = temp_; - } else { - res = src; } // pad logic diff --git a/src/io/image_io.cc b/src/io/image_io.cc index b50849b7b099..c54ce5e61596 100644 --- a/src/io/image_io.cc +++ b/src/io/image_io.cc @@ -5,6 +5,7 @@ * \author Junyuan Xie */ #include +#include #include #include #include @@ -23,7 +24,8 @@ namespace mxnet { namespace io { // http://www.64lines.com/jpeg-width-height -// Gets the JPEG size from the array of data passed to the function, file reference: http://www.obrador.com/essentialjpeg/headerinfo.htm +// Gets the JPEG size from the array of data passed to the function, +// file reference: http://www.obrador.com/essentialjpeg/headerinfo.htm bool get_jpeg_size(const uint8_t* data, uint32_t data_size, uint32_t *width, uint32_t *height) { // Check for valid JPEG image uint32_t i = 0; // Keeps track of the position within the file @@ -39,7 +41,8 @@ bool get_jpeg_size(const uint8_t* data, uint32_t data_size, uint32_t *width, uin i+=block_length; // Increase the file index to get to the next block if (i >= data_size) return false; // Check to protect against segmentation faults if (data[i] != 0xFF) return false; // Check that we are truly at the start of another block - if (data[i+1] == 0xC0) { + uint8_t m = data[i+1]; + if (m == 0xC0 || (m >= 0xC1 && m <= 0xCF && m != 0xC4 && m != 0xC8 && m != 0xCC)) { // 0xFFC0 is the "Start of frame" marker which contains the file size // The structure of the 0xFFC0 block is quite simple // [0xFFC0][ushort length][uchar precision][ushort x][ushort y] @@ -99,16 +102,34 @@ void Imdecode(const nnvm::NodeAttrs& attrs, const uint8_t* str_img = reinterpret_cast(inputs[0].data().dptr_); uint32_t len = inputs[0].shape().Size(); - inputs[0].WaitToRead(); + NDArray ndin = inputs[0]; + ndin.WaitToRead(); TShape oshape(3); oshape[2] = param.flag == 0 ? 1 : 3; if (get_jpeg_size(str_img, len, &oshape[1], &oshape[0])) { } else if (get_png_size(str_img, len, &oshape[1], &oshape[0])) { } else { - LOG(FATAL) << "Only supports png and jpg."; + cv::Mat buf(1, ndin.shape().Size(), CV_8U, ndin.data().dptr_); + cv::Mat res = cv::imdecode(buf, param.flag); + if (res.empty()) { + LOG(INFO) << "Invalid image file. Only supports png and jpg."; + (*outputs)[0] = NDArray(); + return; + } + oshape[0] = res.rows; + oshape[1] = res.cols; + NDArray ndout(oshape, Context::CPU(), false, mshadow::kUint8); + cv::Mat dst(ndout.shape()[0], ndout.shape()[1], + param.flag == 0 ? CV_8U : CV_8UC3, + ndout.data().dptr_); + res.copyTo(dst); + if (param.to_rgb) { + cv::cvtColor(dst, dst, CV_BGR2RGB); + } + (*outputs)[0] = ndout; + return; } - NDArray ndin = inputs[0]; NDArray ndout(oshape, Context::CPU(), true, mshadow::kUint8); Engine::Get()->PushSync([ndin, ndout, param](RunContext ctx){ ndout.CheckAndAlloc(); diff --git a/src/io/iter_image_recordio.cc b/src/io/iter_image_recordio.cc index 8a3f6b17bc44..02436f9b13a6 100644 --- a/src/io/iter_image_recordio.cc +++ b/src/io/iter_image_recordio.cc @@ -127,7 +127,7 @@ struct ImageRecParserParam : public dmlc::Parameter { DMLC_DECLARE_FIELD(label_width).set_lower_bound(1).set_default(1) .describe("Dataset Param: How many labels for an image."); DMLC_DECLARE_FIELD(data_shape) - .enforce_nonzero() + .set_expect_ndim(3).enforce_nonzero() .describe("Dataset Param: Shape of each instance generated by the DataIter."); DMLC_DECLARE_FIELD(preprocess_threads).set_lower_bound(1).set_default(4) .describe("Backend Param: Number of thread to do preprocessing."); @@ -279,8 +279,23 @@ ParseNext(std::vector> *out_vec) { cv::Mat res; rec.Load(blob.dptr, blob.size); cv::Mat buf(1, rec.content_size, CV_8U, rec.content); - // -1 to keep the number of channel of the encoded image, and not force gray or color. - res = cv::imdecode(buf, -1); + switch (param_.data_shape[0]) { + case 1: + res = cv::imdecode(buf, 0); + break; + case 3: + res = cv::imdecode(buf, 1); + break; + case 4: + // -1 to keep the number of channel of the encoded image, and not force gray or color. + res = cv::imdecode(buf, -1); + CHECK_EQ(res.channels(), 4) + << "Invalid image with index " << rec.image_index() + << ". Expected 4 channels, got " << res.channels(); + break; + default: + LOG(FATAL) << "Invalid output shape " << param_.data_shape; + } const int n_channels = res.channels(); for (auto& aug : augmenters_[tid]) { res = aug->Process(res, prnds_[tid].get()); diff --git a/tools/im2rec.py b/tools/im2rec.py index 19855f3f80dd..a5799e61e26f 100644 --- a/tools/im2rec.py +++ b/tools/im2rec.py @@ -21,7 +21,9 @@ def list_image(root, recursive, exts): i = 0 if recursive: cat = {} - for path, _, files in os.walk(root, followlinks=True): + for path, dirs, files in os.walk(root, followlinks=True): + dirs.sort() + files.sort() for fname in files: fpath = os.path.join(path, fname) suffix = os.path.splitext(fname)[1].lower() @@ -30,10 +32,10 @@ def list_image(root, recursive, exts): cat[path] = len(cat) yield (i, os.path.relpath(fpath, root), cat[path]) i += 1 - for k, v in cat.items(): + for k, v in sorted(cat.items(), key=lambda x: x[1]): print(os.path.relpath(k, root), v) else: - for fname in os.listdir(root): + for fname in sorted(os.listdir(root)): fpath = os.path.join(root, fname) suffix = os.path.splitext(fname)[1].lower() if os.path.isfile(fpath) and (suffix in exts):