Skip to content

Commit

Permalink
Merge pull request deeplearning4j#1422 from deeplearning4j/sa_arraywritable
Browse files Browse the repository at this point in the history

Add ArrayWritable to optimize loading of images
  • Loading branch information
saudet committed Apr 27, 2016
2 parents d5609f8 + 30c6b51 commit 6a0efff
Show file tree
Hide file tree
Showing 10 changed files with 169 additions and 40 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,18 @@
import org.canova.api.records.reader.RecordReader;
import org.canova.api.records.reader.SequenceRecordReader;
import org.canova.api.writable.Writable;
import org.canova.common.data.NDArrayWritable;
import org.deeplearning4j.datasets.iterator.DataSetIterator;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.DataSetPreProcessor;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.util.FeatureUtil;

import java.util.*;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Iterator;
import java.util.List;


/**
Expand Down Expand Up @@ -204,15 +208,7 @@ private DataSet getDataSet(Collection<Writable> record) {
}

INDArray label = null;
INDArray featureVector;
if(regression && labelIndex >= 0){
//Handle the possibly multi-label regression case here:
int nLabels = labelIndexTo - labelIndex + 1;
featureVector = Nd4j.create(1, currList.size() - nLabels);
} else {
//Classification case, and also no-labels case
featureVector = Nd4j.create(labelIndex >= 0 ? currList.size() - 1 : currList.size());
}
INDArray featureVector = null;
int featureCount = 0;
int labelCount = 0;
for (int j = 0; j < currList.size(); j++) {
Expand Down Expand Up @@ -243,7 +239,28 @@ private DataSet getDataSet(Collection<Writable> record) {
label = FeatureUtil.toOutcomeVector(curr, numPossibleLabels);
}
} else {
featureVector.putScalar(featureCount++, current.toDouble());
try {
double value = current.toDouble();
if (featureVector == null) {
if(regression && labelIndex >= 0){
//Handle the possibly multi-label regression case here:
int nLabels = labelIndexTo - labelIndex + 1;
featureVector = Nd4j.create(1, currList.size() - nLabels);
} else {
//Classification case, and also no-labels case
featureVector = Nd4j.create(labelIndex >= 0 ? currList.size() - 1 : currList.size());
}
}
featureVector.putScalar(featureCount++, value);
} catch (UnsupportedOperationException e) {
// This isn't a scalar, so check if we got an array already
if (current instanceof NDArrayWritable) {
assert featureVector == null;
featureVector = ((NDArrayWritable)current).get();
} else {
throw e;
}
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,22 @@
import org.canova.api.records.reader.RecordReader;
import org.canova.api.records.reader.SequenceRecordReader;
import org.canova.api.writable.Writable;
import org.canova.common.data.NDArrayWritable;
import org.deeplearning4j.berkeley.Pair;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.dataset.api.MultiDataSetPreProcessor;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.NDArrayIndex;

import java.util.*;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.NoSuchElementException;

/**RecordReaderMultiDataSetIterator: A {@link MultiDataSetIterator} for data from one or more RecordReaders and SequenceRecordReaders<br>
* The idea: generate multiple inputs and multiple outputs from one or more Sequence/RecordReaders. Inputs and outputs
Expand Down Expand Up @@ -204,7 +212,16 @@ private INDArray convertWritables(List<Collection<Writable>> list, int minValues
int j = 0;
for (Writable w : c) {
idx[1] = j++;
arr.putScalar(idx, w.toDouble());
try {
arr.putScalar(idx, w.toDouble());
} catch (UnsupportedOperationException e) {
// This isn't a scalar, so check if we got an array already
if (w instanceof NDArrayWritable) {
arr.putRow(idx[0], ((NDArrayWritable)w).get());
} else {
throw e;
}
}
}
} else if(details.oneHot){
//Convert a single column to a one-hot representation
Expand All @@ -223,7 +240,18 @@ private INDArray convertWritables(List<Collection<Writable>> list, int minValues
int k=0;
for( int j=details.subsetStart; j<=details.subsetEndInclusive; j++){
idx[1] = k++;
arr.putScalar(idx,iter.next().toDouble());
Writable w = iter.next();
try {
arr.putScalar(idx,w.toDouble());
} catch (UnsupportedOperationException e) {
// This isn't a scalar, so check if we got an array already
if (w instanceof NDArrayWritable) {
arr.putRow(idx[0], ((NDArrayWritable)w).get().get(NDArrayIndex.all(),
NDArrayIndex.interval(details.subsetStart, details.subsetEndInclusive + 1)));
} else {
throw e;
}
}
}
}
}
Expand Down Expand Up @@ -278,7 +306,18 @@ private Pair<INDArray,INDArray> convertWritablesSequence(List<Collection<Collect
int j = 0;
while (iter.hasNext()) {
idx[1] = j++;
arr.putScalar(idx,iter.next().toDouble());
Writable w = iter.next();
try {
arr.putScalar(idx,w.toDouble());
} catch (UnsupportedOperationException e) {
// This isn't a scalar, so check if we got an array already
if (w instanceof NDArrayWritable) {
arr.get(NDArrayIndex.point(idx[0]), NDArrayIndex.all(), NDArrayIndex.point(idx[2]))
.putRow(0, ((NDArrayWritable)w).get());
} else {
throw e;
}
}
}
} else if(details.oneHot){
//Convert a single column to a one-hot representation
Expand All @@ -298,7 +337,19 @@ private Pair<INDArray,INDArray> convertWritablesSequence(List<Collection<Collect
int k=0;
for( int j=details.subsetStart; j<=details.subsetEndInclusive; j++){
idx[1] = k++;
arr.putScalar(idx,iter.next().toDouble());
Writable w = iter.next();
try {
arr.putScalar(idx,w.toDouble());
} catch (UnsupportedOperationException e) {
// This isn't a scalar, so check if we got an array already
if (w instanceof NDArrayWritable) {
arr.get(NDArrayIndex.point(idx[0]), NDArrayIndex.all(), NDArrayIndex.point(idx[2]))
.putRow(0, ((NDArrayWritable)w).get().get(NDArrayIndex.all(),
NDArrayIndex.interval(details.subsetStart, details.subsetEndInclusive + 1)));
} else {
throw e;
}
}
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import org.canova.api.records.reader.SequenceRecordReader;
import org.canova.api.writable.Writable;
import org.canova.common.data.NDArrayWritable;
import org.deeplearning4j.datasets.iterator.DataSetIterator;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
Expand All @@ -11,7 +12,11 @@
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.util.FeatureUtil;

import java.util.*;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Iterator;
import java.util.List;
import java.util.NoSuchElementException;

/**
* Sequence record reader data set iterator
Expand Down Expand Up @@ -438,7 +443,16 @@ private INDArray getFeatures(Collection<Collection<Writable>> features) {
int f = 0;
while (timeStepIter.hasNext()) {
Writable current = timeStepIter.next();
out.put(i, f++, current.toDouble());
try {
out.put(i, f++, current.toDouble());
} catch (UnsupportedOperationException e) {
// This isn't a scalar, so check if we got an array already
if (current instanceof NDArrayWritable) {
out.putRow(i, ((NDArrayWritable)current).get());
} else {
throw e;
}
}
}
i++;
}
Expand Down Expand Up @@ -514,7 +528,16 @@ private INDArray[] getFeaturesLabelsSingleReader(Collection<Collection<Writable>
}
} else {
//feature
features.put(i, countFeatures++, current.toDouble());
try {
features.put(i, countFeatures++, current.toDouble());
} catch (UnsupportedOperationException e) {
// This isn't a scalar, so check if we got an array already
if (current instanceof NDArrayWritable) {
features.putRow(i, ((NDArrayWritable)current).get());
} else {
throw e;
}
}
}
}
i++;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ public class CifarDataSetIterator extends RecordReaderDataSetIterator {
* @param numExamples the overall number of examples
* */
public CifarDataSetIterator(int batchSize, int numExamples) {
super(null, batchSize, height * width * channels , CifarLoader.NUM_LABELS);
super(null, batchSize, 1 , CifarLoader.NUM_LABELS);
this.loader = new CifarLoader();
this.inputStream = loader.getInputStream();
this.numExamples = numExamples > totalExamples? totalExamples: numExamples;
Expand All @@ -47,7 +47,7 @@ public CifarDataSetIterator(int batchSize, int numExamples) {
* @param numExamples the overall number of examples
* */
public CifarDataSetIterator(int batchSize, int numExamples, String version) {
super(null, batchSize, height * width * channels, CifarLoader.NUM_LABELS);
super(null, batchSize, 1, CifarLoader.NUM_LABELS);
this.loader = new CifarLoader(version);
this.inputStream = loader.getInputStream();
this.numExamples = numExamples > totalExamples? totalExamples: numExamples;
Expand All @@ -61,7 +61,7 @@ public CifarDataSetIterator(int batchSize, int numExamples, String version) {
* @param numExamples the overall number of examples
* */
public CifarDataSetIterator(int batchSize, int numExamples, int[] imgDim) {
super(null, batchSize, imgDim[0] * imgDim[1] * imgDim[2], CifarLoader.NUM_LABELS);
super(null, batchSize, 1, CifarLoader.NUM_LABELS);
this.loader = new CifarLoader();
this.inputStream = loader.getInputStream();
this.numExamples = numExamples > totalExamples? totalExamples: numExamples;
Expand All @@ -74,7 +74,7 @@ public CifarDataSetIterator(int batchSize, int numExamples, int[] imgDim) {
* @param numExamples the overall number of examples
* */
public CifarDataSetIterator(int batchSize, int numExamples, int[] imgDim, int numCategories) {
super(null, batchSize, imgDim[0] * imgDim[1] * imgDim[2], numCategories);
super(null, batchSize, 1, numCategories);
this.loader = new CifarLoader();
this.inputStream = loader.getInputStream();
this.numExamples = numExamples > totalExamples? totalExamples: numExamples;
Expand All @@ -87,7 +87,7 @@ public CifarDataSetIterator(int batchSize, int numExamples, int[] imgDim, int nu
* @param numCategories the overall number of labels
* */
public CifarDataSetIterator(int batchSize, int numExamples, int numCategories) {
super(null, batchSize, height * width * channels, numCategories);
super(null, batchSize, 1, numCategories);
this.loader = new CifarLoader();
this.inputStream = loader.getInputStream();
this.numExamples = numExamples > totalExamples? totalExamples: numExamples;
Expand All @@ -100,7 +100,7 @@ public CifarDataSetIterator(int batchSize, int numExamples, int numCategories) {
* @param imgDim an array of height, width and channels
*/
public CifarDataSetIterator(int batchSize, int[] imgDim) {
super(null, batchSize, imgDim[0] * imgDim[1] * imgDim[2], CifarLoader.NUM_LABELS);
super(null, batchSize, 1, CifarLoader.NUM_LABELS);
this.loader = new CifarLoader();
this.inputStream = loader.getInputStream();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ public class LFWDataSetIterator extends RecordReaderDataSetIterator {
* @param numExamples the overall number of examples
* */
public LFWDataSetIterator(int batchSize, int numExamples) {
super(new LFWLoader().getRecordReader(numExamples), batchSize, height * width * channels, LFWLoader.NUM_LABELS);
super(new LFWLoader().getRecordReader(numExamples), batchSize, 1, LFWLoader.NUM_LABELS);
}

/**
Expand All @@ -46,7 +46,7 @@ public LFWDataSetIterator(int batchSize, int numExamples) {
* @param numExamples the overall number of examples
* */
public LFWDataSetIterator(int batchSize, int numExamples, int[] imgDim) {
super(new LFWLoader().getRecordReader(numExamples, imgDim[0], imgDim[1], imgDim[2]), batchSize, imgDim[0] * imgDim[1] * imgDim[2], LFWLoader.NUM_LABELS);
super(new LFWLoader().getRecordReader(numExamples, imgDim[0], imgDim[1], imgDim[2]), batchSize, 1, LFWLoader.NUM_LABELS);
}

/**
Expand All @@ -56,7 +56,7 @@ public LFWDataSetIterator(int batchSize, int numExamples, int[] imgDim) {
* @param numExamples the overall number of examples
* */
public LFWDataSetIterator(int batchSize, int numExamples, int[] imgDim, int numCategories) {
super(new LFWLoader().getRecordReader(numExamples, imgDim[0], imgDim[1], imgDim[2]), batchSize, imgDim[0] * imgDim[1] * imgDim[2], numCategories);
super(new LFWLoader().getRecordReader(numExamples, imgDim[0], imgDim[1], imgDim[2]), batchSize, 1, numCategories);
}

/**
Expand All @@ -66,7 +66,7 @@ public LFWDataSetIterator(int batchSize, int numExamples, int[] imgDim, int numC
* @param numCategories the overall number of labels
* */
public LFWDataSetIterator(int batchSize, int numExamples, int numCategories) {
super(new LFWLoader().getRecordReader(numExamples, numCategories), batchSize, height * width * channels, numCategories);
super(new LFWLoader().getRecordReader(numExamples, numCategories), batchSize, 1, numCategories);
}

/**
Expand All @@ -75,7 +75,7 @@ public LFWDataSetIterator(int batchSize, int numExamples, int numCategories) {
* @param imgDim an array of height, width and channels
*/
public LFWDataSetIterator(int batchSize, int[] imgDim) {
super(new LFWLoader().getRecordReader(imgDim[0], imgDim[1], imgDim[2]), batchSize, imgDim[0] * imgDim[1] * imgDim[2], LFWLoader.NUM_LABELS);
super(new LFWLoader().getRecordReader(imgDim[0], imgDim[1], imgDim[2]), batchSize, 1, LFWLoader.NUM_LABELS);
}


Expand All @@ -86,7 +86,7 @@ public LFWDataSetIterator(int batchSize, int[] imgDim) {
* @param numCategories the overall number of labels
* */
public LFWDataSetIterator(int batchSize, int numExamples, int numCategories, boolean useSubset) {
super(new LFWLoader(useSubset).getRecordReader(numExamples, numCategories), batchSize, height * width * channels, numCategories);
super(new LFWLoader(useSubset).getRecordReader(numExamples, numCategories), batchSize, 1, numCategories);
}

/**
Expand All @@ -95,7 +95,7 @@ public LFWDataSetIterator(int batchSize, int numExamples, int numCategories, boo
* @param imgDim an array of height, width and channels
*/
public LFWDataSetIterator(int batchSize, int[] imgDim, boolean useSubset) {
super(new LFWLoader(useSubset).getRecordReader(imgDim[0], imgDim[1], imgDim[2]), batchSize, imgDim[0] * imgDim[1] * imgDim[2], useSubset ? LFWLoader.SUB_NUM_LABELS : LFWLoader.NUM_LABELS);
super(new LFWLoader(useSubset).getRecordReader(imgDim[0], imgDim[1], imgDim[2]), batchSize, 1, useSubset ? LFWLoader.SUB_NUM_LABELS : LFWLoader.NUM_LABELS);
}


Expand All @@ -106,7 +106,7 @@ public LFWDataSetIterator(int batchSize, int[] imgDim, boolean useSubset) {
* @param numExamples the overall number of examples
* */
public LFWDataSetIterator(int batchSize, int numExamples, int[] imgDim, int numCategories, boolean useSubset, Random rng) {
super(new LFWLoader(useSubset).getRecordReader(numExamples, imgDim[0], imgDim[1], imgDim[2], numCategories, rng), batchSize, imgDim[0] * imgDim[1] * imgDim[2], numCategories);
super(new LFWLoader(useSubset).getRecordReader(numExamples, imgDim[0], imgDim[1], imgDim[2], numCategories, rng), batchSize, 1, numCategories);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ public void testMultiChannel() throws Exception {

RecordReader reader = new ImageRecordReader(28,28,3,true,labels);
reader.initialize(new FileSplit(new File(rootDir)));
DataSetIterator recordReader = new RecordReaderDataSetIterator(reader,28 * 28 * 3,labels.size());
DataSetIterator recordReader = new RecordReaderDataSetIterator(reader,1,labels.size());

labels.remove("lfwtest");
NeuralNetConfiguration.ListBuilder builder = (NeuralNetConfiguration.ListBuilder) incompleteLFW();
Expand All @@ -154,7 +154,7 @@ public void testLRN() throws Exception{

RecordReader reader = new ImageRecordReader(28,28,3,true,labels);
reader.initialize(new FileSplit(new File(rootDir)));
DataSetIterator recordReader = new RecordReaderDataSetIterator(reader,28 * 28 * 3,labels.size());
DataSetIterator recordReader = new RecordReaderDataSetIterator(reader,1,labels.size());
labels.remove("lfwtest");
NeuralNetConfiguration.ListBuilder builder = (NeuralNetConfiguration.ListBuilder) incompleteLRN();
new ConvolutionLayerSetup(builder,28,28,3);
Expand Down
Loading

0 comments on commit 6a0efff

Please sign in to comment.