Skip to content

Commit

Permalink
Merge pull request deeplearning4j#1488 from deeplearning4j/ab_morelstm
Browse files Browse the repository at this point in the history
Additional LSTM (and general) optimizations
  • Loading branch information
AlexDBlack committed May 2, 2016
2 parents 6a19b44 + b401fa1 commit 3996c61
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -54,15 +54,15 @@ private INDArray reshape3dTo2d(INDArray in){
if(shape[0]==1) return in.tensorAlongDimension(0,1,2).permutei(1,0); //Edge case: miniBatchSize==1
if(shape[2]==1) return in.tensorAlongDimension(0,1,0); //Edge case: timeSeriesLength=1
INDArray permuted = in.permute(0, 2, 1); //Permute, so we get correct order after reshaping
return permuted.reshape(shape[0] * shape[2], shape[1]);
return permuted.reshape('f',shape[0] * shape[2], shape[1]);
}

/**
 * Reshape a rank-2 activations array [miniBatchSize*timeSeriesLength, size] back to the
 * rank-3 time-series layout [miniBatchSize, size, timeSeriesLength].
 * Inverse of reshape3dTo2d; based on RnnToFeedForwardPreProcessor.
 *
 * @param in            rank-2 input; any ordering (copied to 'f' order if necessary,
 *                      so reshape is a cheap view rather than a data copy)
 * @param miniBatchSize number of examples in the minibatch; in.rows() must be a multiple of it
 * @return rank-3 array of shape [miniBatchSize, size, timeSeriesLength]
 * @throws IllegalArgumentException if in is not rank 2
 */
private INDArray reshape2dTo3d(INDArray in, int miniBatchSize){
    if( in.rank() != 2 ) throw new IllegalArgumentException("Invalid input: expect NDArray with rank 2");
    //Based on: RnnToFeedForwardPreProcessor
    int[] shape = in.shape();
    //Ensure 'f' ordering so the 'f'-order reshape below is a view, not a copy
    if(in.ordering() != 'f') in = Shape.toOffsetZeroCopy(in, 'f');
    INDArray reshaped = in.reshape('f', miniBatchSize, shape[0] / miniBatchSize, shape[1]);
    //Permute [miniBatch, timeSeriesLength, size] -> [miniBatch, size, timeSeriesLength]
    return reshaped.permute(0, 2, 1);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1341,8 +1341,8 @@ public INDArray labelProbabilities(INDArray examples) {
*/
@Override
public void fit(INDArray data, INDArray labels) {
setInput(data.dup());
setLabels(labels.dup());
setInput(data);
setLabels(labels);
update(TaskUtils.buildTask(data, labels));

if (layerWiseConfigurations.isPretrain()) {
Expand Down Expand Up @@ -1375,7 +1375,7 @@ public void fit(INDArray data, INDArray labels) {

/**
 * Fit (pretrain) the network on the given unlabelled data.
 * NOTE(review): the input is used directly without a defensive {@code dup()} — this avoids
 * an extra copy, but means the network may hold a reference to the caller's array; confirm
 * callers do not rely on the previous copy-on-fit behavior.
 *
 * @param data training features
 */
@Override
public void fit(INDArray data) {
    setInput(data);
    update(TaskUtils.buildTask(data));
    pretrain(data);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,37 +74,39 @@ public static INDArray initWeights(int[] shape, float min, float max) {
* @return a matrix of the specified dimensions with the specified
* distribution based on the initialization scheme
*/
public static INDArray initWeights(int[] shape, WeightInit initScheme,
Distribution dist) {
public static INDArray initWeights(int[] shape, WeightInit initScheme, Distribution dist) {

//Note: using f order here as params get flattened to f order

INDArray ret;
switch (initScheme) {
case DISTRIBUTION:
ret = dist.sample(shape);
return ret;
case NORMALIZED:
ret = Nd4j.rand(shape, Nd4j.getRandom());
ret = Nd4j.rand('f', shape);
return ret.subi(0.5).divi(shape[0]);
case RELU:
return Nd4j.randn(shape).muli(FastMath.sqrt(2.0 / shape[0])); //N(0, 2/nIn)
return Nd4j.randn('f',shape).muli(FastMath.sqrt(2.0 / shape[0])); //N(0, 2/nIn)
case SIZE:
return uniformBasedOnInAndOut(shape, shape[0], shape[1]);
case UNIFORM:
double a = 1 / (double) shape[0];
return Nd4j.rand(shape, -a, a, Nd4j.getRandom());
return Nd4j.rand('f',shape).muli(2*a).subi(a);
case VI:
ret = Nd4j.rand(shape, Nd4j.getRandom());
ret = Nd4j.rand('f',shape);
int len = 0;
for (int aShape : shape) {
len += aShape;
}
double r = Math.sqrt(6) / Math.sqrt(len + 1);
ret.muli(2).muli(r).subi(r);
ret.muli(2*r).subi(r);
return ret;
case XAVIER:
ret = Nd4j.randn(shape).divi(FastMath.sqrt(shape[0] + shape[1]));
ret = Nd4j.randn('f',shape).divi(FastMath.sqrt(shape[0] + shape[1]));
return ret;
case ZERO:
return Nd4j.create(shape);
return Nd4j.create(shape,'f');
}

throw new IllegalStateException("Illegal weight init value");
Expand Down

0 comments on commit 3996c61

Please sign in to comment.