diff --git a/deeplearning4j-core/src/main/java/org/deeplearning4j/nn/layers/recurrent/RnnOutputLayer.java b/deeplearning4j-core/src/main/java/org/deeplearning4j/nn/layers/recurrent/RnnOutputLayer.java index 9c5595315bcf..4c0646338b6e 100644 --- a/deeplearning4j-core/src/main/java/org/deeplearning4j/nn/layers/recurrent/RnnOutputLayer.java +++ b/deeplearning4j-core/src/main/java/org/deeplearning4j/nn/layers/recurrent/RnnOutputLayer.java @@ -51,7 +51,7 @@ public RnnOutputLayer(NeuralNetConfiguration conf, INDArray input) { private INDArray reshape3dTo2d(INDArray in){ if( in.rank() != 3 ) throw new IllegalArgumentException("Invalid input: expect NDArray with rank 3"); int[] shape = in.shape(); - if(shape[0]==1) return in.tensorAlongDimension(0,1,2); //Edge case: miniBatchSize==1 + if(shape[0]==1) return in.tensorAlongDimension(0,1,2).permutei(1,0); //Edge case: miniBatchSize==1 if(shape[2]==1) return in.tensorAlongDimension(0,1,0); //Edge case: timeSeriesLength=1 INDArray permuted = in.permute(0, 2, 1); //Permute, so we get correct order after reshaping return permuted.reshape(shape[0] * shape[2], shape[1]);