Skip to content

Commit

Permalink
Merge pull request deeplearning4j#4066 from deeplearning4j/sa_kerasfix
Browse files Browse the repository at this point in the history
Fix Keras 2 import of BatchNormalization
  • Loading branch information
maxpumperla authored Sep 14, 2017
2 parents 1f9f14f + 4abd836 commit f23db9f
Showing 1 changed file with 8 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -151,22 +151,22 @@ public void setWeights(Map<String, INDArray> weights) throws InvalidKerasConfigu
else
throw new InvalidKerasConfigurationException(
"Parameter " + PARAM_NAME_GAMMA + " does not exist in weights");
if (weights.containsKey(PARAM_NAME_RUNNING_MEAN))
this.weights.put(BatchNormalizationParamInitializer.GLOBAL_MEAN, weights.get(PARAM_NAME_RUNNING_MEAN));
if (weights.containsKey(conf.getLAYER_FIELD_BATCHNORMALIZATION_MOVING_MEAN()))
this.weights.put(BatchNormalizationParamInitializer.GLOBAL_MEAN, weights.get(conf.getLAYER_FIELD_BATCHNORMALIZATION_MOVING_MEAN()));
else
throw new InvalidKerasConfigurationException(
"Parameter " + PARAM_NAME_RUNNING_MEAN + " does not exist in weights");
if (weights.containsKey(PARAM_NAME_RUNNING_STD))
this.weights.put(BatchNormalizationParamInitializer.GLOBAL_VAR, weights.get(PARAM_NAME_RUNNING_STD));
"Parameter " + conf.getLAYER_FIELD_BATCHNORMALIZATION_MOVING_MEAN() + " does not exist in weights");
if (weights.containsKey(conf.getLAYER_FIELD_BATCHNORMALIZATION_MOVING_VARIANCE()))
this.weights.put(BatchNormalizationParamInitializer.GLOBAL_VAR, weights.get(conf.getLAYER_FIELD_BATCHNORMALIZATION_MOVING_VARIANCE()));
else
throw new InvalidKerasConfigurationException(
"Parameter " + PARAM_NAME_RUNNING_STD + " does not exist in weights");
"Parameter " + conf.getLAYER_FIELD_BATCHNORMALIZATION_MOVING_VARIANCE() + " does not exist in weights");
if (weights.size() > 4) {
Set<String> paramNames = weights.keySet();
paramNames.remove(PARAM_NAME_BETA);
paramNames.remove(PARAM_NAME_GAMMA);
paramNames.remove(PARAM_NAME_RUNNING_MEAN);
paramNames.remove(PARAM_NAME_RUNNING_STD);
paramNames.remove(conf.getLAYER_FIELD_BATCHNORMALIZATION_MOVING_MEAN());
paramNames.remove(conf.getLAYER_FIELD_BATCHNORMALIZATION_MOVING_VARIANCE());
String unknownParamNames = paramNames.toString();
log.warn("Attemping to set weights for unknown parameters: "
+ unknownParamNames.substring(1, unknownParamNames.length() - 1));
Expand Down

0 comments on commit f23db9f

Please sign in to comment.