Skip to content

Commit

Permalink
add orthogonal and truncated normal
Browse files Browse the repository at this point in the history
  • Loading branch information
maxpumperla committed Sep 13, 2017
1 parent 1267664 commit 140d9d6
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,19 +18,15 @@
package org.deeplearning4j.nn.modelimport.keras.utils;

import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.nn.conf.distribution.Distribution;
import org.deeplearning4j.nn.conf.distribution.NormalDistribution;
import org.deeplearning4j.nn.conf.distribution.UniformDistribution;
import org.deeplearning4j.nn.conf.distribution.*;
import org.deeplearning4j.nn.modelimport.keras.config.KerasLayerConfiguration;
import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException;
import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException;
import org.deeplearning4j.nn.weights.WeightInit;
import org.nd4j.linalg.api.rng.distribution.impl.OrthogonalDistribution;
import org.nd4j.linalg.primitives.Pair;

import java.util.HashMap;
import java.util.Map;
import java.util.Objects;

/**
* Utility functionality for Keras weight initializers
Expand All @@ -53,6 +49,7 @@ public static Pair<WeightInit, Distribution> mapWeightInitialization(String kera
int kerasMajorVersion)
throws UnsupportedKerasConfigurationException, InvalidKerasConfigurationException {

// TODO: Identity and VarianceScaling need "scale" factor
WeightInit init = null;
Distribution dist = null;
if (kerasInit != null) {
Expand Down Expand Up @@ -111,20 +108,19 @@ public static Pair<WeightInit, Distribution> mapWeightInitialization(String kera
kerasInit.equals(conf.getINIT_ORTHOGONAL_ALIAS())) {
if (kerasMajorVersion == 2) {
double gain = (double) initConfig.get(conf.getLAYER_FIELD_INIT_GAIN());
// TODO: dist = new OrthogonalDistribution(gain);
dist = new OrthogonalDistribution(gain);
} else {
double scale = 1.1;
if (initConfig.containsKey(conf.getLAYER_FIELD_INIT_SCALE()))
scale = (double) initConfig.get(conf.getLAYER_FIELD_INIT_SCALE());
//TODO: dist = new OrthogonalDistribution(scale);
dist = new OrthogonalDistribution(scale);
}
init = WeightInit.DISTRIBUTION;
} else if (kerasInit.equals(conf.getINIT_TRUNCATED_NORMAL()) ||
kerasInit.equals(conf.getINIT_TRUNCATED_NORMAL_ALIAS())) {
double mean = (double) initConfig.get(conf.getLAYER_FIELD_INIT_MEAN());
double stdDev = (double) initConfig.get(conf.getLAYER_FIELD_INIT_STDDEV());
// TODO: map to truncated
dist = new NormalDistribution(mean, stdDev);
dist = new TruncatedNormalDistribution(mean, stdDev);
init = WeightInit.DISTRIBUTION;
} else if (kerasInit.equals(conf.getINIT_IDENTITY()) ||
kerasInit.equals(conf.getINIT_IDENTITY_ALIAS())) {
Expand All @@ -139,7 +135,6 @@ public static Pair<WeightInit, Distribution> mapWeightInitialization(String kera
if (scale != 1.)
log.warn("Scaled identity weight init not supported, setting scale=1");
}
// TODO: map to scaled Identity
init = WeightInit.IDENTITY;
} else if (kerasInit.equals(conf.getINIT_VARIANCE_SCALING())) {
double scale;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,11 +58,12 @@ private String[] initializers(KerasLayerConfiguration conf) {
conf.getINIT_ONES(),
conf.getINIT_ZERO(),
conf.getINIT_IDENTITY(),
conf.getINIT_NORMAL(),
conf.getINIT_ORTHOGONAL(),
conf.getINIT_VARIANCE_SCALING()
// TODO: add these initializations
// conf.getINIT_CONSTANT(),
// conf.getINIT_NORMAL(),
// conf.getINIT_ORTHOGONAL(),

};
}

Expand All @@ -78,6 +79,8 @@ private WeightInit[] dl4jInitializers() {
WeightInit.ONES,
WeightInit.ZERO,
WeightInit.IDENTITY,
WeightInit.DISTRIBUTION,
WeightInit.DISTRIBUTION,
WeightInit.VAR_SCALING_NORMAL_FAN_IN,
};
}
Expand Down

0 comments on commit 140d9d6

Please sign in to comment.