Skip to content

Commit

Permalink
Switch record readers to use new putScalar methods
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexDBlack committed Apr 29, 2016
1 parent 6a0efff commit 940b24b
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 50 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -203,25 +203,23 @@ private INDArray convertWritables(List<Collection<Writable>> list, int minValues
else if(details.oneHot) arr = Nd4j.zeros(minValues, details.oneHotNumClasses);
else arr = Nd4j.create(minValues, details.subsetEndInclusive-details.subsetStart + 1);

int[] idx = new int[2];
for( int i=0; i<minValues; i++){
idx[0] = i;
Collection<Writable> c = list.get(i);
if(details.entireReader) {
//Convert entire reader contents, without modification
int j = 0;
for (Writable w : c) {
idx[1] = j++;
try {
arr.putScalar(idx, w.toDouble());
arr.putScalar(i,j, w.toDouble());
} catch (UnsupportedOperationException e) {
// This isn't a scalar, so check if we got an array already
if (w instanceof NDArrayWritable) {
arr.putRow(idx[0], ((NDArrayWritable)w).get());
arr.putRow(i, ((NDArrayWritable)w).get());
} else {
throw e;
}
}
j++;
}
} else if(details.oneHot){
//Convert a single column to a one-hot representation
Expand All @@ -231,27 +229,27 @@ private INDArray convertWritables(List<Collection<Writable>> list, int minValues
Iterator<Writable> iter = c.iterator();
for( int k=0; k<=details.subsetStart; k++ ) w = iter.next();
}
idx[1] = w.toInt(); //Index of class
arr.putScalar(idx,1.0);
//Index of class
arr.putScalar(i,w.toInt(),1.0);
} else {
//Convert a subset of the columns
Iterator<Writable> iter = c.iterator();
for( int j=0; j<details.subsetStart; j++ ) iter.next();
int k=0;
for( int j=details.subsetStart; j<=details.subsetEndInclusive; j++){
idx[1] = k++;
Writable w = iter.next();
try {
arr.putScalar(idx,w.toDouble());
arr.putScalar(i,k,w.toDouble());
} catch (UnsupportedOperationException e) {
// This isn't a scalar, so check if we got an array already
if (w instanceof NDArrayWritable) {
arr.putRow(idx[0], ((NDArrayWritable)w).get().get(NDArrayIndex.all(),
arr.putRow(i, ((NDArrayWritable)w).get().get(NDArrayIndex.all(),
NDArrayIndex.interval(details.subsetStart, details.subsetEndInclusive + 1)));
} else {
throw e;
}
}
k++;
}
}
}
Expand Down Expand Up @@ -279,11 +277,7 @@ private Pair<INDArray,INDArray> convertWritablesSequence(List<Collection<Collect
if(needMaskArray) maskArray = Nd4j.ones(minValues,maxTSLength);
else maskArray = null;


int[] idx = new int[3];
int[] maskIdx = new int[2];
for( int i=0; i<minValues; i++ ){
idx[0] = i;
Collection<Collection<Writable>> sequence = list.get(i);

//Offset for alignment:
Expand All @@ -297,53 +291,52 @@ private Pair<INDArray,INDArray> convertWritablesSequence(List<Collection<Collect
}

int t=0;
int k;
for (Collection<Writable> timeStep : sequence) {
idx[2] = startOffset + t++;
k = startOffset + t++;

if(details.entireReader) {
//Convert entire reader contents, without modification
Iterator<Writable> iter = timeStep.iterator();
int j = 0;
while (iter.hasNext()) {
idx[1] = j++;
Writable w = iter.next();
try {
arr.putScalar(idx,w.toDouble());
arr.putScalar(i,j,k,w.toDouble());
} catch (UnsupportedOperationException e) {
// This isn't a scalar, so check if we got an array already
if (w instanceof NDArrayWritable) {
arr.get(NDArrayIndex.point(idx[0]), NDArrayIndex.all(), NDArrayIndex.point(idx[2]))
arr.get(NDArrayIndex.point(i), NDArrayIndex.all(), NDArrayIndex.point(k))
.putRow(0, ((NDArrayWritable)w).get());
} else {
throw e;
}
}
j++;
}
} else if(details.oneHot){
//Convert a single column to a one-hot representation
Writable w = null;
if(timeStep instanceof List) w = ((List<Writable>)timeStep).get(details.subsetStart);
else{
Iterator<Writable> iter = timeStep.iterator();
for( int k=0; k<=details.subsetStart; k++ ) w = iter.next();
for( int x=0; x<=details.subsetStart; x++ ) w = iter.next();
}
int classIdx = w.toInt();
idx[1] = classIdx;
arr.putScalar(idx,1.0);
arr.putScalar(i,classIdx,k,1.0);
} else {
//Convert a subset of the columns...
Iterator<Writable> iter = timeStep.iterator();
for( int j=0; j<details.subsetStart; j++ ) iter.next();
int k=0;
int l = 0;
for( int j=details.subsetStart; j<=details.subsetEndInclusive; j++){
idx[1] = k++;
Writable w = iter.next();
try {
arr.putScalar(idx,w.toDouble());
arr.putScalar(i,l++,k,w.toDouble());
} catch (UnsupportedOperationException e) {
// This isn't a scalar, so check if we got an array already
if (w instanceof NDArrayWritable) {
arr.get(NDArrayIndex.point(idx[0]), NDArrayIndex.all(), NDArrayIndex.point(idx[2]))
arr.get(NDArrayIndex.point(i), NDArrayIndex.all(), NDArrayIndex.point(k))
.putRow(0, ((NDArrayWritable)w).get().get(NDArrayIndex.all(),
NDArrayIndex.interval(details.subsetStart, details.subsetEndInclusive + 1)));
} else {
Expand All @@ -356,20 +349,17 @@ private Pair<INDArray,INDArray> convertWritablesSequence(List<Collection<Collect

//For any remaining time steps: set mask array to 0 (just padding)
if(needMaskArray){
maskIdx[0] = i;
//Masking array entries at start (for align end)
if(alignmentMode == AlignmentMode.ALIGN_END) {
for (int t2 = 0; t2 < startOffset; t2++) {
maskIdx[1] = t2;
maskArray.putScalar(maskIdx, 0.0);
maskArray.putScalar(i,t2, 0.0);
}
}

//Masking array entries at end (for align start)
if(alignmentMode == AlignmentMode.ALIGN_START) {
for (int t2 = t; t2 < maxTSLength; t2++) {
maskIdx[1] = t2;
maskArray.putScalar(maskIdx, 0.0);
maskArray.putScalar(i,t2, 0.0);
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,6 @@ private DataSet nextMultipleSequenceReaders(int num){
labelsOut = Nd4j.create(labelsShape,'f');
featuresMask = Nd4j.ones(featureList.size(),longestTimeSeries);
labelsMask = Nd4j.ones(labelList.size(),longestTimeSeries);
int[] temp = new int[2];
for (int i = 0; i < featureList.size(); i++) {
INDArray f = featureList.get(i);
INDArray l = labelList.get(i);
Expand All @@ -256,14 +255,11 @@ private DataSet nextMultipleSequenceReaders(int num){
.put(new INDArrayIndex[]{NDArrayIndex.interval(0, f.size(0)), NDArrayIndex.all()}, f);
labelsOut.tensorAlongDimension(i, 1, 2)
.put(new INDArrayIndex[]{NDArrayIndex.interval(0, l.size(0)), NDArrayIndex.all()}, l);
temp[0] = i;
for( int j=f.size(0); j<longestTimeSeries; j++ ){
temp[1] = j;
featuresMask.putScalar(temp,0.0);
featuresMask.putScalar(i,j,0.0);
}
for( int j=l.size(0); j<longestTimeSeries; j++ ){
temp[1] = j;
labelsMask.putScalar(temp,0.0);
labelsMask.putScalar(i,j,0.0);
}
}
} else if( alignmentMode == AlignmentMode.ALIGN_END ){ //Align at end
Expand All @@ -289,14 +285,12 @@ private DataSet nextMultipleSequenceReaders(int num){
labelsOut = Nd4j.create(labelsShape,'f');
featuresMask = Nd4j.ones(featureList.size(), longestTimeSeries);
labelsMask = Nd4j.ones(labelList.size(), longestTimeSeries);
int[] temp = new int[2];
for (int i = 0; i < featureList.size(); i++) {
INDArray f = featureList.get(i);
INDArray l = labelList.get(i);

int fLen = f.size(0);
int lLen = l.size(0);
temp[0] = i;

if(fLen >= lLen){
//Align labels with end of features (features are longer)
Expand All @@ -306,18 +300,15 @@ private DataSet nextMultipleSequenceReaders(int num){
.put(new INDArrayIndex[]{NDArrayIndex.interval(fLen-lLen, fLen), NDArrayIndex.all()}, l);

for( int j=fLen; j<longestTimeSeries; j++ ){
temp[1] = j;
featuresMask.putScalar(temp,0.0);
featuresMask.putScalar(i,j,0.0);
}
//labels mask: component before labels
for( int j=0; j<fLen-lLen; j++ ){
temp[1] = j;
labelsMask.putScalar(temp,0.0);
labelsMask.putScalar(i,j,0.0);
}
//labels mask: component after labels
for( int j=fLen; j<longestTimeSeries; j++ ){
temp[1] = j;
labelsMask.putScalar(temp,0.0);
labelsMask.putScalar(i,j,0.0);
}
} else {
//Align features with end of labels (labels are longer)
Expand All @@ -328,19 +319,16 @@ private DataSet nextMultipleSequenceReaders(int num){

//features mask: component before features
for( int j=0; j<lLen-fLen; j++ ){
temp[1] = j;
featuresMask.putScalar(temp,0.0);
featuresMask.putScalar(i,j,0.0);
}
//features mask: component after features
for( int j=lLen; j<longestTimeSeries; j++ ){
temp[1] = j;
featuresMask.putScalar(temp,0.0);
featuresMask.putScalar(i,j,0.0);
}

//labels mask
for( int j=lLen; j<longestTimeSeries; j++ ){
temp[1] = j;
labelsMask.putScalar(temp,0.0);
labelsMask.putScalar(i,j,0.0);
}
}
}
Expand Down

0 comments on commit 940b24b

Please sign in to comment.