Skip to content

Commit

Permalink
add IO classes
Browse files Browse the repository at this point in the history
  • Loading branch information
yanqingmen committed Dec 17, 2015
1 parent 05431f5 commit a41c2e8
Show file tree
Hide file tree
Showing 4 changed files with 82 additions and 18 deletions.
34 changes: 16 additions & 18 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -88,23 +88,23 @@ example/notebooks/.ipynb_checkpoints/*

# Scala package
# Jetbrain
mxnet-scala/.idea
scala-package/.idea

# ctags
mxnet-scala/tags
scala-package/tags

mxnet-scala/*.class
mxnet-scala/*.log
scala-package/*.class
scala-package/*.log

# sbt specific
mxnet-scala/.cache
mxnet-scala/.lib/
mxnet-scala/dist/*
mxnet-scala/target/
mxnet-scala/lib_managed/
mxnet-scala/src_managed/
mxnet-scala/project/boot/
mxnet-scala/project/plugins/project/
scala-package/.cache
scala-package/.lib/
scala-package/dist/*
scala-package/target/
scala-package/lib_managed/
scala-package/src_managed/
scala-package/project/boot/
scala-package/project/plugins/project/

#scala target folders
scala-package/*/target/
Expand All @@ -116,9 +116,7 @@ scala-package/*/*/target/
.settings

# IDE specific
mxnet-scala/.scala_dependencies
mxnet-scala/.worksheet
mxnet-scala/.idea
mxnet-scala/*.iml


.scala_dependencies
.worksheet
.idea
*.iml
1 change: 1 addition & 0 deletions scala-package/core/src/main/scala/ml/dmlc/mxnet/Base.scala
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ object Base {
type MXFloatRef = RefFloat
type NDArrayHandle = RefLong
type FunctionHandle = RefLong
type DataIterHandle = RefLong

System.loadLibrary("mxnet-scala")
val _LIB = new LibInfo
Expand Down
57 changes: 57 additions & 0 deletions scala-package/core/src/main/scala/ml/dmlc/mxnet/IO.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
package ml.dmlc.mxnet

import ml.dmlc.mxnet.Base._
import ml.dmlc.mxnet.NDArray
import org.slf4j.LoggerFactory



abstract class DataIter (val batchSize: Int = 0) {
/**
* reset the iterator
*/
def reset(): Unit
/**
* Iterate to next batch
* @return whether the move is successful
*/
def iterNext(): Boolean

/**
* get data of current batch
* @return the data of current batch
*/
def getData(): NDArray

/**
* Get label of current batch
* @return the label of current batch
*/
def getLabel(): NDArray

/**
* get the number of padding examples
* in current batch
* @return number of padding examples in current batch
*/
def getPad(): Int

/**
* the index of current batch
* @return
*/
def getIndex(): Seq[Int]
}

class MXDataIter(var handle: DataIterHandle) extends DataIter {
private val logger = LoggerFactory.getLogger(classOf[MXDataIter])

def reset(): Unit = {
checkCall(_LIB.mxDataIterBeforeFirst(handle))
}

def iterNext(): Boolean = {
checkCall(_LIB.mxDataIterNext(handle))
return true
}
}
8 changes: 8 additions & 0 deletions scala-package/core/src/main/scala/ml/dmlc/mxnet/LibInfo.scala
Original file line number Diff line number Diff line change
Expand Up @@ -47,4 +47,12 @@ class LibInfo {
start: MXUint,
end: MXUint,
sliceHandle: NDArrayHandle): Int
@native def mxDataIterBeforeFirst(handle: DataIterHandle): Int
@native def mxDataIterNext(handle: DataIterHandle): Int
@native def mxDataIterGetLabel(handle: DataIterHandle,
out: NDArrayHandle): Int
@native def mxDataIterGetData(handle: DataIterHandle,
out: NDArrayHandle): Int
@native def mxDataIterGetPadNum(handle: DataIterHandle,
out: MXUintRef): Int
}

0 comments on commit a41c2e8

Please sign in to comment.