
[MXNET-531] GAN MNIST Examples for Scala new API (apache#11547)
* add gan base file and example suite
lanking520 authored and nswamy committed Jul 4, 2018
1 parent 01f1457 commit 92f0c51
Showing 4 changed files with 188 additions and 120 deletions.
@@ -17,66 +17,52 @@

package org.apache.mxnetexamples.gan

import org.apache.mxnet.{Context, CustomMetric, DataBatch, IO, NDArray, Shape, Symbol, Xavier}
import org.apache.mxnet.optimizer.Adam
import org.kohsuke.args4j.{CmdLineParser, Option}
import org.slf4j.LoggerFactory

import scala.collection.JavaConverters._
import Viz._
import org.apache.mxnet.Context
import org.apache.mxnet.Shape
import org.apache.mxnet.IO
import org.apache.mxnet.NDArray
import org.apache.mxnet.CustomMetric
import org.apache.mxnet.Xavier
import org.apache.mxnet.optimizer.Adam
import org.apache.mxnet.DataBatch
import org.apache.mxnet.Symbol
import org.apache.mxnet.Shape

/**
* @author Depeng Liang
*/
object GanMnist {

private val logger = LoggerFactory.getLogger(classOf[GanMnist])

// a deconv layer that enlarges the feature map
// a deconv layer that enlarges the feature map
def deconv2D(data: Symbol, iShape: Shape, oShape: Shape,
kShape: (Int, Int), name: String, stride: (Int, Int) = (2, 2)): Symbol = {
val targetShape = (oShape(oShape.length - 2), oShape(oShape.length - 1))
val net = Symbol.Deconvolution(name)()(Map(
"data" -> data,
"kernel" -> s"$kShape",
"stride" -> s"$stride",
"target_shape" -> s"$targetShape",
"num_filter" -> oShape(0),
"no_bias" -> true))
kShape: (Int, Int), name: String, stride: (Int, Int) = (2, 2)): Symbol = {
val targetShape = Shape(oShape(oShape.length - 2), oShape(oShape.length - 1))
val net = Symbol.api.Deconvolution(data = Some(data), kernel = Shape(kShape._1, kShape._2),
stride = Some(Shape(stride._1, stride._2)), target_shape = Some(targetShape),
num_filter = oShape(0), no_bias = Some(true), name = name)
net
}

def deconv2DBnRelu(data: Symbol, prefix: String, iShape: Shape,
oShape: Shape, kShape: (Int, Int), eps: Float = 1e-5f + 1e-12f): Symbol = {
oShape: Shape, kShape: (Int, Int), eps: Float = 1e-5f + 1e-12f): Symbol = {
var net = deconv2D(data, iShape, oShape, kShape, name = s"${prefix}_deconv")
net = Symbol.BatchNorm(s"${prefix}_bn")()(Map("data" -> net, "fix_gamma" -> true, "eps" -> eps))
net = Symbol.Activation(s"${prefix}_act")()(Map("data" -> net, "act_type" -> "relu"))
net = Symbol.api.BatchNorm(name = s"${prefix}_bn", data = Some(net),
fix_gamma = Some(true), eps = Some(eps))
net = Symbol.api.Activation(data = Some(net), act_type = "relu", name = s"${prefix}_act")
net
}

def deconv2DAct(data: Symbol, prefix: String, actType: String,
iShape: Shape, oShape: Shape, kShape: (Int, Int)): Symbol = {
iShape: Shape, oShape: Shape, kShape: (Int, Int)): Symbol = {
var net = deconv2D(data, iShape, oShape, kShape, name = s"${prefix}_deconv")
net = Symbol.Activation(s"${prefix}_act")()(Map("data" -> net, "act_type" -> actType))
net = Symbol.api.Activation(data = Some(net), act_type = actType, name = s"${prefix}_act")
net
}

def makeDcganSym(oShape: Shape, ngf: Int = 128, finalAct: String = "sigmoid",
eps: Float = 1e-5f + 1e-12f): (Symbol, Symbol) = {
eps: Float = 1e-5f + 1e-12f): (Symbol, Symbol) = {

val code = Symbol.Variable("rand")
var net = Symbol.FullyConnected("g1")()(Map("data" -> code,
"num_hidden" -> 4 * 4 * ngf * 4, "no_bias" -> true))
net = Symbol.Activation("gact1")()(Map("data" -> net, "act_type" -> "relu"))
var net = Symbol.api.FullyConnected(data = Some(code), num_hidden = 4 * 4 * ngf * 4,
no_bias = Some(true), name = "g1")
net = Symbol.api.Activation(data = Some(net), act_type = "relu", name = "gact1")
// 4 x 4
net = Symbol.Reshape()()(Map("data" -> net, "shape" -> s"(-1, ${ngf * 4}, 4, 4)"))
net = Symbol.api.Reshape(data = Some(net), shape = Some(Shape(-1, ngf * 4, 4, 4)))
// 8 x 8
net = deconv2DBnRelu(net, prefix = "g2",
iShape = Shape(ngf * 4, 4, 4), oShape = Shape(ngf * 2, 8, 8), kShape = (3, 3))
@@ -89,22 +75,22 @@ object GanMnist

val data = Symbol.Variable("data")
// 28 x 28
val conv1 = Symbol.Convolution("conv1")()(Map("data" -> data,
"kernel" -> "(5,5)", "num_filter" -> 20))
val tanh1 = Symbol.Activation()()(Map("data" -> conv1, "act_type" -> "tanh"))
val pool1 = Symbol.Pooling()()(Map("data" -> tanh1,
"pool_type" -> "max", "kernel" -> "(2,2)", "stride" -> "(2,2)"))
val conv1 = Symbol.api.Convolution(data = Some(data), kernel = Shape(5, 5),
num_filter = 20, name = "conv1")
val tanh1 = Symbol.api.Activation(data = Some(conv1), act_type = "tanh")
val pool1 = Symbol.api.Pooling(data = Some(tanh1), pool_type = Some("max"),
kernel = Some(Shape(2, 2)), stride = Some(Shape(2, 2)))
// second conv
val conv2 = Symbol.Convolution("conv2")()(Map("data" -> pool1,
"kernel" -> "(5,5)", "num_filter" -> 50))
val tanh2 = Symbol.Activation()()(Map("data" -> conv2, "act_type" -> "tanh"))
val pool2 = Symbol.Pooling()()(Map("data" -> tanh2, "pool_type" -> "max",
"kernel" -> "(2,2)", "stride" -> "(2,2)"))
var d5 = Symbol.Flatten()()(Map("data" -> pool2))
d5 = Symbol.FullyConnected("fc1")()(Map("data" -> d5, "num_hidden" -> 500))
d5 = Symbol.Activation()()(Map("data" -> d5, "act_type" -> "tanh"))
d5 = Symbol.FullyConnected("fc_dloss")()(Map("data" -> d5, "num_hidden" -> 1))
val dloss = Symbol.LogisticRegressionOutput("dloss")()(Map("data" -> d5))
val conv2 = Symbol.api.Convolution(data = Some(pool1), kernel = Shape(5, 5),
num_filter = 50, name = "conv2")
val tanh2 = Symbol.api.Activation(data = Some(conv2), act_type = "tanh")
val pool2 = Symbol.api.Pooling(data = Some(tanh2), pool_type = Some("max"),
kernel = Some(Shape(2, 2)), stride = Some(Shape(2, 2)))
var d5 = Symbol.api.Flatten(data = Some(pool2))
d5 = Symbol.api.FullyConnected(data = Some(d5), num_hidden = 500, name = "fc1")
d5 = Symbol.api.Activation(data = Some(d5), act_type = "tanh")
d5 = Symbol.api.FullyConnected(data = Some(d5), num_hidden = 1, name = "fc_dloss")
val dloss = Symbol.api.LogisticRegressionOutput(data = Some(d5), name = "dloss")

(gout, dloss)
}
@@ -116,85 +102,92 @@ object GanMnist
labelArr.zip(predArr).map { case (l, p) => Math.abs(l - p) }.sum / label.shape(0)
}

def runTraining(dataPath : String, context : Context,
outputPath : String, numEpoch : Int): Float = {
val lr = 0.0005f
val beta1 = 0.5f
val batchSize = 100
val randShape = Shape(batchSize, 100)
val dataShape = Shape(batchSize, 1, 28, 28)

val (symGen, symDec) =
makeDcganSym(oShape = dataShape, ngf = 32, finalAct = "sigmoid")

val gMod = new GANModule(
symGen,
symDec,
context = context,
dataShape = dataShape,
codeShape = randShape)

gMod.initGParams(new Xavier(factorType = "in", magnitude = 2.34f))
gMod.initDParams(new Xavier(factorType = "in", magnitude = 2.34f))

gMod.initOptimizer(new Adam(learningRate = lr, wd = 0f, beta1 = beta1))

val params = Map(
"image" -> s"$dataPath/train-images-idx3-ubyte",
"label" -> s"$dataPath/train-labels-idx1-ubyte",
"input_shape" -> s"(1, 28, 28)",
"batch_size" -> s"$batchSize",
"shuffle" -> "True"
)

val mnistIter = IO.MNISTIter(params)

val metricAcc = new CustomMetric(ferr, "ferr")

var t = 0
var dataBatch: DataBatch = null
var acc = 0.0f
for (epoch <- 0 until numEpoch) {
mnistIter.reset()
metricAcc.reset()
t = 0
while (mnistIter.hasNext) {
dataBatch = mnistIter.next()
gMod.update(dataBatch)
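// score fake outputs against label 0 and real outputs against label 1 for the ferr metric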
gMod.dLabel.set(0f)
metricAcc.update(Array(gMod.dLabel), gMod.outputsFake)
gMod.dLabel.set(1f)
metricAcc.update(Array(gMod.dLabel), gMod.outputsReal)

if (t % 50 == 0) {
val (name, value) = metricAcc.get
acc = value(0)
logger.info(s"epoch: $epoch, iter $t, metric=${value.mkString(" ")}")
Viz.imSave("gout", outputPath, gMod.tempOutG(0), flip = true)
val diff = gMod.tempDiffD
val arr = diff.toArray
val mean = arr.sum / arr.length
val std = {
val tmpA = arr.map(a => (a - mean) * (a - mean))
Math.sqrt(tmpA.sum / tmpA.length).toFloat
}
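// standardize tempDiffD and re-center it at 0.5 so it can be saved as a viewable image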
diff.set((diff - mean) / std + 0.5f)
Viz.imSave("diff", outputPath, diff, flip = true)
Viz.imSave("data", outputPath, dataBatch.data(0), flip = true)
}

t += 1
}
}
acc
}

def main(args: Array[String]): Unit = {
val anst = new GanMnist
val parser: CmdLineParser = new CmdLineParser(anst)
try {
parser.parseArgument(args.toList.asJava)

val dataPath = if (anst.mnistDataPath == null) System.getenv("MXNET_DATA_DIR")
else anst.mnistDataPath
else anst.mnistDataPath

assert(dataPath != null)

val lr = 0.0005f
val beta1 = 0.5f
val batchSize = 100
val randShape = Shape(batchSize, 100)
val numEpoch = 100
val dataShape = Shape(batchSize, 1, 28, 28)
val context = if (anst.gpu == -1) Context.cpu() else Context.gpu(anst.gpu)

val (symGen, symDec) =
makeDcganSym(oShape = dataShape, ngf = 32, finalAct = "sigmoid")

val gMod = new GANModule(
symGen,
symDec,
context = context,
dataShape = dataShape,
codeShape = randShape)

gMod.initGParams(new Xavier(factorType = "in", magnitude = 2.34f))
gMod.initDParams(new Xavier(factorType = "in", magnitude = 2.34f))

gMod.initOptimizer(new Adam(learningRate = lr, wd = 0f, beta1 = beta1))

val params = Map(
"image" -> s"${dataPath}/train-images-idx3-ubyte",
"label" -> s"${dataPath}/train-labels-idx1-ubyte",
"input_shape" -> s"(1, 28, 28)",
"batch_size" -> s"$batchSize",
"shuffle" -> "True"
)

val mnistIter = IO.MNISTIter(params)

val metricAcc = new CustomMetric(ferr, "ferr")

var t = 0
var dataBatch: DataBatch = null
for (epoch <- 0 until numEpoch) {
mnistIter.reset()
metricAcc.reset()
t = 0
while (mnistIter.hasNext) {
dataBatch = mnistIter.next()
gMod.update(dataBatch)
gMod.dLabel.set(0f)
metricAcc.update(Array(gMod.dLabel), gMod.outputsFake)
gMod.dLabel.set(1f)
metricAcc.update(Array(gMod.dLabel), gMod.outputsReal)

if (t % 50 == 0) {
val (name, value) = metricAcc.get
logger.info(s"epoch: $epoch, iter $t, metric=$value")
Viz.imSave("gout", anst.outputPath, gMod.tempOutG(0), flip = true)
val diff = gMod.tempDiffD
val arr = diff.toArray
val mean = arr.sum / arr.length
val std = {
val tmpA = arr.map(a => (a - mean) * (a - mean))
Math.sqrt(tmpA.sum / tmpA.length).toFloat
}
diff.set((diff - mean) / std + 0.5f)
Viz.imSave("diff", anst.outputPath, diff, flip = true)
Viz.imSave("data", anst.outputPath, dataBatch.data(0), flip = true)
}

t += 1
}
}
runTraining(dataPath, context, anst.outputPath, 100)
} catch {
case ex: Exception => {
logger.error(ex.getMessage, ex)
@@ -26,9 +26,6 @@ import org.apache.mxnet.Initializer
import org.apache.mxnet.DataBatch
import org.apache.mxnet.Random

/**
* @author Depeng Liang
*/
class GANModule(
symbolGenerator: Symbol,
symbolEncoder: Symbol,
@@ -0,0 +1,18 @@
# GAN MNIST Example for Scala
This is the GAN MNIST training example implemented with the Scala type-safe API.

This example is for illustration only and is not tuned to achieve the best accuracy.
## Setup
### Download the source file
```
https://s3.us-east-2.amazonaws.com/mxnet-scala/scala-example-ci/mnist/mnist.zip
```
### Unzip the file
```
unzip mnist.zip
```
### Argument configuration
Then define the arguments that you would like to pass to the model:
```
--mnist-data-path <location of your downloaded file>
```
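For reference, the example can also be driven programmatically through `GanMnist.runTraining`, which this commit adds. The sketch below is a minimal, hypothetical driver: the object name, paths, and epoch count are placeholders, not part of the example.

```scala
import org.apache.mxnet.Context
import org.apache.mxnetexamples.gan.GanMnist

object RunGanMnistSketch {
  def main(args: Array[String]): Unit = {
    // Placeholder locations: point these at the unzipped MNIST files and an output directory
    val mnistDataPath = "/tmp/mnist"
    val outputPath = "/tmp/gan-output"
    // Use Context.gpu(0) when a GPU is available; CPU training is slow
    val context = Context.cpu()
    // runTraining returns the last recorded value of the "ferr" metric
    val ferr = GanMnist.runTraining(mnistDataPath, context, outputPath, 100)
    println(s"final ferr metric: $ferr")
  }
}
```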
@@ -0,0 +1,60 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.mxnetexamples.gan

import java.io.File
import java.net.URL

import org.apache.commons.io.FileUtils
import org.apache.mxnet.Context
import org.scalatest.{BeforeAndAfterAll, FunSuite}
import org.slf4j.LoggerFactory

import scala.sys.process.Process

class GanExampleSuite extends FunSuite with BeforeAndAfterAll {
private val logger = LoggerFactory.getLogger(classOf[GanExampleSuite])

test("Example CI: Test GAN MNIST") {
if (System.getenv().containsKey("SCALA_TEST_ON_GPU") &&
System.getenv("SCALA_TEST_ON_GPU").toInt == 1) {
logger.info("Downloading mnist model")
val baseUrl = "https://s3.us-east-2.amazonaws.com/mxnet-scala/scala-example-ci"
val tempDirPath = System.getProperty("java.io.tmpdir")
val modelDirPath = tempDirPath + File.separator + "mnist/"
logger.info("tempDirPath: %s".format(tempDirPath))
val tmpFile = new File(tempDirPath + "/mnist/mnist.zip")
if (!tmpFile.exists()) {
FileUtils.copyURLToFile(new URL(baseUrl + "/mnist/mnist.zip"),
tmpFile)
}
// TODO: confirm that this works on Windows
Process("unzip " + tempDirPath + "/mnist/mnist.zip -d "
+ tempDirPath + "/mnist/") !

val context = Context.gpu()

val output = GanMnist.runTraining(modelDirPath, context, modelDirPath, 5)
Process("rm -rf " + modelDirPath) !

assert(output >= 0.0f)
} else {
logger.info("GPU test only, skipped...")
}
}
}
