Skip to content

Commit

Permalink
[Scala] NDArrayIter constructor fix for null (apache#4308)
Browse files: browse the repository at this point in the history
  • Loading branch information
benqua authored and yzhliu committed Dec 22, 2016
1 parent 1a71e3b commit 1610b4f
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 5 deletions.
5 changes: 3 additions & 2 deletions scala-package/core/src/main/scala/ml/dmlc/mxnet/IO.scala
Original file line number Diff line number Diff line change
Expand Up @@ -95,8 +95,9 @@ object IO {
private[mxnet] def initData(data: IndexedSeq[NDArray],
allowEmpty: Boolean,
defaultName: String): IndexedSeq[(String, NDArray)] = {
require(data != null || allowEmpty)
if (data == null) {
require(data != null)
require(data != IndexedSeq.empty || allowEmpty)
if (data == IndexedSeq.empty) {
IndexedSeq()
} else if (data.length == 1) {
IndexedSeq((defaultName, data(0)))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ import scala.collection.immutable.ListMap
* the size of data does not match batch_size. Roll over is intended
* for training and can cause problems if used for prediction.
*/
class NDArrayIter (data: IndexedSeq[NDArray], label: IndexedSeq[NDArray] = null,
class NDArrayIter (data: IndexedSeq[NDArray], label: IndexedSeq[NDArray] = IndexedSeq.empty,
private val dataBatchSize: Int = 1, shuffle: Boolean = false,
lastBatchHandle: String = "pad") extends DataIter {
private val logger = LoggerFactory.getLogger(classOf[NDArrayIter])
Expand All @@ -35,6 +35,9 @@ class NDArrayIter (data: IndexedSeq[NDArray], label: IndexedSeq[NDArray] = null,
require(data != null && data.size > 0,
"data should not be null and data.size should not be zero")

require(label != null,
"label should not be null. Use IndexedSeq.empty if there are no labels")

// shuffle is not supported currently
require(shuffle == false, "shuffle is not supported currently")

Expand All @@ -45,11 +48,11 @@ class NDArrayIter (data: IndexedSeq[NDArray], label: IndexedSeq[NDArray] = null,
"batch_size need to be smaller than data size when not padding.")
val keepSize = dataSize - dataSize % dataBatchSize
val dataList = data.map(ndArray => {ndArray.slice(0, keepSize)})
if (label != null) {
if (!label.isEmpty) {
val labelList = label.map(ndArray => {ndArray.slice(0, keepSize)})
(dataList, labelList)
} else {
(dataList, null)
(dataList, label)
}
} else {
(data, label)
Expand Down
14 changes: 14 additions & 0 deletions scala-package/core/src/test/scala/ml/dmlc/mxnet/IOSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -251,5 +251,19 @@ class IOSuite extends FunSuite with BeforeAndAfterAll {
}

assert(batchCount === nBatch1)

// test empty label (for prediction)
val dataIter2 = new NDArrayIter(data = data, dataBatchSize = 128, lastBatchHandle = "discard")
batchCount = 0
while(dataIter2.hasNext) {
val tBatch = dataIter2.next()
batchCount += 1

assert(tBatch.data(0).toArray === batchData0.toArray)
assert(tBatch.data(1).toArray === batchData1.toArray)
}

assert(batchCount === nBatch1)
assert(dataIter2.initLabel == IndexedSeq.empty)
}
}

0 comments on commit 1610b4f

Please sign in to comment.