Skip to content

Commit

Permalink
add constant and augment tests
Browse files Browse the repository at this point in the history
  • Loading branch information
maxpumperla committed Sep 13, 2017
1 parent 9e3bdc9 commit 23dce46
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -103,8 +103,9 @@ public static Pair<WeightInit, Distribution> mapWeightInitialization(String kera
init = WeightInit.DISTRIBUTION;
} else if (kerasInit.equals(conf.getINIT_CONSTANT()) ||
kerasInit.equals(conf.getINIT_CONSTANT_ALIAS())) {
// TODO: CONSTANT keras.initializers.Constant(value=0)
init = WeightInit.ONES;
double value = (double) initConfig.get(conf.getLAYER_FIELD_INIT_VALUE());
dist = new ConstantDistribution(value);
init = WeightInit.DISTRIBUTION;
} else if (kerasInit.equals(conf.getINIT_ORTHOGONAL()) ||
kerasInit.equals(conf.getINIT_ORTHOGONAL_ALIAS())) {
if (kerasMajorVersion == 2) {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
package org.deeplearning4j.nn.modelimport.keras.configurations;

import org.deeplearning4j.nn.conf.distribution.*;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.modelimport.keras.config.Keras1LayerConfiguration;
import org.deeplearning4j.nn.modelimport.keras.config.Keras2LayerConfiguration;
import org.deeplearning4j.nn.modelimport.keras.config.KerasLayerConfiguration;
import org.deeplearning4j.nn.modelimport.keras.layers.core.KerasDense;
import org.deeplearning4j.nn.weights.WeightInit;
import org.junit.Test;
import org.nd4j.linalg.primitives.Pair;

import java.util.HashMap;
import java.util.Map;
Expand All @@ -15,13 +17,13 @@

public class KerasInitilizationTest {

double scale = 0.4;
double minValue = -0.3;
double scale = 0.2;
double minValue = -0.2;
double maxValue = 0.2;
double mean = 0.1;
double stdDev = 0.3;
double value = 0.0;
double gain = 2.0;
double mean = 0.0;
double stdDev = 0.2;
double value = 42.0;
double gain = 0.2;
String distribution = "normal";
String mode = "fan_in";

Expand All @@ -37,12 +39,14 @@ public void testInitializers() throws Exception {
String[] keras1Inits = initializers(conf1);
String[] keras2Inits = initializers(conf2);
WeightInit[] dl4jInits = dl4jInitializers();
Distribution[] dl4jDistributions = dl4jDistributions();

for (int i=0; i< dl4jInits.length - 1; i++) {
initilizationDenseLayer(conf1, keras1, keras1Inits[i], dl4jInits[i]);
initilizationDenseLayer(conf2, keras2, keras2Inits[i], dl4jInits[i]);
initilizationDenseLayer(conf1, keras1, keras1Inits[i], dl4jInits[i], dl4jDistributions[i]);
initilizationDenseLayer(conf2, keras2, keras2Inits[i], dl4jInits[i], dl4jDistributions[i]);

initilizationDenseLayer(conf2, keras2, keras2Inits[dl4jInits.length-1], dl4jInits[dl4jInits.length-1]);
initilizationDenseLayer(conf2, keras2, keras2Inits[dl4jInits.length-1],
dl4jInits[dl4jInits.length-1], dl4jDistributions[dl4jInits.length-1]);
}
}

Expand Down Expand Up @@ -80,13 +84,30 @@ private WeightInit[] dl4jInitializers() {
WeightInit.IDENTITY,
WeightInit.DISTRIBUTION,
WeightInit.DISTRIBUTION,
WeightInit.ONES,
WeightInit.VAR_SCALING_NORMAL_FAN_IN,
};
WeightInit.DISTRIBUTION,
WeightInit.VAR_SCALING_NORMAL_FAN_IN};
}

private Distribution[] dl4jDistributions() {
return new Distribution[] {
null,
null,
null,
null,
new UniformDistribution(minValue, maxValue),
null,
null,
null,
null,
null,
new NormalDistribution(mean, stdDev),
new OrthogonalDistribution(gain),
new ConstantDistribution(value),
null};
}

private void initilizationDenseLayer(KerasLayerConfiguration conf, Integer kerasVersion,
String initializer, WeightInit dl4jInitializer)
String initializer, WeightInit dl4jInitializer, Distribution dl4jDistribution)
throws Exception {
Map<String, Object> layerConfig = new HashMap<>();
layerConfig.put(conf.getLAYER_FIELD_CLASS_NAME(), conf.getLAYER_CLASS_NAME_DENSE());
Expand Down Expand Up @@ -125,5 +146,7 @@ private void initilizationDenseLayer(KerasLayerConfiguration conf, Integer keras

DenseLayer layer = new KerasDense(layerConfig, false).getDenseLayer();
assertEquals(dl4jInitializer, layer.getWeightInit());
assertEquals(dl4jDistribution, layer.getDist());

}
}

0 comments on commit 23dce46

Please sign in to comment.