Skip to content

Commit

Permalink
Fix SequenceRecordReaderDataSetIterator post TAD changes
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexDBlack committed May 1, 2016
1 parent 04156c0 commit 598efe9
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 18 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -161,8 +161,9 @@ private DataSet nextSingleSequenceReader(int num){

if(minLength == maxLength){
for (int i = 0; i < listFeatures.size(); i++) {
featuresOut.tensorAlongDimension(i, 1, 2).assign(listFeatures.get(i));
labelsOut.tensorAlongDimension(i, 1, 2).assign(listLabels.get(i));
//Note: this TAD gives us shape [vectorSize,tsLength] whereas we need a [vectorSize,timeSeriesLength] matrix (that listFeatures contains)
featuresOut.tensorAlongDimension(i, 1, 2).permutei(1,0).assign(listFeatures.get(i));
labelsOut.tensorAlongDimension(i, 1, 2).permutei(1,0).assign(listLabels.get(i));
}
} else {
featuresMask = Nd4j.ones(listFeatures.size(),maxLength);
Expand All @@ -171,8 +172,10 @@ private DataSet nextSingleSequenceReader(int num){
INDArray f = listFeatures.get(i);
int tsLength = f.size(0);

featuresOut.tensorAlongDimension(i, 1, 2).put(new INDArrayIndex[]{NDArrayIndex.interval(0, tsLength), NDArrayIndex.all()}, f);
labelsOut.tensorAlongDimension(i, 1, 2).put(new INDArrayIndex[]{NDArrayIndex.interval(0, tsLength), NDArrayIndex.all()}, listLabels.get(i));
featuresOut.tensorAlongDimension(i, 1, 2).permutei(1,0)
.put(new INDArrayIndex[]{NDArrayIndex.interval(0, tsLength), NDArrayIndex.all()}, f);
labelsOut.tensorAlongDimension(i, 1, 2).permutei(1,0)
.put(new INDArrayIndex[]{NDArrayIndex.interval(0, tsLength), NDArrayIndex.all()}, listLabels.get(i));
for( int j=tsLength; j<maxLength; j++ ){
featuresMask.put(i,j,0.0);
labelsMask.put(i,j,0.0);
Expand Down Expand Up @@ -222,8 +225,10 @@ private DataSet nextMultipleSequenceReaders(int num){
featuresOut = Nd4j.create(featureShape,'f');
labelsOut = Nd4j.create(labelShape,'f');
for (int i = 0; i < featureList.size(); i++) {
featuresOut.tensorAlongDimension(i, 1, 2).assign(featureList.get(i));
labelsOut.tensorAlongDimension(i, 1, 2).assign(labelList.get(i));
featuresOut.tensorAlongDimension(i, 1, 2).permutei(1,0)
.assign(featureList.get(i));
labelsOut.tensorAlongDimension(i, 1, 2).permutei(1,0)
.assign(labelList.get(i));
}
} else if( alignmentMode == AlignmentMode.ALIGN_START ){
int longestTimeSeries = 0;
Expand Down Expand Up @@ -251,9 +256,10 @@ private DataSet nextMultipleSequenceReaders(int num){
INDArray f = featureList.get(i);
INDArray l = labelList.get(i);

featuresOut.tensorAlongDimension(i, 1, 2)
//Again, permute is to put [timeSeriesLength,vectorSize] into a [vectorSize,timeSeriesLength] matrix
featuresOut.tensorAlongDimension(i, 1, 2).permutei(1,0)
.put(new INDArrayIndex[]{NDArrayIndex.interval(0, f.size(0)), NDArrayIndex.all()}, f);
labelsOut.tensorAlongDimension(i, 1, 2)
labelsOut.tensorAlongDimension(i, 1, 2).permutei(1,0)
.put(new INDArrayIndex[]{NDArrayIndex.interval(0, l.size(0)), NDArrayIndex.all()}, l);
for( int j=f.size(0); j<longestTimeSeries; j++ ){
featuresMask.putScalar(i,j,0.0);
Expand Down Expand Up @@ -294,9 +300,9 @@ private DataSet nextMultipleSequenceReaders(int num){

if(fLen >= lLen){
//Align labels with end of features (features are longer)
featuresOut.tensorAlongDimension(i, 1, 2)
featuresOut.tensorAlongDimension(i, 1, 2).permutei(1,0)
.put(new INDArrayIndex[]{NDArrayIndex.interval(0, fLen), NDArrayIndex.all()}, f);
labelsOut.tensorAlongDimension(i, 1, 2)
labelsOut.tensorAlongDimension(i, 1, 2).permutei(1,0)
.put(new INDArrayIndex[]{NDArrayIndex.interval(fLen-lLen, fLen), NDArrayIndex.all()}, l);

for( int j=fLen; j<longestTimeSeries; j++ ){
Expand All @@ -312,9 +318,9 @@ private DataSet nextMultipleSequenceReaders(int num){
}
} else {
//Align features with end of labels (labels are longer)
featuresOut.tensorAlongDimension(i, 1, 2)
featuresOut.tensorAlongDimension(i, 1, 2).permutei(1,0)
.put(new INDArrayIndex[]{NDArrayIndex.interval(lLen-fLen, lLen), NDArrayIndex.all()}, f);
labelsOut.tensorAlongDimension(i, 1, 2)
labelsOut.tensorAlongDimension(i, 1, 2).permutei(1,0)
.put(new INDArrayIndex[]{NDArrayIndex.interval(0, lLen), NDArrayIndex.all()}, l);

//features mask: component before features
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -416,24 +416,24 @@ public void testVariableLengthSequence() throws Exception{
expF0.tensorAlongDimension(1, 1).assign(Nd4j.create(new double[]{10, 11, 12}));
expF0.tensorAlongDimension(2, 1).assign(Nd4j.create(new double[]{20, 21, 22}));
expF0.tensorAlongDimension(3, 1).assign(Nd4j.create(new double[]{30, 31, 32}));
assertEquals(dsListAlignStart.get(0).getFeatureMatrix(), expF0);
assertEquals(dsListAlignEnd.get(0).getFeatureMatrix(), expF0);
assertEquals(expF0, dsListAlignStart.get(0).getFeatureMatrix());
assertEquals(expF0, dsListAlignEnd.get(0).getFeatureMatrix());

INDArray expF1 = Nd4j.create(1, 3, 4);
expF1.tensorAlongDimension(0, 1).assign(Nd4j.create(new double[]{100, 101, 102}));
expF1.tensorAlongDimension(1, 1).assign(Nd4j.create(new double[]{110, 111, 112}));
expF1.tensorAlongDimension(2, 1).assign(Nd4j.create(new double[]{120, 121, 122}));
expF1.tensorAlongDimension(3, 1).assign(Nd4j.create(new double[]{130, 131, 132}));
assertEquals(dsListAlignStart.get(1).getFeatureMatrix(), expF1);
assertEquals(dsListAlignEnd.get(1).getFeatureMatrix(), expF1);
assertEquals(expF1, dsListAlignStart.get(1).getFeatureMatrix());
assertEquals(expF1, dsListAlignEnd.get(1).getFeatureMatrix());

INDArray expF2 = Nd4j.create(1, 3, 4);
expF2.tensorAlongDimension(0, 1).assign(Nd4j.create(new double[]{200, 201, 202}));
expF2.tensorAlongDimension(1, 1).assign(Nd4j.create(new double[]{210, 211, 212}));
expF2.tensorAlongDimension(2, 1).assign(Nd4j.create(new double[]{220, 221, 222}));
expF2.tensorAlongDimension(3, 1).assign(Nd4j.create(new double[]{230, 231, 232}));
assertEquals(dsListAlignStart.get(2).getFeatureMatrix(), expF2);
assertEquals(dsListAlignEnd.get(2).getFeatureMatrix(), expF2);
assertEquals(expF2, dsListAlignStart.get(2).getFeatureMatrix());
assertEquals(expF2, dsListAlignEnd.get(2).getFeatureMatrix());

//Check features mask array:
INDArray featuresMaskExpected = Nd4j.ones(1,4); //1 example, 4 values: same for both start/end align here
Expand Down

0 comments on commit 598efe9

Please sign in to comment.