Skip to content

Commit

Permalink
pass init config to initializer
Browse files Browse the repository at this point in the history
  • Loading branch information
maxpumperla committed Sep 12, 2017
1 parent d5f9485 commit 1f13f10
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,9 @@ public class KerasInitilizationUtils {
* @return DL4J weight initialization enum
* @see WeightInit
*/
public static Pair<WeightInit, Distribution> mapWeightInitialization(String kerasInit, KerasLayerConfiguration conf)
public static Pair<WeightInit, Distribution> mapWeightInitialization(String kerasInit,
KerasLayerConfiguration conf,
Map<String, Object> initConfig)
throws UnsupportedKerasConfigurationException {

WeightInit init = WeightInit.XAVIER;
Expand Down Expand Up @@ -133,17 +135,19 @@ public static WeightInit getWeightInitFromConfig(Map<String, Object> layerConfig
if (!innerConfig.containsKey(initField))
throw new InvalidKerasConfigurationException("Keras layer is missing " + initField + " field");
String kerasInit = "glorot_normal";
if (kerasMajorVersion != 2)
Map<String, Object> initMap;
if (kerasMajorVersion != 2) {
kerasInit = (String) innerConfig.get(initField);
else {
HashMap initMap = (HashMap) innerConfig.get(initField);
initMap = innerConfig;
} else {
initMap = (HashMap) innerConfig.get(initField);
if (initMap.containsKey("class_name")) {
kerasInit = (String) initMap.get("class_name");
}
}
Pair<WeightInit, Distribution> init;
try {
init = mapWeightInitialization(kerasInit, conf);
init = mapWeightInitialization(kerasInit, conf, initMap);
} catch (UnsupportedKerasConfigurationException e) {
if (enforceTrainingConfig)
throw e;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ private WeightInit[] dl4jInitializers() {
return new WeightInit[] {
WeightInit.XAVIER,
WeightInit.XAVIER_UNIFORM,
WeightInit.NORMAL,
WeightInit.LECUN_NORMAL,
WeightInit.UNIFORM,
WeightInit.RELU,
WeightInit.RELU_UNIFORM,
Expand Down

0 comments on commit 1f13f10

Please sign in to comment.