From dfe1db54e7d81872989f8748369693d59cbdf380 Mon Sep 17 00:00:00 2001 From: Junyuan Xie Date: Mon, 19 Sep 2016 13:42:37 -0700 Subject: [PATCH] fix image io --- python/mxnet/image.py | 37 +++++++------- python/mxnet/recordio.py | 44 ++++++++--------- src/operator/tensor/elemwise_unary_op.cu | 2 +- tests/python/unittest/test_recordio.py | 2 +- tools/im2rec.py | 62 ++++++++++++------------ 5 files changed, 72 insertions(+), 75 deletions(-) diff --git a/python/mxnet/image.py b/python/mxnet/image.py index 519af8f22612..1ef08f24f388 100644 --- a/python/mxnet/image.py +++ b/python/mxnet/image.py @@ -285,27 +285,29 @@ class ImageIter(io.DataIter): """ def __init__(self, batch_size, data_shape, label_width=1, path_imgrec=None, path_imglist=None, path_root=None, path_imgidx=None, - shuffle=False, part_index=0, num_parts=1, **kwargs): + shuffle=False, part_index=0, num_parts=1, aug_list=None, **kwargs): super(ImageIter, self).__init__() assert path_imgrec or path_imglist if path_imgrec: if path_imgidx: self.imgrec = recordio.MXIndexedRecordIO(path_imgidx, path_imgrec, 'r') - self.imgidx = self.imgrec.idx + self.imgidx = list(self.imgrec.keys) else: self.imgrec = recordio.MXRecordIO(path_imgrec, 'r') + self.imgidx = None else: self.imgrec = None + if path_imglist: with open(path_imglist) as fin: imglist = {} - while True: - line = fin.readline() - if not line: - break + imgkeys = [] + for line in iter(fin.readline, ''): line = line.strip().split('\t') label = nd.array([float(i) for i in line[1:-1]]) - imglist[int(line[0])] = (label, line[-1]) + key = int(line[0]) + imglist[key] = (label, line[-1]) + imgkeys.append(key) self.imglist = imglist else: self.imglist = None @@ -319,12 +321,11 @@ def __init__(self, batch_size, data_shape, label_width=1, self.label_width = label_width self.shuffle = shuffle - if shuffle or num_parts > 1: - if self.imgrec is None: - self.seq = self.imglist.keys() - else: - assert self.imgidx is not None - self.seq = self.imgidx.keys() + if self.imgrec is None: + self.seq = imgkeys + elif shuffle or num_parts > 1: + assert self.imgidx is not None + self.seq = self.imgidx else: self.seq = None @@ -333,8 +334,10 @@ def __init__(self, batch_size, data_shape, label_width=1, N = len(self.seq) C = N/num_parts self.seq = self.seq[part_index*C:(part_index+1)*C] - - self.auglist = CreateAugmenter(data_shape, **kwargs) + if aug_list is None: + self.auglist = CreateAugmenter(data_shape, **kwargs) + else: + self.auglist = aug_list self.cur = 0 self.reset() @@ -360,7 +363,7 @@ def next_sample(self): else: return self.imglist[idx][0], img else: - label, fname = self.imglist[self.seq[self.cur]] + label, fname = self.imglist[idx] if self.imgrec is None: with open(os.path.join(self.path_root, fname), 'rb') as fin: img = fin.read() @@ -389,4 +392,4 @@ def next(self): raise StopIteration batch_data = nd.transpose(batch_data, axes=(0, 3, 1, 2)) - return io.DataBatch(batch_data, batch_label, batch_size-1-i) + return io.DataBatch([batch_data], [batch_label], batch_size-1-i) diff --git a/python/mxnet/recordio.py b/python/mxnet/recordio.py index 3be6db258db6..34348fe67dbe 100644 --- a/python/mxnet/recordio.py +++ b/python/mxnet/recordio.py @@ -5,7 +5,6 @@ from __future__ import absolute_import from collections import namedtuple -import os import ctypes import struct import numbers @@ -117,28 +116,28 @@ class MXIndexedRecordIO(MXRecordIO): data type for keys """ def __init__(self, idx_path, uri, flag, key_type=int): - super(MXIndexedRecordIO, self).__init__(uri, flag) self.idx_path = idx_path self.idx = {} + self.keys = [] self.key_type = key_type - if not self.writable and os.path.isfile(idx_path): - with open(idx_path) as fin: - for line in fin.readlines(): - line = line.strip().split('\t') - self.idx[key_type(line[0])] = int(line[1]) + self.fidx = None + super(MXIndexedRecordIO, self).__init__(uri, flag) + + def open(self): + super(MXIndexedRecordIO, self).open() + self.idx = {} + self.keys = [] + self.fidx = open(self.idx_path, self.flag) + if not self.writable: + for line in iter(self.fidx.readline, ''): + line = line.strip().split('\t') + key = self.key_type(line[0]) + self.idx[key] = int(line[1]) + self.keys.append(key) def close(self): - if self.writable: - with open(self.idx_path, 'w') as fout: - for k, v in self.idx.items(): - fout.write(str(k)+'\t'+str(v)+'\n') super(MXIndexedRecordIO, self).close() - - def reset(self): - if self.writable: - self.idx = {} - super(MXIndexedRecordIO, self).close() - super(MXIndexedRecordIO, self).open() + self.fidx.close() def seek(self, idx): """Query current read head position""" @@ -160,15 +159,12 @@ def read_idx(self, idx): def write_idx(self, idx, buf): """Write record with index""" + key = self.key_type(idx) pos = self.tell() - self.idx[self.key_type(idx)] = pos self.write(buf) - - def keys(self): - """List all keys from index""" - return list(self.idx.keys()) - - + self.fidx.write('%s\t%d\n'%(str(key), pos)) + self.idx[key] = pos + self.keys.append(key) IRHeader = namedtuple('HEADER', ['flag', 'label', 'id', 'id2']) diff --git a/src/operator/tensor/elemwise_unary_op.cu b/src/operator/tensor/elemwise_unary_op.cu index a30448f9afab..2abe3ee77a86 100644 --- a/src/operator/tensor/elemwise_unary_op.cu +++ b/src/operator/tensor/elemwise_unary_op.cu @@ -114,7 +114,7 @@ NNVM_REGISTER_OP(_backward_arccos) NNVM_REGISTER_OP(arctan) .set_attr("FCompute", UnaryCompute); -NNVM_REGISTER_OP(_backward_arccos) +NNVM_REGISTER_OP(_backward_arctan) .set_attr("FCompute", BinaryCompute >); } // namespace op diff --git a/tests/python/unittest/test_recordio.py b/tests/python/unittest/test_recordio.py index a3853ee891c2..f4489bdfe641 100644 --- a/tests/python/unittest/test_recordio.py +++ b/tests/python/unittest/test_recordio.py @@ -40,7 +40,7 @@ def test_indexed_recordio(): del writer reader = mx.recordio.MXIndexedRecordIO(fidx, frec, 'r') - keys = reader.keys() + keys = reader.keys assert sorted(keys) == [i for i in range(N)] random.shuffle(keys) for i in keys: diff --git a/tools/im2rec.py b/tools/im2rec.py index b935557c7a9c..42c8743c98a9 100644 --- a/tools/im2rec.py +++ b/tools/im2rec.py @@ -77,16 +77,16 @@ def read_list(path_in): item = [int(line[0])] + [line[-1]] + [float(i) for i in line[1:-1]] yield item -def image_encode(args, item, q_out): +def image_encode(args, i, item, q_out): fullpath = os.path.join(args.root, item[1]) try: img = cv2.imread(fullpath, args.color) except: traceback.print_exc() - print('imread error trying to load file: %s ' % fullpath) + print('imread error trying to load file: %s ' % fullpath, e) return if img is None: - print('imread read blank (None) image for file: %s' % fullpath) + print('imread read blank (None) image for file: %s' % fullpath, e) return if args.center_crop: if img.shape[0] > img.shape[1]: @@ -108,18 +108,19 @@ def image_encode(args, item, q_out): try: s = mx.recordio.pack_img(header, img, quality=args.quality, img_fmt=args.encoding) - q_out.put((s, item)) - except Exception: + q_out.put((i, s, item)) + except Exception, e: traceback.print_exc() - print('pack_img error on file: %s' % fullpath) + print('pack_img error on file: %s' % fullpath, e) return def read_worker(args, q_in, q_out): while True: - item = q_in.get() - if item is None: + deq = q_in.get() + if deq is None: break - image_encode(args, item, q_out) + i, item = deq + image_encode(args, i, item, q_out) def write_worker(q_out, fname, working_dir): pre_time = time.time() @@ -127,30 +128,27 @@ def write_worker(q_out, fname, working_dir): fname = os.path.basename(fname) fname_rec = os.path.splitext(fname)[0] + '.rec' fname_idx = os.path.splitext(fname)[0] + '.idx' - fout = open(fname+'.tmp', 'w') record = mx.recordio.MXIndexedRecordIO(os.path.join(working_dir, fname_idx), os.path.join(working_dir, fname_rec), 'w') - while True: + buf = {} + more = True + while more: deq = q_out.get() - if deq is None: - break - s, item = deq - record.write_idx(item[0], s) - - line = '%d\t' % item[0] - for j in item[2:]: - line += '%f\t' % j - line += '%s\n' % item[1] - fout.write(line) + if deq is not None: + i, s, item = deq + buf[i] = (s, item) + else: + more = False + while count in buf: + s, item = buf[count] + del buf[count] + record.write_idx(item[0], s) - if count % 1000 == 0: - cur_time = time.time() - print('time:', cur_time - pre_time, ' count:', count) - pre_time = cur_time - count += 1 - fout.close() - os.remove(fname) - os.rename(fname+'.tmp', fname) + if count % 1000 == 0: + cur_time = time.time() + print('time:', cur_time - pre_time, ' count:', count) + pre_time = cur_time + count += 1 def parse_args(): parser = argparse.ArgumentParser( @@ -176,6 +174,8 @@ def parse_args(): help='If true recursively walk through subdirs and assign an unique label\ to images in each folder. Otherwise only include images in the root folder\ and give them label 0.') + cgroup.add_argument('--shuffle', default=True, help='If this is set as True, \ + im2rec will randomize the image order in .lst') rgroup = parser.add_argument_group('Options for creating database') rgroup.add_argument('--resize', type=int, default=0, @@ -196,8 +196,6 @@ def parse_args(): -1:Loads image as such including alpha channel.') rgroup.add_argument('--encoding', type=str, default='.jpg', choices=['.jpg', '.png'], help='specify the encoding of the images.') - rgroup.add_argument('--shuffle', default=True, help='If this is set as True, \ - im2rec will randomize the image order in .lst') rgroup.add_argument('--pack-label', default=False, help='Whether to also pack multi dimensional label in the record file') args = parser.parse_args() @@ -235,7 +233,7 @@ def parse_args(): write_process.start() for i, item in enumerate(image_list): - q_in[i % len(q_in)].put(item) + q_in[i % len(q_in)].put((i, item)) for q in q_in: q.put(None) for p in read_process: