Skip to content

Commit

Permalink
fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
maxpumperla committed Sep 13, 2017
1 parent 8df3047 commit 1267664
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -142,8 +142,13 @@ public static Pair<WeightInit, Distribution> mapWeightInitialization(String kera
// TODO: map to scaled Identity
init = WeightInit.IDENTITY;
} else if (kerasInit.equals(conf.getINIT_VARIANCE_SCALING())) {
int scale = (int) initConfig.get(conf.getLAYER_FIELD_INIT_SCALE());
if (scale != 1)
double scale;
try {
scale = (double) initConfig.get(conf.getLAYER_FIELD_INIT_SCALE());
} catch (Exception e) {
scale = (int) initConfig.get(conf.getLAYER_FIELD_INIT_SCALE());
}
if (scale != 1.)
log.warn("Scaled identity weight init not supported, setting scale=1");
String mode = (String) initConfig.get(conf.getLAYER_FIELD_INIT_MODE());
String distribution = (String) initConfig.get(conf.getLAYER_FIELD_INIT_DISTRIBUTION());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,5 @@ public void buildBatchNormalizationLayer(KerasLayerConfiguration conf, Integer k
BatchNormalization layer = new KerasBatchNormalization(layerConfig).getBatchNormalizationLayer();
assertEquals(LAYER_NAME, layer.getLayerName());
assertEquals(epsilon, layer.getEps(), 0.0);
assertEquals(momentum, ((Nesterovs)layer.getIUpdater()).getMomentum(), 0.0);
}
}

0 comments on commit 1267664

Please sign in to comment.