
Merge pull request deeplearning4j#2204 from deeplearning4j/ab_xavier
Weight init changes
AlexDBlack authored Oct 21, 2016
2 parents 04f8c76 + 21a2562 commit 38c7fd8
Showing 15 changed files with 137 additions and 137 deletions.
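For context, the central API change in this commit is that WeightInitUtil.initWeights now takes explicit fan-in and fan-out arguments ahead of the weight shape. The sketch below is inferred from the test changes in this diff; it is not repository code, the class name and values are illustrative, and the null distribution is only acceptable because schemes other than DISTRIBUTION/NORMALIZED do not consult it (imports assume the DL4J/ND4J packages at this revision).

import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.nn.weights.WeightInitUtil;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.rng.distribution.Distribution;
import org.nd4j.linalg.factory.Nd4j;

// Hypothetical migration sketch, not code from this repository.
public class WeightInitMigrationSketch {
    public static void main(String[] args) {
        int fanIn = 3, fanOut = 2;                    // dense layer: fanIn = nIn, fanOut = nOut
        int[] shape = new int[]{fanIn, fanOut};
        INDArray paramView = Nd4j.create(shape, 'f'); // view the initialized weights are written into
        Distribution dist = null;                     // only used for WeightInit.DISTRIBUTION / NORMALIZED

        // Before this commit (shape-only signature):
        // INDArray weights = WeightInitUtil.initWeights(shape, WeightInit.XAVIER, dist, paramView);

        // After this commit, fan-in and fan-out are passed explicitly:
        INDArray weights = WeightInitUtil.initWeights(fanIn, fanOut, shape, WeightInit.XAVIER, dist, paramView);
        System.out.println(java.util.Arrays.toString(weights.shape()));
    }
}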
@@ -104,7 +104,7 @@ public void testCifarDataSetIteratorReset() {

MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();
net.setListeners(Arrays.asList((IterationListener) new ScoreIterationListener(1)));
net.setListeners(new ScoreIterationListener(1));

MultipleEpochsIterator ds = new MultipleEpochsIterator(epochs, new CifarDataSetIterator(10,20, new int[]{20,20,1}));
net.fit(ds);
@@ -5,6 +5,7 @@
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.Updater;
import org.deeplearning4j.nn.conf.distribution.NormalDistribution;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
@@ -239,6 +240,7 @@ public void testCnnWithSubsampling(){
.regularization(false)
.learningRate(1.0)
.updater(Updater.SGD)
.weightInit(WeightInit.DISTRIBUTION).dist(new NormalDistribution(0,1))
.list()
.layer(0, new ConvolutionLayer.Builder(kernel, stride, padding)
.nIn(inputDepth).nOut(3)
@@ -299,6 +301,7 @@ public void testCnnWithSubsamplingV2(){
.regularization(false)
.learningRate(1.0)
.updater(Updater.SGD)
.weightInit(WeightInit.DISTRIBUTION).dist(new NormalDistribution(0,1))
.list()
.layer(0, new ConvolutionLayer.Builder(kernel, stride, padding)
.nIn(inputDepth).nOut(3)
@@ -616,13 +616,13 @@ public void testGradientCnnFfRnn() {
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
.updater(Updater.NONE)
.seed(12345)
.weightInit(WeightInit.DISTRIBUTION).dist(new NormalDistribution(0,1))
.list()
.layer(0, new ConvolutionLayer.Builder(5, 5)
.nIn(3)
.nOut(5)
.stride(1, 1)
.activation("tanh")
.weightInit(WeightInit.XAVIER)
.build()) //Out: (10-5)/1+1 = 6 -> 6x6x5
.layer(1, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX)
.kernelSize(2, 2)
@@ -631,13 +631,11 @@ public void testGradientCnnFfRnn() {
.layer(2, new DenseLayer.Builder()
.nIn(5 * 5 * 5)
.nOut(4)
.weightInit(WeightInit.XAVIER)
.activation("tanh")
.build())
.layer(3, new GravesLSTM.Builder()
.nIn(4)
.nOut(3)
.weightInit(WeightInit.XAVIER)
.activation("tanh")
.build())
.layer(4, new RnnOutputLayer.Builder()
@@ -797,7 +795,7 @@ public void testAutoEncoder() {
.l2(l2).l1(l1)
.optimizationAlgo(OptimizationAlgorithm.CONJUGATE_GRADIENT)
.seed(12345L)
.weightInit(WeightInit.XAVIER)
.weightInit(WeightInit.DISTRIBUTION).dist(new NormalDistribution(0,1))
.updater(Updater.SGD)
.list()
.layer(0, new AutoEncoder.Builder()
@@ -470,17 +470,17 @@ public void testPreTraining(){
.nIn(4).nOut(3)
.weightInit(WeightInit.DISTRIBUTION).dist(new UniformDistribution(0, 1))
.activation("tanh")
.lossFunction(LossFunctions.LossFunction.RMSE_XENT).build(), "in")
.lossFunction(LossFunctions.LossFunction.KL_DIVERGENCE).build(), "in")
.addLayer("layer1", new RBM.Builder(RBM.HiddenUnit.GAUSSIAN, RBM.VisibleUnit.GAUSSIAN)
.nIn(4).nOut(3)
.weightInit(WeightInit.DISTRIBUTION).dist(new UniformDistribution(0, 1))
.activation("tanh")
.lossFunction(LossFunctions.LossFunction.RMSE_XENT).build(), "in")
.lossFunction(LossFunctions.LossFunction.KL_DIVERGENCE).build(), "in")
.addLayer("layer2", new RBM.Builder(RBM.HiddenUnit.GAUSSIAN, RBM.VisibleUnit.GAUSSIAN)
.nIn(3).nOut(3)
.weightInit(WeightInit.DISTRIBUTION).dist(new UniformDistribution(0, 1))
.activation("tanh")
.lossFunction(LossFunctions.LossFunction.RMSE_XENT).build(),"layer1")
.lossFunction(LossFunctions.LossFunction.KL_DIVERGENCE).build(),"layer1")
.addLayer("out", new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
.nIn(3+3).nOut(3)
.weightInit(WeightInit.DISTRIBUTION).dist(new UniformDistribution(0, 1))
@@ -80,7 +80,7 @@ public void testLfw() throws Exception {
.layer(new org.deeplearning4j.nn.conf.layers.RBM.Builder(org.deeplearning4j.nn.conf.layers.RBM.HiddenUnit.RECTIFIED, org.deeplearning4j.nn.conf.layers.RBM.VisibleUnit.GAUSSIAN)
.nIn(d.numInputs()).nOut(nOut)
.weightInit(WeightInit.VI)
.lossFunction(LossFunctions.LossFunction.RMSE_XENT)
.lossFunction(LossFunctions.LossFunction.KL_DIVERGENCE)
.build())
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
.learningRate(1e-3f)
@@ -105,7 +105,7 @@ public void testIrisGaussianHidden() {
.layer(new org.deeplearning4j.nn.conf.layers.RBM.Builder(
org.deeplearning4j.nn.conf.layers.RBM.HiddenUnit.GAUSSIAN, org.deeplearning4j.nn.conf.layers.RBM.VisibleUnit.GAUSSIAN)
.nIn(d.numInputs()).nOut(3)
.lossFunction(LossFunctions.LossFunction.RMSE_XENT).build())
.lossFunction(LossFunctions.LossFunction.KL_DIVERGENCE).build())
.build();

int numParams = conf.getLayer().initializer().numParams(conf,true);
@@ -127,7 +127,7 @@ public void testIris() {
.learningRate(1e-1f)
.layer(new org.deeplearning4j.nn.conf.layers.RBM.Builder(org.deeplearning4j.nn.conf.layers.RBM.HiddenUnit.RECTIFIED, org.deeplearning4j.nn.conf.layers.RBM.VisibleUnit.GAUSSIAN)
.nIn(d.numInputs()).nOut(3)
.lossFunction(LossFunctions.LossFunction.RMSE_XENT).build())
.lossFunction(LossFunctions.LossFunction.KL_DIVERGENCE).build())
.build();

int numParams = conf.getLayer().initializer().numParams(conf,true);
@@ -157,7 +157,7 @@ public void testBasic() {
.learningRate(1e-1f)
.layer(new org.deeplearning4j.nn.conf.layers.RBM.Builder()
.nIn(6).nOut(4)
.lossFunction(LossFunctions.LossFunction.RMSE_XENT).build())
.lossFunction(LossFunctions.LossFunction.KL_DIVERGENCE).build())
.build();

int numParams = conf.getLayer().initializer().numParams(conf,true);
@@ -208,7 +208,7 @@ public void testSetGetParams() {
.learningRate(1e-1f)
.layer(new org.deeplearning4j.nn.conf.layers.RBM.Builder()
.nIn(6).nOut(4)
.lossFunction(LossFunctions.LossFunction.RMSE_XENT).build())
.lossFunction(LossFunctions.LossFunction.KL_DIVERGENCE).build())
.build();

int numParams = conf.getLayer().initializer().numParams(conf,true);
@@ -241,7 +241,7 @@ public void testCg() {
.learningRate(1e-1f)
.layer(new org.deeplearning4j.nn.conf.layers.RBM.Builder()
.nIn(6).nOut(4)
.lossFunction(LossFunctions.LossFunction.RMSE_XENT).build())
.lossFunction(LossFunctions.LossFunction.KL_DIVERGENCE).build())
.build();

int numParams = conf.getLayer().initializer().numParams(conf,true);
@@ -276,7 +276,7 @@ public void testGradient() {
.learningRate(1e-1f)
.layer(new org.deeplearning4j.nn.conf.layers.RBM.Builder()
.nIn(6).nOut(4)
.lossFunction(LossFunctions.LossFunction.RMSE_XENT).build())
.lossFunction(LossFunctions.LossFunction.KL_DIVERGENCE).build())
.build();

int numParams = conf.getLayer().initializer().numParams(conf,true);
@@ -184,7 +184,7 @@ public void testDbn() throws Exception {
.nIn(4).nOut(3)
.weightInit(WeightInit.DISTRIBUTION).dist(new UniformDistribution(0, 1))
.activation("tanh")
.lossFunction(LossFunctions.LossFunction.RMSE_XENT).build())
.lossFunction(LossFunctions.LossFunction.KL_DIVERGENCE).build())
.layer(1, new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
.nIn(3).nOut(3)
.weightInit(WeightInit.DISTRIBUTION).dist(new UniformDistribution(0, 1))
@@ -655,6 +655,7 @@ public void testPredict() throws Exception{
.layer(0, new DenseLayer.Builder().nIn(400).nOut(50).activation("relu").build())
.layer(1, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).activation("softmax").nIn(50).nOut(10).build())
.pretrain(false).backprop(true)
.setInputType(InputType.convolutional(20,20,1))
.build();

MultiLayerNetwork net = new MultiLayerNetwork(conf);
@@ -698,6 +699,7 @@ public void testOutput() throws Exception{
.layer(0, new DenseLayer.Builder().nIn(400).nOut(50).activation("relu").build())
.layer(1, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).activation("softmax").nIn(50).nOut(10).build())
.pretrain(false).backprop(true)
.setInputType(InputType.convolutional(20,20,1))
.build();

MultiLayerNetwork net = new MultiLayerNetwork(conf);
@@ -15,7 +15,9 @@
* Created by nyghtowl on 11/14/15.
*/
public class WeightInitUtilTest {
protected int[] shape = new int[]{2, 2};
protected int fanIn = 3;
protected int fanOut = 2;
protected int[] shape = new int[]{fanIn, fanOut};
protected Distribution dist = Distributions.createDistribution(new GaussianDistribution(0.0, 0.1));

@Before
@@ -26,7 +28,7 @@ public void doBefore(){
@Test
public void testDistribution(){
INDArray params = Nd4j.create(shape,'f');
INDArray weightsActual = WeightInitUtil.initWeights(shape, WeightInit.DISTRIBUTION, dist, params);
INDArray weightsActual = WeightInitUtil.initWeights(-1, -1, shape, WeightInit.DISTRIBUTION, dist, params); //fan in/out not used

// expected calculation
Nd4j.getRandom().setSeed(123);
@@ -38,7 +40,7 @@ public void testDistribution(){
@Test
public void testNormalize(){
INDArray params = Nd4j.create(shape,'f');
INDArray weightsActual = WeightInitUtil.initWeights(shape, WeightInit.NORMALIZED, dist, params);
INDArray weightsActual = WeightInitUtil.initWeights(fanIn, fanOut, shape, WeightInit.NORMALIZED, dist, params);

// expected calculation
Nd4j.getRandom().setSeed(123);
@@ -51,19 +53,19 @@ public void testNormalize(){
@Test
public void testRelu(){
INDArray params = Nd4j.create(shape,'f');
INDArray weightsActual = WeightInitUtil.initWeights(shape, WeightInit.RELU, dist,params);
INDArray weightsActual = WeightInitUtil.initWeights(fanIn, fanOut, shape, WeightInit.RELU, dist,params);

// expected calculation
Nd4j.getRandom().setSeed(123);
INDArray weightsExpected = Nd4j.randn('f',shape).muli(FastMath.sqrt(2.0 / shape[0]));
INDArray weightsExpected = Nd4j.randn('f',shape).muli(FastMath.sqrt(2.0 / fanIn));

assertEquals(weightsExpected, weightsActual);
}

@Test
public void testSize(){
public void testSigmoidUniform(){
INDArray params = Nd4j.create(shape,'f');
INDArray weightsActual = WeightInitUtil.initWeights(shape, WeightInit.SIZE, dist, params);
INDArray weightsActual = WeightInitUtil.initWeights(fanIn, fanOut, shape, WeightInit.SIGMOID_UNIFORM, dist, params);

// expected calculation
Nd4j.getRandom().setSeed(123);
@@ -77,74 +79,59 @@ public void testSize(){
@Test
public void testUniform(){
INDArray params = Nd4j.create(shape,'f');
INDArray weightsActual = WeightInitUtil.initWeights(shape, WeightInit.UNIFORM, dist, params);
INDArray weightsActual = WeightInitUtil.initWeights(fanIn, fanOut, shape, WeightInit.UNIFORM, dist, params);

// expected calculation
Nd4j.getRandom().setSeed(123);
double a = 1/(double) shape[0];
INDArray weightsExpected = Nd4j.rand('f',shape).muli(2*a).subi(a);

assertEquals(weightsExpected, weightsActual);
}

@Test
public void testVI(){
INDArray params = Nd4j.create(shape,'f');
INDArray weightsActual = WeightInitUtil.initWeights(shape, WeightInit.VI, dist, params);

// expected calculation
Nd4j.getRandom().setSeed(123);
INDArray weightsExpected = Nd4j.rand('f',shape);
int numValues = shape[0] + shape[1];
double r = Math.sqrt(6) / Math.sqrt(numValues + 1);
weightsExpected.muli(2).muli(r).subi(r);
double a = 1.0/Math.sqrt(fanIn);
INDArray weightsExpected = Nd4j.rand(shape,Nd4j.getDistributions().createUniform(-a,a));

assertEquals(weightsExpected, weightsActual);
}

@Test
public void testXavier(){
INDArray params = Nd4j.create(shape,'f');
INDArray weightsActual = WeightInitUtil.initWeights(shape, WeightInit.XAVIER, dist, params);
INDArray weightsActual = WeightInitUtil.initWeights(fanIn, fanOut, shape, WeightInit.XAVIER, dist, params);

// expected calculation
Nd4j.getRandom().setSeed(123);
INDArray weightsExpected = Nd4j.randn('f',shape);
weightsExpected.divi(FastMath.sqrt(shape[0] + shape[1]));
weightsExpected.divi(FastMath.sqrt(2.0 / (fanIn + fanOut)));

assertEquals(weightsExpected, weightsActual);
}

@Test
public void testXavierCaffe(){
public void testXavierFanIn(){
INDArray params = Nd4j.create(shape,'f');
INDArray weightsActual = WeightInitUtil.initWeights(shape, WeightInit.XAVIER_CAFFE, dist, params);
INDArray weightsActual = WeightInitUtil.initWeights(fanIn, fanOut, shape, WeightInit.XAVIER_FAN_IN, dist, params);

// expected calculation
Nd4j.getRandom().setSeed(123);
INDArray weightsExpected = Nd4j.randn('f',shape);
weightsExpected.divi(FastMath.sqrt(shape[0]));
weightsExpected.divi(FastMath.sqrt(fanIn));

assertEquals(weightsExpected, weightsActual);
}

@Test
public void testXavierTorch(){
public void testXavierLegacy(){
INDArray params = Nd4j.create(shape,'f');
INDArray weightsActual = WeightInitUtil.initWeights(shape, WeightInit.XAVIER_TORCH, dist, params);
INDArray weightsActual = WeightInitUtil.initWeights(fanIn, fanOut, shape, WeightInit.XAVIER_LEGACY, dist, params);

// expected calculation
Nd4j.getRandom().setSeed(123);
INDArray weightsExpected = Nd4j.randn('f',shape);
weightsExpected.muli(FastMath.sqrt(2.0 / (shape[0] + shape[1])));
weightsExpected.muli(FastMath.sqrt(1.0 / (fanIn + fanOut)));

assertEquals(weightsExpected, weightsActual);
}

@Test
public void testZero(){
INDArray params = Nd4j.create(shape,'f');
INDArray weightsActual = WeightInitUtil.initWeights(shape, WeightInit.ZERO, dist, params);
INDArray weightsActual = WeightInitUtil.initWeights(fanIn, fanOut, shape, WeightInit.ZERO, dist, params);

// expected calculation
INDArray weightsExpected = Nd4j.create(shape,'f');
@@ -232,7 +232,7 @@ private static MultiLayerConfiguration getIrisMultiLayerConfig(String activation
.optimizationAlgo(optimizer)
.iterations(iterations)
.miniBatch(false).momentum(0.9)
.learningRate(0.1).updater(Updater.NESTEROVS)
.learningRate(0.01).updater(Updater.NESTEROVS)
.seed(12345L)
.list()
.layer(0, new DenseLayer.Builder()
@@ -52,7 +52,7 @@ public void testPlotter() throws Exception {
.nIn(784).nOut(600)
.weightInit(WeightInit.DISTRIBUTION).dist(new NormalDistribution(1e-3, 1e-1))
.dropOut(0.5)
.lossFunction(LossFunctions.LossFunction.RMSE_XENT).build())
.lossFunction(LossFunctions.LossFunction.KL_DIVERGENCE).build())
.build();


@@ -126,7 +126,17 @@ protected INDArray createWeightMatrix(NeuralNetConfiguration conf, INDArray weig
if(initializeParams) {
Distribution dist = Distributions.createDistribution(conf.getLayer().getDist());
int[] kernel = layerConf.getKernelSize();
return WeightInitUtil.initWeights(new int[]{layerConf.getNOut(), layerConf.getNIn(), kernel[0], kernel[1]},
int[] stride = layerConf.getStride();

int inputDepth = layerConf.getNIn();
int outputDepth = layerConf.getNOut();

double fanIn = inputDepth * kernel[0] * kernel[1];
double fanOut = outputDepth * kernel[0] * kernel[1] / ((double) stride[0] * stride[1]);

int[] weightsShape = new int[]{outputDepth, inputDepth, kernel[0], kernel[1]};

return WeightInitUtil.initWeights(fanIn, fanOut, weightsShape,
layerConf.getWeightInit(), dist, 'c', weightView);
} else {
int[] kernel = layerConf.getKernelSize();
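The convolution initializer above now derives fan-in and fan-out from the layer depths, kernel size and stride. As a sanity check, the illustrative snippet below (not repository code) plugs in the 5x5 convolution from testGradientCnnFfRnn earlier in this diff — nIn 3, nOut 5, kernel 5x5, stride 1x1 — giving fanIn = 75 and fanOut = 125.

// Illustrative standalone check of the conv fan-in/fan-out formulas introduced above.
public class ConvFanInOutSketch {
    public static void main(String[] args) {
        int[] kernel = {5, 5};
        int[] stride = {1, 1};
        int inputDepth = 3;   // nIn: input feature maps
        int outputDepth = 5;  // nOut: output feature maps

        // Each output activation is connected to inputDepth * kH * kW weights.
        double fanIn = inputDepth * kernel[0] * kernel[1];
        // Each input activation feeds roughly outputDepth * kH * kW / (sH * sW) weights.
        double fanOut = outputDepth * kernel[0] * kernel[1] / ((double) stride[0] * stride[1]);

        System.out.println("fanIn = " + fanIn + ", fanOut = " + fanOut);  // fanIn = 75.0, fanOut = 125.0
    }
}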
@@ -117,10 +117,16 @@ protected INDArray createWeightMatrix(NeuralNetConfiguration conf, INDArray weig
(org.deeplearning4j.nn.conf.layers.FeedForwardLayer) conf.getLayer();

if(initializeParameters) {

int nIn = layerConf.getNIn();
int nOut = layerConf.getNOut();
int[] shape = new int[]{nIn,nOut};

Distribution dist = Distributions.createDistribution(layerConf.getDist());
INDArray ret = WeightInitUtil.initWeights(
layerConf.getNIn(),
layerConf.getNOut(),
nIn, //Fan in
nOut, //Fan out
shape,
layerConf.getWeightInit(),
dist,
weightParamView);