diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Base.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Base.scala index 76b5d1a3df17..9f02c6e4e2e0 100644 --- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Base.scala +++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Base.scala @@ -10,8 +10,6 @@ object Base { type MXUint = Int type MXFloat = Float type CPtrAddress = Long - // TODO: make it more friendly to java - type Shape = Vector[Int] type NDArrayHandle = CPtrAddress type FunctionHandle = CPtrAddress diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/EvalMetric.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/EvalMetric.scala index bf65735821e7..beed27696fc8 100644 --- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/EvalMetric.scala +++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/EvalMetric.scala @@ -50,8 +50,8 @@ class Accuracy extends EvalMetric("accuracy") { for ((pred, label) <- preds zip labels) { val predLabel = NDArray.argmaxChannel(pred) - require(label.shape.sameElements(predLabel.shape), - s"label (${label.shape.mkString(",")}) and prediction (${predLabel.shape.mkString(",")})" + + require(label.shape == predLabel.shape, + s"label ${label.shape} and prediction ${predLabel.shape}" + s"should have the same length.") for ((labelElem, predElem) <- label.toArray zip predLabel.toArray) { if (labelElem == predElem) { diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Executor.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Executor.scala index 72d5ad3623ea..14f0c4b1c2a4 100644 --- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Executor.scala +++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Executor.scala @@ -294,7 +294,7 @@ class DataParallelExecutorManager(symbol: Symbol, ctx.zipWithIndex.map { case (context, i) => val dataShapes = trainData.provideData.map { case (name: String, shape: Shape) => - (name, Vector(slices(i)._2 - slices(i)._1) ++ shape.drop(1)) + (name, Shape(slices(i)._2 
- slices(i)._1) ++ shape.drop(1)) } symbol.simpleBind(context, "write", shapeDict = dataShapes) } @@ -334,7 +334,7 @@ class DataParallelExecutorManager(symbol: Symbol, }.toArray private val batchSize = trainData.batchSize private val outputShapes: Array[Shape] = trainExecs(0).outputs.map { x: NDArray => - Vector(batchSize) ++ x.shape.drop(1) + Shape(batchSize) ++ x.shape.drop(1) } private[mxnet] val cpuOutputArrays = outputShapes.map(NDArray.zeros(_)) diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Model.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Model.scala index db1db0d0609d..a23ce54d489c 100644 --- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Model.scala +++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Model.scala @@ -1,6 +1,5 @@ package ml.dmlc.mxnet -import ml.dmlc.mxnet.Base.Shape import ml.dmlc.mxnet.io.NDArrayIter import ml.dmlc.mxnet.optimizer.SGD import org.slf4j.{Logger, LoggerFactory} diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Monitor.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Monitor.scala index 5ef6c8eaec81..d19048125196 100644 --- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Monitor.scala +++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Monitor.scala @@ -85,7 +85,7 @@ class Monitor(protected val interval: Int, protected var statFunc: (NDArray) => val res = new mutable.Queue[(Int, String, String)] queue.foreach { q => val (n, k, v) = q - if (v.shape.sameElements(Array(1))) { + if (v.shape == Shape(1)) { res += ((n, k, v.toScalar.toString)) } else { res += ((n, k, s"[${v.toArray.mkString(",")}]")) diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/NDArray.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/NDArray.scala index 8ad6d32248d3..9c040ea4c4b4 100644 --- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/NDArray.scala +++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/NDArray.scala @@ -222,9 +222,9 @@ object NDArray { new NDArray(handle = 
NDArray.newAllocHandle(shape, context, delayAlloc = false)) } - def empty(shape: Int *): NDArray = empty(shape.toVector) + def empty(shape: Int *): NDArray = empty(Shape(shape: _*)) - def empty(ctx: Context, shape: Int *): NDArray = empty(shape.toVector, ctx) + def empty(ctx: Context, shape: Int *): NDArray = empty(Shape(shape: _*), ctx) /** * Create a new NDArray filled with 0, with specified shape. @@ -240,9 +240,9 @@ object NDArray { arr } - def zeros(shape: Int *): NDArray = zeros(shape.toVector) + def zeros(shape: Int *): NDArray = zeros(Shape(shape: _*)) - def zeros(ctx: Context, shape: Int *): NDArray = zeros(shape.toVector, ctx) + def zeros(ctx: Context, shape: Int *): NDArray = zeros(Shape(shape: _*), ctx) /** * Create a new NDArray filled with 1, with specified shape. @@ -256,9 +256,9 @@ object NDArray { arr } - def ones(shape: Int *): NDArray = ones(shape.toVector) + def ones(shape: Int *): NDArray = ones(Shape(shape: _*)) - def ones(ctx: Context, shape: Int *): NDArray = ones(shape.toVector, ctx) + def ones(ctx: Context, shape: Int *): NDArray = ones(Shape(shape: _*), ctx) /** * Clip ndarray elements to range (from, to) @@ -477,12 +477,12 @@ object NDArray { val shape = array0.shape.drop(1) var axis0 = array0.shape(0) arrays.drop(1).foreach { array => - require(shape.sameElements(array.shape.drop(1)), - s"shape mismatch between (${array.shape.mkString(",")}) and (${shape.mkString(",")})") + require(shape == array.shape.drop(1), + s"shape mismatch between ${array.shape} and $shape") axis0 += array.shape(0) } - val output = NDArray.empty(Vector(axis0) ++ shape, ctx) + val output = NDArray.empty(Shape(axis0) ++ shape, ctx) axis0 = 0 arrays.foreach { array => output.slice(axis0, axis0 + array.shape(0)).set(array) @@ -767,7 +767,7 @@ class NDArray(private[mxnet] val handle: NDArrayHandle, val writable: Boolean = * @return The scalar representation of the ndarray. 
package ml.dmlc.mxnet

/**
 * Shape of [[NDArray]] or other multi-dimensional data.
 *
 * Immutable: every transforming method returns a new [[Shape]].
 * Internally backed by a `Vector[Int]`, one entry per dimension.
 *
 * @param dims sizes of each dimension, outermost first
 * @author Yizhi Liu
 */
class Shape(dims: Traversable[Int]) {
  // Materialize once so repeated accessors don't re-traverse `dims`.
  private val shape = dims.toVector

  /** Convenience constructor: `new Shape(2, 3, 4)`. */
  def this(dims: Int*) = {
    this(dims.toVector)
  }

  /** Size of dimension `dim` (0-indexed). */
  def apply(dim: Int): Int = shape(dim)

  /** Number of dimensions. */
  def size: Int = shape.size

  /** Number of dimensions (alias of [[size]]). */
  def length: Int = shape.length

  /** New shape with the first `dim` dimensions removed. */
  def drop(dim: Int): Shape = new Shape(shape.drop(dim))

  /** New shape containing dimensions `[from, end)`. */
  def slice(from: Int, end: Int): Shape = new Shape(shape.slice(from, end))

  /** Total number of elements, i.e. the product of all dimensions. */
  def product: Int = shape.product

  /** Size of the first dimension. Throws if the shape is empty. */
  def head: Int = shape.head

  /** Concatenate two shapes, e.g. `Shape(2) ++ Shape(3, 4) == Shape(2, 3, 4)`. */
  def ++(other: Shape): Shape = new Shape(shape ++ other.shape)

  def toArray: Array[Int] = shape.toArray
  def toVector: Vector[Int] = shape

  // Pure accessor: no parens, per Scala convention.
  override def toString: String = s"(${shape.mkString(",")})"

  override def equals(o: Any): Boolean = o match {
    // A typed pattern never matches null, so no explicit null guard is needed;
    // Vector equality (==) compares elements, same as sameElements here.
    case that: Shape => that.shape == this.shape
    case _ => false
  }

  // Consistent with equals: delegates to the backing Vector's hashCode.
  override def hashCode(): Int = shape.hashCode()
}

object Shape {
  def apply(dims: Int*): Shape = new Shape(dims: _*)
  def apply(dims: Traversable[Int]): Shape = new Shape(dims)
}
DataBatch, DataIter, NDArray, Shape} import ml.dmlc.mxnet.IO._ import org.slf4j.LoggerFactory diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/io/NDArrayIter.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/io/NDArrayIter.scala index e9273e7e4341..f0169b87bcb8 100644 --- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/io/NDArrayIter.scala +++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/io/NDArrayIter.scala @@ -1,7 +1,7 @@ package ml.dmlc.mxnet.io import ml.dmlc.mxnet.Base._ -import ml.dmlc.mxnet.{DataIter, NDArray} +import ml.dmlc.mxnet.{DataIter, NDArray, Shape} /** * TODO diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/io/PrefetchingIter.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/io/PrefetchingIter.scala index 4667d89b7f3b..270584079cfa 100644 --- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/io/PrefetchingIter.scala +++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/io/PrefetchingIter.scala @@ -1,7 +1,6 @@ package ml.dmlc.mxnet.io -import ml.dmlc.mxnet.{DataIter, NDArray} -import ml.dmlc.mxnet.Base._ +import ml.dmlc.mxnet.{DataIter, NDArray, Shape} /** * TODO diff --git a/scala-package/core/src/test/scala/ml/dmlc/mxnet/ExecutorSuite.scala b/scala-package/core/src/test/scala/ml/dmlc/mxnet/ExecutorSuite.scala index 4da8e6728e2f..65f8783ffec9 100644 --- a/scala-package/core/src/test/scala/ml/dmlc/mxnet/ExecutorSuite.scala +++ b/scala-package/core/src/test/scala/ml/dmlc/mxnet/ExecutorSuite.scala @@ -5,7 +5,7 @@ import ml.dmlc.mxnet.CheckUtils._ class ExecutorSuite extends FunSuite with BeforeAndAfterAll { test("bind") { - val shape = Vector(100, 30) + val shape = Shape(100, 30) val lhs = Symbol.Variable("lhs") val rhs = Symbol.Variable("rhs") val ret = lhs + rhs diff --git a/scala-package/core/src/test/scala/ml/dmlc/mxnet/IOSuite.scala b/scala-package/core/src/test/scala/ml/dmlc/mxnet/IOSuite.scala index a1946d9f410a..326248114fcb 100644 --- a/scala-package/core/src/test/scala/ml/dmlc/mxnet/IOSuite.scala 
+++ b/scala-package/core/src/test/scala/ml/dmlc/mxnet/IOSuite.scala @@ -34,8 +34,8 @@ class IOSuite extends FunSuite with BeforeAndAfterAll { // test provideData val provideData = mnistIter.provideData val provideLabel = mnistIter.provideLabel - assert(provideData("data") === Array(100, 784)) - assert(provideLabel("label") === Array(100)) + assert(provideData("data") === Shape(100, 784)) + assert(provideLabel("label") === Shape(100)) // test_loop mnistIter.reset() batchCount = 0 diff --git a/scala-package/core/src/test/scala/ml/dmlc/mxnet/KVStoreSuite.scala b/scala-package/core/src/test/scala/ml/dmlc/mxnet/KVStoreSuite.scala index 3e9fc625b783..7a3dc940a737 100644 --- a/scala-package/core/src/test/scala/ml/dmlc/mxnet/KVStoreSuite.scala +++ b/scala-package/core/src/test/scala/ml/dmlc/mxnet/KVStoreSuite.scala @@ -5,7 +5,7 @@ import org.scalatest.{BeforeAndAfterAll, FunSuite} class KVStoreSuite extends FunSuite with BeforeAndAfterAll { test("init and pull") { val kv = KVStore.create() - val shape = Vector(2, 1) + val shape = Shape(2, 1) val ndArray = NDArray.zeros(shape) kv.init(3, NDArray.ones(shape)) @@ -15,7 +15,7 @@ class KVStoreSuite extends FunSuite with BeforeAndAfterAll { test("push and pull") { val kv = KVStore.create() - val shape = Vector(2, 1) + val shape = Shape(2, 1) val ndArray = NDArray.zeros(shape) kv.init(3, NDArray.ones(shape)) @@ -36,7 +36,7 @@ class KVStoreSuite extends FunSuite with BeforeAndAfterAll { } kv.setUpdater(updater) - val shape = Vector(2, 1) + val shape = Shape(2, 1) val ndArray = NDArray.zeros(shape) kv.init(3, NDArray.ones(shape) * 4) diff --git a/scala-package/core/src/test/scala/ml/dmlc/mxnet/ModelParallelSuite.scala b/scala-package/core/src/test/scala/ml/dmlc/mxnet/ModelParallelSuite.scala index 4d2e57d3d135..afbcdca65efd 100644 --- a/scala-package/core/src/test/scala/ml/dmlc/mxnet/ModelParallelSuite.scala +++ b/scala-package/core/src/test/scala/ml/dmlc/mxnet/ModelParallelSuite.scala @@ -18,7 +18,7 @@ class ModelParallelSuite 
extends FunSuite with BeforeAndAfterAll { net = net + data1 } - val shape = Vector(4, 5) + val shape = Shape(4, 5) val (arr, arrGrad) = new Context(Context.cpu(0)).withScope { val arr = (0 until n).map(_ => NDArray.empty(shape)) diff --git a/scala-package/core/src/test/scala/ml/dmlc/mxnet/NDArraySuite.scala b/scala-package/core/src/test/scala/ml/dmlc/mxnet/NDArraySuite.scala index eacc8d6b3206..acaa88d023b9 100644 --- a/scala-package/core/src/test/scala/ml/dmlc/mxnet/NDArraySuite.scala +++ b/scala-package/core/src/test/scala/ml/dmlc/mxnet/NDArraySuite.scala @@ -27,7 +27,7 @@ class NDArraySuite extends FunSuite with BeforeAndAfterAll with Matchers { test("size and shape") { val ndzeros = NDArray.zeros(4, 1) - assert(ndzeros.shape === Array(4, 1)) + assert(ndzeros.shape === Shape(4, 1)) assert(ndzeros.size === 4) } @@ -103,7 +103,7 @@ class NDArraySuite extends FunSuite with BeforeAndAfterAll with Matchers { } test("rsqrt") { - val ndarray = NDArray.array(Array(1f, 4f), shape = Vector(2, 1)) + val ndarray = NDArray.array(Array(1f, 4f), shape = Shape(2, 1)) assert(NDArray.rsqrt(ndarray).toArray === Array(1f, 0.5f)) } @@ -111,47 +111,47 @@ class NDArraySuite extends FunSuite with BeforeAndAfterAll with Matchers { val ndarray = NDArray.empty(3, 1) ndarray.set(Array(1f, 2f, 3f)) val normed = NDArray.norm(ndarray) - assert(normed.shape === Array(1)) + assert(normed.shape === Shape(1)) assert(normed.toScalar === math.sqrt(14.0).toFloat +- 1e-3f) } test("one hot encode") { - val indices = NDArray.array(Array(1f, 0f, 2f), shape = Vector(3)) + val indices = NDArray.array(Array(1f, 0f, 2f), shape = Shape(3)) val array = NDArray.empty(3, 3) NDArray.onehotEncode(indices, array) - assert(array.shape === Array(3, 3)) + assert(array.shape === Shape(3, 3)) assert(array.toArray === Array(0f, 1f, 0f, 1f, 0f, 0f, 0f, 0f, 1f)) } test("dot") { - val arr1 = NDArray.array(Array(1f, 2f), shape = Vector(1, 2)) - val arr2 = NDArray.array(Array(3f, 4f), shape = Vector(2, 1)) + val arr1 = 
NDArray.array(Array(1f, 2f), shape = Shape(1, 2)) + val arr2 = NDArray.array(Array(3f, 4f), shape = Shape(2, 1)) val res = NDArray.dot(arr1, arr2) - assert(res.shape === Array(1, 1)) + assert(res.shape === Shape(1, 1)) assert(res.toArray === Array(11f)) } test("choose_element_0index") { - val arr = NDArray.array(Array(1f, 2f, 3f, 4f, 6f, 5f), shape = Vector(2, 3)) - val indices = NDArray.array(Array(0f, 1f), shape = Vector(2)) + val arr = NDArray.array(Array(1f, 2f, 3f, 4f, 6f, 5f), shape = Shape(2, 3)) + val indices = NDArray.array(Array(0f, 1f), shape = Shape(2)) val res = NDArray.chooseElement0Index(arr, indices) assert(res.toArray === Array(1f, 6f)) } test("copy to") { - val source = NDArray.array(Array(1f, 2f, 3f), shape = Vector(1, 3)) + val source = NDArray.array(Array(1f, 2f, 3f), shape = Shape(1, 3)) val dest = NDArray.empty(1, 3) source.copyTo(dest) - assert(dest.shape === Array(1, 3)) + assert(dest.shape === Shape(1, 3)) assert(dest.toArray === Array(1f, 2f, 3f)) } test("random uniform") { val matrix = NDArray.empty(3, 2) NDArray.randomUniform(0f, 1f, matrix) - assert(matrix.shape === Array(3, 2)) + assert(matrix.shape === Shape(3, 2)) val arr = matrix.toArray // scalastyle:off println println(s"Random Uniform: [${arr.mkString(",")}]") @@ -164,7 +164,7 @@ class NDArraySuite extends FunSuite with BeforeAndAfterAll with Matchers { test("random gaussian") { val matrix = NDArray.empty(3, 2) NDArray.randomGaussian(0f, 1f, matrix) - assert(matrix.shape === Array(3, 2)) + assert(matrix.shape === Shape(3, 2)) val arr = matrix.toArray // scalastyle:off println println(s"Random Gaussian: [${arr.mkString(",")}]") @@ -172,32 +172,32 @@ class NDArraySuite extends FunSuite with BeforeAndAfterAll with Matchers { } test("abs") { - val arr = NDArray.array(Array(-1f, -2f, 3f), shape = Vector(3, 1)) + val arr = NDArray.array(Array(-1f, -2f, 3f), shape = Shape(3, 1)) assert(NDArray.abs(arr).toArray === Array(1f, 2f, 3f)) } test("sign") { - val arr = NDArray.array(Array(-1f, 
-2f, 3f), shape = Vector(3, 1)) + val arr = NDArray.array(Array(-1f, -2f, 3f), shape = Shape(3, 1)) assert(NDArray.sign(arr).toArray === Array(-1f, -1f, 1f)) } test("round") { - val arr = NDArray.array(Array(1.5f, 2.1f, 3.7f), shape = Vector(3, 1)) + val arr = NDArray.array(Array(1.5f, 2.1f, 3.7f), shape = Shape(3, 1)) assert(NDArray.round(arr).toArray === Array(2f, 2f, 4f)) } test("ceil") { - val arr = NDArray.array(Array(1.5f, 2.1f, 3.7f), shape = Vector(3, 1)) + val arr = NDArray.array(Array(1.5f, 2.1f, 3.7f), shape = Shape(3, 1)) assert(NDArray.ceil(arr).toArray === Array(2f, 3f, 4f)) } test("floor") { - val arr = NDArray.array(Array(1.5f, 2.1f, 3.7f), shape = Vector(3, 1)) + val arr = NDArray.array(Array(1.5f, 2.1f, 3.7f), shape = Shape(3, 1)) assert(NDArray.floor(arr).toArray === Array(1f, 2f, 3f)) } test("square") { - val arr = NDArray.array(Array(1f, 2f, 3f), shape = Vector(3, 1)) + val arr = NDArray.array(Array(1f, 2f, 3f), shape = Shape(3, 1)) assert(NDArray.square(arr).toArray === Array(1f, 4f, 9f)) } @@ -225,32 +225,32 @@ class NDArraySuite extends FunSuite with BeforeAndAfterAll with Matchers { } test("max") { - val arr = NDArray.array(Array(1.5f, 2.1f, 3.7f), shape = Vector(3, 1)) + val arr = NDArray.array(Array(1.5f, 2.1f, 3.7f), shape = Shape(3, 1)) assert(NDArray.max(arr).toScalar === 3.7f +- 1e-3f) } test("min") { - val arr = NDArray.array(Array(1.5f, 2.1f, 3.7f), shape = Vector(3, 1)) + val arr = NDArray.array(Array(1.5f, 2.1f, 3.7f), shape = Shape(3, 1)) assert(NDArray.min(arr).toScalar === 1.5f +- 1e-3f) } test("sum") { - val arr = NDArray.array(Array(1f, 2f, 3f, 4f), shape = Vector(2, 2)) + val arr = NDArray.array(Array(1f, 2f, 3f, 4f), shape = Shape(2, 2)) assert(NDArray.sum(arr).toScalar === 10f +- 1e-3f) } test("argmaxChannel") { - val arr = NDArray.array(Array(1f, 2f, 4f, 3f), shape = Vector(2, 2)) + val arr = NDArray.array(Array(1f, 2f, 4f, 3f), shape = Shape(2, 2)) val argmax = NDArray.argmaxChannel(arr) - assert(argmax.shape === 
Array(2)) + assert(argmax.shape === Shape(2)) assert(argmax.toArray === Array(1f, 0f)) } test("concatenate") { - val arr1 = NDArray.array(Array(1f, 2f, 4f, 3f, 3f, 3f), shape = Vector(2, 3)) - val arr2 = NDArray.array(Array(8f, 7f, 6f), shape = Vector(1, 3)) + val arr1 = NDArray.array(Array(1f, 2f, 4f, 3f, 3f, 3f), shape = Shape(2, 3)) + val arr2 = NDArray.array(Array(8f, 7f, 6f), shape = Shape(1, 3)) val arr = NDArray.concatenate(arr1, arr2) - assert(arr.shape === Array(3, 3)) + assert(arr.shape === Shape(3, 3)) assert(arr.toArray === Array(1f, 2f, 4f, 3f, 3f, 3f, 8f, 7f, 6f)) } @@ -258,14 +258,14 @@ class NDArraySuite extends FunSuite with BeforeAndAfterAll with Matchers { val filename = s"${System.getProperty("java.io.tmpdir")}/ndarray-${sequence.getAndIncrement}.bin" try { - val ndarray = NDArray.array(Array(1f, 2f, 3f), shape = Vector(3, 1)) + val ndarray = NDArray.array(Array(1f, 2f, 3f), shape = Shape(3, 1)) NDArray.save(filename, Map("local" -> ndarray)) val (keys, arrays) = NDArray.load(filename) assert(keys.length === 1) assert(keys(0) === "local") assert(arrays.length === 1) val loadedArray = arrays(0) - assert(loadedArray.shape === Array(3, 1)) + assert(loadedArray.shape === Shape(3, 1)) assert(loadedArray.toArray === Array(1f, 2f, 3f)) } finally { val file = new File(filename) @@ -277,13 +277,13 @@ class NDArraySuite extends FunSuite with BeforeAndAfterAll with Matchers { val filename = s"${System.getProperty("java.io.tmpdir")}/ndarray-${sequence.getAndIncrement}.bin" try { - val ndarray = NDArray.array(Array(1f, 2f, 3f), shape = Vector(3, 1)) + val ndarray = NDArray.array(Array(1f, 2f, 3f), shape = Shape(3, 1)) NDArray.save(filename, Array(ndarray)) val (keys, arrays) = NDArray.load(filename) assert(keys.length === 0) assert(arrays.length === 1) val loadedArray = arrays(0) - assert(loadedArray.shape === Array(3, 1)) + assert(loadedArray.shape === Shape(3, 1)) assert(loadedArray.toArray === Array(1f, 2f, 3f)) } finally { val file = new File(filename) 
@@ -299,24 +299,24 @@ class NDArraySuite extends FunSuite with BeforeAndAfterAll with Matchers { } test("equals") { - val ndarray1 = NDArray.array(Array(1f, 2f, 3f), shape = Vector(3, 1)) - val ndarray2 = NDArray.array(Array(1f, 2f, 3f), shape = Vector(3, 1)) - val ndarray3 = NDArray.array(Array(1f, 2f, 3f), shape = Vector(1, 3)) - val ndarray4 = NDArray.array(Array(3f, 2f, 3f), shape = Vector(3, 1)) + val ndarray1 = NDArray.array(Array(1f, 2f, 3f), shape = Shape(3, 1)) + val ndarray2 = NDArray.array(Array(1f, 2f, 3f), shape = Shape(3, 1)) + val ndarray3 = NDArray.array(Array(1f, 2f, 3f), shape = Shape(1, 3)) + val ndarray4 = NDArray.array(Array(3f, 2f, 3f), shape = Shape(3, 1)) ndarray1 shouldEqual ndarray2 ndarray1 shouldNot equal(ndarray3) ndarray1 shouldNot equal(ndarray4) } test("slice") { - val arr = NDArray.array(Array(1f, 2f, 3f, 4f, 5f, 6f), shape = Vector(3, 2)) + val arr = NDArray.array(Array(1f, 2f, 3f, 4f, 5f, 6f), shape = Shape(3, 2)) val arr1 = arr.slice(1) - assert(arr1.shape === Vector(1, 2)) + assert(arr1.shape === Shape(1, 2)) assert(arr1.toArray === Array(3f, 4f)) val arr2 = arr.slice(1, 3) - assert(arr2.shape === Vector(2, 2)) + assert(arr2.shape === Shape(2, 2)) assert(arr2.toArray === Array(3f, 4f, 5f, 6f)) } } diff --git a/scala-package/core/src/test/scala/ml/dmlc/mxnet/OperatorSuite.scala b/scala-package/core/src/test/scala/ml/dmlc/mxnet/OperatorSuite.scala index 74d935bc6a98..e02553832c82 100644 --- a/scala-package/core/src/test/scala/ml/dmlc/mxnet/OperatorSuite.scala +++ b/scala-package/core/src/test/scala/ml/dmlc/mxnet/OperatorSuite.scala @@ -1,6 +1,5 @@ package ml.dmlc.mxnet -import ml.dmlc.mxnet.Base.Shape import ml.dmlc.mxnet.CheckUtils._ import org.scalatest.prop.GeneratorDrivenPropertyChecks @@ -30,10 +29,10 @@ class OperatorSuite extends FunSuite with BeforeAndAfterAll } test("elementwise sum") { - checkElementwiseSumWithShape(Vector(5, 5, 3), 4) + checkElementwiseSumWithShape(Shape(5, 5, 3), 4) forAll (Gen.choose(1, 4), 
Gen.choose(1, 8)) { (dim, n) => forAll (Gen.listOfN(dim, Gen.choose(1, Math.pow(1000, 1.0 / dim).toInt))) { shape => - checkElementwiseSumWithShape(shape.toVector, n) + checkElementwiseSumWithShape(Shape(shape), n) } } } @@ -86,9 +85,9 @@ class OperatorSuite extends FunSuite with BeforeAndAfterAll test("concat") { val merge = Array(2, 3, 4, 5, 6) forAll (Gen.choose(2, 5)) { dim => - val shapes = mutable.ArrayBuffer.empty[Vector[Int]] + val shapes = mutable.ArrayBuffer.empty[Shape] for (i <- 0 until dim) { - shapes += Vector(merge(i), 2) + shapes += Shape(merge(i), 2) } // TODO: check dimension > 0 checkConcatWithShape(shapes, 0, skipSecond = true) @@ -99,9 +98,9 @@ class OperatorSuite extends FunSuite with BeforeAndAfterAll private def checkRegression(model: Symbol, forward: Float => Float, backward: (Float, Float) => Float) = { - val shape = Vector(3, 1) + val shape = Shape(3, 1) val arrData = Random.uniform(-1, 1, shape) - val arrLabel = Random.uniform(0, 1, Vector(shape.head)) + val arrLabel = Random.uniform(0, 1, Shape(shape.head)) val arrGrad = NDArray.empty(shape) val exec1 = model.bind(Context.cpu(), args = Array(arrData, arrLabel), argsGrad = Map("data" -> arrGrad)) @@ -135,7 +134,7 @@ class OperatorSuite extends FunSuite with BeforeAndAfterAll test("swap axes") { val data = Symbol.Variable("data") - val shape = Vector(2, 3, 4) + val shape = Shape(2, 3, 4) val arrData = NDArray.ones(shape) arrData.slice(0).set(1f) arrData.slice(1).set(2f) @@ -167,7 +166,7 @@ class OperatorSuite extends FunSuite with BeforeAndAfterAll // // [[ 1., 1., 1.], // [ 2., 2., 2.]]] - assert(out.shape === Vector(4, 2, 3)) + assert(out.shape === Shape(4, 2, 3)) for (i <- 0 until 4) { val axis0 = out.slice(i) assert(CheckUtils.reldiff(axis0.toArray, Array(1f, 1f, 1f, 2f, 2f, 2f)) < 1e-6f) @@ -176,7 +175,7 @@ class OperatorSuite extends FunSuite with BeforeAndAfterAll test("scalar op") { val data = Symbol.Variable("data") - val shape = Vector(3, 4) + val shape = Shape(3, 4) val dataTmp 
= NDArray.ones(shape) * 5 val test = { @@ -200,7 +199,7 @@ class OperatorSuite extends FunSuite with BeforeAndAfterAll test("scalar pow") { val data = Symbol.Variable("data") - val shape = Vector(1, 1) + val shape = Shape(1, 1) val dataTmp = NDArray.ones(shape) * 3 val dataTmpPowered = NDArray.ones(shape) * 9 val test = Symbol.pow(data, 2) @@ -210,7 +209,7 @@ class OperatorSuite extends FunSuite with BeforeAndAfterAll } test("symbol pow") { - val shape = Vector(1, 1) + val shape = Shape(1, 1) val data = Symbol.Variable("data") val dataTmp = NDArray.ones(shape) * 2 @@ -231,7 +230,7 @@ class OperatorSuite extends FunSuite with BeforeAndAfterAll } test("pow fn") { - val shape = Vector(3, 4) + val shape = Shape(3, 4) val exp = Symbol.Variable("exp") val y = Symbol.pow(2, exp) val x = NDArray.ones(shape) * 3 @@ -259,7 +258,7 @@ class OperatorSuite extends FunSuite with BeforeAndAfterAll // check ops handle duplicate input correctly. test("binary op duplicate input") { val data = Symbol.Variable("data") - val shape = Vector(3, 4) + val shape = Shape(3, 4) val dataTmp = NDArray.ones(shape) * 5 val arrData = dataTmp.copy() val arrGrad = NDArray.ones(shape) * 3 @@ -274,7 +273,7 @@ class OperatorSuite extends FunSuite with BeforeAndAfterAll test("sign") { val data = Symbol.Variable("data") - val shape = Vector(3, 4) + val shape = Shape(3, 4) val dataTmp = NDArray.ones(shape) * 5 val arrData = dataTmp.copy() val arrGrad = NDArray.ones(shape) * 3 @@ -293,7 +292,7 @@ class OperatorSuite extends FunSuite with BeforeAndAfterAll test("round, ceil, floor") { val data = Symbol.Variable("data") - val shape = Vector(3, 4) + val shape = Shape(3, 4) val dataTmp = NDArray.ones(shape) * 5.543f val arrData = dataTmp.copy() val arrGrad = NDArray.ones(shape) * 2 @@ -308,7 +307,7 @@ class OperatorSuite extends FunSuite with BeforeAndAfterAll test("rsqrt, cos, sin") { val data = Symbol.Variable("data") - val shape = Vector(3, 4) + val shape = Shape(3, 4) val dataTmp = NDArray.ones(shape) * 5 
val arrData = dataTmp.copy() val arrGrad = NDArray.ones(shape) * 3 @@ -336,7 +335,7 @@ class OperatorSuite extends FunSuite with BeforeAndAfterAll test("maximum") { val data1 = Symbol.Variable("data") val data2 = Symbol.Variable("data") - val shape = Vector(3, 4) + val shape = Shape(3, 4) val dataTmp1 = Random.uniform(0, 100, shape) val dataTmp2 = Random.uniform(0, 100, shape) @@ -354,7 +353,7 @@ class OperatorSuite extends FunSuite with BeforeAndAfterAll test("minimum") { val data1 = Symbol.Variable("data") val data2 = Symbol.Variable("data") - val shape = Vector(3, 4) + val shape = Shape(3, 4) val dataTmp1 = Random.uniform(0, 100, shape) val dataTmp2 = Random.uniform(0, 100, shape) @@ -371,7 +370,7 @@ class OperatorSuite extends FunSuite with BeforeAndAfterAll test("maximum minimum scalar") { val data = Symbol.Variable("data") - val shape = Vector(3, 4) + val shape = Shape(3, 4) val dataTmp = NDArray.ones(shape) * 2 val arrData = dataTmp.copy() @@ -386,7 +385,7 @@ class OperatorSuite extends FunSuite with BeforeAndAfterAll test("abs") { val data = Symbol.Variable("data") - val shape = Vector(3, 4) + val shape = Shape(3, 4) val dataTmp = NDArray.ones(shape) * 5 val arrData = dataTmp.copy() val arrGrad = NDArray.ones(shape) * 3 @@ -428,7 +427,7 @@ class OperatorSuite extends FunSuite with BeforeAndAfterAll val (argShapes, outShapes, _) = deconv.inferShape(Map("data" -> inputShape)) val inputData = Random.uniform(-5, 5, inputShape) val outGrad = inputData - val convWeight = Random.normal(0, 1, Vector(numFilter, inputShape(1), kernel._1, kernel._2)) + val convWeight = Random.normal(0, 1, Shape(numFilter, inputShape(1), kernel._1, kernel._2)) val args: Map[String, NDArray] = Map("data" -> inputData, "conv_weight" -> convWeight, "deconv_weight" -> convWeight) val argsGrad: Seq[NDArray] = argShapes.map(NDArray.empty(_)) @@ -442,21 +441,21 @@ class OperatorSuite extends FunSuite with BeforeAndAfterAll test("deconvolution forward & backward") { 
checkDeconvolutionForwardBackward( - inputShape = Vector(1, 1, 5, 5), + inputShape = Shape(1, 1, 5, 5), numFilter = 1, kernel = (3, 3), stride = (1, 1), pad = (1, 1) ) checkDeconvolutionForwardBackward( - inputShape = Vector(32, 3, 28, 28), + inputShape = Shape(32, 3, 28, 28), numFilter = 3, kernel = (3, 3), stride = (1, 1), pad = (1, 1) ) checkDeconvolutionForwardBackward( - inputShape = Vector(10, 3, 403, 403), + inputShape = Shape(10, 3, 403, 403), numFilter = 3, kernel = (7, 7), stride = (5, 5), @@ -486,10 +485,10 @@ class OperatorSuite extends FunSuite with BeforeAndAfterAll val convData = Random.uniform(-5, 5, inputShape) val convArgs = Map("data_conv" -> convData, - "conv_weight" -> Random.normal(0, 1, Vector(numFilter, inputShape(1), kernel._1, kernel._2))) + "conv_weight" -> Random.normal(0, 1, Shape(numFilter, inputShape(1), kernel._1, kernel._2))) val convArgsGrad = Seq(NDArray.zeros(convData.shape), - NDArray.zeros(Vector(numFilter, inputShape(1), kernel._1, kernel._2))) + NDArray.zeros(Shape(numFilter, inputShape(1), kernel._1, kernel._2))) val exeConv = conv.bind(Context.cpu(), args = convArgs, argsGrad = convArgsGrad) val convOutGrad = Random.normal(0, 2, exeConv.outputs.head.shape) exeConv.backward(convOutGrad) @@ -497,7 +496,7 @@ class OperatorSuite extends FunSuite with BeforeAndAfterAll val deconvData = convOutGrad val deconvArgs = Map("data_deconv" -> deconvData, "deconv_weight" -> convArgs("conv_weight")) val deconvArgsGrad = Seq(NDArray.zeros(deconvData.shape), - NDArray.zeros(Vector(numFilter, inputShape(1), kernel._1, kernel._2))) + NDArray.zeros(Shape(numFilter, inputShape(1), kernel._1, kernel._2))) val exeDeconv = deconv.bind(Context.cpu(), args = deconvArgs, argsGrad = deconvArgsGrad) val deconvOutGrad = convData exeDeconv.backward(deconvOutGrad) @@ -506,12 +505,12 @@ class OperatorSuite extends FunSuite with BeforeAndAfterAll test("deconvolution gradient") { checkDeconvolutionGradient( - inputShape = Vector(1, 3, 5, 5), + inputShape = 
package ml.dmlc.mxnet

import org.scalatest.{BeforeAndAfterAll, FunSuite}

/** Unit tests for [[Shape]]: string rendering and structural equality. */
class ShapeSuite extends FunSuite with BeforeAndAfterAll {
  test("to string") {
    assert(Shape(1, 2, 3).toString === "(1,2,3)")
  }

  test("equals") {
    val shape123 = Shape(1, 2, 3)
    assert(shape123 === Shape(1, 2, 3))
    assert(Shape(1, 2) != shape123)
  }
}