Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
fix image io
Browse files: browse the repository at this point in the history
  • Loading branch information
piiswrong committed Dec 29, 2016
1 parent 550fff1 commit dfe1db5
Show file tree
Hide file tree
Showing 5 changed files with 72 additions and 75 deletions.
37 changes: 20 additions & 17 deletions python/mxnet/image.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -285,27 +285,29 @@ class ImageIter(io.DataIter):
"""
def __init__(self, batch_size, data_shape, label_width=1,
path_imgrec=None, path_imglist=None, path_root=None, path_imgidx=None,
shuffle=False, part_index=0, num_parts=1, **kwargs):
shuffle=False, part_index=0, num_parts=1, aug_list=None, **kwargs):
super(ImageIter, self).__init__()
assert path_imgrec or path_imglist
if path_imgrec:
if path_imgidx:
self.imgrec = recordio.MXIndexedRecordIO(path_imgidx, path_imgrec, 'r')
self.imgidx = self.imgrec.idx
self.imgidx = list(self.imgrec.keys)
else:
self.imgrec = recordio.MXRecordIO(path_imgrec, 'r')
self.imgidx = None
else:
self.imgrec = None

if path_imglist:
with open(path_imglist) as fin:
imglist = {}
while True:
line = fin.readline()
if not line:
break
imgkeys = []
for line in iter(fin.readline, ''):
line = line.strip().split('\t')
label = nd.array([float(i) for i in line[1:-1]])
imglist[int(line[0])] = (label, line[-1])
key = int(line[0])
imglist[key] = (label, line[-1])
imgkeys.append(key)
self.imglist = imglist
else:
self.imglist = None
Expand All @@ -319,12 +321,11 @@ def __init__(self, batch_size, data_shape, label_width=1,
self.label_width = label_width

self.shuffle = shuffle
if shuffle or num_parts > 1:
if self.imgrec is None:
self.seq = self.imglist.keys()
else:
assert self.imgidx is not None
self.seq = self.imgidx.keys()
if self.imgrec is None:
self.seq = imgkeys
elif shuffle or num_parts > 1:
assert self.imgidx is not None
self.seq = self.imgidx
else:
self.seq = None

Expand All @@ -333,8 +334,10 @@ def __init__(self, batch_size, data_shape, label_width=1,
N = len(self.seq)
C = N/num_parts
self.seq = self.seq[part_index*C:(part_index+1)*C]

self.auglist = CreateAugmenter(data_shape, **kwargs)
if aug_list is None:
self.auglist = CreateAugmenter(data_shape, **kwargs)
else:
self.auglist = aug_list
self.cur = 0
self.reset()

Expand All @@ -360,7 +363,7 @@ def next_sample(self):
else:
return self.imglist[idx][0], img
else:
label, fname = self.imglist[self.seq[self.cur]]
label, fname = self.imglist[idx]
if self.imgrec is None:
with open(os.path.join(self.path_root, fname), 'rb') as fin:
img = fin.read()
Expand Down Expand Up @@ -389,4 +392,4 @@ def next(self):
raise StopIteration

batch_data = nd.transpose(batch_data, axes=(0, 3, 1, 2))
return io.DataBatch(batch_data, batch_label, batch_size-1-i)
return io.DataBatch([batch_data], [batch_label], batch_size-1-i)
44 changes: 20 additions & 24 deletions python/mxnet/recordio.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -5,7 +5,6 @@
from __future__ import absolute_import
from collections import namedtuple

import os
import ctypes
import struct
import numbers
Expand Down Expand Up @@ -117,28 +116,28 @@ class MXIndexedRecordIO(MXRecordIO):
data type for keys
"""
def __init__(self, idx_path, uri, flag, key_type=int):
super(MXIndexedRecordIO, self).__init__(uri, flag)
self.idx_path = idx_path
self.idx = {}
self.keys = []
self.key_type = key_type
if not self.writable and os.path.isfile(idx_path):
with open(idx_path) as fin:
for line in fin.readlines():
line = line.strip().split('\t')
self.idx[key_type(line[0])] = int(line[1])
self.fidx = None
super(MXIndexedRecordIO, self).__init__(uri, flag)

def open(self):
super(MXIndexedRecordIO, self).open()
self.idx = {}
self.keys = []
self.fidx = open(self.idx_path, self.flag)
if not self.writable:
for line in iter(self.fidx.readline, ''):
line = line.strip().split('\t')
key = self.key_type(line[0])
self.idx[key] = int(line[1])
self.keys.append(key)

def close(self):
if self.writable:
with open(self.idx_path, 'w') as fout:
for k, v in self.idx.items():
fout.write(str(k)+'\t'+str(v)+'\n')
super(MXIndexedRecordIO, self).close()

def reset(self):
if self.writable:
self.idx = {}
super(MXIndexedRecordIO, self).close()
super(MXIndexedRecordIO, self).open()
self.fidx.close()

def seek(self, idx):
"""Query current read head position"""
Expand All @@ -160,15 +159,12 @@ def read_idx(self, idx):

def write_idx(self, idx, buf):
"""Write record with index"""
key = self.key_type(idx)
pos = self.tell()
self.idx[self.key_type(idx)] = pos
self.write(buf)

def keys(self):
"""List all keys from index"""
return list(self.idx.keys())


self.fidx.write('%s\t%d\n'%(str(key), pos))
self.idx[key] = pos
self.keys.append(key)


IRHeader = namedtuple('HEADER', ['flag', 'label', 'id', 'id2'])
Expand Down
2 changes: 1 addition & 1 deletion src/operator/tensor/elemwise_unary_op.cu
Original file line number | Diff line number | Diff line change
Expand Up @@ -114,7 +114,7 @@ NNVM_REGISTER_OP(_backward_arccos)
NNVM_REGISTER_OP(arctan)
.set_attr<FCompute>("FCompute<gpu>", UnaryCompute<gpu, mshadow_op::arctan>);

NNVM_REGISTER_OP(_backward_arccos)
NNVM_REGISTER_OP(_backward_arctan)
.set_attr<FCompute>("FCompute<gpu>", BinaryCompute<gpu, unary_bwd<mshadow_op::arctan_grad> >);

} // namespace op
Expand Down
2 changes: 1 addition & 1 deletion tests/python/unittest/test_recordio.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -40,7 +40,7 @@ def test_indexed_recordio():
del writer

reader = mx.recordio.MXIndexedRecordIO(fidx, frec, 'r')
keys = reader.keys()
keys = reader.keys
assert sorted(keys) == [i for i in range(N)]
random.shuffle(keys)
for i in keys:
Expand Down
62 changes: 30 additions & 32 deletions tools/im2rec.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -77,16 +77,16 @@ def read_list(path_in):
item = [int(line[0])] + [line[-1]] + [float(i) for i in line[1:-1]]
yield item

def image_encode(args, item, q_out):
def image_encode(args, i, item, q_out):
fullpath = os.path.join(args.root, item[1])
try:
img = cv2.imread(fullpath, args.color)
except:
traceback.print_exc()
print('imread error trying to load file: %s ' % fullpath)
print('imread error trying to load file: %s ' % fullpath, e)
return
if img is None:
print('imread read blank (None) image for file: %s' % fullpath)
print('imread read blank (None) image for file: %s' % fullpath, e)
return
if args.center_crop:
if img.shape[0] > img.shape[1]:
Expand All @@ -108,49 +108,47 @@ def image_encode(args, item, q_out):

try:
s = mx.recordio.pack_img(header, img, quality=args.quality, img_fmt=args.encoding)
q_out.put((s, item))
except Exception:
q_out.put((i, s, item))
except Exception, e:
traceback.print_exc()
print('pack_img error on file: %s' % fullpath)
print('pack_img error on file: %s' % fullpath, e)
return

def read_worker(args, q_in, q_out):
while True:
item = q_in.get()
if item is None:
deq = q_in.get()
if deq is None:
break
image_encode(args, item, q_out)
i, item = deq
image_encode(args, i, item, q_out)

def write_worker(q_out, fname, working_dir):
pre_time = time.time()
count = 0
fname = os.path.basename(fname)
fname_rec = os.path.splitext(fname)[0] + '.rec'
fname_idx = os.path.splitext(fname)[0] + '.idx'
fout = open(fname+'.tmp', 'w')
record = mx.recordio.MXIndexedRecordIO(os.path.join(working_dir, fname_idx),
os.path.join(working_dir, fname_rec), 'w')
while True:
buf = {}
more = True
while more:
deq = q_out.get()
if deq is None:
break
s, item = deq
record.write_idx(item[0], s)

line = '%d\t' % item[0]
for j in item[2:]:
line += '%f\t' % j
line += '%s\n' % item[1]
fout.write(line)
if deq is not None:
i, s, item = deq
buf[i] = (s, item)
else:
more = False
while count in buf:
s, item = buf[count]
del buf[count]
record.write_idx(item[0], s)

if count % 1000 == 0:
cur_time = time.time()
print('time:', cur_time - pre_time, ' count:', count)
pre_time = cur_time
count += 1
fout.close()
os.remove(fname)
os.rename(fname+'.tmp', fname)
if count % 1000 == 0:
cur_time = time.time()
print('time:', cur_time - pre_time, ' count:', count)
pre_time = cur_time
count += 1

def parse_args():
parser = argparse.ArgumentParser(
Expand All @@ -176,6 +174,8 @@ def parse_args():
help='If true recursively walk through subdirs and assign an unique label\
to images in each folder. Otherwise only include images in the root folder\
and give them label 0.')
cgroup.add_argument('--shuffle', default=True, help='If this is set as True, \
im2rec will randomize the image order in <prefix>.lst')

rgroup = parser.add_argument_group('Options for creating database')
rgroup.add_argument('--resize', type=int, default=0,
Expand All @@ -196,8 +196,6 @@ def parse_args():
-1:Loads image as such including alpha channel.')
rgroup.add_argument('--encoding', type=str, default='.jpg', choices=['.jpg', '.png'],
help='specify the encoding of the images.')
rgroup.add_argument('--shuffle', default=True, help='If this is set as True, \
im2rec will randomize the image order in <prefix>.lst')
rgroup.add_argument('--pack-label', default=False,
help='Whether to also pack multi dimensional label in the record file')
args = parser.parse_args()
Expand Down Expand Up @@ -235,7 +233,7 @@ def parse_args():
write_process.start()

for i, item in enumerate(image_list):
q_in[i % len(q_in)].put(item)
q_in[i % len(q_in)].put((i, item))
for q in q_in:
q.put(None)
for p in read_process:
Expand Down

0 comments on commit dfe1db5

Please sign in to comment.