Commit d5f9485

init changes to keras import / document todos
maxpumperla committed Sep 12, 2017
1 parent e044b17 commit d5f9485
Showing 1 changed file with 47 additions and 11 deletions.
@@ -18,10 +18,12 @@
 package org.deeplearning4j.nn.modelimport.keras.utils;
 
 import lombok.extern.slf4j.Slf4j;
+import org.deeplearning4j.nn.conf.distribution.Distribution;
 import org.deeplearning4j.nn.modelimport.keras.config.KerasLayerConfiguration;
 import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException;
 import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException;
 import org.deeplearning4j.nn.weights.WeightInit;
+import org.nd4j.linalg.primitives.Pair;
 
 import java.util.HashMap;
 import java.util.Map;
@@ -41,21 +43,21 @@ public class KerasInitilizationUtils {
      * @return DL4J weight initialization enum
      * @see WeightInit
      */
-    public static WeightInit mapWeightInitialization(String kerasInit, KerasLayerConfiguration conf)
+    public static Pair<WeightInit, Distribution> mapWeightInitialization(String kerasInit, KerasLayerConfiguration conf)
             throws UnsupportedKerasConfigurationException {
+
         WeightInit init = WeightInit.XAVIER;
+        Distribution dist = null;
         if (kerasInit != null) {
             if (kerasInit.equals(conf.getINIT_GLOROT_NORMAL())) {
                 init = WeightInit.XAVIER;
             } else if (kerasInit.equals(conf.getINIT_GLOROT_UNIFORM())) {
                 init = WeightInit.XAVIER_UNIFORM;
             } else if (kerasInit.equals(conf.getINIT_LECUN_NORMAL())) {
-                init = WeightInit.NORMAL;
-            } else if (kerasInit.equals(conf.getINIT_UNIFORM()) ||
-                    kerasInit.equals(conf.getINIT_RANDOM_UNIFORM()) ||
-                    kerasInit.equals(conf.getINIT_RANDOM_UNIFORM_ALIAS())) {
-                init = WeightInit.UNIFORM;
-            } else if (kerasInit.equals(conf.getINIT_HE_NORMAL())) {
+                init = WeightInit.LECUN_NORMAL;
+            } else if (kerasInit.equals(conf.getINIT_LECUN_UNIFORM())) {
+                init = WeightInit.LECUN_UNIFORM;
+            } else if (kerasInit.equals(conf.getINIT_HE_NORMAL())) {
                 init = WeightInit.RELU;
             } else if (kerasInit.equals(conf.getINIT_HE_UNIFORM())) {
                 init = WeightInit.RELU_UNIFORM;
@@ -67,16 +69,50 @@ public static WeightInit mapWeightInitialization(String kerasInit, KerasLayerConfiguration conf)
                     kerasInit.equals(conf.getINIT_ZEROS()) ||
                     kerasInit.equals(conf.getINIT_ZEROS_ALIAS())) {
                 init = WeightInit.ZERO;
+            } else if (kerasInit.equals(conf.getINIT_CONSTANT()) ||
+                    kerasInit.equals(conf.getINIT_CONSTANT_ALIAS())) {
+                // FIXME: CONSTANT
+                // keras.initializers.Constant(value=0)
+                init = WeightInit.ZERO;
+            } else if (kerasInit.equals(conf.getINIT_UNIFORM()) ||
+                    kerasInit.equals(conf.getINIT_RANDOM_UNIFORM()) ||
+                    kerasInit.equals(conf.getINIT_RANDOM_UNIFORM_ALIAS())) {
+                // FIXME: read minval and maxval from config
+                // keras.initializers.RandomUniform(minval=-0.05, maxval=0.05, seed=None) keras1: scale
+                init = WeightInit.UNIFORM;
             } else if (kerasInit.equals(conf.getINIT_RANDOM_NORMAL()) ||
                     kerasInit.equals(conf.getINIT_RANDOM_NORMAL_ALIAS())) {
+                // FIXME: read mean and stddev from config
+                // keras.initializers.RandomNormal(mean=0.0, stddev=0.05, seed=None)
                 init = WeightInit.DISTRIBUTION;
+            } else if (kerasInit.equals(conf.getINIT_ORTHOGONAL()) ||
+                    kerasInit.equals(conf.getINIT_ORTHOGONAL_ALIAS())) {
+                // TODO keras.initializers.Orthogonal(gain=1.0, seed=None)
+                init = WeightInit.DISTRIBUTION;
+            } else if (kerasInit.equals(conf.getINIT_TRUNCATED_NORMAL()) ||
+                    kerasInit.equals(conf.getINIT_TRUNCATED_NORMAL_ALIAS())) {
+                // FIXME: read mean and stddev from config
+                // keras.initializers.TruncatedNormal(mean=0.0, stddev=0.05, seed=None) keras1: no mean, always 0, stddev is scale
+                init = WeightInit.DISTRIBUTION;
+            } else if (kerasInit.equals(conf.getINIT_IDENTITY()) ||
+                    kerasInit.equals(conf.getINIT_IDENTITY_ALIAS())) {
+                // TODO: takes gain/scale parameter
+                // keras.initializers.Identity(gain=1.0) keras1: scale
+                init = WeightInit.IDENTITY;
             } else if (kerasInit.equals(conf.getINIT_VARIANCE_SCALING())) {
+                // keras.initializers.VarianceScaling(scale=1.0, mode='fan_in', distribution='normal', seed=None)
+                // With distribution="normal", samples are drawn from a truncated normal distribution centered on zero, with stddev = sqrt(scale / n) where n is:
+                //     number of input units in the weight tensor, if mode = "fan_in"
+                //     number of output units, if mode = "fan_out"
+                //     average of the numbers of input and output units, if mode = "fan_avg"
+                // With distribution="uniform", samples are drawn from a uniform distribution within [-limit, limit], with limit = sqrt(3 * scale / n).
+
                 init = WeightInit.XAVIER_UNIFORM;
             } else {
                 throw new UnsupportedKerasConfigurationException("Unknown keras weight initializer " + kerasInit);
             }
         }
-        return init;
+        return new Pair<>(init, dist);
     }
 
     /**
@@ -105,18 +141,18 @@ public static WeightInit getWeightInitFromConfig(Map<String, Object> layerConfig
                 kerasInit = (String) initMap.get("class_name");
             }
         }
-        WeightInit init;
+        Pair<WeightInit, Distribution> init;
         try {
             init = mapWeightInitialization(kerasInit, conf);
         } catch (UnsupportedKerasConfigurationException e) {
             if (enforceTrainingConfig)
                 throw e;
             else {
-                init = WeightInit.XAVIER;
+                init = new Pair<>(WeightInit.XAVIER, null);
                 log.warn("Unknown weight initializer " + kerasInit + " (Using XAVIER instead).");
             }
         }
-        return init;
+        return init.getFirst();
     }
 
 }
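
The FIXME branches above all share one gap: they select WeightInit.DISTRIBUTION without building the Distribution that should travel with it, which is presumably why mapWeightInitialization now returns a Pair<WeightInit, Distribution> even though dist is still always null. Below is a minimal sketch of how one such branch could be completed, assuming DL4J's NormalDistribution(mean, std) constructor and the key names from Keras 2's serialized RandomNormal config ({"class_name": "RandomNormal", "config": {"mean": ..., "stddev": ..., "seed": ...}}); the helper class and its config handling are illustrative, not part of this commit:

    import java.util.Map;

    import org.deeplearning4j.nn.conf.distribution.Distribution;
    import org.deeplearning4j.nn.conf.distribution.NormalDistribution;
    import org.deeplearning4j.nn.weights.WeightInit;
    import org.nd4j.linalg.primitives.Pair;

    public class RandomNormalMappingSketch {

        /**
         * Hypothetical resolution of the "read mean and stddev from config" FIXME:
         * pull the initializer's parameters out of the Keras "config" map and wrap
         * them in a DL4J Distribution, so that WeightInit.DISTRIBUTION carries
         * concrete values instead of a null dist.
         */
        public static Pair<WeightInit, Distribution> mapRandomNormal(Map<String, Object> initConfig) {
            // Keras 2 defaults: mean=0.0, stddev=0.05. JSON numbers may arrive as
            // Integer or Double, so go through Number rather than casting directly.
            double mean = ((Number) initConfig.getOrDefault("mean", 0.0)).doubleValue();
            double stddev = ((Number) initConfig.getOrDefault("stddev", 0.05)).doubleValue();
            return new Pair<>(WeightInit.DISTRIBUTION, new NormalDistribution(mean, stddev));
        }
    }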
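
The VarianceScaling comment block quotes the Keras documentation verbatim; the arithmetic is simple enough to check directly. A dependency-free worked example (the fan values are arbitrary):

    /**
     * Worked example of the VarianceScaling rules quoted in the diff:
     * truncated normal with stddev = sqrt(scale / n), or uniform on
     * [-limit, limit] with limit = sqrt(3 * scale / n).
     */
    public class VarianceScalingExample {

        static double fanN(int fanIn, int fanOut, String mode) {
            switch (mode) {
                case "fan_in":  return fanIn;
                case "fan_out": return fanOut;
                case "fan_avg": return (fanIn + fanOut) / 2.0;
                default: throw new IllegalArgumentException("Unknown mode: " + mode);
            }
        }

        public static void main(String[] args) {
            double scale = 1.0;
            int fanIn = 256, fanOut = 128;

            double n = fanN(fanIn, fanOut, "fan_in");     // n = 256
            System.out.println(Math.sqrt(scale / n));     // stddev = 0.0625
            System.out.println(Math.sqrt(3 * scale / n)); // limit  ~= 0.1083
        }
    }

Note that VarianceScaling with scale=1.0, mode="fan_avg" and distribution="uniform" is exactly Glorot uniform, which is presumably why this branch falls back to WeightInit.XAVIER_UNIFORM until the parameters are read from the config.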
