diff --git a/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/config/KerasLayerConfiguration.java b/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/config/KerasLayerConfiguration.java
index f201507cbeaa..ab418844da64 100644
--- a/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/config/KerasLayerConfiguration.java
+++ b/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/config/KerasLayerConfiguration.java
@@ -189,6 +189,15 @@ public class KerasLayerConfiguration {
     /* Keras weight initializers. */
     private final String LAYER_FIELD_INIT = ""; // 1: init, 2: kernel_initializer
     private final String LAYER_FIELD_BIAS_INIT = "bias_initializer"; // keras 2 only
+    private final String LAYER_FIELD_INIT_MEAN = "mean";
+    private final String LAYER_FIELD_INIT_STDDEV = "stddev";
+    private final String LAYER_FIELD_INIT_SCALE = "scale";
+    private final String LAYER_FIELD_INIT_MINVAL = "minval";
+    private final String LAYER_FIELD_INIT_MAXVAL = "maxval";
+    private final String LAYER_FIELD_INIT_VALUE = "value";
+    private final String LAYER_FIELD_INIT_GAIN = "gain";
+    private final String LAYER_FIELD_INIT_MODE = "mode";
+    private final String LAYER_FIELD_INIT_DISTRIBUTION = "distribution";
 
     private final String INIT_UNIFORM = "uniform";
     private final String INIT_RANDOM_UNIFORM = "random_uniform";
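Note: the new `LAYER_FIELD_INIT_*` constants name the keys that Keras 2 serializes under an initializer's `config` block (Keras 1 stores the equivalent values, such as `scale`, directly on the layer config). A minimal sketch of the deserialized structure these constants index into; the class name and literal values below are hypothetical examples, not taken from this diff:

```java
import java.util.HashMap;
import java.util.Map;

// Shape of a deserialized Keras 2 initializer entry. JSON source (hypothetical):
//   "kernel_initializer": {"class_name": "RandomUniform",
//                          "config": {"minval": -0.05, "maxval": 0.05, "seed": null}}
public class InitializerConfigShape {
    public static void main(String[] args) {
        Map<String, Object> config = new HashMap<>();
        config.put("minval", -0.05);   // read via getLAYER_FIELD_INIT_MINVAL()
        config.put("maxval", 0.05);    // read via getLAYER_FIELD_INIT_MAXVAL()

        Map<String, Object> initializer = new HashMap<>();
        initializer.put("class_name", "RandomUniform");
        initializer.put("config", config);
        System.out.println(initializer);
    }
}
```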
diff --git a/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/utils/KerasInitilizationUtils.java b/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/utils/KerasInitilizationUtils.java
index c75e93d4fcdb..63e3d1893709 100644
--- a/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/utils/KerasInitilizationUtils.java
+++ b/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/utils/KerasInitilizationUtils.java
@@ -19,14 +19,18 @@
 
 import lombok.extern.slf4j.Slf4j;
 import org.deeplearning4j.nn.conf.distribution.Distribution;
+import org.deeplearning4j.nn.conf.distribution.NormalDistribution;
+import org.deeplearning4j.nn.conf.distribution.UniformDistribution;
 import org.deeplearning4j.nn.modelimport.keras.config.KerasLayerConfiguration;
 import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException;
 import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException;
 import org.deeplearning4j.nn.weights.WeightInit;
+import org.nd4j.linalg.api.rng.distribution.impl.OrthogonalDistribution;
 import org.nd4j.linalg.primitives.Pair;
 
 import java.util.HashMap;
 import java.util.Map;
+import java.util.Objects;
 
 /**
  * Utility functionality for Keras weight initializers
@@ -45,10 +49,11 @@ public class KerasInitilizationUtils {
      */
    public static Pair<WeightInit, Distribution> mapWeightInitialization(String kerasInit,
                                                                          KerasLayerConfiguration conf,
-                                                                         Map<String, Object> initConfig)
-            throws UnsupportedKerasConfigurationException {
+                                                                         Map<String, Object> initConfig,
+                                                                         int kerasMajorVersion)
+            throws UnsupportedKerasConfigurationException, InvalidKerasConfigurationException {
-        WeightInit init = WeightInit.XAVIER;
+        WeightInit init = null;
         Distribution dist = null;
         if (kerasInit != null) {
             if (kerasInit.equals(conf.getINIT_GLOROT_NORMAL())) {
@@ -71,45 +76,103 @@ public static Pair<WeightInit, Distribution> mapWeightInitialization(String kera
                 kerasInit.equals(conf.getINIT_ZEROS()) ||
                 kerasInit.equals(conf.getINIT_ZEROS_ALIAS())) {
                 init = WeightInit.ZERO;
-            } else if (kerasInit.equals(conf.getINIT_CONSTANT()) ||
-                kerasInit.equals(conf.getINIT_CONSTANT_ALIAS())) {
-                // FIXME: CONSTANT
-                // keras.initializers.Constant(value=0)
-                init = WeightInit.ZERO;
             } else if (kerasInit.equals(conf.getINIT_UNIFORM()) ||
                 kerasInit.equals(conf.getINIT_RANDOM_UNIFORM()) ||
                 kerasInit.equals(conf.getINIT_RANDOM_UNIFORM_ALIAS())) {
-                // FIXME: read minval and maxval from config
-                // keras.initializers.RandomUniform(minval=-0.05, maxval=0.05, seed=None) keras1: scale
-                init = WeightInit.UNIFORM;
+                if (kerasMajorVersion == 2) {
+                    double minVal = (double) initConfig.get(conf.getLAYER_FIELD_INIT_MINVAL());
+                    double maxVal = (double) initConfig.get(conf.getLAYER_FIELD_INIT_MAXVAL());
+                    dist = new UniformDistribution(minVal, maxVal);
+                } else {
+                    double scale = 0.05;
+                    if (initConfig.containsKey(conf.getLAYER_FIELD_INIT_SCALE()))
+                        scale = (double) initConfig.get(conf.getLAYER_FIELD_INIT_SCALE());
+                    dist = new UniformDistribution(-scale, scale);
+                }
+                init = WeightInit.DISTRIBUTION;
             } else if (kerasInit.equals(conf.getINIT_RANDOM_NORMAL()) ||
                 kerasInit.equals(conf.getINIT_RANDOM_NORMAL_ALIAS())) {
-                // FIXME: read mean and stddev from config
-                // keras.initializers.RandomNormal(mean=0.0, stddev=0.05, seed=None)
+                if (kerasMajorVersion == 2) {
+                    double mean = (double) initConfig.get(conf.getLAYER_FIELD_INIT_MEAN());
+                    double stdDev = (double) initConfig.get(conf.getLAYER_FIELD_INIT_STDDEV());
+                    dist = new NormalDistribution(mean, stdDev);
+                } else {
+                    double scale = 0.05;
+                    if (initConfig.containsKey(conf.getLAYER_FIELD_INIT_SCALE()))
+                        scale = (double) initConfig.get(conf.getLAYER_FIELD_INIT_SCALE());
+                    dist = new NormalDistribution(0, scale);
+                }
                 init = WeightInit.DISTRIBUTION;
+            } else if (kerasInit.equals(conf.getINIT_CONSTANT()) ||
+                kerasInit.equals(conf.getINIT_CONSTANT_ALIAS())) {
+                // FIXME: CONSTANT keras.initializers.Constant(value=0)
+                init = WeightInit.ZERO;
             } else if (kerasInit.equals(conf.getINIT_ORTHOGONAL()) ||
                 kerasInit.equals(conf.getINIT_ORTHOGONAL_ALIAS())) {
-                // TODO keras.initializers.Orthogonal(gain=1.0, seed=None)
+                if (kerasMajorVersion == 2) {
+                    double gain = (double) initConfig.get(conf.getLAYER_FIELD_INIT_GAIN());
+                    // TODO: dist = new OrthogonalDistribution(gain);
+                } else {
+                    double scale = 1.1;
+                    if (initConfig.containsKey(conf.getLAYER_FIELD_INIT_SCALE()))
+                        scale = (double) initConfig.get(conf.getLAYER_FIELD_INIT_SCALE());
+                    // TODO: dist = new OrthogonalDistribution(scale);
+                }
                 init = WeightInit.DISTRIBUTION;
             } else if (kerasInit.equals(conf.getINIT_TRUNCATED_NORMAL()) ||
                 kerasInit.equals(conf.getINIT_TRUNCATED_NORMAL_ALIAS())) {
-                // FIXME: read mean and stddev from config
-                // keras.initializers.TruncatedNormal(mean=0.0, stddev=0.05, seed=None) keras1: no mean, always 0, stddev is scale
+                double mean = (double) initConfig.get(conf.getLAYER_FIELD_INIT_MEAN());
+                double stdDev = (double) initConfig.get(conf.getLAYER_FIELD_INIT_STDDEV());
+                // TODO: map to a truncated normal distribution to match Keras exactly
+                dist = new NormalDistribution(mean, stdDev);
                 init = WeightInit.DISTRIBUTION;
             } else if (kerasInit.equals(conf.getINIT_IDENTITY()) ||
                 kerasInit.equals(conf.getINIT_IDENTITY_ALIAS())) {
-                // TODO: takes gain/scale parameter
-                // keras.initializers.Identity(gain=1.0) keras1: scale
+                if (kerasMajorVersion == 2) {
+                    double gain = (double) initConfig.get(conf.getLAYER_FIELD_INIT_GAIN());
+                    if (gain != 1.)
+                        log.warn("Scaled identity weight init not supported, setting gain=1");
+                } else {
+                    double scale = 1.;
+                    if (initConfig.containsKey(conf.getLAYER_FIELD_INIT_SCALE()))
+                        scale = (double) initConfig.get(conf.getLAYER_FIELD_INIT_SCALE());
+                    if (scale != 1.)
+                        log.warn("Scaled identity weight init not supported, setting scale=1");
+                }
+                // TODO: map to scaled identity
                 init = WeightInit.IDENTITY;
             } else if (kerasInit.equals(conf.getINIT_VARIANCE_SCALING())) {
-                // keras.initializers.VarianceScaling(scale=1.0, mode='fan_in', distribution='normal', seed=None)
-                // With distribution="normal", samples are drawn from a truncated normal distribution centered on zero, with stddev = sqrt(scale / n) where n is:
-                // number of input units in the weight tensor, if mode = "fan_in"
-                // number of output units, if mode = "fan_out"
-                // average of the numbers of input and output units, if mode = "fan_avg"
-                // With distribution="uniform", samples are drawn from a uniform distribution within [-limit, limit], with limit = sqrt(3 * scale / n).
-
-                init = WeightInit.XAVIER_UNIFORM;
+                double scale = (double) initConfig.get(conf.getLAYER_FIELD_INIT_SCALE());
+                if (scale != 1.)
+                    log.warn("Scaled variance scaling weight init not supported, setting scale=1");
+                String mode = (String) initConfig.get(conf.getLAYER_FIELD_INIT_MODE());
+                String distribution = (String) initConfig.get(conf.getLAYER_FIELD_INIT_DISTRIBUTION());
+                switch (mode) {
+                    case "fan_in":
+                        if (distribution.equals("normal")) {
+                            init = WeightInit.VAR_SCALING_NORMAL_FAN_IN;
+                        } else {
+                            init = WeightInit.VAR_SCALING_UNIFORM_FAN_IN;
+                        }
+                        break;
+                    case "fan_out":
+                        if (distribution.equals("normal")) {
+                            init = WeightInit.VAR_SCALING_NORMAL_FAN_OUT;
+                        } else {
+                            init = WeightInit.VAR_SCALING_UNIFORM_FAN_OUT;
+                        }
+                        break;
+                    case "fan_avg":
+                        if (distribution.equals("normal")) {
+                            init = WeightInit.VAR_SCALING_NORMAL_FAN_AVG;
+                        } else {
+                            init = WeightInit.VAR_SCALING_UNIFORM_FAN_AVG;
+                        }
+                        break;
+                    default:
+                        throw new InvalidKerasConfigurationException("Initialization argument 'mode' has to be either " +
+                                "fan_in, fan_out or fan_avg");
+                }
             } else {
                 throw new UnsupportedKerasConfigurationException("Unknown keras weight initializer " + kerasInit);
             }
@@ -134,20 +197,23 @@ public static WeightInit getWeightInitFromConfig(Map<String, Object> layerConfig
         Map<String, Object> innerConfig = KerasLayerUtils.getInnerLayerConfigFromConfig(layerConfig, conf);
         if (!innerConfig.containsKey(initField))
             throw new InvalidKerasConfigurationException("Keras layer is missing " + initField + " field");
-        String kerasInit = "glorot_normal";
+        String kerasInit;
         Map<String, Object> initMap;
         if (kerasMajorVersion != 2) {
             kerasInit = (String) innerConfig.get(initField);
             initMap = innerConfig;
         } else {
-            initMap = (HashMap) innerConfig.get(initField);
-            if (initMap.containsKey("class_name")) {
-                kerasInit = (String) initMap.get("class_name");
+            Map<String, Object> fullInitMap = (HashMap) innerConfig.get(initField);
+            initMap = (HashMap) fullInitMap.get("config");
+            if (fullInitMap.containsKey("class_name")) {
+                kerasInit = (String) fullInitMap.get("class_name");
+            } else {
+                throw new UnsupportedKerasConfigurationException("Incomplete initialization class");
             }
         }
         Pair<WeightInit, Distribution> init;
         try {
-            init = mapWeightInitialization(kerasInit, conf, initMap);
+            init = mapWeightInitialization(kerasInit, conf, initMap, kerasMajorVersion);
         } catch (UnsupportedKerasConfigurationException e) {
             if (enforceTrainingConfig)
                 throw e;
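Note: the removed comment block documented the VarianceScaling math that the new `VAR_SCALING_*` mapping targets: stddev = sqrt(scale / n) for the normal case, limit = sqrt(3 * scale / n) for the uniform case, where n is fan-in, fan-out, or their average depending on `mode`. A standalone sketch of that arithmetic, with the formulas taken from the quoted Keras docs; the class and method names are illustrative, not DL4J API:

```java
// Standalone sketch of the VarianceScaling arithmetic (formulas from the Keras
// doc comment removed above; names here are illustrative).
public class VarianceScalingMath {
    static double fan(String mode, double fanIn, double fanOut) {
        switch (mode) {
            case "fan_in":  return fanIn;
            case "fan_out": return fanOut;
            case "fan_avg": return (fanIn + fanOut) / 2.0;
            default: throw new IllegalArgumentException(
                    "mode has to be either fan_in, fan_out or fan_avg");
        }
    }

    public static void main(String[] args) {
        double scale = 1.0, fanIn = 128, fanOut = 64;
        double n = fan("fan_in", fanIn, fanOut);
        double stddev = Math.sqrt(scale / n);      // distribution = "normal"
        double limit = Math.sqrt(3.0 * scale / n); // distribution = "uniform"
        System.out.printf("stddev = %.6f, limit = %.6f%n", stddev, limit);
    }
}
```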
+ log.warn("Scaled identity weight init not supported, setting gain=1"); + } else { + double scale = 1.; + if (initConfig.containsKey(conf.getLAYER_FIELD_INIT_SCALE())) + scale = (double) initConfig.get(conf.getLAYER_FIELD_INIT_SCALE()); + if (scale != 1.) + log.warn("Scaled identity weight init not supported, setting scale=1"); + } + // TODO: map to scaled Identity init = WeightInit.IDENTITY; } else if (kerasInit.equals(conf.getINIT_VARIANCE_SCALING())) { - // keras.initializers.VarianceScaling(scale=1.0, mode='fan_in', distribution='normal', seed=None) - // With distribution="normal", samples are drawn from a truncated normal distribution centered on zero, with stddev = sqrt(scale / n) where n is: - // number of input units in the weight tensor, if mode = "fan_in" - // number of output units, if mode = "fan_out" - // average of the numbers of input and output units, if mode = "fan_avg" - // With distribution="uniform", samples are drawn from a uniform distribution within [-limit, limit], with limit = sqrt(3 * scale / n). - - init = WeightInit.XAVIER_UNIFORM; + int scale = (int) initConfig.get(conf.getLAYER_FIELD_INIT_SCALE()); + if (scale != 1) + log.warn("Scaled identity weight init not supported, setting scale=1"); + String mode = (String) initConfig.get(conf.getLAYER_FIELD_INIT_MODE()); + String distribution = (String) initConfig.get(conf.getLAYER_FIELD_INIT_DISTRIBUTION()); + switch (mode) { + case "fan_in": + if (distribution.equals("normal")) { + init = WeightInit.VAR_SCALING_NORMAL_FAN_IN; + } else { + init = WeightInit.VAR_SCALING_UNIFORM_FAN_IN; + } + break; + case "fan_out": + if (distribution.equals("normal")) { + init = WeightInit.VAR_SCALING_NORMAL_FAN_OUT; + } else { + init = WeightInit.VAR_SCALING_UNIFORM_FAN_OUT; + } + break; + case "fan_avg": + if (distribution.equals("normal")) { + init = WeightInit.VAR_SCALING_NORMAL_FAN_AVG; + } else { + init = WeightInit.VAR_SCALING_UNIFORM_FAN_AVG; + } + break; + default: + throw new InvalidKerasConfigurationException("Initialization argument 'mode' has to be either " + + "fan_in, fan_out or fan_avg"); + } } else { throw new UnsupportedKerasConfigurationException("Unknown keras weight initializer " + kerasInit); } @@ -134,20 +197,23 @@ public static WeightInit getWeightInitFromConfig(Map layerConfig Map innerConfig = KerasLayerUtils.getInnerLayerConfigFromConfig(layerConfig, conf); if (!innerConfig.containsKey(initField)) throw new InvalidKerasConfigurationException("Keras layer is missing " + initField + " field"); - String kerasInit = "glorot_normal"; + String kerasInit; Map initMap; if (kerasMajorVersion != 2) { kerasInit = (String) innerConfig.get(initField); initMap = innerConfig; } else { - initMap = (HashMap) innerConfig.get(initField); - if (initMap.containsKey("class_name")) { - kerasInit = (String) initMap.get("class_name"); + Map fullInitMap = (HashMap) innerConfig.get(initField); + initMap = (HashMap) fullInitMap.get("config"); + if (fullInitMap.containsKey("class_name")) { + kerasInit = (String) fullInitMap.get("class_name"); + } else { + throw new UnsupportedKerasConfigurationException("Incomplete initialization class"); } } Pair init; try { - init = mapWeightInitialization(kerasInit, conf, initMap); + init = mapWeightInitialization(kerasInit, conf, initMap, kerasMajorVersion); } catch (UnsupportedKerasConfigurationException e) { if (enforceTrainingConfig) throw e; diff --git a/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/configurations/KerasInitilizationTest.java 
diff --git a/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/weights/WeightInit.java b/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/weights/WeightInit.java
index 7bdf0a64faa1..58bb746d7319 100755
--- a/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/weights/WeightInit.java
+++ b/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/weights/WeightInit.java
@@ -53,5 +53,6 @@
  */
 public enum WeightInit {
     DISTRIBUTION, ZERO, ONES, SIGMOID_UNIFORM, NORMAL, LECUN_NORMAL, UNIFORM, XAVIER, XAVIER_UNIFORM, XAVIER_FAN_IN, XAVIER_LEGACY, RELU,
-    RELU_UNIFORM, IDENTITY, LECUN_UNIFORM
+    RELU_UNIFORM, IDENTITY, LECUN_UNIFORM, VAR_SCALING_NORMAL_FAN_IN, VAR_SCALING_NORMAL_FAN_OUT, VAR_SCALING_NORMAL_FAN_AVG,
+    VAR_SCALING_UNIFORM_FAN_IN, VAR_SCALING_UNIFORM_FAN_OUT, VAR_SCALING_UNIFORM_FAN_AVG
 }
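Note: the six new enum values follow the pattern `VAR_SCALING_<distribution>_<fan mode>`. A sketch of selecting one from user code, assuming the usual `weightInit(WeightInit)` method on DL4J layer builders (that builder call is an assumption, not part of this diff):

```java
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.weights.WeightInit;

public class VarScalingUsage {
    public static void main(String[] args) {
        // Pick the variance-scaling flavor explicitly; fan-in/fan-out are taken
        // from the layer shape at initialization time.
        DenseLayer layer = new DenseLayer.Builder()
                .nIn(128).nOut(64)
                .weightInit(WeightInit.VAR_SCALING_NORMAL_FAN_IN)
                .build();
        System.out.println(layer);
    }
}
```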
diff --git a/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/weights/WeightInitUtil.java b/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/weights/WeightInitUtil.java
index 88fa12b6695d..e8efb701f616 100755
--- a/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/weights/WeightInitUtil.java
+++ b/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/weights/WeightInitUtil.java
@@ -123,10 +123,31 @@ public static INDArray initWeights(double fanIn, double fanOut, int[] shape, Wei
                     ret = Nd4j.createUninitialized(shape, order).assign(Nd4j.eye(shape[0]));
                 }
                 break;
+            case VAR_SCALING_NORMAL_FAN_IN:
+                // TODO: needs to be truncated normal to match keras.
+                ret = Nd4j.randn(order, shape).divi(FastMath.sqrt(fanIn));
+                break;
+            case VAR_SCALING_NORMAL_FAN_OUT:
+                ret = Nd4j.randn(order, shape).divi(FastMath.sqrt(fanOut));
+                break;
+            case VAR_SCALING_NORMAL_FAN_AVG:
+                ret = Nd4j.randn(order, shape).divi(FastMath.sqrt((fanIn + fanOut) / 2));
+                break;
+            case VAR_SCALING_UNIFORM_FAN_IN:
+                double scalingFanIn = 3.0 / Math.sqrt(fanIn);
+                ret = Nd4j.rand(shape, Nd4j.getDistributions().createUniform(-scalingFanIn, scalingFanIn));
+                break;
+            case VAR_SCALING_UNIFORM_FAN_OUT:
+                double scalingFanOut = 3.0 / Math.sqrt(fanOut);
+                ret = Nd4j.rand(shape, Nd4j.getDistributions().createUniform(-scalingFanOut, scalingFanOut));
+                break;
+            case VAR_SCALING_UNIFORM_FAN_AVG:
+                double scalingFanAvg = 3.0 / Math.sqrt((fanIn + fanOut) / 2);
+                ret = Nd4j.rand(shape, Nd4j.getDistributions().createUniform(-scalingFanAvg, scalingFanAvg));
+                break;
             default:
                 throw new IllegalStateException("Illegal weight init value: " + initScheme);
         }
-
         INDArray flat = Nd4j.toFlattened(order, ret);
         if (flat.length() != paramView.length())
             throw new RuntimeException("ParamView length does not match initialized weights length (view length: "
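Note: a quick numeric sanity check of the fan-in normal case (illustration only, not a unit test): weights drawn as in the `VAR_SCALING_NORMAL_FAN_IN` branch above should have a sample standard deviation near 1/sqrt(fanIn). The shape and fan-in values below are arbitrary:

```java
import org.apache.commons.math3.util.FastMath;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class VarScalingSanityCheck {
    public static void main(String[] args) {
        double fanIn = 400;
        // Mirrors the VAR_SCALING_NORMAL_FAN_IN case: N(0,1) samples scaled by 1/sqrt(fanIn)
        INDArray w = Nd4j.randn('c', new int[]{400, 100}).divi(FastMath.sqrt(fanIn));
        System.out.println("expected stddev: " + (1.0 / FastMath.sqrt(fanIn)));
        System.out.println("sample stddev:   " + w.stdNumber());
    }
}
```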