From 84ce29bcbbe0b17240d1976f49871eab0560238c Mon Sep 17 00:00:00 2001 From: Eric Junyuan Xie Date: Mon, 24 Jul 2017 20:39:50 -0700 Subject: [PATCH] Gluon data pipeline (#7155) * add data pipeline to gluon * add cifar * fix * fix * fix --- example/gluon/actor_critic.py | 2 +- example/gluon/dcgan.py | 339 ++++++++++++----------- example/gluon/mnist.py | 35 ++- example/gluon/super_resolution.py | 7 +- python/mxnet/gluon/.gitignore | 1 + python/mxnet/gluon/__init__.py | 2 + python/mxnet/gluon/data/__init__.py | 11 + python/mxnet/gluon/data/dataloader.py | 70 +++++ python/mxnet/gluon/data/dataset.py | 89 ++++++ python/mxnet/gluon/data/sampler.py | 120 ++++++++ python/mxnet/gluon/data/vision.py | 125 +++++++++ python/mxnet/gluon/utils.py | 50 ++++ python/mxnet/image/image.py | 4 +- python/mxnet/ndarray.py | 12 +- src/operator/elemwise_op_common.h | 16 +- src/operator/tensor/elemwise_sum.cc | 20 +- src/operator/tensor/matrix_op-inl.h | 109 ++++++++ src/operator/tensor/matrix_op.cc | 52 ++++ src/operator/tensor/matrix_op.cu | 6 + tests/python/unittest/test_gluon_data.py | 53 ++++ tests/python/unittest/test_operator.py | 17 ++ 21 files changed, 933 insertions(+), 207 deletions(-) create mode 100644 python/mxnet/gluon/.gitignore create mode 100644 python/mxnet/gluon/data/__init__.py create mode 100644 python/mxnet/gluon/data/dataloader.py create mode 100644 python/mxnet/gluon/data/dataset.py create mode 100644 python/mxnet/gluon/data/sampler.py create mode 100644 python/mxnet/gluon/data/vision.py create mode 100644 tests/python/unittest/test_gluon_data.py diff --git a/example/gluon/actor_critic.py b/example/gluon/actor_critic.py index 7910c73030e1..9c475ce15017 100644 --- a/example/gluon/actor_critic.py +++ b/example/gluon/actor_critic.py @@ -43,7 +43,7 @@ def forward(self, x): return F.softmax(probs), values net = Policy() -net.collect_params().initialize(mx.init.Uniform(0.02)) +net.initialize(mx.init.Uniform(0.02)) trainer = gluon.Trainer(net.collect_params(), 'adam', {'learning_rate': 3e-2}) loss = gluon.loss.L1Loss() diff --git a/example/gluon/dcgan.py b/example/gluon/dcgan.py index 17d02e7fbede..7f644cba5962 100644 --- a/example/gluon/dcgan.py +++ b/example/gluon/dcgan.py @@ -7,10 +7,8 @@ from mxnet import gluon from mxnet.gluon import nn from mxnet import autograd -from data import cifar10_iterator import numpy as np import logging -import cv2 from datetime import datetime import os import time @@ -32,173 +30,190 @@ def visual(title, X, name): buff = np.zeros((int(n*X.shape[1]), int(n*X.shape[2]), int(X.shape[3])), dtype=np.uint8) for i, img in enumerate(X): fill_buf(buff, i, img, X.shape[1:3]) - buff = cv2.cvtColor(buff, cv2.COLOR_BGR2RGB) + buff = buff[:,:,::-1] plt.imshow(buff) plt.title(title) plt.savefig(name) - return None -if __name__ == '__main__': - parser = argparse.ArgumentParser() - parser.add_argument('--dataset', type=str, default='cifar10', help='dataset to use. 
options are cifar10 and imagenet.') - parser.add_argument('--batchSize', type=int, default=64, help='input batch size') - parser.add_argument('--imageSize', type=int, default=64, help='the height / width of the input image to network') - parser.add_argument('--nz', type=int, default=100, help='size of the latent z vector') - parser.add_argument('--ngf', type=int, default=64) - parser.add_argument('--ndf', type=int, default=64) - parser.add_argument('--niter', type=int, default=25, help='number of epochs to train for') - parser.add_argument('--lr', type=float, default=0.0002, help='learning rate, default=0.0002') - parser.add_argument('--beta1', type=float, default=0.5, help='beta1 for adam. default=0.5') - parser.add_argument('--cuda', action='store_true', help='enables cuda') - parser.add_argument('--ngpu', type=int, default=1, help='number of GPUs to use') - parser.add_argument('--netG', default='', help="path to netG (to continue training)") - parser.add_argument('--netD', default='', help="path to netD (to continue training)") - parser.add_argument('--outf', default='./results', help='folder to output images and model checkpoints') - parser.add_argument('--manualSeed', type=int, help='manual seed') - parser.add_argument('--check_point', default=True, help="save results at each epoch or not") - - opt = parser.parse_args() - print(opt) - - logging.basicConfig(level=logging.DEBUG) - ngpu = int(opt.ngpu) - nz = int(opt.nz) - ngf = int(opt.ngf) - ndf = int(opt.ndf) - nc = 3 + +parser = argparse.ArgumentParser() +parser.add_argument('--dataset', type=str, default='cifar10', help='dataset to use. options are cifar10 and imagenet.') +parser.add_argument('--batch-size', type=int, default=64, help='input batch size') +parser.add_argument('--nz', type=int, default=100, help='size of the latent z vector') +parser.add_argument('--ngf', type=int, default=64) +parser.add_argument('--ndf', type=int, default=64) +parser.add_argument('--nepoch', type=int, default=25, help='number of epochs to train for') +parser.add_argument('--lr', type=float, default=0.0002, help='learning rate, default=0.0002') +parser.add_argument('--beta1', type=float, default=0.5, help='beta1 for adam. default=0.5') +parser.add_argument('--cuda', action='store_true', help='enables cuda') +parser.add_argument('--ngpu', type=int, default=1, help='number of GPUs to use') +parser.add_argument('--netG', default='', help="path to netG (to continue training)") +parser.add_argument('--netD', default='', help="path to netD (to continue training)") +parser.add_argument('--outf', default='./results', help='folder to output images and model checkpoints') +parser.add_argument('--check-point', default=True, help="save results at each epoch or not") + +opt = parser.parse_args() +print(opt) + +logging.basicConfig(level=logging.DEBUG) +ngpu = int(opt.ngpu) +nz = int(opt.nz) +ngf = int(opt.ngf) +ndf = int(opt.ndf) +nc = 3 +if opt.cuda: ctx = mx.gpu(0) - check_point = bool(opt.check_point) - outf = opt.outf - - if not os.path.exists(outf): - os.makedirs(outf) - - if opt.dataset == 'cifar10': - train_iter, val_iter = cifar10_iterator(opt.batchSize, (3, 64, 64), 64) - - # build the generator - netG = nn.Sequential() - with netG.name_scope(): - # input is Z, going into a convolution - netG.add(nn.Conv2DTranspose(ngf * 8, 4, 1, 0, use_bias=False)) - netG.add(nn.BatchNorm()) - netG.add(nn.Activation('relu')) - # state size. 
(ngf*8) x 4 x 4 - netG.add(nn.Conv2DTranspose(ngf * 4, 4, 2, 1, use_bias=False)) - netG.add(nn.BatchNorm()) - netG.add(nn.Activation('relu')) - # state size. (ngf*8) x 8 x 8 - netG.add(nn.Conv2DTranspose(ngf * 2, 4, 2, 1, use_bias=False)) - netG.add(nn.BatchNorm()) - netG.add(nn.Activation('relu')) - # state size. (ngf*8) x 16 x 16 - netG.add(nn.Conv2DTranspose(ngf, 4, 2, 1, use_bias=False)) - netG.add(nn.BatchNorm()) - netG.add(nn.Activation('relu')) - # state size. (ngf*8) x 32 x 32 - netG.add(nn.Conv2DTranspose(nc, 4, 2, 1, use_bias=False)) - netG.add(nn.Activation('tanh')) - # state size. (nc) x 64 x 64 - - # build the discriminator - netD = nn.Sequential() - with netD.name_scope(): - # input is (nc) x 64 x 64 - netD.add(nn.Conv2D(ndf, 4, 2, 1, use_bias=False)) - netD.add(nn.LeakyReLU(0.2)) - # state size. (ndf) x 32 x 32 - netD.add(nn.Conv2D(ndf * 2, 4, 2, 1, use_bias=False)) - netD.add(nn.BatchNorm()) - netD.add(nn.LeakyReLU(0.2)) - # state size. (ndf) x 16 x 16 - netD.add(nn.Conv2D(ndf * 4, 4, 2, 1, use_bias=False)) - netD.add(nn.BatchNorm()) - netD.add(nn.LeakyReLU(0.2)) - # state size. (ndf) x 8 x 8 - netD.add(nn.Conv2D(ndf * 8, 4, 2, 1, use_bias=False)) - netD.add(nn.BatchNorm()) - netD.add(nn.LeakyReLU(0.2)) - # state size. (ndf) x 4 x 4 - netD.add(nn.Conv2D(2, 4, 1, 0, use_bias=False)) - - # loss - loss = gluon.loss.SoftmaxCrossEntropyLoss() - - # initialize the generator and the discriminator - netG.collect_params().initialize(mx.init.Normal(0.02), ctx=ctx) - netD.collect_params().initialize(mx.init.Normal(0.02), ctx=ctx) - - # trainer for the generator and the discriminator - trainerG = gluon.Trainer(netG.collect_params(), 'adam', {'learning_rate': opt.lr, 'beta1': opt.beta1}) - trainerD = gluon.Trainer(netD.collect_params(), 'adam', {'learning_rate': opt.lr, 'beta1': opt.beta1}) - - # ============printing============== - real_label = mx.nd.ones((opt.batchSize,), ctx=ctx) - fake_label = mx.nd.zeros((opt.batchSize,), ctx=ctx) - - metric = mx.metric.Accuracy() - print('Training... 
') - stamp = datetime.now().strftime('%Y_%m_%d-%H_%M') - - iter = 0 - for epoch in range(opt.niter): - tic = time.time() - train_iter.reset() - btic = time.time() - for batch in train_iter: - ############################ - # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z))) - ########################### - # train with real_t - data = batch.data[0].copyto(ctx) - noise = mx.nd.random_normal(0, 1, shape=(opt.batchSize, nz, 1, 1), ctx=ctx) - - with autograd.record(): - output = netD(data) - output = output.reshape((opt.batchSize, 2)) - errD_real = loss(output, real_label) - metric.update([real_label,], [output,]) - - fake = netG(noise) - output = netD(fake.detach()) - output = output.reshape((opt.batchSize, 2)) - errD_fake = loss(output, fake_label) - errD = errD_real + errD_fake - errD.backward() - metric.update([fake_label,], [output,]) - - trainerD.step(opt.batchSize) - - ############################ - # (2) Update G network: maximize log(D(G(z))) - ########################### - with autograd.record(): - output = netD(fake) - output = output.reshape((opt.batchSize, 2)) - errG = loss(output, real_label) - errG.backward() - - trainerG.step(opt.batchSize) - - name, acc = metric.get() - # logging.info('speed: {} samples/s'.format(opt.batchSize / (time.time() - btic))) - logging.info('discriminator loss = %f, generator loss = %f, binary training acc = %f at iter %d epoch %d' %(mx.nd.mean(errD).asscalar(), mx.nd.mean(errG).asscalar(), acc, iter, epoch)) - if iter % 200 == 0: - visual('gout', fake.asnumpy(), name=os.path.join(outf,'fake_img_iter_%d.png' %iter)) - visual('data', batch.data[0].asnumpy(), name=os.path.join(outf,'real_img_iter_%d.png' %iter)) - - iter = iter + 1 - btic = time.time() +else: + ctx = mx.cpu() +check_point = bool(opt.check_point) +outf = opt.outf + +if not os.path.exists(outf): + os.makedirs(outf) + + +def transformer(data, label): + # resize to 64x64 + data = mx.image.imresize(data, 64, 64) + # transpose from (64, 64, 3) to (3, 64, 64) + data = mx.nd.transpose(data, (2,0,1)) + # normalize to [-1, 1] + data = data.astype(np.float32)/128 - 1 + # if image is greyscale, repeat 3 times to get RGB image. + if data.shape[0] == 1: + data = mx.nd.tile(data, (3, 1, 1)) + return data, label + +train_data = gluon.data.DataLoader( + gluon.data.vision.MNIST('./data', train=True, transform=transformer), + batch_size=opt.batch_size, shuffle=True, last_batch='discard') + +val_data = gluon.data.DataLoader( + gluon.data.vision.MNIST('./data', train=False, transform=transformer), + batch_size=opt.batch_size, shuffle=False) + + +# build the generator +netG = nn.Sequential() +with netG.name_scope(): + # input is Z, going into a convolution + netG.add(nn.Conv2DTranspose(ngf * 8, 4, 1, 0, use_bias=False)) + netG.add(nn.BatchNorm()) + netG.add(nn.Activation('relu')) + # state size. (ngf*8) x 4 x 4 + netG.add(nn.Conv2DTranspose(ngf * 4, 4, 2, 1, use_bias=False)) + netG.add(nn.BatchNorm()) + netG.add(nn.Activation('relu')) + # state size. (ngf*8) x 8 x 8 + netG.add(nn.Conv2DTranspose(ngf * 2, 4, 2, 1, use_bias=False)) + netG.add(nn.BatchNorm()) + netG.add(nn.Activation('relu')) + # state size. (ngf*8) x 16 x 16 + netG.add(nn.Conv2DTranspose(ngf, 4, 2, 1, use_bias=False)) + netG.add(nn.BatchNorm()) + netG.add(nn.Activation('relu')) + # state size. (ngf*8) x 32 x 32 + netG.add(nn.Conv2DTranspose(nc, 4, 2, 1, use_bias=False)) + netG.add(nn.Activation('tanh')) + # state size. 
(nc) x 64 x 64 + +# build the discriminator +netD = nn.Sequential() +with netD.name_scope(): + # input is (nc) x 64 x 64 + netD.add(nn.Conv2D(ndf, 4, 2, 1, use_bias=False)) + netD.add(nn.LeakyReLU(0.2)) + # state size. (ndf) x 32 x 32 + netD.add(nn.Conv2D(ndf * 2, 4, 2, 1, use_bias=False)) + netD.add(nn.BatchNorm()) + netD.add(nn.LeakyReLU(0.2)) + # state size. (ndf) x 16 x 16 + netD.add(nn.Conv2D(ndf * 4, 4, 2, 1, use_bias=False)) + netD.add(nn.BatchNorm()) + netD.add(nn.LeakyReLU(0.2)) + # state size. (ndf) x 8 x 8 + netD.add(nn.Conv2D(ndf * 8, 4, 2, 1, use_bias=False)) + netD.add(nn.BatchNorm()) + netD.add(nn.LeakyReLU(0.2)) + # state size. (ndf) x 4 x 4 + netD.add(nn.Conv2D(2, 4, 1, 0, use_bias=False)) + +# loss +loss = gluon.loss.SoftmaxCrossEntropyLoss() + +# initialize the generator and the discriminator +netG.initialize(mx.init.Normal(0.02), ctx=ctx) +netD.initialize(mx.init.Normal(0.02), ctx=ctx) + +# trainer for the generator and the discriminator +trainerG = gluon.Trainer(netG.collect_params(), 'adam', {'learning_rate': opt.lr, 'beta1': opt.beta1}) +trainerD = gluon.Trainer(netD.collect_params(), 'adam', {'learning_rate': opt.lr, 'beta1': opt.beta1}) + +# ============printing============== +real_label = mx.nd.ones((opt.batch_size,), ctx=ctx) +fake_label = mx.nd.zeros((opt.batch_size,), ctx=ctx) + +metric = mx.metric.Accuracy() +print('Training... ') +stamp = datetime.now().strftime('%Y_%m_%d-%H_%M') + +iter = 0 +for epoch in range(opt.nepoch): + tic = time.time() + btic = time.time() + for data, _ in train_data: + ############################ + # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z))) + ########################### + # train with real_t + data = data.as_in_context(ctx) + noise = mx.nd.random_normal(0, 1, shape=(opt.batch_size, nz, 1, 1), ctx=ctx) + + with autograd.record(): + output = netD(data) + output = output.reshape((opt.batch_size, 2)) + errD_real = loss(output, real_label) + metric.update([real_label,], [output,]) + + fake = netG(noise) + output = netD(fake.detach()) + output = output.reshape((opt.batch_size, 2)) + errD_fake = loss(output, fake_label) + errD = errD_real + errD_fake + errD.backward() + metric.update([fake_label,], [output,]) + + trainerD.step(opt.batch_size) + + ############################ + # (2) Update G network: maximize log(D(G(z))) + ########################### + with autograd.record(): + output = netD(fake) + output = output.reshape((-1, 2)) + errG = loss(output, real_label) + errG.backward() + + trainerG.step(opt.batch_size) name, acc = metric.get() - metric.reset() - logging.info('\nbinary training acc at epoch %d: %s=%f' % (epoch, name, acc)) - logging.info('time: %f' % (time.time() - tic)) + # logging.info('speed: {} samples/s'.format(opt.batch_size / (time.time() - btic))) + logging.info('discriminator loss = %f, generator loss = %f, binary training acc = %f at iter %d epoch %d' %(mx.nd.mean(errD).asscalar(), mx.nd.mean(errG).asscalar(), acc, iter, epoch)) + if iter % 1 == 0: + visual('gout', fake.asnumpy(), name=os.path.join(outf,'fake_img_iter_%d.png' %iter)) + visual('data', data.asnumpy(), name=os.path.join(outf,'real_img_iter_%d.png' %iter)) + + iter = iter + 1 + btic = time.time() - if check_point: - netG.collect_params().save(os.path.join(outf,'generator_epoch_%d.params' %epoch)) - netD.collect_params().save(os.path.join(outf,'discriminator_epoch_%d.params' % epoch)) + name, acc = metric.get() + metric.reset() + logging.info('\nbinary training acc at epoch %d: %s=%f' % (epoch, name, acc)) + logging.info('time: %f' % 
(time.time() - tic)) - netG.collect_params().save(os.path.join(outf, 'generator.params')) - netD.collect_params().save(os.path.join(outf, 'discriminator.params')) + if check_point: + netG.save_params(os.path.join(outf,'generator_epoch_%d.params' %epoch)) + netD.save_params(os.path.join(outf,'discriminator_epoch_%d.params' % epoch)) +netG.save_params(os.path.join(outf, 'generator.params')) +netD.save_params(os.path.join(outf, 'discriminator.params')) diff --git a/example/gluon/mnist.py b/example/gluon/mnist.py index 99ac2a9a8d48..9d567d5011cb 100644 --- a/example/gluon/mnist.py +++ b/example/gluon/mnist.py @@ -10,9 +10,6 @@ from mxnet import gluon, autograd from mxnet.gluon import nn -from data import mnist_iterator - - # Parse CLI arguments parser = argparse.ArgumentParser(description='MXNet Gluon MNIST Example') @@ -41,16 +38,25 @@ # data -train_data, val_data = mnist_iterator(batch_size=opt.batch_size, input_shape=(28*28,)) +def transformer(data, label): + data = data.reshape((-1,)).astype(np.float32)/255 + return data, label + +train_data = gluon.data.DataLoader( + gluon.data.vision.MNIST('./data', train=True, transform=transformer), + batch_size=opt.batch_size, shuffle=True, last_batch='discard') + +val_data = gluon.data.DataLoader( + gluon.data.vision.MNIST('./data', train=False, transform=transformer), + batch_size=opt.batch_size, shuffle=False) # train def test(ctx): metric = mx.metric.Accuracy() - val_data.reset() - for batch in val_data: - data = batch.data[0].as_in_context(ctx) - label = batch.label[0].as_in_context(ctx) + for data, label in val_data: + data = data.as_in_context(ctx) + label = label.as_in_context(ctx) output = net(data) metric.update([label], [output]) @@ -59,21 +65,20 @@ def test(ctx): def train(epochs, ctx): # Collect all parameters from net and its children, then initialize them. - net.collect_params().initialize(mx.init.Xavier(magnitude=2.24), ctx=ctx) + net.initialize(mx.init.Xavier(magnitude=2.24), ctx=ctx) # Trainer is for updating parameters with gradient. trainer = gluon.Trainer(net.collect_params(), 'sgd', - {'learning_rate': opt.lr, 'momentum': opt.momentum}) + {'learning_rate': opt.lr, 'momentum': opt.momentum}) metric = mx.metric.Accuracy() loss = gluon.loss.SoftmaxCrossEntropyLoss() for epoch in range(epochs): # reset data iterator and metric at begining of epoch. - train_data.reset() metric.reset() - for i, batch in enumerate(train_data): + for i, (data, label) in enumerate(train_data): # Copy data to ctx if necessary - data = batch.data[0].as_in_context(ctx) - label = batch.label[0].as_in_context(ctx) + data = data.as_in_context(ctx) + label = label.as_in_context(ctx) # Start recording computation graph with record() section. # Recorded graphs can then be differentiated with backward. 
with autograd.record(): @@ -95,7 +100,7 @@ def train(epochs, ctx): name, val_acc = test(ctx) print('[Epoch %d] Validation: %s=%f'%(epoch, name, val_acc)) - net.collect_params().save('mnist.params') + net.save_params('mnist.params') if __name__ == '__main__': diff --git a/example/gluon/super_resolution.py b/example/gluon/super_resolution.py index 521c17aeb71d..d61fb160e197 100644 --- a/example/gluon/super_resolution.py +++ b/example/gluon/super_resolution.py @@ -125,7 +125,8 @@ def test(ctx): def train(epoch, ctx): if isinstance(ctx, mx.Context): ctx = [ctx] - net.collect_params().initialize(mx.init.Orthogonal(), ctx=ctx) + net.initialize(mx.init.Orthogonal(), ctx=ctx) + # re-initialize conv4's weight to be Orthogonal net.conv4.collect_params().initialize(mx.init.Orthogonal(scale=1), ctx=ctx) trainer = gluon.Trainer(net.collect_params(), 'adam', {'learning_rate': opt.lr}) loss = gluon.loss.L2Loss() @@ -150,13 +151,13 @@ def train(epoch, ctx): print('training mse at epoch %d: %s=%f'%(i, name, acc)) test(ctx) - net.collect_params().save('superres.params') + net.save_params('superres.params') def resolve(ctx): from PIL import Image if isinstance(ctx, list): ctx = [ctx[0]] - net.collect_params().load('superres.params', ctx=ctx) + net.load_params('superres.params', ctx=ctx) img = Image.open(opt.resolve_img).convert('YCbCr') y, cb, cr = img.split() data = mx.nd.expand_dims(mx.nd.expand_dims(mx.nd.array(y), axis=0), axis=0) diff --git a/python/mxnet/gluon/.gitignore b/python/mxnet/gluon/.gitignore new file mode 100644 index 000000000000..8436a89ff416 --- /dev/null +++ b/python/mxnet/gluon/.gitignore @@ -0,0 +1 @@ +!data diff --git a/python/mxnet/gluon/__init__.py b/python/mxnet/gluon/__init__.py index 0910fdf8ce06..c559e7af343b 100644 --- a/python/mxnet/gluon/__init__.py +++ b/python/mxnet/gluon/__init__.py @@ -15,3 +15,5 @@ from . import loss from . import utils + +from . import data diff --git a/python/mxnet/gluon/data/__init__.py b/python/mxnet/gluon/data/__init__.py new file mode 100644 index 000000000000..a0623257417c --- /dev/null +++ b/python/mxnet/gluon/data/__init__.py @@ -0,0 +1,11 @@ +# coding: utf-8 +# pylint: disable=wildcard-import +"""Dataset utilities.""" + +from .dataset import * + +from .sampler import * + +from .dataloader import * + +from . import vision diff --git a/python/mxnet/gluon/data/dataloader.py b/python/mxnet/gluon/data/dataloader.py new file mode 100644 index 000000000000..148d7cd6e53c --- /dev/null +++ b/python/mxnet/gluon/data/dataloader.py @@ -0,0 +1,70 @@ +# coding: utf-8 +# pylint: disable= +"""Dataset generator.""" + +import numpy as np + +from . import sampler as _sampler +from ... import nd + + +def _batchify(data): + """Collate data into batch.""" + if isinstance(data[0], nd.NDArray): + return nd.stack(*data) + elif isinstance(data[0], tuple): + data = zip(*data) + return [_batchify(i) for i in data] + else: + data = np.asarray(data) + return nd.array(data, dtype=data.dtype) + + +class DataLoader(object): + """Loads data from a dataset and returns mini-batches of data. + + Parameters + ---------- + dataset : Dataset + Source dataset. Note that numpy and mxnet arrays can be directly used + as a Dataset. + batch_size : int + Size of mini-batch. + shuffle : bool + Whether to shuffle the samples. + sampler : Sampler + The sampler to use. Either specify sampler or shuffle, not both. + batch_sampler : Sampler + A sampler that returns mini-batches. Do not specify batch_size, + shuffle, sampler, and last_batch if batch_sampler is specified. 
+    """
+    def __init__(self, dataset, batch_size=None, shuffle=False, sampler=None,
+                 last_batch=None, batch_sampler=None):
+        self._dataset = dataset
+
+        if batch_sampler is None:
+            if batch_size is None:
+                raise ValueError("batch_size must be specified unless " \
+                                 "batch_sampler is specified")
+            if sampler is None:
+                if shuffle:
+                    sampler = _sampler.RandomSampler(len(dataset))
+                else:
+                    sampler = _sampler.SequentialSampler(len(dataset))
+            elif shuffle:
+                raise ValueError("shuffle must not be specified if sampler is specified")
+
+            batch_sampler = _sampler.BatchSampler(sampler, batch_size, last_batch)
+        elif batch_size is not None or shuffle or sampler is not None or \
+                last_batch is not None:
+            raise ValueError("batch_size, shuffle, sampler and last_batch must " \
+                             "not be specified if batch_sampler is specified.")
+
+        self._batch_sampler = batch_sampler
+
+    def __iter__(self):
+        for batch in self._batch_sampler:
+            yield _batchify([self._dataset[idx] for idx in batch])
+
+    def __len__(self):
+        return len(self._batch_sampler)
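Usage sketch for the DataLoader above (illustrative only, not part of the
patch; `ArrayDataset` is defined in dataset.py below)::

    import numpy as np
    import mxnet as mx
    from mxnet import gluon

    X = np.random.uniform(size=(10, 20))
    y = np.random.uniform(size=(10,))
    dataset = gluon.data.ArrayDataset(X, y)   # pairs data[i] with label[i]

    # shuffle=True internally builds a RandomSampler wrapped in a BatchSampler
    loader = gluon.data.DataLoader(dataset, batch_size=4, shuffle=True,
                                   last_batch='discard')
    for data, label in loader:
        # 'discard' drops the short tail, so every batch has exactly 4 samples
        assert data.shape == (4, 20)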
diff --git a/python/mxnet/gluon/data/dataset.py b/python/mxnet/gluon/data/dataset.py
new file mode 100644
index 000000000000..aefff0af16c9
--- /dev/null
+++ b/python/mxnet/gluon/data/dataset.py
@@ -0,0 +1,89 @@
+# coding: utf-8
+# pylint: disable=
+"""Dataset container."""
+import os
+
+from ... import recordio, image
+
+class Dataset(object):
+    """Abstract dataset class. All datasets should have this interface.
+
+    Subclasses need to override `__getitem__`, which returns the i-th
+    element, and `__len__`, which returns the total number of elements.
+
+    .. note:: An mxnet or numpy array can be directly used as a dataset.
+    """
+    def __getitem__(self, idx):
+        raise NotImplementedError
+
+    def __len__(self):
+        raise NotImplementedError
+
+
+class ArrayDataset(Dataset):
+    """A dataset with a data array and a label array.
+
+    The i-th sample is `(data[i], label[i])`.
+
+    Parameters
+    ----------
+    data : array-like object
+        The data array. Can be mxnet or numpy array.
+    label : array-like object
+        The label array. Can be mxnet or numpy array.
+    """
+    def __init__(self, data, label):
+        assert len(data) == len(label)
+        self._data = data
+        self._label = label
+
+    def __getitem__(self, idx):
+        return self._data[idx], self._label[idx]
+
+    def __len__(self):
+        return len(self._data)
+
+
+class RecordFileDataset(Dataset):
+    """A dataset wrapping over a RecordIO (.rec) file.
+
+    Each sample is a string representing the raw content of a record.
+
+    Parameters
+    ----------
+    filename : str
+        Path to rec file.
+    """
+    def __init__(self, filename):
+        idx_file = os.path.splitext(filename)[0] + '.idx'
+        self._record = recordio.MXIndexedRecordIO(idx_file, filename, 'r')
+
+    def __getitem__(self, idx):
+        return self._record.read_idx(idx)
+
+    def __len__(self):
+        return len(self._record.keys)
+
+
+class ImageRecordDataset(RecordFileDataset):
+    """A dataset wrapping over a RecordIO file containing images.
+
+    Each sample is an image and its corresponding label.
+
+    Parameters
+    ----------
+    filename : str
+        Path to rec file.
+    flag : {0, 1}, default 1
+        If 0, always convert images to greyscale.
+
+        If 1, always convert images to colored (RGB).
+    """
+    def __init__(self, filename, flag=1):
+        super(ImageRecordDataset, self).__init__(filename)
+        self._flag = flag
+
+    def __getitem__(self, idx):
+        record = super(ImageRecordDataset, self).__getitem__(idx)
+        header, img = recordio.unpack(record)
+        return image.imdecode(img, self._flag), header.label
diff --git a/python/mxnet/gluon/data/sampler.py b/python/mxnet/gluon/data/sampler.py
new file mode 100644
index 000000000000..7bfc418399f5
--- /dev/null
+++ b/python/mxnet/gluon/data/sampler.py
@@ -0,0 +1,120 @@
+# coding: utf-8
+# pylint: disable=
+"""Dataset sampler."""
+
+import random
+
+class Sampler(object):
+    """Base class for samplers.
+
+    All samplers should subclass `Sampler` and define `__iter__` and `__len__`
+    methods.
+    """
+    def __iter__(self):
+        raise NotImplementedError
+
+    def __len__(self):
+        raise NotImplementedError
+
+
+class SequentialSampler(Sampler):
+    """Samples elements from [0, length) sequentially.
+
+    Parameters
+    ----------
+    length : int
+        Length of the sequence.
+    """
+    def __init__(self, length):
+        self._length = length
+
+    def __iter__(self):
+        return iter(range(self._length))
+
+    def __len__(self):
+        return self._length
+
+
+class RandomSampler(Sampler):
+    """Samples elements from [0, length) randomly without replacement.
+
+    Parameters
+    ----------
+    length : int
+        Length of the sequence.
+    """
+    def __init__(self, length):
+        self._length = length
+
+    def __iter__(self):
+        indices = list(range(self._length))
+        random.shuffle(indices)
+        return iter(indices)
+
+    def __len__(self):
+        return self._length
+
+
+class BatchSampler(Sampler):
+    """Wraps over another `Sampler` and returns mini-batches of samples.
+
+    Parameters
+    ----------
+    sampler : Sampler
+        The source Sampler.
+    batch_size : int
+        Size of mini-batch.
+    last_batch : {'keep', 'discard', 'rollover'}
+        Specifies how the last batch is handled if batch_size does not evenly
+        divide sequence length.
+
+        If 'keep', the last batch will be returned directly, but will contain
+        fewer elements than `batch_size` requires.
+
+        If 'discard', the last batch will be discarded.
+
+        If 'rollover', the remaining elements will be rolled over to the next
+        iteration.
+
+    Examples
+    --------
+    >>> sampler = gluon.data.SequentialSampler(10)
+    >>> batch_sampler = gluon.data.BatchSampler(sampler, 3, 'keep')
+    >>> list(batch_sampler)
+    [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
+    """
+    def __init__(self, sampler, batch_size, last_batch='keep'):
+        self._sampler = sampler
+        self._batch_size = batch_size
+        self._last_batch = last_batch
+        self._prev = []
+
+    def __iter__(self):
+        batch, self._prev = self._prev, []
+        for i in self._sampler:
+            batch.append(i)
+            if len(batch) == self._batch_size:
+                yield batch
+                batch = []
+        if batch:
+            if self._last_batch == 'keep':
+                yield batch
+            elif self._last_batch == 'discard':
+                return
+            elif self._last_batch == 'rollover':
+                self._prev = batch
+            else:
+                raise ValueError(
+                    "last_batch must be one of 'keep', 'discard', or 'rollover', " \
+                    "but got %s"%self._last_batch)
+
+    def __len__(self):
+        if self._last_batch == 'keep':
+            return (len(self._sampler) + self._batch_size - 1) // self._batch_size
+        if self._last_batch == 'discard':
+            return len(self._sampler) // self._batch_size
+        if self._last_batch == 'rollover':
+            return (len(self._prev) + len(self._sampler)) // self._batch_size
+        raise ValueError(
+            "last_batch must be one of 'keep', 'discard', or 'rollover', " \
+            "but got %s"%self._last_batch)
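The three `last_batch` modes above differ only in how the short final batch
is handled; a small sketch of the observable behavior (illustrative only,
not part of the patch)::

    from mxnet import gluon

    sampler = gluon.data.SequentialSampler(10)

    assert list(gluon.data.BatchSampler(sampler, 3, 'keep'))[-1] == [9]
    assert len(list(gluon.data.BatchSampler(sampler, 3, 'discard'))) == 3

    # 'rollover' holds the tail back and prepends it to the next iteration
    rollover = gluon.data.BatchSampler(sampler, 3, 'rollover')
    assert list(rollover)[-1] == [6, 7, 8]   # [9] is held back...
    assert list(rollover)[0] == [9, 0, 1]    # ...and leads the next pass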
diff --git a/python/mxnet/gluon/data/vision.py b/python/mxnet/gluon/data/vision.py
new file mode 100644
index 000000000000..36c4642e7665
--- /dev/null
+++ b/python/mxnet/gluon/data/vision.py
@@ -0,0 +1,125 @@
+# coding: utf-8
+# pylint: disable=
+"""Dataset container."""
+
+import os
+import gzip
+import tarfile
+import struct
+import numpy as np
+
+from . import dataset
+from ..utils import download
+from ... import nd
+
+
+class _DownloadedDataset(dataset.Dataset):
+    """Base class for MNIST, cifar10, etc."""
+    def __init__(self, root, train, transform):
+        self._root = os.path.expanduser(root)
+        self._train = train
+        self._transform = transform
+        self._data = None
+        self._label = None
+
+        self._get_data()
+
+    def __getitem__(self, idx):
+        return self._transform(self._data[idx], self._label[idx])
+
+    def __len__(self):
+        return len(self._label)
+
+    def _get_data(self):
+        raise NotImplementedError
+
+
+class MNIST(_DownloadedDataset):
+    """MNIST handwritten digits dataset from `http://yann.lecun.com/exdb/mnist`_.
+
+    Each sample is an image (in 3D NDArray) with shape (28, 28, 1).
+
+    Parameters
+    ----------
+    root : str
+        Path to temp folder for storing data.
+    train : bool
+        Whether to load the training or testing set.
+    transform : function
+        A user defined callback that transforms each instance. For example::
+
+            transform=lambda data, label: (data.astype(np.float32)/255, label)
+    """
+    def __init__(self, root, train=True, transform=lambda data, label: (data, label)):
+        super(MNIST, self).__init__(root, train, transform)
+
+    def _get_data(self):
+        if not os.path.isdir(self._root):
+            os.makedirs(self._root)
+        url = 'http://data.mxnet.io/data/mnist/'
+        if self._train:
+            data_file = download(url+'train-images-idx3-ubyte.gz', self._root)
+            label_file = download(url+'train-labels-idx1-ubyte.gz', self._root)
+        else:
+            data_file = download(url+'t10k-images-idx3-ubyte.gz', self._root)
+            label_file = download(url+'t10k-labels-idx1-ubyte.gz', self._root)
+
+        with gzip.open(label_file, 'rb') as fin:
+            struct.unpack(">II", fin.read(8))
+            label = np.fromstring(fin.read(), dtype=np.uint8).astype(np.int32)
+
+        with gzip.open(data_file, 'rb') as fin:
+            struct.unpack(">IIII", fin.read(16))
+            data = np.fromstring(fin.read(), dtype=np.uint8)
+            data = data.reshape(len(label), 28, 28, 1)
+
+        self._data = [nd.array(x, dtype=x.dtype) for x in data]
+        self._label = label
+
+
+class CIFAR10(_DownloadedDataset):
+    """CIFAR10 image classification dataset from `https://www.cs.toronto.edu/~kriz/cifar.html`_.
+
+    Each sample is an image (in 3D NDArray) with shape (32, 32, 3).
+
+    Parameters
+    ----------
+    root : str
+        Path to temp folder for storing data.
+    train : bool
+        Whether to load the training or testing set.
+    transform : function
+        A user defined callback that transforms each instance. For example::
+
+            transform=lambda data, label: (data.astype(np.float32)/255, label)
+    """
+    def __init__(self, root, train=True, transform=lambda data, label: (data, label)):
+        super(CIFAR10, self).__init__(root, train, transform)
+
+    def _read_batch(self, filename):
+        with open(filename, 'rb') as fin:
+            data = np.fromstring(fin.read(), dtype=np.uint8).reshape(-1, 3072+1)
+
+        return data[:, 1:].reshape(-1, 3, 32, 32).transpose(0, 2, 3, 1), \
+               data[:, 0].astype(np.int32)
+
+    def _get_data(self):
+        if not os.path.isdir(self._root):
+            os.makedirs(self._root)
+        url = 'https://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz'
+        filename = download(url, self._root)
+
+        with tarfile.open(filename) as tar:
+            tar.extractall(self._root)
+
+        if self._train:
+            filename = os.path.join(self._root, 'cifar-10-batches-bin/data_batch_%d.bin')
+            data, label = zip(*[self._read_batch(filename%i) for i in range(1, 6)])
+            data = np.concatenate(data)
+            label = np.concatenate(label)
+        else:
+            filename = os.path.join(self._root, 'cifar-10-batches-bin/test_batch.bin')
+            data, label = self._read_batch(filename)
+
+        self._data = [nd.array(x, dtype=x.dtype) for x in data]
+        self._label = label
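A sketch of the per-sample `transform` hook shared by MNIST and CIFAR10 above
(illustrative only, not part of the patch; downloads to ./data on first use)::

    import numpy as np
    from mxnet import gluon

    def transform(data, label):
        # data arrives as a (28, 28, 1) uint8 NDArray in HWC layout
        return data.reshape((-1,)).astype(np.float32)/255, label

    mnist = gluon.data.vision.MNIST('./data', train=True, transform=transform)
    data, label = mnist[0]    # the transform runs lazily, at access time
    assert data.shape == (784,)

    train_data = gluon.data.DataLoader(mnist, batch_size=32, shuffle=True)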
diff --git a/python/mxnet/gluon/utils.py b/python/mxnet/gluon/utils.py
index 27576b55f7f9..505fbc55248c 100644
--- a/python/mxnet/gluon/utils.py
+++ b/python/mxnet/gluon/utils.py
@@ -1,6 +1,14 @@
 # coding: utf-8
 # pylint: disable=
 """Parallelization utility optimizer."""
+import os
+try:
+    import requests
+except ImportError:
+    class requests_failed_to_import(object):
+        pass
+    requests = requests_failed_to_import
+
 import math
 
 from .. import ndarray
@@ -109,3 +117,45 @@ def _indent(s_, numSpaces):
         s = [first] + [(numSpaces * ' ') + line for line in s]
     s = '\n'.join(s)
     return s
+
+
+def download(url, path=None, overwrite=False):
+    """Download a given URL
+
+    Parameters
+    ----------
+    url : str
+        URL to download
+    path : str, optional
+        Destination path to store downloaded file. By default stores to the
+        current directory with the same name as in the URL.
+    overwrite : bool, optional
+        Whether to overwrite the destination file if it already exists.
+
+    Returns
+    -------
+    str
+        The filename of the downloaded file.
+    """
+    if path is None:
+        fname = url.split('/')[-1]
+    elif os.path.isdir(path):
+        fname = os.path.join(path, url.split('/')[-1])
+    else:
+        fname = path
+
+    if overwrite or not os.path.exists(fname):
+        dirname = os.path.dirname(os.path.abspath(os.path.expanduser(fname)))
+        if not os.path.exists(dirname):
+            os.makedirs(dirname)
+
+        print('Downloading %s from %s...'%(fname, url))
+        r = requests.get(url, stream=True)
+        if r.status_code != 200:
+            raise RuntimeError("Failed downloading url %s"%url)
+        with open(fname, 'wb') as f:
+            for chunk in r.iter_content(chunk_size=1024):
+                if chunk: # filter out keep-alive new chunks
+                    f.write(chunk)
+
+    return fname
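A sketch of the download() helper above (illustrative only, not part of the
patch; the MNIST URL is taken from vision.py above). Note that `path` is
treated as a directory only if it already exists as one::

    import os
    from mxnet.gluon.utils import download

    if not os.path.isdir('./data'):
        os.makedirs('./data')
    url = 'http://data.mxnet.io/data/mnist/t10k-labels-idx1-ubyte.gz'
    fname = download(url, path='./data')  # existing dir: URL basename is kept
    fname = download(url, path='./data')  # second call is a no-op unless overwrite=True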
diff --git a/python/mxnet/image/image.py b/python/mxnet/image/image.py
index 5bf2afd09204..32b7c4f282b9 100644
--- a/python/mxnet/image/image.py
+++ b/python/mxnet/image/image.py
@@ -24,7 +24,7 @@
 from .. import recordio
 
 
-def imdecode(buf, **kwargs):
+def imdecode(buf, *args, **kwargs):
     """Decode an image to an NDArray.
 
     Note: `imdecode` uses OpenCV (not the CV2 Python library).
@@ -75,7 +75,7 @@ def imdecode(buf, **kwargs):
     """
     if not isinstance(buf, nd.NDArray):
         buf = nd.array(np.frombuffer(buf, dtype=np.uint8), dtype=np.uint8)
-    return _internal._cvimdecode(buf, **kwargs)
+    return _internal._cvimdecode(buf, *args, **kwargs)
 
 
 def scale_down(src_size, size):
diff --git a/python/mxnet/ndarray.py b/python/mxnet/ndarray.py
index 4939b6c221a5..dff4889749c0 100644
--- a/python/mxnet/ndarray.py
+++ b/python/mxnet/ndarray.py
@@ -53,7 +53,9 @@
     np.float64 : 1,
     np.float16 : 2,
     np.uint8 : 3,
-    np.int32 : 4
+    np.int32 : 4,
+    np.int8 : 5,
+    np.int64 : 6,
 }
 
 _DTYPE_MX_TO_NP = {
@@ -62,7 +64,9 @@
     1 : np.float64,
     2 : np.float16,
     3 : np.uint8,
-    4 : np.int32
+    4 : np.int32,
+    5 : np.int8,
+    6 : np.int64,
 }
 
 _GRAD_REQ_MAP = {
@@ -272,6 +276,10 @@ def __bool__(self):
 
     __nonzero__ = __bool__
 
+    def __len__(self):
+        """Number of elements along the first axis."""
+        return self.shape[0]
+
     def __getstate__(self):
         handle = self.handle
         this = {'handle' : None}
diff --git a/src/operator/elemwise_op_common.h b/src/operator/elemwise_op_common.h
index aa95d2d8696a..228303c85a82 100644
--- a/src/operator/elemwise_op_common.h
+++ b/src/operator/elemwise_op_common.h
@@ -66,8 +66,12 @@ template<int n_in, int n_out>
 inline bool ElemwiseShape(const nnvm::NodeAttrs& attrs,
                           std::vector<TShape> *in_attrs,
                           std::vector<TShape> *out_attrs) {
-  CHECK_EQ(in_attrs->size(), static_cast<size_t>(n_in)) << " in operator " << attrs.name;
-  CHECK_EQ(out_attrs->size(), static_cast<size_t>(n_out)) << " in operator " << attrs.name;
+  if (n_in != -1) {
+    CHECK_EQ(in_attrs->size(), static_cast<size_t>(n_in)) << " in operator " << attrs.name;
+  }
+  if (n_out != -1) {
+    CHECK_EQ(out_attrs->size(), static_cast<size_t>(n_out)) << " in operator " << attrs.name;
+  }
   return ElemwiseAttr<TShape, shape_is_none, shape_assign, true, shape_string>(
     attrs, in_attrs, out_attrs, TShape());
 }
@@ -76,8 +80,12 @@ template<int n_in, int n_out>
 inline bool ElemwiseType(const nnvm::NodeAttrs& attrs,
                          std::vector<int> *in_attrs,
                          std::vector<int> *out_attrs) {
-  CHECK_EQ(in_attrs->size(), static_cast<size_t>(n_in)) << " in operator " << attrs.name;
-  CHECK_EQ(out_attrs->size(), static_cast<size_t>(n_out)) << " in operator " << attrs.name;
+  if (n_in != -1) {
+    CHECK_EQ(in_attrs->size(), static_cast<size_t>(n_in)) << " in operator " << attrs.name;
+  }
+  if (n_out != -1) {
+    CHECK_EQ(out_attrs->size(), static_cast<size_t>(n_out)) << " in operator " << attrs.name;
+  }
   return ElemwiseAttr<int, type_is_none, type_assign, true, type_string>(
     attrs, in_attrs, out_attrs, -1);
 }
diff --git
a/src/operator/tensor/elemwise_sum.cc b/src/operator/tensor/elemwise_sum.cc index 3c4bf719e18f..7ae7ae97acea 100644 --- a/src/operator/tensor/elemwise_sum.cc +++ b/src/operator/tensor/elemwise_sum.cc @@ -36,22 +36,6 @@ std::vector ElementWiseSumGrad( return ret; } -bool ElementWiseSumShape(const nnvm::NodeAttrs& attrs, - std::vector *in_attrs, - std::vector *out_attrs) { - CHECK_EQ(out_attrs->size(), 1); - return ElemwiseAttr( - attrs, in_attrs, out_attrs, TShape()); -} - -bool ElementWiseSumType(const nnvm::NodeAttrs& attrs, - std::vector *in_attrs, - std::vector *out_attrs) { - CHECK_EQ(out_attrs->size(), 1); - return ElemwiseAttr( - attrs, in_attrs, out_attrs, -1); -} - NNVM_REGISTER_OP(add_n) .add_alias("ElementWiseSum") .describe(R"doc(Adds all input arguments element-wise. @@ -81,8 +65,8 @@ NNVM_REGISTER_OP(add_n) "FInplaceOption", [](const NodeAttrs& attrs) { return std::vector >{{0, 0}}; }) -.set_attr("FInferShape", ElementWiseSumShape) -.set_attr("FInferType", ElementWiseSumType) +.set_attr("FInferShape", ElemwiseShape<-1, 1>) +.set_attr("FInferType", ElemwiseType<-1, 1>) .set_attr("FGradient", CloneGradient{"_backward_add_n"}) .add_argument("args", "NDArray-or-Symbol[]", "Positional input arguments"); diff --git a/src/operator/tensor/matrix_op-inl.h b/src/operator/tensor/matrix_op-inl.h index 2e1aa6661b67..75da055d0098 100644 --- a/src/operator/tensor/matrix_op-inl.h +++ b/src/operator/tensor/matrix_op-inl.h @@ -12,6 +12,7 @@ #include #include "../mshadow_op.h" #include "../elemwise_op_common.h" +#include "../channel_op_common.h" #include "../mxnet_op.h" #include "broadcast_reduce_op.h" @@ -1775,6 +1776,114 @@ void ReverseOpForward(const nnvm::NodeAttrs& attrs, } +struct StackParam : public dmlc::Parameter { + int axis; + int num_args; + DMLC_DECLARE_PARAMETER(StackParam) { + DMLC_DECLARE_FIELD(axis) + .set_default(0) + .describe("The axis in the result array along which the input arrays are stacked."); + DMLC_DECLARE_FIELD(num_args).set_lower_bound(1) + .describe("Number of inputs to be stacked."); + } +}; + + +inline bool StackOpShape(const nnvm::NodeAttrs& attrs, + std::vector *in_attrs, + std::vector *out_attrs) { + const StackParam& param = dmlc::get(attrs.parsed); + + TShape dshape; + for (const TShape& i : (*in_attrs)) { + shape_assign(&dshape, i); + } + if (dshape.ndim() == 0) return false; + + TShape oshape(dshape.ndim() + 1); + int axis = CheckAxis(param.axis, oshape.ndim()); + for (int i = 0; i < axis; ++i) { + oshape[i] = dshape[i]; + } + oshape[axis] = param.num_args; + for (index_t i = axis + 1; i < oshape.ndim(); ++i) { + oshape[i] = dshape[i-1]; + } + SHAPE_ASSIGN_CHECK(*out_attrs, 0, oshape); + + return true; +} + + +template +void StackOpForward(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + using namespace mshadow; + using namespace mshadow::expr; + const StackParam& param = dmlc::get(attrs.parsed); + int axis = CheckAxis(param.axis, outputs[0].ndim()); + + Stream *s = ctx.get_stream(); + MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, { + std::vector > data(inputs.size()); + Tensor out; + size_t leading = 1, trailing = 1; + for (int i = 0; i < axis; ++i) { + leading *= outputs[0].shape_[i]; + } + for (index_t i = axis + 1; i < outputs[0].ndim(); ++i) { + trailing *= outputs[0].shape_[i]; + } + size_t mid = outputs[0].shape_[axis]; + Shape<3> oshape = Shape3(leading, mid, trailing); + out = outputs[0].get_with_shape(oshape, s); + + for (index_t i = 0; i < 
inputs.size(); ++i) { + Shape<3> dshape = Shape3(leading, 1, trailing); + data[i] = inputs[i].get_with_shape(dshape, s); + } + Concatenate(data, &out, 1, req[0]); + }) +} + +template +void StackOpBackward(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + using namespace mshadow; + using namespace mshadow::expr; + const StackParam& param = dmlc::get(attrs.parsed); + int axis = CheckAxis(param.axis, inputs[0].ndim()); + + Stream *s = ctx.get_stream(); + MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, DType, { + std::vector > grad_in(outputs.size()); + Tensor grad; + size_t leading = 1, trailing = 1; + for (int i = 0; i < axis; ++i) { + leading *= inputs[0].shape_[i]; + } + for (index_t i = axis + 1; i < inputs[0].ndim(); ++i) { + trailing *= inputs[0].shape_[i]; + } + size_t mid = inputs[0].shape_[axis]; + Shape<3> oshape = Shape3(leading, mid, trailing); + grad = inputs[0].get_with_shape(oshape, s); + + for (index_t i = 0; i < outputs.size(); ++i) { + Shape<3> dshape = Shape3(leading, 1, trailing); + grad_in[i] = outputs[i].get_with_shape(dshape, s); + } + Split(grad, &grad_in, 1, req); + }) +} + + } // namespace op } // namespace mxnet diff --git a/src/operator/tensor/matrix_op.cc b/src/operator/tensor/matrix_op.cc index 6a51d46db25c..4832b13f56c5 100644 --- a/src/operator/tensor/matrix_op.cc +++ b/src/operator/tensor/matrix_op.cc @@ -20,6 +20,7 @@ DMLC_REGISTER_PARAMETER(DotParam); DMLC_REGISTER_PARAMETER(RepeatParam); DMLC_REGISTER_PARAMETER(TileParam); DMLC_REGISTER_PARAMETER(ReverseParam); +DMLC_REGISTER_PARAMETER(StackParam); NNVM_REGISTER_OP(Reshape) .add_alias("reshape") @@ -627,5 +628,56 @@ NNVM_REGISTER_OP(_backward_reverse) return std::vector {ResourceRequest::kTempSpace}; }) .set_attr("FCompute", ReverseOpForward); + +NNVM_REGISTER_OP(stack) +.describe(R"code(Join a sequence of arrays along a new axis. + +The axis parameter specifies the index of the new axis in the dimensions of the +result. For example, if axis=0 it will be the first dimension and if axis=-1 it +will be the last dimension. 
+ +Examples:: + + x = [1, 2] + y = [3, 4] + + stack(x, y) = [[1, 2], + [3, 4]] + stack(x, y, axis=1) = [[1, 3], + [2, 4]] +)code") +.set_num_inputs([](const nnvm::NodeAttrs& attrs) { + const StackParam& param = dmlc::get(attrs.parsed); + return static_cast(param.num_args); + }) +.set_num_outputs(1) +.set_attr_parser(ParamParser) +.set_attr("FListInputNames", + [](const NodeAttrs& attrs) { + uint32_t num_args = dmlc::get(attrs.parsed).num_args; + std::vector ret; + for (uint32_t i = 0; i < num_args; ++i) { + ret.push_back(std::string("arg") + std::to_string(i)); + } + return ret; + }) +.set_attr("key_var_num_args", "num_args") +.set_attr("FInferShape", StackOpShape) +.set_attr("FInferType", ElemwiseType<-1, 1>) +.set_attr("FCompute", StackOpForward) +.set_attr("FGradient", ElemwiseGradUseNone{"_backward_stack"}) +.add_argument("data", "NDArray-or-Symbol[]", "List of arrays to stack") +.add_arguments(StackParam::__FIELDS__()); + +NNVM_REGISTER_OP(_backward_stack) +.set_num_inputs(1) +.set_num_outputs([](const nnvm::NodeAttrs& attrs) { + const StackParam& param = dmlc::get(attrs.parsed); + return static_cast(param.num_args); + }) +.set_attr_parser(ParamParser) +.set_attr("TIsBackward", true) +.set_attr("FCompute", StackOpBackward); + } // namespace op } // namespace mxnet diff --git a/src/operator/tensor/matrix_op.cu b/src/operator/tensor/matrix_op.cu index 96c075a7d483..8cf656e999b8 100644 --- a/src/operator/tensor/matrix_op.cu +++ b/src/operator/tensor/matrix_op.cu @@ -74,5 +74,11 @@ NNVM_REGISTER_OP(reverse) NNVM_REGISTER_OP(_backward_reverse) .set_attr("FCompute", ReverseOpForward); + +NNVM_REGISTER_OP(stack) +.set_attr("FCompute", StackOpForward); + +NNVM_REGISTER_OP(_backward_stack) +.set_attr("FCompute", StackOpBackward); } // namespace op } // namespace mxnet diff --git a/tests/python/unittest/test_gluon_data.py b/tests/python/unittest/test_gluon_data.py new file mode 100644 index 000000000000..0d25cc497d32 --- /dev/null +++ b/tests/python/unittest/test_gluon_data.py @@ -0,0 +1,53 @@ +import os +import mxnet as mx +import numpy as np +from mxnet import gluon + +def test_array_dataset(): + X = np.random.uniform(size=(10, 20)) + Y = np.random.uniform(size=(10,)) + dataset = gluon.data.ArrayDataset(X, Y) + loader = gluon.data.DataLoader(dataset, 2) + + for i, (x, y) in enumerate(loader): + assert mx.test_utils.almost_equal(x.asnumpy(), X[i*2:(i+1)*2]) + assert mx.test_utils.almost_equal(y.asnumpy(), Y[i*2:(i+1)*2]) + + +def prepare_record(): + if not os.path.isdir("data/test_images"): + os.system("wget http://data.mxnet.io/data/test_images.tar.gz -O data/test_images.tar.gz") + os.system("tar -xf data/test_images.tar.gz -C data") + imgs = os.listdir('data/test_images') + record = mx.recordio.MXIndexedRecordIO('data/test.idx', 'data/test.rec', 'w') + for i, img in enumerate(imgs): + str_img = open('data/test_images/'+img, 'rb').read() + s = mx.recordio.pack((0, i, i, 0), str_img) + record.write_idx(i, s) + return 'data/test.rec' + + +def test_recordimage_dataset(): + recfile = prepare_record() + dataset = gluon.data.ImageRecordDataset(recfile) + loader = gluon.data.DataLoader(dataset, 1) + + for i, (x, y) in enumerate(loader): + assert x.shape[0] == 1 and x.shape[3] == 3 + assert y.asscalar() == i + +def test_sampler(): + seq_sampler = gluon.data.SequentialSampler(10) + assert list(seq_sampler) == list(range(10)) + rand_sampler = gluon.data.RandomSampler(10) + assert sorted(list(rand_sampler)) == list(range(10)) + seq_batch_keep = gluon.data.BatchSampler(seq_sampler, 3, 'keep') + assert 
sum(list(seq_batch_keep), []) == list(range(10)) + seq_batch_discard = gluon.data.BatchSampler(seq_sampler, 3, 'discard') + assert sum(list(seq_batch_discard), []) == list(range(9)) + rand_batch_keep = gluon.data.BatchSampler(rand_sampler, 3, 'keep') + assert sorted(sum(list(rand_batch_keep), [])) == list(range(10)) + +if __name__ == '__main__': + import nose + nose.runmodule() diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py index 58d39513a4a8..2f7c3b904e01 100644 --- a/tests/python/unittest/test_operator.py +++ b/tests/python/unittest/test_operator.py @@ -3642,6 +3642,23 @@ def test_laop(): check_numeric_gradient(test_sumlogdiag, [a]) +def test_stack(): + for _ in range(100): + ndim = random.randint(1, 5) + axis = random.randint(0, ndim) + if random.randint(0, 1): + axis = axis - ndim - 1 + nin = random.randint(1, 3) + dshape = [random.randint(1, 5) for _ in range(ndim)] + inputs = [np.random.uniform(size=dshape) for _ in range(nin)] + output = np.stack(inputs, axis=axis) + sym_ins = [mx.sym.var('x%d'%i) for i in range(nin)] + out = mx.sym.stack(*sym_ins, axis=axis) + check_symbolic_forward(out, inputs, [output]) + check_numeric_gradient(out, inputs) + + + if __name__ == '__main__': import nose nose.runmodule()
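A quick sketch of the new stack operator exercised by test_stack above
(illustrative only, not part of the patch)::

    import mxnet as mx

    x = mx.nd.array([1, 2])
    y = mx.nd.array([3, 4])

    # stack joins inputs along a NEW axis, unlike concat, which reuses an
    # existing one; the new axis has length num_args
    assert mx.nd.stack(x, y).shape == (2, 2)          # axis=0
    assert mx.nd.stack(x, y, axis=1).asnumpy().tolist() == [[1.0, 3.0], [2.0, 4.0]]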