Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Fix metric & im2rec.py
Browse files: browse the repository at this point in the history
  • Loading branch information
piiswrong committed Dec 29, 2016
1 parent 32bc762 commit 96d4ae4
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 26 deletions.
4 changes: 2 additions & 2 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -127,9 +127,9 @@ endif

all: lib/libmxnet.a lib/libmxnet.so $(BIN)

SRC = $(wildcard src/*.cc src/*/*.cc src/*/*/*.cc)
SRC = $(wildcard src/*/*/*.cc src/*/*.cc src/*.cc)
OBJ = $(patsubst %.cc, build/%.o, $(SRC))
CUSRC = $(wildcard src/*.cu src/*/*.cu src/*/*/*.cu)
CUSRC = $(wildcard src/*/*/*.cu src/*/*.cu src/*.cu)
CUOBJ = $(patsubst %.cu, build/%_gpu.o, $(CUSRC))

# extra operators
Expand Down
4 changes: 3 additions & 1 deletion python/mxnet/metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,9 @@ def update(self, labels, preds):
check_label_shapes(labels, preds)

for label, pred_label in zip(labels, preds):
pred_label = ndarray.argmax_channel(pred_label).asnumpy().astype('int32')
if pred_label.shape != label.shape:
pred_label = ndarray.argmax_channel(pred_label)
pred_label = pred_label.asnumpy().astype('int32')
label = label.asnumpy().astype('int32')

check_label_shapes(label, pred_label)
Expand Down
66 changes: 43 additions & 23 deletions tools/im2rec.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,16 @@
import time
import traceback

try:
import multiprocessing
except ImportError:
multiprocessing = None

def list_image(root, recursive, exts):
i = 0
if recursive:
cat = {}
for path, subdirs, files in os.walk(root, followlinks=True):
subdirs.sort()
print(len(cat), path)
for path, _, files in os.walk(root, followlinks=True):
for fname in files:
fpath = os.path.join(path, fname)
suffix = os.path.splitext(fname)[1].lower()
Expand All @@ -28,6 +30,8 @@ def list_image(root, recursive, exts):
cat[path] = len(cat)
yield (i, os.path.relpath(fpath, root), cat[path])
i += 1
for k, v in cat.items():
print(os.path.relpath(k, root), v)
else:
for fname in os.listdir(root):
fpath = os.path.join(root, fname)
Expand Down Expand Up @@ -61,11 +65,14 @@ def make_list(args):
str_chunk = ''
sep = int(chunk_size * args.train_ratio)
sep_test = int(chunk_size * args.test_ratio)
if args.test_ratio:
write_list(args.prefix + str_chunk + '_test.lst', chunk[:sep_test])
if args.train_ratio + args.test_ratio < 1.0:
write_list(args.prefix + str_chunk + '_val.lst', chunk[sep_test + sep:])
write_list(args.prefix + str_chunk + '_train.lst', chunk[sep_test:sep_test + sep])
if args.train_ratio == 1.0:
write_list(args.prefix + str_chunk + '.lst', chunk)
else:
if args.test_ratio:
write_list(args.prefix + str_chunk + '_test.lst', chunk[:sep_test])
if args.train_ratio + args.test_ratio < 1.0:
write_list(args.prefix + str_chunk + '_val.lst', chunk[sep_test + sep:])
write_list(args.prefix + str_chunk + '_train.lst', chunk[sep_test:sep_test + sep])

def read_list(path_in):
with open(path_in) as fin:
Expand All @@ -79,14 +86,31 @@ def read_list(path_in):

def image_encode(args, i, item, q_out):
fullpath = os.path.join(args.root, item[1])

if len(item) > 3 and args.pack_label:
header = mx.recordio.IRHeader(0, item[2:], item[0], 0)
else:
header = mx.recordio.IRHeader(0, item[2], item[0], 0)

if args.pass_through:
try:
with open(fullpath) as fin:
img = fin.read()
s = mx.recordio.pack(header, img)
q_out.put((i, s, item))
except Exception, e:
traceback.print_exc()
print('pack_img error:', item[1], e)
return

try:
img = cv2.imread(fullpath, args.color)
except:
traceback.print_exc()
print('imread error trying to load file: %s ' % fullpath, e)
print('imread error trying to load file: %s ' % fullpath)
return
if img is None:
print('imread read blank (None) image for file: %s' % fullpath, e)
print('imread read blank (None) image for file: %s' % fullpath)
return
if args.center_crop:
if img.shape[0] > img.shape[1]:
Expand All @@ -101,10 +125,6 @@ def image_encode(args, i, item, q_out):
else:
newsize = (img.shape[1] * args.resize / img.shape[0], args.resize)
img = cv2.resize(img, newsize)
if len(item) > 3 and args.pack_label:
header = mx.recordio.IRHeader(0, item[2:], item[0], 0)
else:
header = mx.recordio.IRHeader(0, item[2], item[0], 0)

try:
s = mx.recordio.pack_img(header, img, quality=args.quality, img_fmt=args.encoding)
Expand Down Expand Up @@ -178,6 +198,8 @@ def parse_args():
im2rec will randomize the image order in <prefix>.lst')

rgroup = parser.add_argument_group('Options for creating database')
rgroup.add_argument('--pass-through', type=bool, default=False,
help='whether to skip transformation and save image as is')
rgroup.add_argument('--resize', type=int, default=0,
help='resize the shorter edge of image to the newsize, original images will\
be packed by default.')
Expand Down Expand Up @@ -221,8 +243,7 @@ def parse_args():
count += 1
image_list = read_list(fname)
# -- write_record -- #
try:
import multiprocessing
if args.num_thread > 1 and multiprocessing is not None:
q_in = [multiprocessing.Queue(1024) for i in range(args.num_thread)]
q_out = multiprocessing.Queue(1024)
read_process = [multiprocessing.Process(target=read_worker, args=(args, q_in[i], q_out)) \
Expand All @@ -241,24 +262,23 @@ def parse_args():

q_out.put(None)
write_process.join()
except ImportError:
else:
print('multiprocessing not available, fall back to single threaded encoding')
import Queue
q_out = Queue.Queue()
fname = os.path.basename(fname)
fname_rec = os.path.splitext(fname)[0] + '.rec'
fname_idx = os.path.splitext(fname)[0] + '.idx'
fidx = open(os.path.join(working_dir, fname_idx), 'w')
record = mx.recordio.MXRecordIO(os.path.join(working_dir, fname_rec), 'w')
record = mx.recordio.MXIndexedRecordIO(os.path.join(working_dir, fname_idx),
os.path.join(working_dir, fname_rec), 'w')
cnt = 0
pre_time = time.time()
for item in image_list:
image_encode(args, item, q_out)
for i, item in enumerate(image_list):
image_encode(args, i, item, q_out)
if q_out.empty():
continue
_, s, _ = q_out.get()
record.write(s)
fidx.write('%d\t%d\n'%(item[0], record.tell()))
record.write_idx(item[0], s)
if cnt % 1000 == 0:
cur_time = time.time()
print('time:', cur_time - pre_time, ' count:', cnt)
Expand Down

0 comments on commit 96d4ae4

Please sign in to comment.