Skip to content

Commit

Permalink
fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
maxpumperla committed Sep 8, 2017
1 parent a3507a0 commit b515399
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,6 @@ public void testUpsampling() throws Exception {
double[] outArray = new double[] {1., 1., 2., 2., 1., 1., 2., 2., 3., 3., 4., 4., 3., 3., 4., 4.};
INDArray containedExpectedOut = Nd4j.create(outArray, new int[] {1, 1, 4, 4});
INDArray containedInput = getContainedData();
System.out.println(containedInput);
System.out.println(containedExpectedOut);
INDArray input = getData();
Layer layer = getUpsamplingLayer();

Expand All @@ -70,7 +68,7 @@ public void testUpsampling2DBackprop() throws Exception {
Nd4j.create(new double[] {1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.},
new int[] {1, 1, 4, 4});

INDArray expectedContainedEpsilonResult = Nd4j.create(new double[] {1., 1., 1., 1.},
INDArray expectedContainedEpsilonResult = Nd4j.create(new double[] {4., 4., 4., 4.},
new int[] {1, 1, 2, 2});

INDArray input = getContainedData();
Expand All @@ -79,6 +77,7 @@ public void testUpsampling2DBackprop() throws Exception {
layer.activate(input);

Pair<Gradient, INDArray> containedOutput = layer.backpropGradient(expectedContainedEpsilonInput);

assertEquals(expectedContainedEpsilonResult, containedOutput.getSecond());
assertEquals(null, containedOutput.getFirst().getGradientFor("W"));
assertEquals(expectedContainedEpsilonResult.shape().length, containedOutput.getSecond().shape().length);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,19 +86,20 @@ public Pair<Gradient, INDArray> backpropGradient(INDArray epsilon) {
int size = layerConf().getSize();

INDArray outEpsilon = Nd4j.createUninitialized(miniBatch * inDepth * inH * inW);
INDArray reshapedEpsilon = outEpsilon.reshape('c', miniBatch, inDepth, inH, inW);

INDArray forwardOutput = preOutput(true, true);

Gradient gradient = new DefaultGradient();

CustomOp op = DynamicCustomOp.builder("upsampling_bp")
.setIntegerArguments(size)
.setInputs(forwardOutput, epsilon)
.setOutputs(outEpsilon)
.setOutputs(reshapedEpsilon)
.callInplace(false)
.build();
Nd4j.getExecutioner().exec(op);

INDArray reshapedEpsilon = outEpsilon.reshape('c', miniBatch, inDepth, inH, inW);

return new Pair<>(gradient, reshapedEpsilon);
}
Expand Down Expand Up @@ -135,17 +136,18 @@ public INDArray preOutput(boolean training, boolean forBackprop) {
int outW = inW * size;

INDArray output = Nd4j.createUninitialized(miniBatch * inDepth * outH * outW);
INDArray reshapedOutput = output.reshape('c', miniBatch, inDepth, outH, outW);

CustomOp op = DynamicCustomOp.builder("upsampling")
.setIntegerArguments(size)
.setInputs(input)
.setOutputs(output)
.setOutputs(reshapedOutput)
.callInplace(false)
.build();

Nd4j.getExecutioner().exec(op);

return output.reshape('c', miniBatch, inDepth, outH, outW);
return reshapedOutput;
}

@Override
Expand Down

0 comments on commit b515399

Please sign in to comment.