diff --git a/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/convolutional/KerasAtrousConvolution1D.java b/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/convolutional/KerasAtrousConvolution1D.java index 3e1870cf7e3a..42e8fbc215fe 100644 --- a/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/convolutional/KerasAtrousConvolution1D.java +++ b/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/convolutional/KerasAtrousConvolution1D.java @@ -18,11 +18,14 @@ package org.deeplearning4j.nn.modelimport.keras.layers.convolutional; import org.deeplearning4j.nn.api.layers.LayerConstraint; +import org.deeplearning4j.nn.conf.distribution.Distribution; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.layers.Convolution1DLayer; import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException; import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException; import org.deeplearning4j.nn.modelimport.keras.utils.KerasConstraintUtils; +import org.deeplearning4j.nn.weights.WeightInit; +import org.nd4j.linalg.primitives.Pair; import java.util.Map; @@ -80,11 +83,15 @@ public KerasAtrousConvolution1D(Map layerConfig, boolean enforce LayerConstraint weightConstraint = KerasConstraintUtils.getConstraintsFromConfig( layerConfig, conf.getLAYER_FIELD_W_CONSTRAINT(), conf, kerasMajorVersion); + Pair init = getWeightInitFromConfig(layerConfig, conf.getLAYER_FIELD_INIT(), + enforceTrainingConfig, conf, kerasMajorVersion); + WeightInit weightInit = init.getFirst(); + Distribution distribution = init.getSecond(); + Convolution1DLayer.Builder builder = new Convolution1DLayer.Builder().name(this.layerName) .nOut(getNOutFromConfig(layerConfig, conf)).dropOut(this.dropout) .activation(getActivationFromConfig(layerConfig, conf)) - .weightInit(getWeightInitFromConfig(layerConfig, conf.getLAYER_FIELD_INIT(), - enforceTrainingConfig, conf, kerasMajorVersion)) + .weightInit(weightInit) .dilation(getDilationRate(layerConfig, 1, conf, true)) .l1(this.weightL1Regularization).l2(this.weightL2Regularization) .convolutionMode(getConvolutionModeFromConfig(layerConfig, conf)) @@ -92,6 +99,8 @@ public KerasAtrousConvolution1D(Map layerConfig, boolean enforce .hasBias(hasBias) .stride(getStrideFromConfig(layerConfig, 1, conf)[0]); int[] padding = getPaddingFromBorderModeConfig(layerConfig, 1, conf, kerasMajorVersion); + if (distribution != null) + builder.dist(distribution); if (hasBias) builder.biasInit(0.0); if (padding != null) diff --git a/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/convolutional/KerasAtrousConvolution2D.java b/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/convolutional/KerasAtrousConvolution2D.java index 14b78e827f51..51d55b57a72a 100644 --- a/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/convolutional/KerasAtrousConvolution2D.java +++ b/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/convolutional/KerasAtrousConvolution2D.java @@ -18,11 +18,14 @@ package org.deeplearning4j.nn.modelimport.keras.layers.convolutional; import org.deeplearning4j.nn.api.layers.LayerConstraint; +import org.deeplearning4j.nn.conf.distribution.Distribution; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.layers.ConvolutionLayer; import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException; import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException; import org.deeplearning4j.nn.modelimport.keras.utils.KerasConstraintUtils; +import org.deeplearning4j.nn.weights.WeightInit; +import org.nd4j.linalg.primitives.Pair; import java.util.Map; @@ -82,11 +85,15 @@ public KerasAtrousConvolution2D(Map layerConfig, boolean enforce LayerConstraint weightConstraint = KerasConstraintUtils.getConstraintsFromConfig( layerConfig, conf.getLAYER_FIELD_W_CONSTRAINT(), conf, kerasMajorVersion); + Pair init = getWeightInitFromConfig(layerConfig, conf.getLAYER_FIELD_INIT(), + enforceTrainingConfig, conf, kerasMajorVersion); + WeightInit weightInit = init.getFirst(); + Distribution distribution = init.getSecond(); + ConvolutionLayer.Builder builder = new ConvolutionLayer.Builder().name(this.layerName) .nOut(getNOutFromConfig(layerConfig, conf)).dropOut(this.dropout) .activation(getActivationFromConfig(layerConfig, conf)) - .weightInit(getWeightInitFromConfig(layerConfig, conf.getLAYER_FIELD_INIT(), - enforceTrainingConfig, conf, kerasMajorVersion)) + .weightInit(weightInit) .dilation(getDilationRate(layerConfig, 2, conf, true)) .l1(this.weightL1Regularization).l2(this.weightL2Regularization) .convolutionMode(getConvolutionModeFromConfig(layerConfig, conf)) @@ -94,6 +101,8 @@ public KerasAtrousConvolution2D(Map layerConfig, boolean enforce .hasBias(hasBias) .stride(getStrideFromConfig(layerConfig, 2, conf)); int[] padding = getPaddingFromBorderModeConfig(layerConfig, 2, conf, kerasMajorVersion); + if (distribution != null) + builder.dist(distribution); if (hasBias) builder.biasInit(0.0); if (padding != null) diff --git a/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/convolutional/KerasConvolution1D.java b/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/convolutional/KerasConvolution1D.java index c62a1d83c2e0..b85b19c55c55 100644 --- a/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/convolutional/KerasConvolution1D.java +++ b/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/convolutional/KerasConvolution1D.java @@ -20,11 +20,14 @@ import lombok.Data; import lombok.extern.slf4j.Slf4j; import org.deeplearning4j.nn.api.layers.LayerConstraint; +import org.deeplearning4j.nn.conf.distribution.Distribution; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.layers.Convolution1DLayer; import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException; import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException; import org.deeplearning4j.nn.modelimport.keras.utils.KerasConstraintUtils; +import org.deeplearning4j.nn.weights.WeightInit; +import org.nd4j.linalg.primitives.Pair; import java.util.Map; @@ -84,17 +87,23 @@ public KerasConvolution1D(Map layerConfig, boolean enforceTraini LayerConstraint weightConstraint = KerasConstraintUtils.getConstraintsFromConfig( layerConfig, conf.getLAYER_FIELD_W_CONSTRAINT(), conf, kerasMajorVersion); + Pair init = getWeightInitFromConfig(layerConfig, conf.getLAYER_FIELD_INIT(), + enforceTrainingConfig, conf, kerasMajorVersion); + WeightInit weightInit = init.getFirst(); + Distribution distribution = init.getSecond(); + Convolution1DLayer.Builder builder = new Convolution1DLayer.Builder().name(this.layerName) .nOut(getNOutFromConfig(layerConfig, conf)).dropOut(this.dropout) .activation(getActivationFromConfig(layerConfig, conf)) - .weightInit(getWeightInitFromConfig(layerConfig, conf.getLAYER_FIELD_INIT(), - enforceTrainingConfig, conf, kerasMajorVersion)) + .weightInit(weightInit) .l1(this.weightL1Regularization).l2(this.weightL2Regularization) .convolutionMode(getConvolutionModeFromConfig(layerConfig, conf)) .kernelSize(getKernelSizeFromConfig(layerConfig, 1, conf, kerasMajorVersion)[0]) .hasBias(hasBias) .stride(getStrideFromConfig(layerConfig, 1, conf)[0]); int[] padding = getPaddingFromBorderModeConfig(layerConfig, 1, conf, kerasMajorVersion); + if (distribution != null) + builder.dist(distribution); if (hasBias) builder.biasInit(0.0); if (padding != null) diff --git a/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/convolutional/KerasConvolution2D.java b/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/convolutional/KerasConvolution2D.java index 74b14a9a1d20..1dbab9307c65 100644 --- a/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/convolutional/KerasConvolution2D.java +++ b/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/convolutional/KerasConvolution2D.java @@ -20,11 +20,14 @@ import lombok.Data; import lombok.extern.slf4j.Slf4j; import org.deeplearning4j.nn.api.layers.LayerConstraint; +import org.deeplearning4j.nn.conf.distribution.Distribution; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.layers.ConvolutionLayer; import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException; import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException; import org.deeplearning4j.nn.modelimport.keras.utils.KerasConstraintUtils; +import org.deeplearning4j.nn.weights.WeightInit; +import org.nd4j.linalg.primitives.Pair; import java.util.Map; @@ -82,6 +85,11 @@ public KerasConvolution2D(Map layerConfig, boolean enforceTraini numTrainableParams = hasBias ? 2 : 1; int[] dilationRate = getDilationRate(layerConfig, 2, conf, false); + Pair init = getWeightInitFromConfig(layerConfig, conf.getLAYER_FIELD_INIT(), + enforceTrainingConfig, conf, kerasMajorVersion); + WeightInit weightInit = init.getFirst(); + Distribution distribution = init.getSecond(); + LayerConstraint biasConstraint = KerasConstraintUtils.getConstraintsFromConfig( layerConfig, conf.getLAYER_FIELD_B_CONSTRAINT(), conf, kerasMajorVersion); LayerConstraint weightConstraint = KerasConstraintUtils.getConstraintsFromConfig( @@ -90,14 +98,15 @@ public KerasConvolution2D(Map layerConfig, boolean enforceTraini ConvolutionLayer.Builder builder = new ConvolutionLayer.Builder().name(this.layerName) .nOut(getNOutFromConfig(layerConfig, conf)).dropOut(this.dropout) .activation(getActivationFromConfig(layerConfig, conf)) - .weightInit(getWeightInitFromConfig(layerConfig, conf.getLAYER_FIELD_INIT(), - enforceTrainingConfig, conf, kerasMajorVersion)) + .weightInit(weightInit) .l1(this.weightL1Regularization).l2(this.weightL2Regularization) .convolutionMode(getConvolutionModeFromConfig(layerConfig, conf)) .kernelSize(getKernelSizeFromConfig(layerConfig, 2, conf, kerasMajorVersion)) .hasBias(hasBias) .stride(getStrideFromConfig(layerConfig, 2, conf)); int[] padding = getPaddingFromBorderModeConfig(layerConfig, 2, conf, kerasMajorVersion); + if (distribution != null) + builder.dist(distribution); if (hasBias) builder.biasInit(0.0); if (padding != null) diff --git a/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/core/KerasDense.java b/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/core/KerasDense.java index 564bb0252973..8fff664227fb 100644 --- a/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/core/KerasDense.java +++ b/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/core/KerasDense.java @@ -4,6 +4,7 @@ import lombok.extern.slf4j.Slf4j; import org.deeplearning4j.nn.api.layers.LayerConstraint; import org.deeplearning4j.nn.conf.InputPreProcessor; +import org.deeplearning4j.nn.conf.distribution.Distribution; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.layers.DenseLayer; import org.deeplearning4j.nn.modelimport.keras.KerasLayer; @@ -11,7 +12,9 @@ import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException; import org.deeplearning4j.nn.modelimport.keras.utils.KerasConstraintUtils; import org.deeplearning4j.nn.params.DefaultParamInitializer; +import org.deeplearning4j.nn.weights.WeightInit; import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.primitives.Pair; import java.util.HashMap; import java.util.Map; @@ -76,13 +79,19 @@ public KerasDense(Map layerConfig, boolean enforceTrainingConfig LayerConstraint weightConstraint = KerasConstraintUtils.getConstraintsFromConfig( layerConfig, conf.getLAYER_FIELD_W_CONSTRAINT(), conf, kerasMajorVersion); + Pair init = getWeightInitFromConfig(layerConfig, conf.getLAYER_FIELD_INIT(), + enforceTrainingConfig, conf, kerasMajorVersion); + WeightInit weightInit = init.getFirst(); + Distribution distribution = init.getSecond(); + DenseLayer.Builder builder = new DenseLayer.Builder().name(this.layerName).nOut(getNOutFromConfig(layerConfig, conf)) .dropOut(this.dropout).activation(getActivationFromConfig(layerConfig, conf)) - .weightInit(getWeightInitFromConfig(layerConfig, conf.getLAYER_FIELD_INIT(), - enforceTrainingConfig, conf, kerasMajorVersion)) + .weightInit(weightInit) .biasInit(0.0) .l1(this.weightL1Regularization).l2(this.weightL2Regularization) .hasBias(hasBias); + if (distribution != null) + builder.dist(distribution); if (biasConstraint != null) builder.constrainBias(biasConstraint); if (weightConstraint != null) diff --git a/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/embeddings/KerasEmbedding.java b/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/embeddings/KerasEmbedding.java index 8d993342b85d..b25d72245294 100644 --- a/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/embeddings/KerasEmbedding.java +++ b/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/embeddings/KerasEmbedding.java @@ -4,6 +4,7 @@ import lombok.extern.slf4j.Slf4j; import org.deeplearning4j.nn.api.layers.LayerConstraint; import org.deeplearning4j.nn.conf.InputPreProcessor; +import org.deeplearning4j.nn.conf.distribution.Distribution; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.layers.EmbeddingLayer; import org.deeplearning4j.nn.modelimport.keras.KerasLayer; @@ -12,8 +13,10 @@ import org.deeplearning4j.nn.modelimport.keras.utils.KerasConstraintUtils; import org.deeplearning4j.nn.modelimport.keras.utils.KerasLayerUtils; import org.deeplearning4j.nn.params.DefaultParamInitializer; +import org.deeplearning4j.nn.weights.WeightInit; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.primitives.Pair; import java.util.HashMap; import java.util.Map; @@ -70,15 +73,21 @@ public KerasEmbedding(Map layerConfig, boolean enforceTrainingCo "in DL4J, apply masking as a pre-processing step to your input." + "See https://deeplearning4j.org/usingrnns#masking for more on this."); + Pair init = getWeightInitFromConfig(layerConfig, conf.getLAYER_FIELD_EMBEDDING_INIT(), + enforceTrainingConfig, conf, kerasMajorVersion); + WeightInit weightInit = init.getFirst(); + Distribution distribution = init.getSecond(); + LayerConstraint embeddingConstraint = KerasConstraintUtils.getConstraintsFromConfig( layerConfig, conf.getLAYER_FIELD_EMBEDDINGS_CONSTRAINT(), conf, kerasMajorVersion); EmbeddingLayer.Builder builder = new EmbeddingLayer.Builder().name(this.layerName).nIn(inputDim) .nOut(getNOutFromConfig(layerConfig, conf)).dropOut(this.dropout).activation(Activation.IDENTITY) - .weightInit(getWeightInitFromConfig(layerConfig, conf.getLAYER_FIELD_EMBEDDING_INIT(), - enforceTrainingConfig, conf, kerasMajorVersion)) + .weightInit(weightInit) .biasInit(0.0) .l1(this.weightL1Regularization).l2(this.weightL2Regularization).hasBias(false); + if (distribution != null) + builder.dist(distribution); if (embeddingConstraint != null) builder.constrainWeights(embeddingConstraint); this.layer = builder.build(); diff --git a/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/recurrent/KerasLstm.java b/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/recurrent/KerasLstm.java index 384cbd84136c..ded9d34564e4 100644 --- a/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/recurrent/KerasLstm.java +++ b/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/recurrent/KerasLstm.java @@ -4,6 +4,7 @@ import lombok.extern.slf4j.Slf4j; import org.deeplearning4j.nn.api.layers.LayerConstraint; import org.deeplearning4j.nn.conf.InputPreProcessor; +import org.deeplearning4j.nn.conf.distribution.Distribution; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.layers.LSTM; import org.deeplearning4j.nn.conf.preprocessor.FeedForwardToRnnPreProcessor; @@ -19,6 +20,7 @@ import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.indexing.INDArrayIndex; import org.nd4j.linalg.indexing.NDArrayIndex; +import org.nd4j.linalg.primitives.Pair; import java.util.HashMap; import java.util.Map; @@ -98,17 +100,24 @@ public KerasLstm(Map layerConfig) public KerasLstm(Map layerConfig, boolean enforceTrainingConfig) throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException { super(layerConfig, enforceTrainingConfig); - WeightInit weightInit = getWeightInitFromConfig(layerConfig, conf.getLAYER_FIELD_INIT(), + + Pair init = getWeightInitFromConfig(layerConfig, conf.getLAYER_FIELD_INIT(), enforceTrainingConfig, conf, kerasMajorVersion); - WeightInit recurrentWeightInit = getWeightInitFromConfig(layerConfig, conf.getLAYER_FIELD_INNER_INIT(), + WeightInit weightInit = init.getFirst(); + Distribution distribution = init.getSecond(); + + Pair recurrentInit = getWeightInitFromConfig(layerConfig, conf.getLAYER_FIELD_INNER_INIT(), enforceTrainingConfig, conf, kerasMajorVersion); + WeightInit recurrentWeightInit = recurrentInit.getFirst(); + Distribution recurrentDistribution = recurrentInit.getSecond(); + Map innerConfig = KerasLayerUtils.getInnerLayerConfigFromConfig(layerConfig, conf); Boolean returnSequences = (Boolean) innerConfig.get(conf.getLAYER_FIELD_RETURN_SEQUENCES()); if (!returnSequences) { log.warn("Keras setting 'return_sequences = False' is not properly supported," + "DL4J's LSTM layer returns sequences by default"); } - if (weightInit != recurrentWeightInit) + if (weightInit != recurrentWeightInit || distribution != recurrentDistribution) if (enforceTrainingConfig) throw new UnsupportedKerasConfigurationException( "Specifying different initialization for recurrent weights not supported."); @@ -135,6 +144,8 @@ public KerasLstm(Map layerConfig, boolean enforceTrainingConfig) .biasInit(0.0) .l1(this.weightL1Regularization) .l2(this.weightL2Regularization); + if (distribution != null) + builder.dist(distribution); if (biasConstraint != null) builder.constrainBias(biasConstraint); if (weightConstraint != null) diff --git a/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/utils/KerasInitilizationUtils.java b/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/utils/KerasInitilizationUtils.java index 4260172c0935..c36a4d3acdf4 100644 --- a/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/utils/KerasInitilizationUtils.java +++ b/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/utils/KerasInitilizationUtils.java @@ -87,7 +87,8 @@ public static Pair mapWeightInitialization(String kera dist = new UniformDistribution(-scale, scale); } init = WeightInit.DISTRIBUTION; - } else if (kerasInit.equals(conf.getINIT_RANDOM_NORMAL()) || + } else if (kerasInit.equals(conf.getINIT_NORMAL()) || + kerasInit.equals(conf.getINIT_RANDOM_NORMAL()) || kerasInit.equals(conf.getINIT_RANDOM_NORMAL_ALIAS())) { if (kerasMajorVersion == 2) { double mean = (double) initConfig.get(conf.getLAYER_FIELD_INIT_MEAN()); @@ -102,8 +103,8 @@ public static Pair mapWeightInitialization(String kera init = WeightInit.DISTRIBUTION; } else if (kerasInit.equals(conf.getINIT_CONSTANT()) || kerasInit.equals(conf.getINIT_CONSTANT_ALIAS())) { - // FIXME: CONSTANT keras.initializers.Constant(value=0) - init = WeightInit.ZERO; + // TODO: CONSTANT keras.initializers.Constant(value=0) + init = WeightInit.ONES; } else if (kerasInit.equals(conf.getINIT_ORTHOGONAL()) || kerasInit.equals(conf.getINIT_ORTHOGONAL_ALIAS())) { if (kerasMajorVersion == 2) { @@ -189,7 +190,7 @@ public static Pair mapWeightInitialization(String kera * @throws InvalidKerasConfigurationException * @throws UnsupportedKerasConfigurationException */ - public static WeightInit getWeightInitFromConfig(Map layerConfig, String initField, + public static Pair getWeightInitFromConfig(Map layerConfig, String initField, boolean enforceTrainingConfig, KerasLayerConfiguration conf, int kerasMajorVersion) @@ -222,7 +223,7 @@ public static WeightInit getWeightInitFromConfig(Map layerConfig log.warn("Unknown weight initializer " + kerasInit + " (Using XAVIER instead)."); } } - return init.getFirst(); + return init; } } diff --git a/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/configurations/KerasInitilizationTest.java b/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/configurations/KerasInitilizationTest.java index 351ea60957ce..0ca933b6460d 100644 --- a/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/configurations/KerasInitilizationTest.java +++ b/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/configurations/KerasInitilizationTest.java @@ -60,9 +60,8 @@ private String[] initializers(KerasLayerConfiguration conf) { conf.getINIT_IDENTITY(), conf.getINIT_NORMAL(), conf.getINIT_ORTHOGONAL(), + conf.getINIT_CONSTANT(), conf.getINIT_VARIANCE_SCALING() - // TODO: add these initializations - // conf.getINIT_CONSTANT(), }; } @@ -81,6 +80,7 @@ private WeightInit[] dl4jInitializers() { WeightInit.IDENTITY, WeightInit.DISTRIBUTION, WeightInit.DISTRIBUTION, + WeightInit.ONES, WeightInit.VAR_SCALING_NORMAL_FAN_IN, }; }