add variance scaling inits etc
maxpumperla committed Sep 13, 2017
1 parent 1f13f10 commit b2def7e
Showing 5 changed files with 166 additions and 37 deletions.
@@ -189,6 +189,15 @@ public class KerasLayerConfiguration {
/* Keras weight initializers. */
private final String LAYER_FIELD_INIT = ""; // 1: init, 2: kernel_initializer
private final String LAYER_FIELD_BIAS_INIT = "bias_initializer"; // keras 2 only
private final String LAYER_FIELD_INIT_MEAN = "mean";
private final String LAYER_FIELD_INIT_STDDEV = "stddev";
private final String LAYER_FIELD_INIT_SCALE = "scale";
private final String LAYER_FIELD_INIT_MINVAL = "minval";
private final String LAYER_FIELD_INIT_MAXVAL = "maxval";
private final String LAYER_FIELD_INIT_VALUE = "value";
private final String LAYER_FIELD_INIT_GAIN = "gain";
private final String LAYER_FIELD_INIT_MODE = "mode";
private final String LAYER_FIELD_INIT_DISTRIBUTION = "distribution";

private final String INIT_UNIFORM = "uniform";
private final String INIT_RANDOM_UNIFORM = "random_uniform";
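For reference, Keras 2 serializes each initializer as a nested map with a class_name and an inner config dictionary; the new LAYER_FIELD_INIT_* constants above are the keys of that inner dictionary. A minimal standalone sketch of such a map, using the RandomUniform defaults quoted in the mapping code below (the class name and literal values are illustrative, not part of this commit):

import java.util.HashMap;
import java.util.Map;

public class KerasInitConfigSketch {
    public static void main(String[] args) {
        // Inner "config" map of a Keras 2 RandomUniform initializer; the keys
        // correspond to LAYER_FIELD_INIT_MINVAL and LAYER_FIELD_INIT_MAXVAL.
        Map<String, Object> innerConfig = new HashMap<>();
        innerConfig.put("minval", -0.05);
        innerConfig.put("maxval", 0.05);

        // Outer map as stored under the kernel_initializer field of a layer config.
        Map<String, Object> initializer = new HashMap<>();
        initializer.put("class_name", "RandomUniform");
        initializer.put("config", innerConfig);

        System.out.println(initializer);
    }
}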
@@ -19,14 +19,18 @@

import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.nn.conf.distribution.Distribution;
import org.deeplearning4j.nn.conf.distribution.NormalDistribution;
import org.deeplearning4j.nn.conf.distribution.UniformDistribution;
import org.deeplearning4j.nn.modelimport.keras.config.KerasLayerConfiguration;
import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException;
import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException;
import org.deeplearning4j.nn.weights.WeightInit;
import org.nd4j.linalg.api.rng.distribution.impl.OrthogonalDistribution;
import org.nd4j.linalg.primitives.Pair;

import java.util.HashMap;
import java.util.Map;
import java.util.Objects;

/**
* Utility functionality for Keras weight initializers
@@ -45,10 +49,11 @@ public class KerasInitilizationUtils {
*/
public static Pair<WeightInit, Distribution> mapWeightInitialization(String kerasInit,
KerasLayerConfiguration conf,
Map<String, Object> initConfig)
throws UnsupportedKerasConfigurationException {
Map<String, Object> initConfig,
int kerasMajorVersion)
throws UnsupportedKerasConfigurationException, InvalidKerasConfigurationException {

WeightInit init = WeightInit.XAVIER;
WeightInit init = null;
Distribution dist = null;
if (kerasInit != null) {
if (kerasInit.equals(conf.getINIT_GLOROT_NORMAL())) {
@@ -71,45 +76,103 @@ public static Pair<WeightInit, Distribution> mapWeightInitialization(String kerasInit,
kerasInit.equals(conf.getINIT_ZEROS()) ||
kerasInit.equals(conf.getINIT_ZEROS_ALIAS())) {
init = WeightInit.ZERO;
} else if (kerasInit.equals(conf.getINIT_CONSTANT()) ||
kerasInit.equals(conf.getINIT_CONSTANT_ALIAS())) {
// FIXME: CONSTANT
// keras.initializers.Constant(value=0)
init = WeightInit.ZERO;
} else if (kerasInit.equals(conf.getINIT_UNIFORM()) ||
kerasInit.equals(conf.getINIT_RANDOM_UNIFORM()) ||
kerasInit.equals(conf.getINIT_RANDOM_UNIFORM_ALIAS())) {
// FIXME: read minval and maxval from config
// keras.initializers.RandomUniform(minval=-0.05, maxval=0.05, seed=None) keras1: scale
init = WeightInit.UNIFORM;
if (kerasMajorVersion == 2) {
double minVal = (double) initConfig.get(conf.getLAYER_FIELD_INIT_MINVAL());
double maxVal = (double) initConfig.get(conf.getLAYER_FIELD_INIT_MAXVAL());
dist = new UniformDistribution(minVal, maxVal);
} else {
double scale = 0.05;
if (initConfig.containsKey(conf.getLAYER_FIELD_INIT_SCALE()))
scale = (double) initConfig.get(conf.getLAYER_FIELD_INIT_SCALE());
dist = new UniformDistribution(-scale, scale);
}
init = WeightInit.DISTRIBUTION;
} else if (kerasInit.equals(conf.getINIT_RANDOM_NORMAL()) ||
kerasInit.equals(conf.getINIT_RANDOM_NORMAL_ALIAS())) {
// FIXME: read mean and stddev from config
// keras.initializers.RandomNormal(mean=0.0, stddev=0.05, seed=None)
if (kerasMajorVersion == 2) {
double mean = (double) initConfig.get(conf.getLAYER_FIELD_INIT_MEAN());
double stdDev = (double) initConfig.get(conf.getLAYER_FIELD_INIT_STDDEV());
dist = new NormalDistribution(mean, stdDev);
} else {
double scale = 0.05;
if (initConfig.containsKey(conf.getLAYER_FIELD_INIT_SCALE()))
scale = (double) initConfig.get(conf.getLAYER_FIELD_INIT_SCALE());
dist = new NormalDistribution(0, scale);
}
init = WeightInit.DISTRIBUTION;
} else if (kerasInit.equals(conf.getINIT_CONSTANT()) ||
kerasInit.equals(conf.getINIT_CONSTANT_ALIAS())) {
// FIXME: CONSTANT keras.initializers.Constant(value=0)
init = WeightInit.ZERO;
} else if (kerasInit.equals(conf.getINIT_ORTHOGONAL()) ||
kerasInit.equals(conf.getINIT_ORTHOGONAL_ALIAS())) {
// TODO keras.initializers.Orthogonal(gain=1.0, seed=None)
if (kerasMajorVersion == 2) {
double gain = (double) initConfig.get(conf.getLAYER_FIELD_INIT_GAIN());
// TODO: dist = new OrthogonalDistribution(gain);
} else {
double scale = 1.1;
if (initConfig.containsKey(conf.getLAYER_FIELD_INIT_SCALE()))
scale = (double) initConfig.get(conf.getLAYER_FIELD_INIT_SCALE());
//TODO: dist = new OrthogonalDistribution(scale);
}
init = WeightInit.DISTRIBUTION;
} else if (kerasInit.equals(conf.getINIT_TRUNCATED_NORMAL()) ||
kerasInit.equals(conf.getINIT_TRUNCATED_NORMAL_ALIAS())) {
// FIXME: read mean and stddev from config
// keras.initializers.TruncatedNormal(mean=0.0, stddev=0.05, seed=None) keras1: no mean, always 0, stddev is scale
double mean = (double) initConfig.get(conf.getLAYER_FIELD_INIT_MEAN());
double stdDev = (double) initConfig.get(conf.getLAYER_FIELD_INIT_STDDEV());
// TODO: map to truncated
dist = new NormalDistribution(mean, stdDev);
init = WeightInit.DISTRIBUTION;
} else if (kerasInit.equals(conf.getINIT_IDENTITY()) ||
kerasInit.equals(conf.getINIT_IDENTITY_ALIAS())) {
// TODO: takes gain/scale parameter
// keras.initializers.Identity(gain=1.0) keras1: scale
if (kerasMajorVersion == 2) {
double gain = (double) initConfig.get(conf.getLAYER_FIELD_INIT_GAIN());
if (gain != 1.)
log.warn("Scaled identity weight init not supported, setting gain=1");
} else {
double scale = 1.;
if (initConfig.containsKey(conf.getLAYER_FIELD_INIT_SCALE()))
scale = (double) initConfig.get(conf.getLAYER_FIELD_INIT_SCALE());
if (scale != 1.)
log.warn("Scaled identity weight init not supported, setting scale=1");
}
// TODO: map to scaled Identity
init = WeightInit.IDENTITY;
} else if (kerasInit.equals(conf.getINIT_VARIANCE_SCALING())) {
// keras.initializers.VarianceScaling(scale=1.0, mode='fan_in', distribution='normal', seed=None)
// With distribution="normal", samples are drawn from a truncated normal distribution centered on zero, with stddev = sqrt(scale / n) where n is:
// number of input units in the weight tensor, if mode = "fan_in"
// number of output units, if mode = "fan_out"
// average of the numbers of input and output units, if mode = "fan_avg"
// With distribution="uniform", samples are drawn from a uniform distribution within [-limit, limit], with limit = sqrt(3 * scale / n).

init = WeightInit.XAVIER_UNIFORM;
double scale = (double) initConfig.get(conf.getLAYER_FIELD_INIT_SCALE());
if (scale != 1.)
log.warn("Variance scaling with scale != 1 not supported, using scale=1");
String mode = (String) initConfig.get(conf.getLAYER_FIELD_INIT_MODE());
String distribution = (String) initConfig.get(conf.getLAYER_FIELD_INIT_DISTRIBUTION());
switch (mode) {
case "fan_in":
if (distribution.equals("normal")) {
init = WeightInit.VAR_SCALING_NORMAL_FAN_IN;
} else {
init = WeightInit.VAR_SCALING_UNIFORM_FAN_IN;
}
break;
case "fan_out":
if (distribution.equals("normal")) {
init = WeightInit.VAR_SCALING_NORMAL_FAN_OUT;
} else {
init = WeightInit.VAR_SCALING_UNIFORM_FAN_OUT;
}
break;
case "fan_avg":
if (distribution.equals("normal")) {
init = WeightInit.VAR_SCALING_NORMAL_FAN_AVG;
} else {
init = WeightInit.VAR_SCALING_UNIFORM_FAN_AVG;
}
break;
default:
throw new InvalidKerasConfigurationException("Initialization argument 'mode' has to be either " +
"fan_in, fan_out or fan_avg");
}
} else {
throw new UnsupportedKerasConfigurationException("Unknown keras weight initializer " + kerasInit);
}
@@ -134,20 +197,23 @@ public static WeightInit getWeightInitFromConfig(Map<String, Object> layerConfig
Map<String, Object> innerConfig = KerasLayerUtils.getInnerLayerConfigFromConfig(layerConfig, conf);
if (!innerConfig.containsKey(initField))
throw new InvalidKerasConfigurationException("Keras layer is missing " + initField + " field");
String kerasInit = "glorot_normal";
String kerasInit;
Map<String, Object> initMap;
if (kerasMajorVersion != 2) {
kerasInit = (String) innerConfig.get(initField);
initMap = innerConfig;
} else {
initMap = (HashMap) innerConfig.get(initField);
if (initMap.containsKey("class_name")) {
kerasInit = (String) initMap.get("class_name");
Map<String, Object> fullInitMap = (HashMap) innerConfig.get(initField);
initMap = (HashMap) fullInitMap.get("config");
if (fullInitMap.containsKey("class_name")) {
kerasInit = (String) fullInitMap.get("class_name");
} else {
throw new UnsupportedKerasConfigurationException("Incomplete initialization class");
}
}
Pair<WeightInit, Distribution> init;
try {
init = mapWeightInitialization(kerasInit, conf, initMap);
init = mapWeightInitialization(kerasInit, conf, initMap, kerasMajorVersion);
} catch (UnsupportedKerasConfigurationException e) {
if (enforceTrainingConfig)
throw e;
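To make the new signature concrete, here is a rough sketch of calling mapWeightInitialization directly for a Keras 2 random_normal initializer; the import paths, the sketch class name, and the mean/stddev values are assumptions for illustration, while the expected result (WeightInit.DISTRIBUTION paired with a NormalDistribution) follows the branch added above:

import org.deeplearning4j.nn.conf.distribution.Distribution;
import org.deeplearning4j.nn.modelimport.keras.config.Keras2LayerConfiguration;
import org.deeplearning4j.nn.modelimport.keras.utils.KerasInitilizationUtils;
import org.deeplearning4j.nn.weights.WeightInit;
import org.nd4j.linalg.primitives.Pair;

import java.util.HashMap;
import java.util.Map;

public class MapInitSketch {
    public static void main(String[] args) throws Exception {
        Keras2LayerConfiguration conf = new Keras2LayerConfiguration();

        // Inner "config" map of a Keras 2 RandomNormal initializer.
        Map<String, Object> initConfig = new HashMap<>();
        initConfig.put(conf.getLAYER_FIELD_INIT_MEAN(), 0.0);
        initConfig.put(conf.getLAYER_FIELD_INIT_STDDEV(), 0.05);

        // Expected mapping: WeightInit.DISTRIBUTION with NormalDistribution(0.0, 0.05).
        Pair<WeightInit, Distribution> mapped = KerasInitilizationUtils.mapWeightInitialization(
                conf.getINIT_RANDOM_NORMAL(), conf, initConfig, 2);
        System.out.println(mapped.getFirst() + " -> " + mapped.getSecond());
    }
}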
@@ -15,6 +15,16 @@

public class KerasInitilizationTest {

double scale = 0.4;
double minValue = -0.3;
double maxValue = 0.2;
double mean = 0.1;
double stdDev = 0.3;
double value = 0.0;
double gain = 2.0;
String distribution = "normal";
String mode = "fan_in";

private Keras1LayerConfiguration conf1 = new Keras1LayerConfiguration();
private Keras2LayerConfiguration conf2 = new Keras2LayerConfiguration();

@@ -28,9 +38,11 @@ public void testInitializers() throws Exception {
String[] keras2Inits = initializers(conf2);
WeightInit[] dl4jInits = dl4jInitializers();

for (int i=0; i< dl4jInits.length; i++) {
for (int i=0; i< dl4jInits.length - 1; i++) {
initilizationDenseLayer(conf1, keras1, keras1Inits[i], dl4jInits[i]);
initilizationDenseLayer(conf2, keras2, keras2Inits[i], dl4jInits[i]);

initilizationDenseLayer(conf2, keras2, keras2Inits[dl4jInits.length-1], dl4jInits[dl4jInits.length-1]);
}
}

@@ -39,6 +51,7 @@ private String[] initializers(KerasLayerConfiguration conf) {
conf.getINIT_GLOROT_NORMAL(),
conf.getINIT_GLOROT_UNIFORM(),
conf.getINIT_LECUN_NORMAL(),
conf.getINIT_LECUN_UNIFORM(),
conf.getINIT_RANDOM_UNIFORM(),
conf.getINIT_HE_NORMAL(),
conf.getINIT_HE_UNIFORM(),
@@ -50,7 +63,6 @@ private String[] initializers(KerasLayerConfiguration conf) {
// conf.getINIT_CONSTANT(),
// conf.getINIT_NORMAL(),
// conf.getINIT_ORTHOGONAL(),
// conf.getINIT_LECUN_UNIFORM()
};
}

@@ -59,13 +71,14 @@ private WeightInit[] dl4jInitializers() {
WeightInit.XAVIER,
WeightInit.XAVIER_UNIFORM,
WeightInit.LECUN_NORMAL,
WeightInit.UNIFORM,
WeightInit.LECUN_UNIFORM,
WeightInit.DISTRIBUTION,
WeightInit.RELU,
WeightInit.RELU_UNIFORM,
WeightInit.ONES,
WeightInit.ZERO,
WeightInit.IDENTITY,
WeightInit.XAVIER_UNIFORM // TODO: Variance scaling is incorrectly mapped
WeightInit.VAR_SCALING_NORMAL_FAN_IN,
};
}

@@ -79,9 +92,28 @@ private void initilizationDenseLayer(KerasLayerConfiguration conf, Integer kerasVersion,
config.put(conf.getLAYER_FIELD_NAME(), "init_test");
if (kerasVersion == 1) {
config.put(conf.getLAYER_FIELD_INIT(), initializer);
config.put(conf.getLAYER_FIELD_INIT_MEAN(), mean);
config.put(conf.getLAYER_FIELD_INIT_STDDEV(), stdDev);
config.put(conf.getLAYER_FIELD_INIT_SCALE(), scale);
config.put(conf.getLAYER_FIELD_INIT_MINVAL(), minValue);
config.put(conf.getLAYER_FIELD_INIT_MAXVAL(), maxValue);
config.put(conf.getLAYER_FIELD_INIT_VALUE(), value);
config.put(conf.getLAYER_FIELD_INIT_GAIN(), gain);
} else {
Map<String, Object> init = new HashMap<>();
init.put("class_name", initializer);
Map<String, Object> innerInit = new HashMap<>();
innerInit.put(conf.getLAYER_FIELD_INIT_MEAN(), mean);
innerInit.put(conf.getLAYER_FIELD_INIT_STDDEV(), stdDev);
innerInit.put(conf.getLAYER_FIELD_INIT_SCALE(), scale);
innerInit.put(conf.getLAYER_FIELD_INIT_MINVAL(), minValue);
innerInit.put(conf.getLAYER_FIELD_INIT_MAXVAL(), maxValue);
innerInit.put(conf.getLAYER_FIELD_INIT_VALUE(), value);
innerInit.put(conf.getLAYER_FIELD_INIT_GAIN(), gain);
innerInit.put(conf.getLAYER_FIELD_INIT_MODE(), mode);
innerInit.put(conf.getLAYER_FIELD_INIT_DISTRIBUTION(), distribution);

init.put(conf.getLAYER_FIELD_CONFIG(), innerInit);
config.put(conf.getLAYER_FIELD_INIT(), init);
}
config.put(conf.getLAYER_FIELD_OUTPUT_DIM(), 1337);
@@ -53,5 +53,6 @@
*/
public enum WeightInit {
DISTRIBUTION, ZERO, ONES, SIGMOID_UNIFORM, NORMAL, LECUN_NORMAL, UNIFORM, XAVIER, XAVIER_UNIFORM, XAVIER_FAN_IN, XAVIER_LEGACY, RELU,
RELU_UNIFORM, IDENTITY, LECUN_UNIFORM
RELU_UNIFORM, IDENTITY, LECUN_UNIFORM, VAR_SCALING_NORMAL_FAN_IN, VAR_SCALING_NORMAL_FAN_OUT, VAR_SCALING_NORMAL_FAN_AVG,
VAR_SCALING_UNIFORM_FAN_IN, VAR_SCALING_UNIFORM_FAN_OUT, VAR_SCALING_UNIFORM_FAN_AVG
}
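The new enum constants can be selected like any existing scheme; a brief sketch, assuming the usual DenseLayer.Builder API and illustrative layer sizes (not part of this commit):

import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.weights.WeightInit;

public class VarScalingLayerSketch {
    public static void main(String[] args) {
        // Pick one of the new variance-scaling schemes on an ordinary layer builder.
        DenseLayer layer = new DenseLayer.Builder()
                .nIn(64)
                .nOut(32)
                .weightInit(WeightInit.VAR_SCALING_NORMAL_FAN_IN)
                .build();
        System.out.println(layer.getWeightInit());
    }
}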
@@ -123,10 +123,31 @@ public static INDArray initWeights(double fanIn, double fanOut, int[] shape, WeightInit initScheme,
ret = Nd4j.createUninitialized(shape, order).assign(Nd4j.eye(shape[0]));
}
break;
case VAR_SCALING_NORMAL_FAN_IN:
// TODO: needs to be truncated normal to match keras.
ret = Nd4j.randn(order, shape).divi(FastMath.sqrt(fanIn));
break;
case VAR_SCALING_NORMAL_FAN_OUT:
ret = Nd4j.randn(order, shape).divi(FastMath.sqrt(fanOut));
break;
case VAR_SCALING_NORMAL_FAN_AVG:
ret = Nd4j.randn(order, shape).divi(FastMath.sqrt((fanIn + fanOut) / 2));
break;
case VAR_SCALING_UNIFORM_FAN_IN:
double scalingFanIn = 3.0 / Math.sqrt(fanIn);
ret = Nd4j.rand(shape, Nd4j.getDistributions().createUniform(-scalingFanIn, scalingFanIn));
break;
case VAR_SCALING_UNIFORM_FAN_OUT:
double scalingFanOut = 3.0 / Math.sqrt(fanOut);
ret = Nd4j.rand(shape, Nd4j.getDistributions().createUniform(-scalingFanOut, scalingFanOut));
break;
case VAR_SCALING_UNIFORM_FAN_AVG:
double scalingFanAvg = 3.0 / Math.sqrt((fanIn + fanOut) / 2);
ret = Nd4j.rand(shape, Nd4j.getDistributions().createUniform(-scalingFanAvg, scalingFanAvg));
break;
default:
throw new IllegalStateException("Illegal weight init value: " + initScheme);
}

INDArray flat = Nd4j.toFlattened(order, ret);
if (flat.length() != paramView.length())
throw new RuntimeException("ParamView length does not match initialized weights length (view length: "
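For intuition, the fan-in normal branch above draws Gaussian weights and scales them by 1 / sqrt(fanIn), i.e. the Keras variance-scaling convention stddev = sqrt(scale / n) with scale = 1 and n = fanIn. A standalone sketch of the same draw (shape and fan-in are illustrative):

import org.apache.commons.math3.util.FastMath;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class VarScalingNormalSketch {
    public static void main(String[] args) {
        // Mirror VAR_SCALING_NORMAL_FAN_IN: gaussian weights scaled by 1 / sqrt(fanIn).
        double fanIn = 64;
        int[] shape = {64, 32};
        INDArray weights = Nd4j.randn('c', shape).divi(FastMath.sqrt(fanIn));

        // Sample standard deviation should be close to 1 / sqrt(64) = 0.125.
        System.out.println("stddev ~ " + weights.stdNumber());
    }
}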
