forked from apache/mxnet
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
05431f5
commit a41c2e8
Showing
4 changed files
with
82 additions
and
18 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
package ml.dmlc.mxnet | ||
|
||
import ml.dmlc.mxnet.Base._ | ||
import ml.dmlc.mxnet.NDArray | ||
import org.slf4j.LoggerFactory | ||
|
||
|
||
|
||
abstract class DataIter (val batchSize: Int = 0) { | ||
/** | ||
* reset the iterator | ||
*/ | ||
def reset(): Unit | ||
/** | ||
* Iterate to next batch | ||
* @return whether the move is successful | ||
*/ | ||
def iterNext(): Boolean | ||
|
||
/** | ||
* get data of current batch | ||
* @return the data of current batch | ||
*/ | ||
def getData(): NDArray | ||
|
||
/** | ||
* Get label of current batch | ||
* @return the label of current batch | ||
*/ | ||
def getLabel(): NDArray | ||
|
||
/** | ||
* get the number of padding examples | ||
* in current batch | ||
* @return number of padding examples in current batch | ||
*/ | ||
def getPad(): Int | ||
|
||
/** | ||
* the index of current batch | ||
* @return | ||
*/ | ||
def getIndex(): Seq[Int] | ||
} | ||
|
||
class MXDataIter(var handle: DataIterHandle) extends DataIter { | ||
private val logger = LoggerFactory.getLogger(classOf[MXDataIter]) | ||
|
||
def reset(): Unit = { | ||
checkCall(_LIB.mxDataIterBeforeFirst(handle)) | ||
} | ||
|
||
def iterNext(): Boolean = { | ||
checkCall(_LIB.mxDataIterNext(handle)) | ||
return true | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters