Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[MXNET-531] GAN MNIST Examples for Scala new API #11547

Merged
merged 3 commits into from
Jul 4, 2018
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
add gan base file and example suite
  • Loading branch information
lanking520 committed Jul 3, 2018
commit e0da62c9a5807a1fa4ab37f598d14a7363d87306
Original file line number Diff line number Diff line change
Expand Up @@ -17,66 +17,52 @@

package org.apache.mxnetexamples.gan

import org.apache.mxnet.{Context, CustomMetric, DataBatch, IO, NDArray, Shape, Symbol, Xavier}
import org.apache.mxnet.optimizer.Adam
import org.kohsuke.args4j.{CmdLineParser, Option}
import org.slf4j.LoggerFactory

import scala.collection.JavaConverters._
import Viz._
import org.apache.mxnet.Context
import org.apache.mxnet.Shape
import org.apache.mxnet.IO
import org.apache.mxnet.NDArray
import org.apache.mxnet.CustomMetric
import org.apache.mxnet.Xavier
import org.apache.mxnet.optimizer.Adam
import org.apache.mxnet.DataBatch
import org.apache.mxnet.Symbol
import org.apache.mxnet.Shape

/**
* @author Depeng Liang
*/
object GanMnist {

private val logger = LoggerFactory.getLogger(classOf[GanMnist])

// a deconv layer that enlarges the feature map
// a deconv layer that enlarges the feature map
def deconv2D(data: Symbol, iShape: Shape, oShape: Shape,
kShape: (Int, Int), name: String, stride: (Int, Int) = (2, 2)): Symbol = {
val targetShape = (oShape(oShape.length - 2), oShape(oShape.length - 1))
val net = Symbol.Deconvolution(name)()(Map(
"data" -> data,
"kernel" -> s"$kShape",
"stride" -> s"$stride",
"target_shape" -> s"$targetShape",
"num_filter" -> oShape(0),
"no_bias" -> true))
kShape: (Int, Int), name: String, stride: (Int, Int) = (2, 2)): Symbol = {
val targetShape = Shape(oShape(oShape.length - 2), oShape(oShape.length - 1))
val net = Symbol.api.Deconvolution(data = Some(data), kernel = Shape(kShape._1, kShape._2),
stride = Some(Shape(stride._1, stride._2)), target_shape = Some(targetShape),
num_filter = oShape(0), no_bias = Some(true), name = name)
net
}

def deconv2DBnRelu(data: Symbol, prefix: String, iShape: Shape,
oShape: Shape, kShape: (Int, Int), eps: Float = 1e-5f + 1e-12f): Symbol = {
oShape: Shape, kShape: (Int, Int), eps: Float = 1e-5f + 1e-12f): Symbol = {
var net = deconv2D(data, iShape, oShape, kShape, name = s"${prefix}_deconv")
net = Symbol.BatchNorm(s"${prefix}_bn")()(Map("data" -> net, "fix_gamma" -> true, "eps" -> eps))
net = Symbol.Activation(s"${prefix}_act")()(Map("data" -> net, "act_type" -> "relu"))
net = Symbol.api.BatchNorm(name = s"${prefix}_bn", data = Some(net),
fix_gamma = Some(true), eps = Some(eps))
net = Symbol.api.Activation(data = Some(net), act_type = "relu", name = s"${prefix}_act")
net
}

def deconv2DAct(data: Symbol, prefix: String, actType: String,
iShape: Shape, oShape: Shape, kShape: (Int, Int)): Symbol = {
iShape: Shape, oShape: Shape, kShape: (Int, Int)): Symbol = {
var net = deconv2D(data, iShape, oShape, kShape, name = s"${prefix}_deconv")
net = Symbol.Activation(s"${prefix}_act")()(Map("data" -> net, "act_type" -> actType))
net = Symbol.api.Activation(data = Some(net), act_type = "relu", name = s"${prefix}_act")
net
}

def makeDcganSym(oShape: Shape, ngf: Int = 128, finalAct: String = "sigmoid",
eps: Float = 1e-5f + 1e-12f): (Symbol, Symbol) = {
eps: Float = 1e-5f + 1e-12f): (Symbol, Symbol) = {

val code = Symbol.Variable("rand")
var net = Symbol.FullyConnected("g1")()(Map("data" -> code,
"num_hidden" -> 4 * 4 * ngf * 4, "no_bias" -> true))
net = Symbol.Activation("gact1")()(Map("data" -> net, "act_type" -> "relu"))
var net = Symbol.api.FullyConnected(data = Some(code), num_hidden = 4 * 4 * ngf * 4,
no_bias = Some(true), name = " g1")
net = Symbol.api.Activation(data = Some(net), act_type = "relu", name = "gact1")
// 4 x 4
net = Symbol.Reshape()()(Map("data" -> net, "shape" -> s"(-1, ${ngf * 4}, 4, 4)"))
net = Symbol.api.Reshape(data = Some(net), shape = Some(Shape(-1, ngf * 4, 4, 4)))
// 8 x 8
net = deconv2DBnRelu(net, prefix = "g2",
iShape = Shape(ngf * 4, 4, 4), oShape = Shape(ngf * 2, 8, 8), kShape = (3, 3))
Expand All @@ -89,22 +75,22 @@ object GanMnist {

val data = Symbol.Variable("data")
// 28 x 28
val conv1 = Symbol.Convolution("conv1")()(Map("data" -> data,
"kernel" -> "(5,5)", "num_filter" -> 20))
val tanh1 = Symbol.Activation()()(Map("data" -> conv1, "act_type" -> "tanh"))
val pool1 = Symbol.Pooling()()(Map("data" -> tanh1,
"pool_type" -> "max", "kernel" -> "(2,2)", "stride" -> "(2,2)"))
val conv1 = Symbol.api.Convolution(data = Some(data), kernel = Shape(5, 5),
num_filter = 20, name = "conv1")
val tanh1 = Symbol.api.Activation(data = Some(conv1), act_type = "tanh")
val pool1 = Symbol.api.Pooling(data = Some(tanh1), pool_type = Some("max"),
kernel = Some(Shape(2, 2)), stride = Some(Shape(2, 2)))
// second conv
val conv2 = Symbol.Convolution("conv2")()(Map("data" -> pool1,
"kernel" -> "(5,5)", "num_filter" -> 50))
val tanh2 = Symbol.Activation()()(Map("data" -> conv2, "act_type" -> "tanh"))
val pool2 = Symbol.Pooling()()(Map("data" -> tanh2, "pool_type" -> "max",
"kernel" -> "(2,2)", "stride" -> "(2,2)"))
var d5 = Symbol.Flatten()()(Map("data" -> pool2))
d5 = Symbol.FullyConnected("fc1")()(Map("data" -> d5, "num_hidden" -> 500))
d5 = Symbol.Activation()()(Map("data" -> d5, "act_type" -> "tanh"))
d5 = Symbol.FullyConnected("fc_dloss")()(Map("data" -> d5, "num_hidden" -> 1))
val dloss = Symbol.LogisticRegressionOutput("dloss")()(Map("data" -> d5))
val conv2 = Symbol.api.Convolution(data = Some(pool1), kernel = Shape(5, 5),
num_filter = 50, name = "conv2")
val tanh2 = Symbol.api.Activation(data = Some(conv2), act_type = "tanh")
val pool2 = Symbol.api.Pooling(data = Some(tanh2), pool_type = Some("max"),
kernel = Some(Shape(2, 2)), stride = Some(Shape(2, 2)))
var d5 = Symbol.api.Flatten(data = Some(pool2))
d5 = Symbol.api.FullyConnected(data = Some(d5), num_hidden = 500, name = "fc1")
d5 = Symbol.api.Activation(data = Some(d5), act_type = "tanh")
d5 = Symbol.api.FullyConnected(data = Some(d5), num_hidden = 1, name = "fc_dloss")
val dloss = Symbol.api.LogisticRegressionOutput(data = Some(d5), name = "dloss")

(gout, dloss)
}
Expand All @@ -116,85 +102,91 @@ object GanMnist {
labelArr.zip(predArr).map { case (l, p) => Math.abs(l - p) }.sum / label.shape(0)
}

def test(dataPath : String, context : Context, outputPath : String, numEpoch : Int): Float = {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

test->runTraining

val lr = 0.0005f
val beta1 = 0.5f
val batchSize = 100
val randShape = Shape(batchSize, 100)
val dataShape = Shape(batchSize, 1, 28, 28)

val (symGen, symDec) =
makeDcganSym(oShape = dataShape, ngf = 32, finalAct = "sigmoid")

val gMod = new GANModule(
symGen,
symDec,
context = context,
dataShape = dataShape,
codeShape = randShape)

gMod.initGParams(new Xavier(factorType = "in", magnitude = 2.34f))
gMod.initDParams(new Xavier(factorType = "in", magnitude = 2.34f))

gMod.initOptimizer(new Adam(learningRate = lr, wd = 0f, beta1 = beta1))

val params = Map(
"image" -> s"$dataPath/train-images-idx3-ubyte",
"label" -> s"$dataPath/train-labels-idx1-ubyte",
"input_shape" -> s"(1, 28, 28)",
"batch_size" -> s"$batchSize",
"shuffle" -> "True"
)

val mnistIter = IO.MNISTIter(params)

val metricAcc = new CustomMetric(ferr, "ferr")

var t = 0
var dataBatch: DataBatch = null
var acc = 0.0f
for (epoch <- 0 until numEpoch) {
mnistIter.reset()
metricAcc.reset()
t = 0
while (mnistIter.hasNext) {
dataBatch = mnistIter.next()
gMod.update(dataBatch)
gMod.dLabel.set(0f)
metricAcc.update(Array(gMod.dLabel), gMod.outputsFake)
gMod.dLabel.set(1f)
metricAcc.update(Array(gMod.dLabel), gMod.outputsReal)

if (t % 50 == 0) {
val (name, value) = metricAcc.get
acc = value(0)
logger.info(s"epoch: $epoch, iter $t, metric=${value.mkString(" ")}")
Viz.imSave("gout", outputPath, gMod.tempOutG(0), flip = true)
val diff = gMod.tempDiffD
val arr = diff.toArray
val mean = arr.sum / arr.length
val std = {
val tmpA = arr.map(a => (a - mean) * (a - mean))
Math.sqrt(tmpA.sum / tmpA.length).toFloat
}
diff.set((diff - mean) / std + 0.5f)
Viz.imSave("diff", outputPath, diff, flip = true)
Viz.imSave("data", outputPath, dataBatch.data(0), flip = true)
}

t += 1
}
}
acc
}

def main(args: Array[String]): Unit = {
val anst = new GanMnist
val parser: CmdLineParser = new CmdLineParser(anst)
try {
parser.parseArgument(args.toList.asJava)

val dataPath = if (anst.mnistDataPath == null) System.getenv("MXNET_DATA_DIR")
else anst.mnistDataPath
else anst.mnistDataPath

assert(dataPath != null)

val lr = 0.0005f
val beta1 = 0.5f
val batchSize = 100
val randShape = Shape(batchSize, 100)
val numEpoch = 100
val dataShape = Shape(batchSize, 1, 28, 28)
val context = if (anst.gpu == -1) Context.cpu() else Context.gpu(anst.gpu)

val (symGen, symDec) =
makeDcganSym(oShape = dataShape, ngf = 32, finalAct = "sigmoid")

val gMod = new GANModule(
symGen,
symDec,
context = context,
dataShape = dataShape,
codeShape = randShape)

gMod.initGParams(new Xavier(factorType = "in", magnitude = 2.34f))
gMod.initDParams(new Xavier(factorType = "in", magnitude = 2.34f))

gMod.initOptimizer(new Adam(learningRate = lr, wd = 0f, beta1 = beta1))

val params = Map(
"image" -> s"${dataPath}/train-images-idx3-ubyte",
"label" -> s"${dataPath}/train-labels-idx1-ubyte",
"input_shape" -> s"(1, 28, 28)",
"batch_size" -> s"$batchSize",
"shuffle" -> "True"
)

val mnistIter = IO.MNISTIter(params)

val metricAcc = new CustomMetric(ferr, "ferr")

var t = 0
var dataBatch: DataBatch = null
for (epoch <- 0 until numEpoch) {
mnistIter.reset()
metricAcc.reset()
t = 0
while (mnistIter.hasNext) {
dataBatch = mnistIter.next()
gMod.update(dataBatch)
gMod.dLabel.set(0f)
metricAcc.update(Array(gMod.dLabel), gMod.outputsFake)
gMod.dLabel.set(1f)
metricAcc.update(Array(gMod.dLabel), gMod.outputsReal)

if (t % 50 == 0) {
val (name, value) = metricAcc.get
logger.info(s"epoch: $epoch, iter $t, metric=$value")
Viz.imSave("gout", anst.outputPath, gMod.tempOutG(0), flip = true)
val diff = gMod.tempDiffD
val arr = diff.toArray
val mean = arr.sum / arr.length
val std = {
val tmpA = arr.map(a => (a - mean) * (a - mean))
Math.sqrt(tmpA.sum / tmpA.length).toFloat
}
diff.set((diff - mean) / std + 0.5f)
Viz.imSave("diff", anst.outputPath, diff, flip = true)
Viz.imSave("data", anst.outputPath, dataBatch.data(0), flip = true)
}

t += 1
}
}
test(dataPath, context, anst.outputPath, 100)
} catch {
case ex: Exception => {
logger.error(ex.getMessage, ex)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.mxnetexamples.gan

import java.io.File
import java.net.URL

import org.apache.commons.io.FileUtils
import org.apache.mxnet.Context
import org.scalatest.{BeforeAndAfterAll, FunSuite}
import org.slf4j.LoggerFactory

import scala.sys.process.Process

class GanExampleSuite extends FunSuite with BeforeAndAfterAll{
private val logger = LoggerFactory.getLogger(classOf[GanExampleSuite])

test("Example CI: Test GAN MNIST") {
if (System.getenv().containsKey("SCALA_TEST_ON_GPU") &&
System.getenv("SCALA_TEST_ON_GPU").toInt == 1) {
logger.info("Downloading mnist model")
val baseUrl = "https://s3.us-east-2.amazonaws.com/mxnet-scala/scala-example-ci"
val tempDirPath = System.getProperty("java.io.tmpdir")
val modelDirPath = tempDirPath + File.separator + "mnist/"
logger.info("tempDirPath: %s".format(tempDirPath))
val tmpFile = new File(tempDirPath + "/mnist/mnist.zip")
if (!tmpFile.exists()) {
FileUtils.copyURLToFile(new URL(baseUrl + "/mnist/mnist.zip"),
tmpFile)
}
// TODO: Need to confirm with Windows
Process("unzip " + tempDirPath + "/mnist/mnist.zip -d "
+ tempDirPath + "/mnist/") !

val context = Context.gpu()

val output = GanMnist.test(modelDirPath, context, modelDirPath, 5)
Process("rm -rf " + modelDirPath) !

assert(output >= 0.0f)
} else {
logger.info("GPU test only, skipped...")
}
}
}