From d27bd78693c0cb64149ebd729e3f7e995d8fd6e0 Mon Sep 17 00:00:00 2001 From: Alex Black Date: Sat, 30 Apr 2016 20:19:48 +1000 Subject: [PATCH] Fix for LSTM edge case (minibatch size of 1) after recent TAD changes --- .../org/deeplearning4j/nn/layers/recurrent/RnnOutputLayer.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deeplearning4j-core/src/main/java/org/deeplearning4j/nn/layers/recurrent/RnnOutputLayer.java b/deeplearning4j-core/src/main/java/org/deeplearning4j/nn/layers/recurrent/RnnOutputLayer.java index 9c5595315bcf..4c0646338b6e 100644 --- a/deeplearning4j-core/src/main/java/org/deeplearning4j/nn/layers/recurrent/RnnOutputLayer.java +++ b/deeplearning4j-core/src/main/java/org/deeplearning4j/nn/layers/recurrent/RnnOutputLayer.java @@ -51,7 +51,7 @@ public RnnOutputLayer(NeuralNetConfiguration conf, INDArray input) { private INDArray reshape3dTo2d(INDArray in){ if( in.rank() != 3 ) throw new IllegalArgumentException("Invalid input: expect NDArray with rank 3"); int[] shape = in.shape(); - if(shape[0]==1) return in.tensorAlongDimension(0,1,2); //Edge case: miniBatchSize==1 + if(shape[0]==1) return in.tensorAlongDimension(0,1,2).permutei(1,0); //Edge case: miniBatchSize==1 if(shape[2]==1) return in.tensorAlongDimension(0,1,0); //Edge case: timeSeriesLength=1 INDArray permuted = in.permute(0, 2, 1); //Permute, so we get correct order after reshaping return permuted.reshape(shape[0] * shape[2], shape[1]);