Skip to content

Commit

Permalink
Switch weight init to f order; remove unnecessary labels/params dups in MultiLayerNetwork
Browse files
Browse the repository at this point in the history
  • Loading branch information
AlexDBlack committed May 2, 2016
1 parent 844e896 commit b401fa1
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1341,8 +1341,8 @@ public INDArray labelProbabilities(INDArray examples) {
*/
@Override
public void fit(INDArray data, INDArray labels) {
setInput(data.dup());
setLabels(labels.dup());
setInput(data);
setLabels(labels);
update(TaskUtils.buildTask(data, labels));

if (layerWiseConfigurations.isPretrain()) {
Expand Down Expand Up @@ -1375,7 +1375,7 @@ public void fit(INDArray data, INDArray labels) {

/**
 * Fits the network from input data only: stores the input via setInput, calls
 * update(...) with a task built by TaskUtils.buildTask(data), then runs pretrain(data).
 *
 * @param data the input array handed to setInput, TaskUtils.buildTask and pretrain
 */
@Override
public void fit(INDArray data) {
// NOTE(review): this is a diff rendered without +/- markers. The data.dup() call below
// appears to be the removed line and the plain setInput(data) after it the replacement
// (commit title: "remove unnecessary labels/params dups") — confirm against the commit.
setInput(data.dup());
setInput(data);
update(TaskUtils.buildTask(data));
pretrain(data);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,37 +74,39 @@ public static INDArray initWeights(int[] shape, float min, float max) {
* @return a matrix of the specified dimensions with the specified
* distribution based on the initialization scheme
*/
public static INDArray initWeights(int[] shape, WeightInit initScheme,
Distribution dist) {
public static INDArray initWeights(int[] shape, WeightInit initScheme, Distribution dist) {

//Note: using f order here as params get flattened to f order

INDArray ret;
switch (initScheme) {
case DISTRIBUTION:
ret = dist.sample(shape);
return ret;
case NORMALIZED:
ret = Nd4j.rand(shape, Nd4j.getRandom());
ret = Nd4j.rand('f', shape);
return ret.subi(0.5).divi(shape[0]);
case RELU:
return Nd4j.randn(shape).muli(FastMath.sqrt(2.0 / shape[0])); //N(0, 2/nIn)
return Nd4j.randn('f',shape).muli(FastMath.sqrt(2.0 / shape[0])); //N(0, 2/nIn)
case SIZE:
return uniformBasedOnInAndOut(shape, shape[0], shape[1]);
case UNIFORM:
double a = 1 / (double) shape[0];
return Nd4j.rand(shape, -a, a, Nd4j.getRandom());
return Nd4j.rand('f',shape).muli(2*a).subi(a);
case VI:
ret = Nd4j.rand(shape, Nd4j.getRandom());
ret = Nd4j.rand('f',shape);
int len = 0;
for (int aShape : shape) {
len += aShape;
}
double r = Math.sqrt(6) / Math.sqrt(len + 1);
ret.muli(2).muli(r).subi(r);
ret.muli(2*r).subi(r);
return ret;
case XAVIER:
ret = Nd4j.randn(shape).divi(FastMath.sqrt(shape[0] + shape[1]));
ret = Nd4j.randn('f',shape).divi(FastMath.sqrt(shape[0] + shape[1]));
return ret;
case ZERO:
return Nd4j.create(shape);
return Nd4j.create(shape,'f');
}

throw new IllegalStateException("Illegal weight init value");
Expand Down

0 comments on commit b401fa1

Please sign in to comment.