
Commit

Merge pull request deeplearning4j#1486 from deeplearning4j/ab_moretestfixes

More test fixes; various config validation fixes; dropout fix, etc
AlexDBlack committed May 2, 2016
2 parents d3714f3 + 67b2819 commit 6a19b44
Showing 38 changed files with 80 additions and 534 deletions.
@@ -792,17 +792,20 @@ private void learningRateValidation(String layerName){
private void generalValidation(String layerName){
if (useDropConnect && (Double.isNaN(dropOut) && (Double.isNaN(layer.getDropOut()))))
throw new IllegalStateException(layerName +" dropConnect is set to true but dropout rate has not been added to configuration.");
if (useRegularization && (Double.isNaN(l1) && layer != null && Double.isNaN(layer.getL1()) && Double.isNaN(l2) && Double.isNaN(layer.getL2())))
log.warn(layerName +" regularization is set to true but l1 or l2 has not been added to configuration.");
if(useDropConnect && dropOut == 0.0) throw new IllegalStateException(layerName + " dropConnect is set to true but dropout rate is set to 0.0");
if (useRegularization && (Double.isNaN(l1) && layer != null && Double.isNaN(layer.getL1())
&& Double.isNaN(l2) && Double.isNaN(layer.getL2())
&& (Double.isNaN(dropOut) || dropOut==0.0) && (Double.isNaN(layer.getDropOut()) || layer.getDropOut() == 0.0)))
log.warn(layerName +" regularization is set to true but l1, l2 or dropout has not been added to configuration.");
// CompGraph may have null layers TODO confirm valid configuration
if (layer != null) {
if (useRegularization) {
if (!Double.isNaN(l1) && Double.isNaN(layer.getL1()))
layer.setL1(l1);
if (!Double.isNaN(l2) && Double.isNaN(layer.getL2()))
layer.setL2(l2);
} else if (!Double.isNaN(l1) || !Double.isNaN(layer.getL1()) || !Double.isNaN(l2) || !Double.isNaN(layer.getL2()))
log.warn(layerName +" l1 or l2 has been added to configuration but useRegularization is set to false.");
} else if (!useRegularization && (!Double.isNaN(l1) || !Double.isNaN(layer.getL1()) || !Double.isNaN(l2) || !Double.isNaN(layer.getL2())) )
throw new IllegalStateException(layerName +" l1 or l2 has been added to configuration but useRegularization is set to false.");
if (Double.isNaN(l2) && Double.isNaN(layer.getL2()))
layer.setL2(0.0);
if (Double.isNaN(l1) && Double.isNaN(layer.getL1()))
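The net effect of the hunk above: DropConnect now rejects an unset or zero dropout rate outright, the "regularization enabled but nothing configured" warning also takes dropout into account, and supplying l1/l2 while useRegularization is false is promoted from a warning to an error. A minimal standalone sketch of those rules (illustrative only, reusing the field names from the diff; this is not the actual NeuralNetConfiguration code, and unset values are represented as Double.NaN):

// Illustrative only: not the NeuralNetConfiguration code. Unset values are Double.NaN,
// mirroring the builder fields in the hunk above.
class ConfigValidationSketch {
    static void validate(String layerName, boolean useDropConnect, boolean useRegularization,
                         double dropOut, double l1, double l2) {
        // DropConnect now needs a usable dropout rate: NaN (unset) and an explicit 0.0 both fail
        if (useDropConnect && (Double.isNaN(dropOut) || dropOut == 0.0))
            throw new IllegalStateException(layerName + ": dropConnect=true but dropout rate is unset or 0.0");
        // Regularization enabled while l1, l2 and dropout are all unset: warn only
        if (useRegularization && Double.isNaN(l1) && Double.isNaN(l2)
                && (Double.isNaN(dropOut) || dropOut == 0.0))
            System.err.println(layerName + ": useRegularization=true but no l1, l2 or dropout configured");
        // l1/l2 supplied while regularization is disabled: promoted from a warning to an error
        if (!useRegularization && (!Double.isNaN(l1) || !Double.isNaN(l2)))
            throw new IllegalStateException(layerName + ": l1/l2 configured but useRegularization=false");
    }
}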
@@ -51,7 +51,6 @@
@JsonSubTypes.Type(value = RnnOutputLayer.class, name = "rnnoutput"),
@JsonSubTypes.Type(value = RBM.class, name = "RBM"),
@JsonSubTypes.Type(value = DenseLayer.class, name = "dense"),
@JsonSubTypes.Type(value = RecursiveAutoEncoder.class, name = "recursiveAutoEncoder"),
@JsonSubTypes.Type(value = SubsamplingLayer.class, name = "subsampling"),
@JsonSubTypes.Type(value = BatchNormalization.class, name = "batchNormalization"),
@JsonSubTypes.Type(value = LocalResponseNormalization.class, name = "localResponseNormalization"),
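For context, the annotations above are Jackson's polymorphic-type registry: each @JsonSubTypes.Type entry maps a JSON type name (e.g. "dense") to a concrete Layer class, so dropping the "recursiveAutoEncoder" entry means that name can no longer be resolved during deserialization. A minimal, self-contained sketch of the same mechanism with hypothetical Shape/Circle classes (not DL4J types):

import com.fasterxml.jackson.annotation.JsonSubTypes;
import com.fasterxml.jackson.annotation.JsonTypeInfo;
import com.fasterxml.jackson.databind.ObjectMapper;

// Illustrative only: Shape and Circle are hypothetical, not classes from this repository.
@JsonTypeInfo(use = JsonTypeInfo.Id.NAME, include = JsonTypeInfo.As.WRAPPER_OBJECT)
@JsonSubTypes({ @JsonSubTypes.Type(value = Circle.class, name = "circle") })
abstract class Shape { }

class Circle extends Shape { public double radius; }

class SubtypeDemo {
    public static void main(String[] args) throws Exception {
        // "circle" resolves to Circle because it is registered above; an unregistered name
        // (like the removed "recursiveAutoEncoder" entry) would fail to deserialize.
        Shape s = new ObjectMapper().readValue("{\"circle\":{\"radius\":2.0}}", Shape.class);
        System.out.println(s.getClass().getSimpleName()); // prints Circle
    }
}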

This file was deleted.

@@ -142,7 +142,7 @@ public Gradient calcGradient(Gradient layerError, INDArray activation) {
@Override
public Pair<Gradient,INDArray> backpropGradient(INDArray epsilon) {
//If this layer is layer L, then epsilon is (w^(L+1)*(d^(L+1))^T) (or equivalent)
INDArray z = preOutput(input);
INDArray z = preOutput(true); //Note: using preOutput(INDArray) can't be used as this does a setInput(input) and resets the 'appliedDropout' flag
INDArray activationDerivative = Nd4j.getExecutioner().execAndReturn(Nd4j.getOpFactory().createTransform(conf().getLayer().getActivationFunction(), z).derivative());
INDArray delta = epsilon.muli(activationDerivative);

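The substance of this change is in the inline comment: the pre-activations z must be obtained without calling setInput(input) again, since that would reset the layer's dropout bookkeeping in the middle of backprop. The delta computed on the following line is simply an elementwise product of the incoming epsilon with the activation derivative; a library-free sketch of that step (illustrative only, using plain arrays and a hard-coded sigmoid instead of Nd4j's generic derivative transform):

// Illustrative only: delta = epsilon (elementwise *) f'(z), shown for f = sigmoid.
class BackpropDeltaSketch {
    static double[] backpropDelta(double[] epsilon, double[] z) {
        double[] delta = new double[z.length];
        for (int i = 0; i < z.length; i++) {
            double sig = 1.0 / (1.0 + Math.exp(-z[i])); // f(z) for f = sigmoid
            double dSig = sig * (1.0 - sig);            // f'(z)
            delta[i] = epsilon[i] * dSig;               // corresponds to epsilon.muli(activationDerivative)
        }
        return delta;
    }
}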
@@ -85,8 +85,6 @@ protected Layer getInstance(NeuralNetConfiguration conf) {
return new org.deeplearning4j.nn.layers.OutputLayer(conf);
if(layerConfig instanceof org.deeplearning4j.nn.conf.layers.RnnOutputLayer)
return new org.deeplearning4j.nn.layers.recurrent.RnnOutputLayer(conf);
if(layerConfig instanceof org.deeplearning4j.nn.conf.layers.RecursiveAutoEncoder)
return new org.deeplearning4j.nn.layers.feedforward.autoencoder.recursive.RecursiveAutoEncoder(conf);
if(layerConfig instanceof org.deeplearning4j.nn.conf.layers.ConvolutionLayer)
return new org.deeplearning4j.nn.layers.convolution.ConvolutionLayer(conf);
if(layerConfig instanceof org.deeplearning4j.nn.conf.layers.SubsamplingLayer)
@@ -52,8 +52,6 @@ else if (clazz.equals(GravesBidirectionalLSTM.class))
return new GravesBidirectionalLSTMLayerFactory(GravesBidirectionalLSTM.class);
else if(clazz.equals(GRU.class))
return new GRULayerFactory(GRU.class);
else if(RecursiveAutoEncoder.class.isAssignableFrom(clazz))
return new RecursiveAutoEncoderLayerFactory(RecursiveAutoEncoder.class);
else if(BasePretrainNetwork.class.isAssignableFrom(clazz))
return new PretrainLayerFactory(clazz);
else if(ConvolutionLayer.class.isAssignableFrom(clazz))

This file was deleted.

This file was deleted.

@@ -482,4 +482,4 @@ public int hashCode() {
result = 31 * result + end;
return result;
}
}
}

This file was deleted.

@@ -100,7 +100,7 @@ public void applyLrDecayPolicy(LearningRatePolicy decay, Layer layer, int iterat…
conf.setLearningRateByParam(variable, lr * Math.pow(decayRate, Math.floor(iteration/conf.getLrPolicySteps())));
break;
case Poly:
conf.setLearningRateByParam(variable, lr * Math.pow((1 - (iteration * 1.0)/conf.getNumIterations()), conf.getLrPolicyPower()));
conf.setLearningRateByParam(variable, lr * Math.pow((1 - ((double)iteration)/conf.getNumIterations()), conf.getLrPolicyPower()));
break;
case Sigmoid:
conf.setLearningRateByParam(variable, lr / (1 + Math.exp(-decayRate * (iteration - conf.getLrPolicySteps()))));
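The Poly change is purely cosmetic: `iteration * 1.0` and `(double) iteration` both force floating-point division, so the schedule itself is unchanged, lr * (1 - iteration/numIterations)^power. A small worked sketch of that schedule (illustrative only; method and variable names here are not from the codebase):

// Illustrative only: the Poly learning-rate schedule from the line above.
class PolyDecaySketch {
    static double polyDecay(double lr, int iteration, int numIterations, double power) {
        return lr * Math.pow(1.0 - ((double) iteration) / numIterations, power);
    }
    public static void main(String[] args) {
        System.out.println(polyDecay(0.1, 50, 100, 2.0)); // 0.1 * (1 - 0.5)^2 = 0.025
    }
}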

0 comments on commit 6a19b44
