Skip to content

Commit

Permalink
Switch RnnOutputLayer reshaping to f order for efficiency
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexDBlack committed May 2, 2016
1 parent 6a19b44 commit 844e896
Showing 1 changed file with 3 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -54,15 +54,15 @@ private INDArray reshape3dTo2d(INDArray in){
if(shape[0]==1) return in.tensorAlongDimension(0,1,2).permutei(1,0); //Edge case: miniBatchSize==1
if(shape[2]==1) return in.tensorAlongDimension(0,1,0); //Edge case: timeSeriesLength=1
INDArray permuted = in.permute(0, 2, 1); //Permute, so we get correct order after reshaping
return permuted.reshape(shape[0] * shape[2], shape[1]);
return permuted.reshape('f',shape[0] * shape[2], shape[1]);
}

private INDArray reshape2dTo3d(INDArray in, int miniBatchSize){
if( in.rank() != 2 ) throw new IllegalArgumentException("Invalid input: expect NDArray with rank 2");
//Based on: RnnToFeedForwardPreProcessor
int[] shape = in.shape();
if(in.ordering() == 'f') in = Shape.toOffsetZeroCopy(in, 'c');
INDArray reshaped = in.reshape(miniBatchSize, shape[0] / miniBatchSize, shape[1]);
if(in.ordering() != 'f') in = Shape.toOffsetZeroCopy(in, 'f');
INDArray reshaped = in.reshape('f',miniBatchSize, shape[0] / miniBatchSize, shape[1]);
return reshaped.permute(0, 2, 1);
}

Expand Down

0 comments on commit 844e896

Please sign in to comment.