
Commit 38c7fd8

Merge pull request deeplearning4j#2204 from deeplearning4j/ab_xavier
Weight init changes
2 parents 04f8c76 + 21a2562

15 files changed (+137 -137 lines)

deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/iterator/MultipleEpochsIteratorTest.java

Lines changed: 1 addition & 1 deletion
@@ -104,7 +104,7 @@ public void testCifarDataSetIteratorReset() {

         MultiLayerNetwork net = new MultiLayerNetwork(conf);
         net.init();
-        net.setListeners(Arrays.asList((IterationListener) new ScoreIterationListener(1)));
+        net.setListeners(new ScoreIterationListener(1));

         MultipleEpochsIterator ds = new MultipleEpochsIterator(epochs, new CifarDataSetIterator(10,20, new int[]{20,20,1}));
         net.fit(ds);
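Note: the new call compiles because setListeners also has a varargs overload taking IterationListener instances directly, so the Arrays.asList wrapper and the explicit IterationListener cast were redundant.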

deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/CNNGradientCheckTest.java

Lines changed: 3 additions & 0 deletions
@@ -5,6 +5,7 @@
 import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
 import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
 import org.deeplearning4j.nn.conf.Updater;
+import org.deeplearning4j.nn.conf.distribution.NormalDistribution;
 import org.deeplearning4j.nn.conf.inputs.InputType;
 import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
 import org.deeplearning4j.nn.conf.layers.OutputLayer;
@@ -239,6 +240,7 @@ public void testCnnWithSubsampling(){
         .regularization(false)
         .learningRate(1.0)
         .updater(Updater.SGD)
+        .weightInit(WeightInit.DISTRIBUTION).dist(new NormalDistribution(0,1))
         .list()
         .layer(0, new ConvolutionLayer.Builder(kernel, stride, padding)
                 .nIn(inputDepth).nOut(3)
@@ -299,6 +301,7 @@ public void testCnnWithSubsamplingV2(){
         .regularization(false)
         .learningRate(1.0)
         .updater(Updater.SGD)
+        .weightInit(WeightInit.DISTRIBUTION).dist(new NormalDistribution(0,1))
         .list()
         .layer(0, new ConvolutionLayer.Builder(kernel, stride, padding)
                 .nIn(inputDepth).nOut(3)
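The line added to both tests pins the weight initialization to an explicit N(0,1) distribution at the builder level, apparently so the numerical gradient checks run with a fixed, scheme-independent scale rather than the XAVIER default this commit revises. A minimal sketch of the idiom — layer sizes here are illustrative, not taken from the diff:

    import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
    import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
    import org.deeplearning4j.nn.conf.Updater;
    import org.deeplearning4j.nn.conf.distribution.NormalDistribution;
    import org.deeplearning4j.nn.conf.inputs.InputType;
    import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
    import org.deeplearning4j.nn.conf.layers.OutputLayer;
    import org.deeplearning4j.nn.weights.WeightInit;
    import org.nd4j.linalg.lossfunctions.LossFunctions;

    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
            .regularization(false)
            .learningRate(1.0)
            .updater(Updater.SGD)
            // fixed, scheme-independent weight scale for the numerical check
            .weightInit(WeightInit.DISTRIBUTION).dist(new NormalDistribution(0, 1))
            .list()
            .layer(0, new ConvolutionLayer.Builder(2, 2)
                    .nIn(1).nOut(3).activation("tanh").build())   // 4x4 in -> 3x3x3 out
            .layer(1, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
                    .activation("softmax").nIn(27).nOut(2).build()) // 3*3*3 = 27 inputs
            .setInputType(InputType.convolutional(4, 4, 1))
            .build();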

deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/GradientCheckTests.java

Lines changed: 2 additions & 4 deletions
@@ -616,13 +616,13 @@ public void testGradientCnnFfRnn() {
         MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                 .updater(Updater.NONE)
                 .seed(12345)
+                .weightInit(WeightInit.DISTRIBUTION).dist(new NormalDistribution(0,1))
                 .list()
                 .layer(0, new ConvolutionLayer.Builder(5, 5)
                         .nIn(3)
                         .nOut(5)
                         .stride(1, 1)
                         .activation("tanh")
-                        .weightInit(WeightInit.XAVIER)
                         .build()) //Out: (10-5)/1+1 = 6 -> 6x6x5
                 .layer(1, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX)
                         .kernelSize(2, 2)
@@ -631,13 +631,11 @@ public void testGradientCnnFfRnn() {
                 .layer(2, new DenseLayer.Builder()
                         .nIn(5 * 5 * 5)
                         .nOut(4)
-                        .weightInit(WeightInit.XAVIER)
                         .activation("tanh")
                         .build())
                 .layer(3, new GravesLSTM.Builder()
                         .nIn(4)
                         .nOut(3)
-                        .weightInit(WeightInit.XAVIER)
                         .activation("tanh")
                         .build())
                 .layer(4, new RnnOutputLayer.Builder()
@@ -797,7 +795,7 @@ public void testAutoEncoder() {
                 .l2(l2).l1(l1)
                 .optimizationAlgo(OptimizationAlgorithm.CONJUGATE_GRADIENT)
                 .seed(12345L)
-                .weightInit(WeightInit.XAVIER)
+                .weightInit(WeightInit.DISTRIBUTION).dist(new NormalDistribution(0,1))
                 .updater(Updater.SGD)
                 .list()
                 .layer(0, new AutoEncoder.Builder()

deeplearning4j-core/src/test/java/org/deeplearning4j/nn/graph/TestComputationGraphNetwork.java

Lines changed: 3 additions & 3 deletions
@@ -470,17 +470,17 @@ public void testPreTraining(){
                 .nIn(4).nOut(3)
                 .weightInit(WeightInit.DISTRIBUTION).dist(new UniformDistribution(0, 1))
                 .activation("tanh")
-                .lossFunction(LossFunctions.LossFunction.RMSE_XENT).build(), "in")
+                .lossFunction(LossFunctions.LossFunction.KL_DIVERGENCE).build(), "in")
         .addLayer("layer1", new RBM.Builder(RBM.HiddenUnit.GAUSSIAN, RBM.VisibleUnit.GAUSSIAN)
                 .nIn(4).nOut(3)
                 .weightInit(WeightInit.DISTRIBUTION).dist(new UniformDistribution(0, 1))
                 .activation("tanh")
-                .lossFunction(LossFunctions.LossFunction.RMSE_XENT).build(), "in")
+                .lossFunction(LossFunctions.LossFunction.KL_DIVERGENCE).build(), "in")
         .addLayer("layer2", new RBM.Builder(RBM.HiddenUnit.GAUSSIAN, RBM.VisibleUnit.GAUSSIAN)
                 .nIn(3).nOut(3)
                 .weightInit(WeightInit.DISTRIBUTION).dist(new UniformDistribution(0, 1))
                 .activation("tanh")
-                .lossFunction(LossFunctions.LossFunction.RMSE_XENT).build(),"layer1")
+                .lossFunction(LossFunctions.LossFunction.KL_DIVERGENCE).build(),"layer1")
         .addLayer("out", new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
                 .nIn(3+3).nOut(3)
                 .weightInit(WeightInit.DISTRIBUTION).dist(new UniformDistribution(0, 1))

deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/feedforward/rbm/RBMTests.java

Lines changed: 7 additions & 7 deletions
@@ -80,7 +80,7 @@ public void testLfw() throws Exception {
         .layer(new org.deeplearning4j.nn.conf.layers.RBM.Builder(org.deeplearning4j.nn.conf.layers.RBM.HiddenUnit.RECTIFIED, org.deeplearning4j.nn.conf.layers.RBM.VisibleUnit.GAUSSIAN)
                 .nIn(d.numInputs()).nOut(nOut)
                 .weightInit(WeightInit.VI)
-                .lossFunction(LossFunctions.LossFunction.RMSE_XENT)
+                .lossFunction(LossFunctions.LossFunction.KL_DIVERGENCE)
                 .build())
         .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
         .learningRate(1e-3f)
@@ -105,7 +105,7 @@ public void testIrisGaussianHidden() {
         .layer(new org.deeplearning4j.nn.conf.layers.RBM.Builder(
                 org.deeplearning4j.nn.conf.layers.RBM.HiddenUnit.GAUSSIAN, org.deeplearning4j.nn.conf.layers.RBM.VisibleUnit.GAUSSIAN)
                 .nIn(d.numInputs()).nOut(3)
-                .lossFunction(LossFunctions.LossFunction.RMSE_XENT).build())
+                .lossFunction(LossFunctions.LossFunction.KL_DIVERGENCE).build())
         .build();

 int numParams = conf.getLayer().initializer().numParams(conf,true);
@@ -127,7 +127,7 @@ public void testIris() {
         .learningRate(1e-1f)
         .layer(new org.deeplearning4j.nn.conf.layers.RBM.Builder(org.deeplearning4j.nn.conf.layers.RBM.HiddenUnit.RECTIFIED, org.deeplearning4j.nn.conf.layers.RBM.VisibleUnit.GAUSSIAN)
                 .nIn(d.numInputs()).nOut(3)
-                .lossFunction(LossFunctions.LossFunction.RMSE_XENT).build())
+                .lossFunction(LossFunctions.LossFunction.KL_DIVERGENCE).build())
         .build();

 int numParams = conf.getLayer().initializer().numParams(conf,true);
@@ -157,7 +157,7 @@ public void testBasic() {
         .learningRate(1e-1f)
         .layer(new org.deeplearning4j.nn.conf.layers.RBM.Builder()
                 .nIn(6).nOut(4)
-                .lossFunction(LossFunctions.LossFunction.RMSE_XENT).build())
+                .lossFunction(LossFunctions.LossFunction.KL_DIVERGENCE).build())
         .build();

 int numParams = conf.getLayer().initializer().numParams(conf,true);
@@ -208,7 +208,7 @@ public void testSetGetParams() {
         .learningRate(1e-1f)
         .layer(new org.deeplearning4j.nn.conf.layers.RBM.Builder()
                 .nIn(6).nOut(4)
-                .lossFunction(LossFunctions.LossFunction.RMSE_XENT).build())
+                .lossFunction(LossFunctions.LossFunction.KL_DIVERGENCE).build())
         .build();

 int numParams = conf.getLayer().initializer().numParams(conf,true);
@@ -241,7 +241,7 @@ public void testCg() {
         .learningRate(1e-1f)
         .layer(new org.deeplearning4j.nn.conf.layers.RBM.Builder()
                 .nIn(6).nOut(4)
-                .lossFunction(LossFunctions.LossFunction.RMSE_XENT).build())
+                .lossFunction(LossFunctions.LossFunction.KL_DIVERGENCE).build())
         .build();

 int numParams = conf.getLayer().initializer().numParams(conf,true);
@@ -276,7 +276,7 @@ public void testGradient() {
         .learningRate(1e-1f)
         .layer(new org.deeplearning4j.nn.conf.layers.RBM.Builder()
                 .nIn(6).nOut(4)
-                .lossFunction(LossFunctions.LossFunction.RMSE_XENT).build())
+                .lossFunction(LossFunctions.LossFunction.KL_DIVERGENCE).build())
         .build();

 int numParams = conf.getLayer().initializer().numParams(conf,true);
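All seven RBM fixtures in this file make the same swap of the pretraining reconstruction loss, RMSE_XENT -> KL_DIVERGENCE. For reference, the smallest of the updated configurations, assembled from the testBasic hunk above (imports as in the test file):

    NeuralNetConfiguration conf = new NeuralNetConfiguration.Builder()
            .learningRate(1e-1f)
            .layer(new org.deeplearning4j.nn.conf.layers.RBM.Builder()
                    .nIn(6).nOut(4)
                    // reconstruction loss used throughout this commit's RBM tests
                    .lossFunction(LossFunctions.LossFunction.KL_DIVERGENCE)
                    .build())
            .build();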

deeplearning4j-core/src/test/java/org/deeplearning4j/nn/multilayer/MultiLayerTest.java

Lines changed: 3 additions & 1 deletion
@@ -184,7 +184,7 @@ public void testDbn() throws Exception {
                 .nIn(4).nOut(3)
                 .weightInit(WeightInit.DISTRIBUTION).dist(new UniformDistribution(0, 1))
                 .activation("tanh")
-                .lossFunction(LossFunctions.LossFunction.RMSE_XENT).build())
+                .lossFunction(LossFunctions.LossFunction.KL_DIVERGENCE).build())
         .layer(1, new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
                 .nIn(3).nOut(3)
                 .weightInit(WeightInit.DISTRIBUTION).dist(new UniformDistribution(0, 1))
@@ -655,6 +655,7 @@ public void testPredict() throws Exception{
         .layer(0, new DenseLayer.Builder().nIn(400).nOut(50).activation("relu").build())
         .layer(1, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).activation("softmax").nIn(50).nOut(10).build())
         .pretrain(false).backprop(true)
+        .setInputType(InputType.convolutional(20,20,1))
         .build();

 MultiLayerNetwork net = new MultiLayerNetwork(conf);
@@ -698,6 +699,7 @@ public void testOutput() throws Exception{
         .layer(0, new DenseLayer.Builder().nIn(400).nOut(50).activation("relu").build())
         .layer(1, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).activation("softmax").nIn(50).nOut(10).build())
         .pretrain(false).backprop(true)
+        .setInputType(InputType.convolutional(20,20,1))
         .build();

 MultiLayerNetwork net = new MultiLayerNetwork(conf);
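The added setInputType(InputType.convolutional(20,20,1)) line declares the shape of the incoming data: a 20x20 single-channel image flattens to 20*20*1 = 400 values, matching the nIn(400) of layer 0. Assembled from the testPredict hunk above (imports as in the test file):

    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
            .list()
            .layer(0, new DenseLayer.Builder().nIn(400).nOut(50).activation("relu").build())
            .layer(1, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
                    .activation("softmax").nIn(50).nOut(10).build())
            .pretrain(false).backprop(true)
            .setInputType(InputType.convolutional(20, 20, 1)) // 20*20*1 = 400 inputs
            .build();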

deeplearning4j-core/src/test/java/org/deeplearning4j/nn/weights/WeightInitUtilTest.java

Lines changed: 21 additions & 34 deletions
@@ -15,7 +15,9 @@
  * Created by nyghtowl on 11/14/15.
  */
 public class WeightInitUtilTest {
-    protected int[] shape = new int[]{2, 2};
+    protected int fanIn = 3;
+    protected int fanOut = 2;
+    protected int[] shape = new int[]{fanIn, fanOut};
     protected Distribution dist = Distributions.createDistribution(new GaussianDistribution(0.0, 0.1));

     @Before
@@ -26,7 +28,7 @@ public void doBefore(){
     @Test
     public void testDistribution(){
         INDArray params = Nd4j.create(shape,'f');
-        INDArray weightsActual = WeightInitUtil.initWeights(shape, WeightInit.DISTRIBUTION, dist, params);
+        INDArray weightsActual = WeightInitUtil.initWeights(-1, -1, shape, WeightInit.DISTRIBUTION, dist, params); //fan in/out not used

         // expected calculation
         Nd4j.getRandom().setSeed(123);
@@ -38,7 +40,7 @@ public void testDistribution(){
     @Test
     public void testNormalize(){
         INDArray params = Nd4j.create(shape,'f');
-        INDArray weightsActual = WeightInitUtil.initWeights(shape, WeightInit.NORMALIZED, dist, params);
+        INDArray weightsActual = WeightInitUtil.initWeights(fanIn, fanOut, shape, WeightInit.NORMALIZED, dist, params);

         // expected calculation
         Nd4j.getRandom().setSeed(123);
@@ -51,19 +53,19 @@ public void testNormalize(){
     @Test
     public void testRelu(){
         INDArray params = Nd4j.create(shape,'f');
-        INDArray weightsActual = WeightInitUtil.initWeights(shape, WeightInit.RELU, dist,params);
+        INDArray weightsActual = WeightInitUtil.initWeights(fanIn, fanOut, shape, WeightInit.RELU, dist,params);

         // expected calculation
         Nd4j.getRandom().setSeed(123);
-        INDArray weightsExpected = Nd4j.randn('f',shape).muli(FastMath.sqrt(2.0 / shape[0]));
+        INDArray weightsExpected = Nd4j.randn('f',shape).muli(FastMath.sqrt(2.0 / fanIn));

         assertEquals(weightsExpected, weightsActual);
     }

     @Test
-    public void testSize(){
+    public void testSigmoidUniform(){
         INDArray params = Nd4j.create(shape,'f');
-        INDArray weightsActual = WeightInitUtil.initWeights(shape, WeightInit.SIZE, dist, params);
+        INDArray weightsActual = WeightInitUtil.initWeights(fanIn, fanOut, shape, WeightInit.SIGMOID_UNIFORM, dist, params);

         // expected calculation
         Nd4j.getRandom().setSeed(123);
@@ -77,74 +79,59 @@ public void testSize(){
     @Test
     public void testUniform(){
         INDArray params = Nd4j.create(shape,'f');
-        INDArray weightsActual = WeightInitUtil.initWeights(shape, WeightInit.UNIFORM, dist, params);
+        INDArray weightsActual = WeightInitUtil.initWeights(fanIn, fanOut, shape, WeightInit.UNIFORM, dist, params);

         // expected calculation
         Nd4j.getRandom().setSeed(123);
-        double a = 1/(double) shape[0];
-        INDArray weightsExpected = Nd4j.rand('f',shape).muli(2*a).subi(a);
-
-        assertEquals(weightsExpected, weightsActual);
-    }
-
-    @Test
-    public void testVI(){
-        INDArray params = Nd4j.create(shape,'f');
-        INDArray weightsActual = WeightInitUtil.initWeights(shape, WeightInit.VI, dist, params);
-
-        // expected calculation
-        Nd4j.getRandom().setSeed(123);
-        INDArray weightsExpected = Nd4j.rand('f',shape);
-        int numValues = shape[0] + shape[1];
-        double r = Math.sqrt(6) / Math.sqrt(numValues + 1);
-        weightsExpected.muli(2).muli(r).subi(r);
+        double a = 1.0/Math.sqrt(fanIn);
+        INDArray weightsExpected = Nd4j.rand(shape,Nd4j.getDistributions().createUniform(-a,a));

         assertEquals(weightsExpected, weightsActual);
     }

     @Test
     public void testXavier(){
         INDArray params = Nd4j.create(shape,'f');
-        INDArray weightsActual = WeightInitUtil.initWeights(shape, WeightInit.XAVIER, dist, params);
+        INDArray weightsActual = WeightInitUtil.initWeights(fanIn, fanOut, shape, WeightInit.XAVIER, dist, params);

         // expected calculation
         Nd4j.getRandom().setSeed(123);
         INDArray weightsExpected = Nd4j.randn('f',shape);
-        weightsExpected.divi(FastMath.sqrt(shape[0] + shape[1]));
+        weightsExpected.divi(FastMath.sqrt(2.0 / (fanIn + fanOut)));

         assertEquals(weightsExpected, weightsActual);
     }

     @Test
-    public void testXavierCaffe(){
+    public void testXavierFanIn(){
         INDArray params = Nd4j.create(shape,'f');
-        INDArray weightsActual = WeightInitUtil.initWeights(shape, WeightInit.XAVIER_CAFFE, dist, params);
+        INDArray weightsActual = WeightInitUtil.initWeights(fanIn, fanOut, shape, WeightInit.XAVIER_FAN_IN, dist, params);

         // expected calculation
         Nd4j.getRandom().setSeed(123);
         INDArray weightsExpected = Nd4j.randn('f',shape);
-        weightsExpected.divi(FastMath.sqrt(shape[0]));
+        weightsExpected.divi(FastMath.sqrt(fanIn));

         assertEquals(weightsExpected, weightsActual);
     }

     @Test
-    public void testXavierTorch(){
+    public void testXavierLegacy(){
         INDArray params = Nd4j.create(shape,'f');
-        INDArray weightsActual = WeightInitUtil.initWeights(shape, WeightInit.XAVIER_TORCH, dist, params);
+        INDArray weightsActual = WeightInitUtil.initWeights(fanIn, fanOut, shape, WeightInit.XAVIER_LEGACY, dist, params);

         // expected calculation
         Nd4j.getRandom().setSeed(123);
         INDArray weightsExpected = Nd4j.randn('f',shape);
-        weightsExpected.muli(FastMath.sqrt(2.0 / (shape[0] + shape[1])));
+        weightsExpected.muli(FastMath.sqrt(1.0 / (fanIn + fanOut)));

         assertEquals(weightsExpected, weightsActual);
     }

     @Test
     public void testZero(){
         INDArray params = Nd4j.create(shape,'f');
-        INDArray weightsActual = WeightInitUtil.initWeights(shape, WeightInit.ZERO, dist, params);
+        INDArray weightsActual = WeightInitUtil.initWeights(fanIn, fanOut, shape, WeightInit.ZERO, dist, params);

         // expected calculation
         INDArray weightsExpected = Nd4j.create(shape,'f');
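Two API changes are visible in these hunks: WeightInitUtil.initWeights now takes explicit fanIn/fanOut arguments instead of deriving them from shape[0]/shape[1], and the test names track renamed schemes (testXavierCaffe -> testXavierFanIn, testXavierTorch -> testXavierLegacy, testSize -> testSigmoidUniform; the testVI case is dropped). A minimal sketch of the new call, reusing the fixture values from the test above (imports as in the test file):

    int fanIn = 3, fanOut = 2;
    int[] shape = new int[]{fanIn, fanOut};   // dense case: shape coincides with (fanIn, fanOut)
    INDArray params = Nd4j.create(shape, 'f');
    Distribution dist = Distributions.createDistribution(new GaussianDistribution(0.0, 0.1));
    // fan-in/fan-out are passed explicitly; for WeightInit.DISTRIBUTION they are
    // unused, which is why testDistribution passes -1, -1
    INDArray weights = WeightInitUtil.initWeights(fanIn, fanOut, shape, WeightInit.XAVIER, dist, params);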

deeplearning4j-core/src/test/java/org/deeplearning4j/optimize/solver/BackTrackLineSearchTest.java

Lines changed: 1 addition & 1 deletion
@@ -232,7 +232,7 @@ private static MultiLayerConfiguration getIrisMultiLayerConfig(String activation
         .optimizationAlgo(optimizer)
         .iterations(iterations)
         .miniBatch(false).momentum(0.9)
-        .learningRate(0.1).updater(Updater.NESTEROVS)
+        .learningRate(0.01).updater(Updater.NESTEROVS)
         .seed(12345L)
         .list()
         .layer(0, new DenseLayer.Builder()

deeplearning4j-core/src/test/java/org/deeplearning4j/plot/RenderTest.java

Lines changed: 1 addition & 1 deletion
@@ -52,7 +52,7 @@ public void testPlotter() throws Exception {
                 .nIn(784).nOut(600)
                 .weightInit(WeightInit.DISTRIBUTION).dist(new NormalDistribution(1e-3, 1e-1))
                 .dropOut(0.5)
-                .lossFunction(LossFunctions.LossFunction.RMSE_XENT).build())
+                .lossFunction(LossFunctions.LossFunction.KL_DIVERGENCE).build())
         .build();

deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/params/ConvolutionParamInitializer.java

Lines changed: 11 additions & 1 deletion
@@ -126,7 +126,17 @@ protected INDArray createWeightMatrix(NeuralNetConfiguration conf, INDArray weig
         if(initializeParams) {
             Distribution dist = Distributions.createDistribution(conf.getLayer().getDist());
             int[] kernel = layerConf.getKernelSize();
-            return WeightInitUtil.initWeights(new int[]{layerConf.getNOut(), layerConf.getNIn(), kernel[0], kernel[1]},
+            int[] stride = layerConf.getStride();
+
+            int inputDepth = layerConf.getNIn();
+            int outputDepth = layerConf.getNOut();
+
+            double fanIn = inputDepth * kernel[0] * kernel[1];
+            double fanOut = outputDepth * kernel[0] * kernel[1] / ((double) stride[0] * stride[1]);
+
+            int[] weightsShape = new int[]{outputDepth, inputDepth, kernel[0], kernel[1]};
+
+            return WeightInitUtil.initWeights(fanIn, fanOut, weightsShape,
                     layerConf.getWeightInit(), dist, 'c', weightView);
         } else {
             int[] kernel = layerConf.getKernelSize();
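The fan formulas generalize the dense-layer notion to convolutions: each output activation sees inputDepth * kernelH * kernelW inputs, while each input value feeds roughly outputDepth * kernelH * kernelW / (strideH * strideW) outputs. A worked example with illustrative numbers, not taken from the diff:

    // illustrative values: 3 input channels, 8 output channels, 5x5 kernel, stride 1x1
    int inputDepth = 3, outputDepth = 8;
    int[] kernel = {5, 5};
    int[] stride = {1, 1};
    double fanIn  = inputDepth * kernel[0] * kernel[1];                                     // 3*5*5 = 75
    double fanOut = outputDepth * kernel[0] * kernel[1] / ((double) stride[0] * stride[1]); // 8*5*5/1 = 200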

deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/params/DefaultParamInitializer.java

Lines changed: 8 additions & 2 deletions
@@ -117,10 +117,16 @@ protected INDArray createWeightMatrix(NeuralNetConfiguration conf, INDArray weig
                 (org.deeplearning4j.nn.conf.layers.FeedForwardLayer) conf.getLayer();

         if(initializeParameters) {
+
+            int nIn = layerConf.getNIn();
+            int nOut = layerConf.getNOut();
+            int[] shape = new int[]{nIn,nOut};
+
             Distribution dist = Distributions.createDistribution(layerConf.getDist());
             INDArray ret = WeightInitUtil.initWeights(
-                    layerConf.getNIn(),
-                    layerConf.getNOut(),
+                    nIn, //Fan in
+                    nOut, //Fan out
+                    shape,
                     layerConf.getWeightInit(),
                     dist,
                     weightParamView);
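For feed-forward layers the fans are simply nIn and nOut, and the weight matrix shape {nIn, nOut} coincides with them (hence the //Fan in and //Fan out comments); the convolution initializer above is where passing the fans separately from the weight shape actually matters.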

Comments (0)