Skip to content

Commit

Permalink
[MXNET-539] Allow Scala users to specify data/label names for NDArray…
Browse files Browse the repository at this point in the history
…Iter (apache#11256)

* improve NDArrayIter to have Builder and ability to specifying names
  • Loading branch information
yzhliu authored and nswamy committed Jun 15, 2018
1 parent e48a8fd commit 02e8a71
Show file tree
Hide file tree
Showing 2 changed files with 113 additions and 29 deletions.
134 changes: 107 additions & 27 deletions scala-package/core/src/main/scala/org/apache/mxnet/io/NDArrayIter.scala
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@ import scala.collection.immutable.ListMap
/**
* NDArrayIter object in mxnet. Taking NDArray to get dataiter.
*
* @param data NDArrayIter supports single or multiple data and label.
* @param data Specify the data as well as the name.
* NDArrayIter supports single or multiple data and label.
* @param label Same as data, but is not fed to the model during testing.
* @param dataBatchSize Batch Size
* @param shuffle Whether to shuffle the data
Expand All @@ -38,15 +39,35 @@ import scala.collection.immutable.ListMap
* the size of data does not match batch_size. Roll over is intended
* for training and can cause problems if used for prediction.
*/
class NDArrayIter (data: IndexedSeq[NDArray], label: IndexedSeq[NDArray] = IndexedSeq.empty,
private val dataBatchSize: Int = 1, shuffle: Boolean = false,
lastBatchHandle: String = "pad",
dataName: String = "data", labelName: String = "label") extends DataIter {
private val logger = LoggerFactory.getLogger(classOf[NDArrayIter])
class NDArrayIter(data: IndexedSeq[(String, NDArray)],
label: IndexedSeq[(String, NDArray)],
private val dataBatchSize: Int, shuffle: Boolean,
lastBatchHandle: String) extends DataIter {

/**
* @param data Specify the data. Data names will be data_0, data_1, ..., etc.
* @param label Same as data, but is not fed to the model during testing.
* Label names will be label_0, label_1, ..., etc.
* @param dataBatchSize Batch Size
* @param shuffle Whether to shuffle the data
* @param lastBatchHandle "pad", "discard" or "roll_over". How to handle the last batch
*
* This iterator will pad, discard or roll over the last batch if
* the size of data does not match batch_size. Roll over is intended
* for training and can cause problems if used for prediction.
*/
def this(data: IndexedSeq[NDArray], label: IndexedSeq[NDArray] = IndexedSeq.empty,
dataBatchSize: Int = 1, shuffle: Boolean = false,
lastBatchHandle: String = "pad",
dataName: String = "data", labelName: String = "label") {
this(IO.initData(data, allowEmpty = false, dataName),
IO.initData(label, allowEmpty = true, labelName),
dataBatchSize, shuffle, lastBatchHandle)
}

private val logger = LoggerFactory.getLogger(classOf[NDArrayIter])

private val (_dataList: IndexedSeq[NDArray],
_labelList: IndexedSeq[NDArray]) = {
val (initData: IndexedSeq[(String, NDArray)], initLabel: IndexedSeq[(String, NDArray)]) = {
// data should not be null and size > 0
require(data != null && data.size > 0,
"data should not be null and data.size should not be zero")
Expand All @@ -55,17 +76,17 @@ class NDArrayIter (data: IndexedSeq[NDArray], label: IndexedSeq[NDArray] = Index
"label should not be null. Use IndexedSeq.empty if there are no labels")

// shuffle is not supported currently
require(shuffle == false, "shuffle is not supported currently")
require(!shuffle, "shuffle is not supported currently")

// discard final part if lastBatchHandle equals discard
if (lastBatchHandle.equals("discard")) {
val dataSize = data(0).shape(0)
val dataSize = data(0)._2.shape(0)
require(dataBatchSize <= dataSize,
"batch_size need to be smaller than data size when not padding.")
val keepSize = dataSize - dataSize % dataBatchSize
val dataList = data.map(ndArray => {ndArray.slice(0, keepSize)})
val dataList = data.map { case (name, ndArray) => (name, ndArray.slice(0, keepSize)) }
if (!label.isEmpty) {
val labelList = label.map(ndArray => {ndArray.slice(0, keepSize)})
val labelList = label.map { case (name, ndArray) => (name, ndArray.slice(0, keepSize)) }
(dataList, labelList)
} else {
(dataList, label)
Expand All @@ -75,13 +96,9 @@ class NDArrayIter (data: IndexedSeq[NDArray], label: IndexedSeq[NDArray] = Index
}
}


val initData: IndexedSeq[(String, NDArray)] = IO.initData(_dataList, false, dataName)
val initLabel: IndexedSeq[(String, NDArray)] = IO.initData(_labelList, true, labelName)
val numData = _dataList(0).shape(0)
val numSource = initData.size
var cursor = -dataBatchSize

val numData = initData(0)._2.shape(0)
val numSource: MXUint = initData.size
private var cursor = -dataBatchSize

private val (_provideData: ListMap[String, Shape],
_provideLabel: ListMap[String, Shape]) = {
Expand Down Expand Up @@ -112,8 +129,8 @@ class NDArrayIter (data: IndexedSeq[NDArray], label: IndexedSeq[NDArray] = Index
* reset the iterator
*/
override def reset(): Unit = {
if (lastBatchHandle.equals("roll_over") && cursor>numData) {
cursor = -dataBatchSize + (cursor%numData)%dataBatchSize
if (lastBatchHandle.equals("roll_over") && cursor > numData) {
cursor = -dataBatchSize + (cursor%numData) % dataBatchSize
} else {
cursor = -dataBatchSize
}
Expand Down Expand Up @@ -154,16 +171,16 @@ class NDArrayIter (data: IndexedSeq[NDArray], label: IndexedSeq[NDArray] = Index
newArray
}

private def _getData(data: IndexedSeq[NDArray]): IndexedSeq[NDArray] = {
private def _getData(data: IndexedSeq[(String, NDArray)]): IndexedSeq[NDArray] = {
require(cursor < numData, "DataIter needs reset.")
if (data == null) {
null
} else {
if (cursor + dataBatchSize <= numData) {
data.map(ndArray => {ndArray.slice(cursor, cursor + dataBatchSize)}).toIndexedSeq
data.map { case (_, ndArray) => ndArray.slice(cursor, cursor + dataBatchSize) }
} else {
// padding
data.map(_padData).toIndexedSeq
data.map { case (_, ndArray) => _padData(ndArray) }
}
}
}
Expand All @@ -173,23 +190,23 @@ class NDArrayIter (data: IndexedSeq[NDArray], label: IndexedSeq[NDArray] = Index
* @return the data of current batch
*/
override def getData(): IndexedSeq[NDArray] = {
_getData(_dataList)
_getData(initData)
}

/**
* Get label of current batch
* @return the label of current batch
*/
override def getLabel(): IndexedSeq[NDArray] = {
_getData(_labelList)
_getData(initLabel)
}

/**
* the index of current batch
* @return
*/
override def getIndex(): IndexedSeq[Long] = {
(cursor.toLong to (cursor + dataBatchSize).toLong).toIndexedSeq
cursor.toLong to (cursor + dataBatchSize).toLong
}

/**
Expand All @@ -213,3 +230,66 @@ class NDArrayIter (data: IndexedSeq[NDArray], label: IndexedSeq[NDArray] = Index

override def batchSize: Int = dataBatchSize
}

object NDArrayIter {

/**
* Builder class for NDArrayIter.
*/
class Builder() {
private var data: IndexedSeq[(String, NDArray)] = IndexedSeq.empty
private var label: IndexedSeq[(String, NDArray)] = IndexedSeq.empty
private var dataBatchSize: Int = 1
private var lastBatchHandle: String = "pad"

/**
* Add one data input with its name.
* @param name Data name.
* @param data Data nd-array.
* @return The builder object itself.
*/
def addData(name: String, data: NDArray): Builder = {
this.data = this.data ++ IndexedSeq((name, data))
this
}

/**
* Add one label input with its name.
* @param name Label name.
* @param label Label nd-array.
* @return The builder object itself.
*/
def addLabel(name: String, label: NDArray): Builder = {
this.label = this.label ++ IndexedSeq((name, label))
this
}

/**
* Set the batch size of the iterator.
* @param batchSize batch size.
* @return The builder object itself.
*/
def setBatchSize(batchSize: Int): Builder = {
this.dataBatchSize = batchSize
this
}

/**
* How to handle the last batch.
* @param lastBatchHandle Can be "pad", "discard" or "roll_over".
* @return The builder object itself.
*/
def setLastBatchHandle(lastBatchHandle: String): Builder = {
this.lastBatchHandle = lastBatchHandle
this
}

/**
* Build the NDArrayIter object.
* @return the built object.
*/
def build(): NDArrayIter = {
new NDArrayIter(data, label, dataBatchSize, false, lastBatchHandle)
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ import scala.sys.process._

class IOSuite extends FunSuite with BeforeAndAfterAll {

private var tu = new TestUtil
private val tu = new TestUtil

test("test MNISTIter & MNISTPack") {
// get data
Expand Down Expand Up @@ -258,7 +258,11 @@ class IOSuite extends FunSuite with BeforeAndAfterAll {
assert(batchCount === nBatch0)

// test discard
val dataIter1 = new NDArrayIter(data, label, 128, false, "discard")
val dataIter1 = new NDArrayIter.Builder()
.addData("data0", data(0)).addData("data1", data(1))
.addLabel("label", label(0))
.setBatchSize(128)
.setLastBatchHandle("discard").build()
val nBatch1 = 7
batchCount = 0
while(dataIter1.hasNext) {
Expand Down

0 comments on commit 02e8a71

Please sign in to comment.