Merge pull request apache#34 from javelinjs/scala-package-l
Change Shape from Vector to an independent class
terrytangyuan committed Feb 27, 2016
2 parents 70d6124 + fcdcafb commit 1f84bc2
Showing 20 changed files with 166 additions and 113 deletions.
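In short: shapes that were previously typed as the plain alias `type Shape = Vector[Int]` (removed from `Base.scala` below) are now instances of a dedicated `Shape` class with structural equality. A minimal before/after sketch of the call-site impact, based on the diffs that follow (variable names are illustrative):

// Before: a shape was a Vector[Int]; comparisons went through sameElements.
val oldShape: Vector[Int] = Vector(2, 1)
require(oldShape.sameElements(Vector(2, 1)))

// After: Shape is its own type; == compares dimensions structurally.
val newShape: Shape = Shape(2, 1)
require(newShape == Shape(2, 1))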
2 changes: 0 additions & 2 deletions scala-package/core/src/main/scala/ml/dmlc/mxnet/Base.scala
@@ -10,8 +10,6 @@ object Base {
type MXUint = Int
type MXFloat = Float
type CPtrAddress = Long
- // TODO: make it more friendly to java
- type Shape = Vector[Int]

type NDArrayHandle = CPtrAddress
type FunctionHandle = CPtrAddress
@@ -50,8 +50,8 @@ class Accuracy extends EvalMetric("accuracy") {

for ((pred, label) <- preds zip labels) {
val predLabel = NDArray.argmaxChannel(pred)
- require(label.shape.sameElements(predLabel.shape),
- s"label (${label.shape.mkString(",")}) and prediction (${predLabel.shape.mkString(",")})" +
+ require(label.shape == predLabel.shape,
+ s"label ${label.shape} and prediction ${predLabel.shape}" +
s"should have the same length.")
for ((labelElem, predElem) <- label.toArray zip predLabel.toArray) {
if (labelElem == predElem) {
@@ -294,7 +294,7 @@ class DataParallelExecutorManager(symbol: Symbol,
ctx.zipWithIndex.map { case (context, i) =>
val dataShapes =
trainData.provideData.map { case (name: String, shape: Shape) =>
- (name, Vector(slices(i)._2 - slices(i)._1) ++ shape.drop(1))
+ (name, Shape(slices(i)._2 - slices(i)._1) ++ shape.drop(1))
}
symbol.simpleBind(context, "write", shapeDict = dataShapes)
}
@@ -334,7 +334,7 @@ class DataParallelExecutorManager(symbol: Symbol,
}.toArray
private val batchSize = trainData.batchSize
private val outputShapes: Array[Shape] = trainExecs(0).outputs.map { x: NDArray =>
- Vector(batchSize) ++ x.shape.drop(1)
+ Shape(batchSize) ++ x.shape.drop(1)
}
private[mxnet] val cpuOutputArrays = outputShapes.map(NDArray.zeros(_))

@@ -1,6 +1,5 @@
package ml.dmlc.mxnet

- import ml.dmlc.mxnet.Base.Shape
import ml.dmlc.mxnet.io.NDArrayIter
import ml.dmlc.mxnet.optimizer.SGD
import org.slf4j.{Logger, LoggerFactory}
@@ -85,7 +85,7 @@ class Monitor(protected val interval: Int, protected var statFunc: (NDArray) =>
val res = new mutable.Queue[(Int, String, String)]
queue.foreach { q =>
val (n, k, v) = q
- if (v.shape.sameElements(Array(1))) {
+ if (v.shape == Shape(1)) {
res += ((n, k, v.toScalar.toString))
} else {
res += ((n, k, s"[${v.toArray.mkString(",")}]"))
27 changes: 13 additions & 14 deletions scala-package/core/src/main/scala/ml/dmlc/mxnet/NDArray.scala
@@ -222,9 +222,9 @@ object NDArray {
new NDArray(handle = NDArray.newAllocHandle(shape, context, delayAlloc = false))
}

- def empty(shape: Int *): NDArray = empty(shape.toVector)
+ def empty(shape: Int *): NDArray = empty(Shape(shape: _*))

- def empty(ctx: Context, shape: Int *): NDArray = empty(shape.toVector, ctx)
+ def empty(ctx: Context, shape: Int *): NDArray = empty(Shape(shape: _*), ctx)

/**
* Create a new NDArray filled with 0, with specified shape.
@@ -240,9 +240,9 @@ object NDArray {
arr
}

- def zeros(shape: Int *): NDArray = zeros(shape.toVector)
+ def zeros(shape: Int *): NDArray = zeros(Shape(shape: _*))

- def zeros(ctx: Context, shape: Int *): NDArray = zeros(shape.toVector, ctx)
+ def zeros(ctx: Context, shape: Int *): NDArray = zeros(Shape(shape: _*), ctx)

/**
* Create a new NDArray filled with 1, with specified shape.
@@ -256,9 +256,9 @@ object NDArray {
arr
}

- def ones(shape: Int *): NDArray = ones(shape.toVector)
+ def ones(shape: Int *): NDArray = ones(Shape(shape: _*))

- def ones(ctx: Context, shape: Int *): NDArray = ones(shape.toVector, ctx)
+ def ones(ctx: Context, shape: Int *): NDArray = ones(Shape(shape: _*), ctx)

/**
* Clip ndarray elements to range (from, to)
@@ -477,12 +477,12 @@ object NDArray {
val shape = array0.shape.drop(1)
var axis0 = array0.shape(0)
arrays.drop(1).foreach { array =>
- require(shape.sameElements(array.shape.drop(1)),
- s"shape mismatch between (${array.shape.mkString(",")}) and (${shape.mkString(",")})")
+ require(shape == array.shape.drop(1),
+ s"shape mismatch between ${array.shape} and $shape")
axis0 += array.shape(0)
}

- val output = NDArray.empty(Vector(axis0) ++ shape, ctx)
+ val output = NDArray.empty(Shape(axis0) ++ shape, ctx)
axis0 = 0
arrays.foreach { array =>
output.slice(axis0, axis0 + array.shape(0)).set(array)
@@ -767,7 +767,7 @@ class NDArray(private[mxnet] val handle: NDArrayHandle, val writable: Boolean =
* @return The scalar representation of the ndarray.
*/
def toScalar: Float = {
- require(shape.sameElements(Array(1)), "The current array is not a scalar")
+ require(shape == Shape(1), "The current array is not a scalar")
this.toArray(0)
}

@@ -812,16 +812,15 @@ class NDArray(private[mxnet] val handle: NDArrayHandle, val writable: Boolean =
val data = ArrayBuffer[Int]()
checkCall(_LIB.mxNDArrayGetShape(handle, ndim, data))
require(ndim.value == data.length, s"ndim=$ndim, while len(pdata)=${data.length}")
- data.toVector
+ Shape(data)
}

// Get size of current NDArray.
def size: Int = shape.product

override def equals(o: Any): Boolean = o match {
- case that: NDArray => {
- that.shape == this.shape && that.toArray.sameElements(this.toArray)
- }
+ case that: NDArray =>
+ that != null && that.shape == this.shape && that.toArray.sameElements(this.toArray)
case _ => false
}

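Note that the varargs overloads above now forward to `Shape`, so the two call styles below should be interchangeable; a small sketch based on the overloads in this hunk:

val a = NDArray.zeros(2, 3)          // varargs forward to Shape(2, 3)
val b = NDArray.zeros(Shape(2, 3))   // explicit Shape
assert(a.shape == b.shape)           // Shape equality replaces sameElements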
43 changes: 43 additions & 0 deletions scala-package/core/src/main/scala/ml/dmlc/mxnet/Shape.scala
@@ -0,0 +1,43 @@
package ml.dmlc.mxnet

/**
* Shape of [[NDArray]] or other data
* @author Yizhi Liu
*/
class Shape(dims: Traversable[Int]) {
private val shape = dims.toVector

def this(dims: Int*) = {
this(dims.toVector)
}

def apply(dim: Int): Int = shape(dim)
def size: Int = shape.size
def length: Int = shape.length
def drop(dim: Int): Shape = new Shape(shape.drop(dim))
def slice(from: Int, end: Int): Shape = new Shape(shape.slice(from, end))
def product: Int = shape.product
def head: Int = shape.head

def ++(other: Shape): Shape = new Shape(shape ++ other.shape)

def toArray: Array[Int] = shape.toArray
def toVector: Vector[Int] = shape

override def toString(): String = s"(${shape.mkString(",")})"

override def equals(o: Any): Boolean = o match {
case that: Shape =>
that != null && that.shape.sameElements(shape)
case _ => false
}

override def hashCode(): Int = {
shape.hashCode()
}
}

object Shape {
def apply(dims: Int *): Shape = new Shape(dims: _*)
def apply(dims: Traversable[Int]): Shape = new Shape(dims)
}
8 changes: 5 additions & 3 deletions scala-package/core/src/main/scala/ml/dmlc/mxnet/Symbol.scala
@@ -188,7 +188,7 @@ class Symbol(private[mxnet] val handle: SymbolHandle) {
val sdata = ArrayBuffer.empty[Int]
args.foreach { shape =>
if (shape != null) {
- sdata ++= shape
+ sdata ++= shape.toVector
indPtr += sdata.size
}
}
@@ -212,7 +212,7 @@ class Symbol(private[mxnet] val handle: SymbolHandle) {
val sdata = ArrayBuffer.empty[Int]
kwargs.foreach { case (key, shape) =>
keys += key
- sdata ++= shape
+ sdata ++= shape.toVector
indPtr += sdata.size
}
inferShape(keys.toArray, indPtr.toArray, sdata.toArray)
@@ -228,7 +228,9 @@ class Symbol(private[mxnet] val handle: SymbolHandle) {
checkCall(_LIB.mxSymbolInferShape(handle, indPtr.size - 1, keys, indPtr, values,
argShapeData, outShapeData, auxShapeData, complete))
if (complete.value != 0) {
- (argShapeData.map(_.toVector), outShapeData.map(_.toVector), auxShapeData.map(_.toVector))
+ (argShapeData.map(s => Shape(s)),
+ outShapeData.map(s => Shape(s)),
+ auxShapeData.map(s => Shape(s)))
} else {
(null, null, null)
}
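With this change, shape inference consumes and produces `Shape` values end to end. A hypothetical call site, assuming the keyword-based `inferShape` overload shown above accepts a Map[String, Shape] (names are illustrative):

val data = Symbol.Variable("data")
val (argShapes, outShapes, auxShapes) = data.inferShape(Map("data" -> Shape(100, 784)))
// each element is a Shape; the triple is (null, null, null) if inference is incomplete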
@@ -1,7 +1,7 @@
package ml.dmlc.mxnet.io

import ml.dmlc.mxnet.Base._
- import ml.dmlc.mxnet.{DataPack, DataBatch, DataIter, NDArray}
+ import ml.dmlc.mxnet.{DataPack, DataBatch, DataIter, NDArray, Shape}
import ml.dmlc.mxnet.IO._
import org.slf4j.LoggerFactory

@@ -1,7 +1,7 @@
package ml.dmlc.mxnet.io

import ml.dmlc.mxnet.Base._
- import ml.dmlc.mxnet.{DataIter, NDArray}
+ import ml.dmlc.mxnet.{DataIter, NDArray, Shape}

/**
* TODO
@@ -1,7 +1,6 @@
package ml.dmlc.mxnet.io

- import ml.dmlc.mxnet.{DataIter, NDArray}
- import ml.dmlc.mxnet.Base._
+ import ml.dmlc.mxnet.{DataIter, NDArray, Shape}

/**
* TODO
@@ -5,7 +5,7 @@ import ml.dmlc.mxnet.CheckUtils._

class ExecutorSuite extends FunSuite with BeforeAndAfterAll {
test("bind") {
- val shape = Vector(100, 30)
+ val shape = Shape(100, 30)
val lhs = Symbol.Variable("lhs")
val rhs = Symbol.Variable("rhs")
val ret = lhs + rhs
4 changes: 2 additions & 2 deletions scala-package/core/src/test/scala/ml/dmlc/mxnet/IOSuite.scala
@@ -34,8 +34,8 @@ class IOSuite extends FunSuite with BeforeAndAfterAll {
// test provideData
val provideData = mnistIter.provideData
val provideLabel = mnistIter.provideLabel
assert(provideData("data") === Array(100, 784))
assert(provideLabel("label") === Array(100))
assert(provideData("data") === Shape(100, 784))
assert(provideLabel("label") === Shape(100))
// test_loop
mnistIter.reset()
batchCount = 0
@@ -5,7 +5,7 @@ import org.scalatest.{BeforeAndAfterAll, FunSuite}
class KVStoreSuite extends FunSuite with BeforeAndAfterAll {
test("init and pull") {
val kv = KVStore.create()
- val shape = Vector(2, 1)
+ val shape = Shape(2, 1)
val ndArray = NDArray.zeros(shape)

kv.init(3, NDArray.ones(shape))
@@ -15,7 +15,7 @@ class KVStoreSuite extends FunSuite with BeforeAndAfterAll {

test("push and pull") {
val kv = KVStore.create()
- val shape = Vector(2, 1)
+ val shape = Shape(2, 1)
val ndArray = NDArray.zeros(shape)

kv.init(3, NDArray.ones(shape))
@@ -36,7 +36,7 @@ class KVStoreSuite extends FunSuite with BeforeAndAfterAll {
}
kv.setUpdater(updater)

- val shape = Vector(2, 1)
+ val shape = Shape(2, 1)
val ndArray = NDArray.zeros(shape)

kv.init(3, NDArray.ones(shape) * 4)
@@ -18,7 +18,7 @@ class ModelParallelSuite extends FunSuite with BeforeAndAfterAll {
net = net + data1
}

- val shape = Vector(4, 5)
+ val shape = Shape(4, 5)
val (arr, arrGrad) =
new Context(Context.cpu(0)).withScope {
val arr = (0 until n).map(_ => NDArray.empty(shape))