Gluon data pipeline (apache#7155)
* add data pipeline to gluon

* add cifar

* fix

* fix

* fix
piiswrong authored Jul 25, 2017
1 parent c0377a5 commit 84ce29b
Showing 21 changed files with 933 additions and 207 deletions.
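
The substance of the change is the switch from the example-local data.py helpers (e.g. cifar10_iterator) to Gluon's dataset/DataLoader pipeline. Below is a minimal sketch of that pattern, assembled only from calls the updated dcgan.py uses further down; the 64x64 resize and the MNIST dataset are illustrative, not part of this commit's API.

import numpy as np
import mxnet as mx
from mxnet import gluon

def transform(data, label):
    # Resize to 64x64, put channels first, scale pixels to [-1, 1]
    # (the same steps as the transformer() added to dcgan.py below).
    data = mx.image.imresize(data, 64, 64)
    data = mx.nd.transpose(data, (2, 0, 1)).astype(np.float32) / 128 - 1
    return data, label

# Dataset + DataLoader replace the hand-rolled iterator helpers.
train_data = gluon.data.DataLoader(
    gluon.data.vision.MNIST('./data', train=True, transform=transform),
    batch_size=64, shuffle=True, last_batch='discard')

for data, label in train_data:
    print(data.shape, label.shape)  # e.g. (64, 1, 64, 64) and (64,) for MNIST
    break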
2 changes: 1 addition & 1 deletion example/gluon/actor_critic.py
@@ -43,7 +43,7 @@ def forward(self, x):
return F.softmax(probs), values

net = Policy()
net.collect_params().initialize(mx.init.Uniform(0.02))
net.initialize(mx.init.Uniform(0.02))
trainer = gluon.Trainer(net.collect_params(), 'adam', {'learning_rate': 3e-2})
loss = gluon.loss.L1Loss()

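Note on the actor_critic.py change above: net.initialize(...) is the Block-level shorthand for the previous net.collect_params().initialize(...) call; both initialize the same parameters. A minimal sketch, assuming a stand-in Dense block rather than the example's Policy network:

import mxnet as mx
from mxnet.gluon import nn

net = nn.Dense(4)                      # stand-in for the example's Policy block
# Old form: go through the collected ParameterDict.
# net.collect_params().initialize(mx.init.Uniform(0.02))
# New form used by the examples after this commit:
net.initialize(mx.init.Uniform(0.02))
out = net(mx.nd.ones((1, 8)))          # first forward pass completes deferred shape inference
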
339 changes: 177 additions & 162 deletions example/gluon/dcgan.py
@@ -7,10 +7,8 @@
from mxnet import gluon
from mxnet.gluon import nn
from mxnet import autograd
from data import cifar10_iterator
import numpy as np
import logging
import cv2
from datetime import datetime
import os
import time
@@ -32,173 +30,190 @@ def visual(title, X, name):
buff = np.zeros((int(n*X.shape[1]), int(n*X.shape[2]), int(X.shape[3])), dtype=np.uint8)
for i, img in enumerate(X):
fill_buf(buff, i, img, X.shape[1:3])
buff = cv2.cvtColor(buff, cv2.COLOR_BGR2RGB)
buff = buff[:,:,::-1]
plt.imshow(buff)
plt.title(title)
plt.savefig(name)
return None
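
The rewritten visual() above also lets the example drop its cv2 dependency (the import cv2 line is removed at the top of this diff): reversing the last axis of an H x W x C array swaps the channel order, which is all cv2.cvtColor(..., COLOR_BGR2RGB) was doing here. A tiny self-contained check with made-up pixel values:

import numpy as np

bgr = np.array([[[10, 20, 30]]], dtype=np.uint8)  # one pixel, channels B, G, R
rgb = bgr[:, :, ::-1]                             # same effect as cv2's BGR -> RGB conversion
assert (rgb[0, 0] == np.array([30, 20, 10])).all()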

if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--dataset', type=str, default='cifar10', help='dataset to use. options are cifar10 and imagenet.')
parser.add_argument('--batchSize', type=int, default=64, help='input batch size')
parser.add_argument('--imageSize', type=int, default=64, help='the height / width of the input image to network')
parser.add_argument('--nz', type=int, default=100, help='size of the latent z vector')
parser.add_argument('--ngf', type=int, default=64)
parser.add_argument('--ndf', type=int, default=64)
parser.add_argument('--niter', type=int, default=25, help='number of epochs to train for')
parser.add_argument('--lr', type=float, default=0.0002, help='learning rate, default=0.0002')
parser.add_argument('--beta1', type=float, default=0.5, help='beta1 for adam. default=0.5')
parser.add_argument('--cuda', action='store_true', help='enables cuda')
parser.add_argument('--ngpu', type=int, default=1, help='number of GPUs to use')
parser.add_argument('--netG', default='', help="path to netG (to continue training)")
parser.add_argument('--netD', default='', help="path to netD (to continue training)")
parser.add_argument('--outf', default='./results', help='folder to output images and model checkpoints')
parser.add_argument('--manualSeed', type=int, help='manual seed')
parser.add_argument('--check_point', default=True, help="save results at each epoch or not")

opt = parser.parse_args()
print(opt)

logging.basicConfig(level=logging.DEBUG)
ngpu = int(opt.ngpu)
nz = int(opt.nz)
ngf = int(opt.ngf)
ndf = int(opt.ndf)
nc = 3

parser = argparse.ArgumentParser()
parser.add_argument('--dataset', type=str, default='cifar10', help='dataset to use. options are cifar10 and imagenet.')
parser.add_argument('--batch-size', type=int, default=64, help='input batch size')
parser.add_argument('--nz', type=int, default=100, help='size of the latent z vector')
parser.add_argument('--ngf', type=int, default=64)
parser.add_argument('--ndf', type=int, default=64)
parser.add_argument('--nepoch', type=int, default=25, help='number of epochs to train for')
parser.add_argument('--lr', type=float, default=0.0002, help='learning rate, default=0.0002')
parser.add_argument('--beta1', type=float, default=0.5, help='beta1 for adam. default=0.5')
parser.add_argument('--cuda', action='store_true', help='enables cuda')
parser.add_argument('--ngpu', type=int, default=1, help='number of GPUs to use')
parser.add_argument('--netG', default='', help="path to netG (to continue training)")
parser.add_argument('--netD', default='', help="path to netD (to continue training)")
parser.add_argument('--outf', default='./results', help='folder to output images and model checkpoints')
parser.add_argument('--check-point', default=True, help="save results at each epoch or not")

opt = parser.parse_args()
print(opt)

logging.basicConfig(level=logging.DEBUG)
ngpu = int(opt.ngpu)
nz = int(opt.nz)
ngf = int(opt.ngf)
ndf = int(opt.ndf)
nc = 3
if opt.cuda:
ctx = mx.gpu(0)
check_point = bool(opt.check_point)
outf = opt.outf

if not os.path.exists(outf):
os.makedirs(outf)

if opt.dataset == 'cifar10':
train_iter, val_iter = cifar10_iterator(opt.batchSize, (3, 64, 64), 64)

# build the generator
netG = nn.Sequential()
with netG.name_scope():
# input is Z, going into a convolution
netG.add(nn.Conv2DTranspose(ngf * 8, 4, 1, 0, use_bias=False))
netG.add(nn.BatchNorm())
netG.add(nn.Activation('relu'))
# state size. (ngf*8) x 4 x 4
netG.add(nn.Conv2DTranspose(ngf * 4, 4, 2, 1, use_bias=False))
netG.add(nn.BatchNorm())
netG.add(nn.Activation('relu'))
# state size. (ngf*8) x 8 x 8
netG.add(nn.Conv2DTranspose(ngf * 2, 4, 2, 1, use_bias=False))
netG.add(nn.BatchNorm())
netG.add(nn.Activation('relu'))
# state size. (ngf*8) x 16 x 16
netG.add(nn.Conv2DTranspose(ngf, 4, 2, 1, use_bias=False))
netG.add(nn.BatchNorm())
netG.add(nn.Activation('relu'))
# state size. (ngf*8) x 32 x 32
netG.add(nn.Conv2DTranspose(nc, 4, 2, 1, use_bias=False))
netG.add(nn.Activation('tanh'))
# state size. (nc) x 64 x 64

# build the discriminator
netD = nn.Sequential()
with netD.name_scope():
# input is (nc) x 64 x 64
netD.add(nn.Conv2D(ndf, 4, 2, 1, use_bias=False))
netD.add(nn.LeakyReLU(0.2))
# state size. (ndf) x 32 x 32
netD.add(nn.Conv2D(ndf * 2, 4, 2, 1, use_bias=False))
netD.add(nn.BatchNorm())
netD.add(nn.LeakyReLU(0.2))
# state size. (ndf) x 16 x 16
netD.add(nn.Conv2D(ndf * 4, 4, 2, 1, use_bias=False))
netD.add(nn.BatchNorm())
netD.add(nn.LeakyReLU(0.2))
# state size. (ndf) x 8 x 8
netD.add(nn.Conv2D(ndf * 8, 4, 2, 1, use_bias=False))
netD.add(nn.BatchNorm())
netD.add(nn.LeakyReLU(0.2))
# state size. (ndf) x 4 x 4
netD.add(nn.Conv2D(2, 4, 1, 0, use_bias=False))

# loss
loss = gluon.loss.SoftmaxCrossEntropyLoss()

# initialize the generator and the discriminator
netG.collect_params().initialize(mx.init.Normal(0.02), ctx=ctx)
netD.collect_params().initialize(mx.init.Normal(0.02), ctx=ctx)

# trainer for the generator and the discriminator
trainerG = gluon.Trainer(netG.collect_params(), 'adam', {'learning_rate': opt.lr, 'beta1': opt.beta1})
trainerD = gluon.Trainer(netD.collect_params(), 'adam', {'learning_rate': opt.lr, 'beta1': opt.beta1})

# ============printing==============
real_label = mx.nd.ones((opt.batchSize,), ctx=ctx)
fake_label = mx.nd.zeros((opt.batchSize,), ctx=ctx)

metric = mx.metric.Accuracy()
print('Training... ')
stamp = datetime.now().strftime('%Y_%m_%d-%H_%M')

iter = 0
for epoch in range(opt.niter):
tic = time.time()
train_iter.reset()
btic = time.time()
for batch in train_iter:
############################
# (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
###########################
# train with real_t
data = batch.data[0].copyto(ctx)
noise = mx.nd.random_normal(0, 1, shape=(opt.batchSize, nz, 1, 1), ctx=ctx)

with autograd.record():
output = netD(data)
output = output.reshape((opt.batchSize, 2))
errD_real = loss(output, real_label)
metric.update([real_label,], [output,])

fake = netG(noise)
output = netD(fake.detach())
output = output.reshape((opt.batchSize, 2))
errD_fake = loss(output, fake_label)
errD = errD_real + errD_fake
errD.backward()
metric.update([fake_label,], [output,])

trainerD.step(opt.batchSize)

############################
# (2) Update G network: maximize log(D(G(z)))
###########################
with autograd.record():
output = netD(fake)
output = output.reshape((opt.batchSize, 2))
errG = loss(output, real_label)
errG.backward()

trainerG.step(opt.batchSize)

name, acc = metric.get()
# logging.info('speed: {} samples/s'.format(opt.batchSize / (time.time() - btic)))
logging.info('discriminator loss = %f, generator loss = %f, binary training acc = %f at iter %d epoch %d' %(mx.nd.mean(errD).asscalar(), mx.nd.mean(errG).asscalar(), acc, iter, epoch))
if iter % 200 == 0:
visual('gout', fake.asnumpy(), name=os.path.join(outf,'fake_img_iter_%d.png' %iter))
visual('data', batch.data[0].asnumpy(), name=os.path.join(outf,'real_img_iter_%d.png' %iter))

iter = iter + 1
btic = time.time()
else:
ctx = mx.cpu()
check_point = bool(opt.check_point)
outf = opt.outf

if not os.path.exists(outf):
os.makedirs(outf)


def transformer(data, label):
# resize to 64x64
data = mx.image.imresize(data, 64, 64)
# transpose from (64, 64, 3) to (3, 64, 64)
data = mx.nd.transpose(data, (2,0,1))
# normalize to [-1, 1]
data = data.astype(np.float32)/128 - 1
# if image is greyscale, repeat 3 times to get RGB image.
if data.shape[0] == 1:
data = mx.nd.tile(data, (3, 1, 1))
return data, label

train_data = gluon.data.DataLoader(
gluon.data.vision.MNIST('./data', train=True, transform=transformer),
batch_size=opt.batch_size, shuffle=True, last_batch='discard')

val_data = gluon.data.DataLoader(
gluon.data.vision.MNIST('./data', train=False, transform=transformer),
batch_size=opt.batch_size, shuffle=False)


# build the generator
netG = nn.Sequential()
with netG.name_scope():
# input is Z, going into a convolution
netG.add(nn.Conv2DTranspose(ngf * 8, 4, 1, 0, use_bias=False))
netG.add(nn.BatchNorm())
netG.add(nn.Activation('relu'))
# state size. (ngf*8) x 4 x 4
netG.add(nn.Conv2DTranspose(ngf * 4, 4, 2, 1, use_bias=False))
netG.add(nn.BatchNorm())
netG.add(nn.Activation('relu'))
# state size. (ngf*8) x 8 x 8
netG.add(nn.Conv2DTranspose(ngf * 2, 4, 2, 1, use_bias=False))
netG.add(nn.BatchNorm())
netG.add(nn.Activation('relu'))
# state size. (ngf*8) x 16 x 16
netG.add(nn.Conv2DTranspose(ngf, 4, 2, 1, use_bias=False))
netG.add(nn.BatchNorm())
netG.add(nn.Activation('relu'))
# state size. (ngf*8) x 32 x 32
netG.add(nn.Conv2DTranspose(nc, 4, 2, 1, use_bias=False))
netG.add(nn.Activation('tanh'))
# state size. (nc) x 64 x 64

# build the discriminator
netD = nn.Sequential()
with netD.name_scope():
# input is (nc) x 64 x 64
netD.add(nn.Conv2D(ndf, 4, 2, 1, use_bias=False))
netD.add(nn.LeakyReLU(0.2))
# state size. (ndf) x 32 x 32
netD.add(nn.Conv2D(ndf * 2, 4, 2, 1, use_bias=False))
netD.add(nn.BatchNorm())
netD.add(nn.LeakyReLU(0.2))
# state size. (ndf) x 16 x 16
netD.add(nn.Conv2D(ndf * 4, 4, 2, 1, use_bias=False))
netD.add(nn.BatchNorm())
netD.add(nn.LeakyReLU(0.2))
# state size. (ndf) x 8 x 8
netD.add(nn.Conv2D(ndf * 8, 4, 2, 1, use_bias=False))
netD.add(nn.BatchNorm())
netD.add(nn.LeakyReLU(0.2))
# state size. (ndf) x 4 x 4
netD.add(nn.Conv2D(2, 4, 1, 0, use_bias=False))

# loss
loss = gluon.loss.SoftmaxCrossEntropyLoss()

# initialize the generator and the discriminator
netG.initialize(mx.init.Normal(0.02), ctx=ctx)
netD.initialize(mx.init.Normal(0.02), ctx=ctx)

# trainer for the generator and the discriminator
trainerG = gluon.Trainer(netG.collect_params(), 'adam', {'learning_rate': opt.lr, 'beta1': opt.beta1})
trainerD = gluon.Trainer(netD.collect_params(), 'adam', {'learning_rate': opt.lr, 'beta1': opt.beta1})

# ============printing==============
real_label = mx.nd.ones((opt.batch_size,), ctx=ctx)
fake_label = mx.nd.zeros((opt.batch_size,), ctx=ctx)

metric = mx.metric.Accuracy()
print('Training... ')
stamp = datetime.now().strftime('%Y_%m_%d-%H_%M')

iter = 0
for epoch in range(opt.nepoch):
tic = time.time()
btic = time.time()
for data, _ in train_data:
############################
# (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
###########################
# train with real_t
data = data.as_in_context(ctx)
noise = mx.nd.random_normal(0, 1, shape=(opt.batch_size, nz, 1, 1), ctx=ctx)

with autograd.record():
output = netD(data)
output = output.reshape((opt.batch_size, 2))
errD_real = loss(output, real_label)
metric.update([real_label,], [output,])

fake = netG(noise)
output = netD(fake.detach())
output = output.reshape((opt.batch_size, 2))
errD_fake = loss(output, fake_label)
errD = errD_real + errD_fake
errD.backward()
metric.update([fake_label,], [output,])

trainerD.step(opt.batch_size)

############################
# (2) Update G network: maximize log(D(G(z)))
###########################
with autograd.record():
output = netD(fake)
output = output.reshape((-1, 2))
errG = loss(output, real_label)
errG.backward()

trainerG.step(opt.batch_size)

name, acc = metric.get()
metric.reset()
logging.info('\nbinary training acc at epoch %d: %s=%f' % (epoch, name, acc))
logging.info('time: %f' % (time.time() - tic))
# logging.info('speed: {} samples/s'.format(opt.batch_size / (time.time() - btic)))
logging.info('discriminator loss = %f, generator loss = %f, binary training acc = %f at iter %d epoch %d' %(mx.nd.mean(errD).asscalar(), mx.nd.mean(errG).asscalar(), acc, iter, epoch))
if iter % 1 == 0:
visual('gout', fake.asnumpy(), name=os.path.join(outf,'fake_img_iter_%d.png' %iter))
visual('data', data.asnumpy(), name=os.path.join(outf,'real_img_iter_%d.png' %iter))

iter = iter + 1
btic = time.time()

if check_point:
netG.collect_params().save(os.path.join(outf,'generator_epoch_%d.params' %epoch))
netD.collect_params().save(os.path.join(outf,'discriminator_epoch_%d.params' % epoch))
name, acc = metric.get()
metric.reset()
logging.info('\nbinary training acc at epoch %d: %s=%f' % (epoch, name, acc))
logging.info('time: %f' % (time.time() - tic))

netG.collect_params().save(os.path.join(outf, 'generator.params'))
netD.collect_params().save(os.path.join(outf, 'discriminator.params'))
if check_point:
netG.save_params(os.path.join(outf,'generator_epoch_%d.params' %epoch))
netD.save_params(os.path.join(outf,'discriminator_epoch_%d.params' % epoch))

netG.save_params(os.path.join(outf, 'generator.params'))
netD.save_params(os.path.join(outf, 'discriminator.params'))
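
The checkpoints written by save_params() above can be restored with the matching Block.load_params() call. This is a sketch, not a standalone script: it assumes netG has been rebuilt exactly as in the code above and that the default --outf and --nz values were used.

import mxnet as mx

ctx = mx.cpu()
# netG must be constructed the same way as in dcgan.py before loading.
netG.load_params('./results/generator.params', ctx=ctx)
noise = mx.nd.random_normal(0, 1, shape=(1, 100, 1, 1), ctx=ctx)
fake = netG(noise)        # one generated 3 x 64 x 64 image with values in [-1, 1]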