Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Fix metric & im2rec.py (#3375)
Browse files Browse the repository at this point in the history
image io fix
  • Loading branch information
piiswrong committed Dec 29, 2016
1 parent 85478cb commit da3cbc3
Show file tree
Hide file tree
Showing 8 changed files with 178 additions and 75 deletions.
6 changes: 5 additions & 1 deletion include/mxnet/ndarray.h
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,11 @@ class NDArray {
CHECK_GE(shape_[0], idx) << "index out of range";
size_t length = shape_.ProdShape(1, shape_.ndim());
ret.offset_ += idx * length;
ret.shape_ = TShape(shape_.data()+1, shape_.data()+shape_.ndim());
if (shape_.ndim() > 1) {
ret.shape_ = TShape(shape_.data()+1, shape_.data()+shape_.ndim());
} else {
ret.shape_ = mshadow::Shape1(1);
}
return ret;
}
/*!
Expand Down
2 changes: 1 addition & 1 deletion nnvm
136 changes: 88 additions & 48 deletions python/mxnet/image.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
# coding: utf-8
# pylint: disable=no-member, too-many-lines, redefined-builtin, protected-access, unused-import, invalid-name
# pylint: disable=too-many-arguments, too-many-locals, no-name-in-module, too-many-branches
# pylint: disable=too-many-arguments, too-many-locals, no-name-in-module, too-many-branches, too-many-statements
"""Image IO API of mxnet."""
from __future__ import absolute_import
from __future__ import absolute_import, print_function

import os
import random
import logging
import numpy as np
from . import ndarray as nd
from . import _ndarray_internal as _internal
Expand All @@ -14,6 +15,11 @@
from . import io
from . import recordio

try:
import cv2
except ImportError:
cv2 = None

def imdecode(buf, **kwargs):
"""Decode an image from string. Requires OpenCV to work.
Expand Down Expand Up @@ -42,6 +48,15 @@ def scale_down(src_size, size):
w, h = sw, float(h*sw)/w
return int(w), int(h)

def resize(src, size, interp=2):
"""Scale shorter edge to size"""
h, w, _ = src.shape
if h > w:
new_h, new_w = size*h/w, size
else:
new_h, new_w = size, size*w/h
return imresize(src, new_w, new_h, interp=interp)

def fixed_crop(src, x0, y0, w, h, size=None, interp=2):
"""Crop src at fixed location, and (optionally) resize it to size"""
out = nd.crop(src, begin=(y0, x0, 0), end=(y0+h, x0+w, int(src.shape[2])))
Expand Down Expand Up @@ -73,63 +88,70 @@ def center_crop(src, size, interp=2):

def color_normalize(src, mean, std):
"""Normalize src with mean and std"""
src = src - mean
src -= mean
if std is not None:
src = src / std
src /= std
return src

def random_size_crop(src, size, min_area=0.08, ratio=(3.0/4.0, 4.0/3.0), interp=2):
def random_size_crop(src, size, min_area, ratio, interp=2):
"""Randomly crop src with size. Randomize area and aspect ratio"""
h, w, _ = src.shape
area = w*h
for _ in range(10):
new_area = random.uniform(min_area, 1.0) * area
new_ratio = random.uniform(*ratio)
new_w = int(np.sqrt(new_area*new_ratio))
new_h = int(np.sqrt(new_area/new_ratio))

if random.random() < 0.5:
new_w, new_h = new_h, new_w
new_ratio = random.uniform(*ratio)
if new_ratio * h > w:
max_area = w*int(w/new_ratio)
else:
max_area = h*int(h*new_ratio)

if new_w > w or new_h > h:
continue
min_area *= h*w
if max_area < min_area:
return random_crop(src, size, interp)
new_area = random.uniform(min_area, max_area)
new_w = int(np.sqrt(new_area*new_ratio))
new_h = int(np.sqrt(new_area/new_ratio))

x0 = random.randint(0, w - new_w)
y0 = random.randint(0, h - new_h)
assert new_w <= w and new_h <= h
x0 = random.randint(0, w - new_w)
y0 = random.randint(0, h - new_h)

out = fixed_crop(src, x0, y0, new_w, new_h, size, interp)
return out, (x0, y0, new_w, new_h)
out = fixed_crop(src, x0, y0, new_w, new_h, size, interp)
return out, (x0, y0, new_w, new_h)

return random_crop(src, size)
def ScaleAug(size, interp=2):
"""Make scale shorter edge to size augumenter"""
def aug(src):
"""Augumenter body"""
return [resize(src, size, interp)]
return aug

def RandomCropAug(size, interp=2):
"""Make random crop augumenter"""
def aug(src):
"""Augumenter body"""
return random_crop(src, size, interp)[0]
return [random_crop(src, size, interp)[0]]
return aug

def RandomSizedCropAug(size, min_area=0.08, ratio=(3.0/4.0, 4.0/3.0), interp=2):
def RandomSizedCropAug(size, min_area, ratio, interp=2):
"""Make random crop with random resizing and random aspect ratio jitter augumenter"""
def aug(src):
"""Augumenter body"""
return random_size_crop(src, size, min_area, ratio, interp)[0]
return [random_size_crop(src, size, min_area, ratio, interp)[0]]
return aug

def CenterCropAug(size, interp=2):
"""Make center crop augmenter"""
def aug(src):
"""Augumenter body"""
return center_crop(src, size, interp)[0]
return [center_crop(src, size, interp)[0]]
return aug

def RandomOrderAug(ts):
"""Apply list of augmenters in random order"""
def aug(src):
"""Augumenter body"""
src = [src]
random.shuffle(ts)
for i in ts:
src = i(src)
for t in ts:
src = [j for i in src for j in t(i)]
return src
return aug

Expand All @@ -142,7 +164,7 @@ def baug(src):
"""Augumenter body"""
alpha = 1.0 + random.uniform(-brightness, brightness)
src *= alpha
return src
return [src]
ts.append(baug)

if contrast > 0:
Expand All @@ -152,8 +174,8 @@ def caug(src):
gray = src*coef
gray = (3.0*(1.0-alpha)/gray.size)*nd.sum(gray)
src *= alpha
src = src + gray
return src
src += gray
return [src]
ts.append(caug)

if saturation > 0:
Expand All @@ -164,8 +186,8 @@ def saug(src):
gray = nd.sum(gray, axis=2, keepdims=True)
gray *= (1.0-alpha)
src *= alpha
src = src + gray
return src
src += gray
return [src]
ts.append(saug)
return RandomOrderAug(ts)

Expand All @@ -175,8 +197,8 @@ def aug(src):
"""Augumenter body"""
alpha = np.random.normal(0, alphastd, size=(3,))
rgb = np.dot(eigvec*alpha, eigval)
src = src + nd.array(rgb)
return src
src += nd.array(rgb)
return [src]
return aug

def ColorNormalizeAug(mean, std):
Expand All @@ -185,7 +207,7 @@ def ColorNormalizeAug(mean, std):
std = nd.array(std)
def aug(src):
"""Augumenter body"""
return color_normalize(src, mean, std)
return [color_normalize(src, mean, std)]
return aug

def HorizontalFlipAug(p):
Expand All @@ -194,26 +216,30 @@ def aug(src):
"""Augumenter body"""
if random.random() < p:
src = nd.flip(src, axis=1)
return src
return [src]
return aug

def CastAug():
"""Cast to float32"""
def aug(src):
"""Augumenter body"""
src = src.astype(np.float32)
return src
return [src]
return aug

def CreateAugmenter(data_shape, rand_crop=False, rand_resize=False, rand_mirror=False,
def CreateAugmenter(data_shape, scale=0, rand_crop=False, rand_resize=False, rand_mirror=False,
mean=None, std=None, brightness=0, contrast=0, saturation=0,
pca_noise=0, inter_method=2):
"""Create augumenter list"""
auglist = []

if scale > 0:
auglist.append(ScaleAug(scale, inter_method))

crop_size = (data_shape[2], data_shape[1])
if rand_resize:
assert rand_crop
auglist.append(RandomSizedCropAug(crop_size, inter_method))
auglist.append(RandomSizedCropAug(crop_size, 0.3, (3.0/4.0, 4.0/3.0), inter_method))
elif rand_crop:
auglist.append(RandomCropAug(crop_size, inter_method))
else:
Expand All @@ -238,7 +264,8 @@ def CreateAugmenter(data_shape, rand_crop=False, rand_resize=False, rand_mirror=
mean = np.array([123.68, 116.28, 103.53])
if std is True:
std = np.array([58.395, 57.12, 57.375])
if mean:
if mean is not None:
assert std is not None
auglist.append(ColorNormalizeAug(mean, std))

return auglist
Expand Down Expand Up @@ -289,6 +316,7 @@ def __init__(self, batch_size, data_shape, label_width=1,
super(ImageIter, self).__init__()
assert path_imgrec or path_imglist
if path_imgrec:
print('loading recordio...')
if path_imgidx:
self.imgrec = recordio.MXIndexedRecordIO(path_imgidx, path_imgrec, 'r')
self.imgidx = list(self.imgrec.keys)
Expand All @@ -299,6 +327,7 @@ def __init__(self, batch_size, data_shape, label_width=1,
self.imgrec = None

if path_imglist:
print('loading image list...')
with open(path_imglist) as fin:
imglist = {}
imgkeys = []
Expand All @@ -315,7 +344,10 @@ def __init__(self, batch_size, data_shape, label_width=1,

assert len(data_shape) == 3 and data_shape[0] == 3
self.provide_data = [('data', (batch_size,) + data_shape)]
self.provide_label = [('softmax_label', (batch_size, label_width))]
if label_width > 1:
self.provide_label = [('softmax_label', (batch_size, label_width))]
else:
self.provide_label = [('softmax_label', (batch_size,))]
self.batch_size = batch_size
self.data_shape = data_shape
self.label_width = label_width
Expand Down Expand Up @@ -370,23 +402,31 @@ def next_sample(self):
return label, img
else:
s = self.imgrec.read()
if s is None:
raise StopIteration
header, img = recordio.unpack(s)
return header.label, img

def next(self):
batch_size = self.batch_size
c, h, w = self.data_shape
batch_data = nd.zeros((batch_size, h, w, c))
batch_label = nd.zeros((batch_size, self.label_width))
batch_label = nd.zeros(self.provide_label[0][1])
i = 0
try:
for i in range(batch_size):
label, data = self.next_sample()
data = imdecode(data)
while i < batch_size:
label, s = self.next_sample()
data = [imdecode(s)]
if len(data[0].shape) == 0:
logging.debug('Invalid image, skipping.')
continue
for aug in self.auglist:
data = aug(data)
batch_data[i][:] = data
batch_label[i][:] = label
data = [ret for src in data for ret in aug(src)]
for d in data:
assert i < batch_size, 'Batch size must be multiples of augmenter output length'
batch_data[i][:] = d
batch_label[i][:] = label
i += 1
except StopIteration:
if not i:
raise StopIteration
Expand Down
7 changes: 3 additions & 4 deletions python/mxnet/recordio.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,8 @@
from .base import c_str
try:
import cv2
opencv_available = True
except ImportError:
opencv_available = False
cv2 = None

class MXRecordIO(object):
"""Python interface for read/write RecordIO data formmat
Expand Down Expand Up @@ -233,7 +232,7 @@ def unpack_img(s, iscolor=-1):
"""
header, s = unpack(s)
img = np.fromstring(s, dtype=np.uint8)
assert opencv_available
assert cv2 is not None
img = cv2.imdecode(img, iscolor)
return header, img

Expand All @@ -257,7 +256,7 @@ def pack_img(header, img, quality=95, img_fmt='.jpg'):
s : str
The packed string
"""
assert opencv_available
assert cv2 is not None
jpg_formats = ['.JPG', '.JPEG']
png_formats = ['.PNG']
encode_params = None
Expand Down
Loading

0 comments on commit da3cbc3

Please sign in to comment.