forked from deeplearning4j/deeplearning4j
Commit
Merge pull request deeplearning4j#4029 from deeplearning4j/mp_upsampling_layers
Upsampling 2D layers
Showing 6 changed files with 700 additions and 0 deletions.
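For context, an Upsampling2D layer with scale factor size repeats every input element into a size x size block (nearest-neighbor upsampling). The following standalone sketch, written in plain Java arrays rather than the DL4J implementation itself, reproduces the 2x2 to 4x4 example that the test file below asserts:

// Illustration of nearest-neighbor 2D upsampling semantics; a sketch
// of the layer's forward pass, not the DL4J code itself.
public class UpsampleSketch {

    // Repeat every element of `in` into a size x size block.
    static double[][] upsample(double[][] in, int size) {
        double[][] out = new double[in.length * size][in[0].length * size];
        for (int i = 0; i < out.length; i++)
            for (int j = 0; j < out[0].length; j++)
                out[i][j] = in[i / size][j / size];
        return out;
    }

    public static void main(String[] args) {
        double[][] in = {{1., 2.}, {3., 4.}};
        // Prints rows 1 1 2 2 / 1 1 2 2 / 3 3 4 4 / 3 3 4 4 -- the same
        // pattern containedExpectedOut encodes in the test below.
        for (double[] row : upsample(in, 2))
            System.out.println(java.util.Arrays.toString(row));
    }
}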
118 changes: 118 additions & 0 deletions
...rning4j-core/src/test/java/org/deeplearning4j/nn/layers/convolution/Upsampling2DTest.java
@@ -0,0 +1,118 @@
package org.deeplearning4j.nn.layers.convolution;

import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.conf.GradientNormalization;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.conf.layers.SubsamplingLayer;
import org.deeplearning4j.nn.conf.layers.Upsampling2D;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.junit.Test;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.primitives.Pair;

import java.util.Arrays;

import static org.junit.Assert.*;

/**
 * @author Max Pumperla
 */
public class Upsampling2DTest {

    private int nExamples = 1;
    private int depth = 20;
    private int nChannelsIn = 1;
    private int inputWidth = 28;
    private int inputHeight = 28;

    private int size = 2;
    private int outputWidth = inputWidth * size;
    private int outputHeight = inputHeight * size;

    private INDArray epsilon = Nd4j.ones(nExamples, depth, outputHeight, outputWidth);


    @Test
    public void testUpsampling() throws Exception {

        double[] outArray = new double[] {1., 1., 2., 2., 1., 1., 2., 2., 3., 3., 4., 4., 3., 3., 4., 4.};
        INDArray containedExpectedOut = Nd4j.create(outArray, new int[] {1, 1, 4, 4});
        INDArray containedInput = getContainedData();
        INDArray input = getData();
        Layer layer = getUpsamplingLayer();

        INDArray containedOutput = layer.activate(containedInput);
        assertTrue(Arrays.equals(containedExpectedOut.shape(), containedOutput.shape()));
        assertEquals(containedExpectedOut, containedOutput);

        INDArray output = layer.activate(input);
        assertTrue(Arrays.equals(new int[] {nExamples, nChannelsIn, outputWidth, outputHeight},
                        output.shape()));
        assertEquals(nChannelsIn, output.size(1), 1e-4);
    }


    @Test
    public void testUpsampling2DBackprop() throws Exception {
        INDArray expectedContainedEpsilonInput =
                        Nd4j.create(new double[] {1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.},
                                        new int[] {1, 1, 4, 4});

        INDArray expectedContainedEpsilonResult = Nd4j.create(new double[] {4., 4., 4., 4.},
                        new int[] {1, 1, 2, 2});

        INDArray input = getContainedData();

        Layer layer = getUpsamplingLayer();
        layer.activate(input);

        Pair<Gradient, INDArray> containedOutput = layer.backpropGradient(expectedContainedEpsilonInput);

        assertEquals(expectedContainedEpsilonResult, containedOutput.getSecond());
        assertEquals(null, containedOutput.getFirst().getGradientFor("W"));
        assertEquals(expectedContainedEpsilonResult.shape().length, containedOutput.getSecond().shape().length);

        INDArray input2 = getData();
        layer.activate(input2);
        int depth = input2.size(1);

        epsilon = Nd4j.ones(5, depth, outputHeight, outputWidth);

        Pair<Gradient, INDArray> out = layer.backpropGradient(epsilon);
        assertEquals(input.shape().length, out.getSecond().shape().length);
        assertEquals(depth, out.getSecond().size(1));
    }


    private Layer getUpsamplingLayer() {
        NeuralNetConfiguration conf = new NeuralNetConfiguration.Builder()
                        .gradientNormalization(GradientNormalization.RenormalizeL2PerLayer).seed(123)
                        .layer(new Upsampling2D.Builder(size).build()).build();
        return conf.getLayer().instantiate(conf, null, 0, null, true);
    }

    public INDArray getData() throws Exception {
        DataSetIterator data = new MnistDataSetIterator(5, 5);
        DataSet mnist = data.next();
        nExamples = mnist.numExamples();
        return mnist.getFeatureMatrix().reshape(nExamples, nChannelsIn, inputWidth, inputHeight);
    }

    private INDArray getContainedData() {
        INDArray ret = Nd4j.create(new double[] {1., 2., 3., 4.},
                        new int[] {1, 1, 2, 2});
        return ret;
    }

}
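A note on the backprop test: because the forward pass fans each input element out to a size x size block, the gradient with respect to the input is the sum of the incoming epsilons over each block, and the layer carries no weights (hence getGradientFor("W") returns null). A minimal sketch of that reduction, again using plain arrays rather than the DL4J internals:

// Companion sketch for testUpsampling2DBackprop: sum the output-side
// epsilons over each size x size block to get the input-side gradient.
public class UpsampleBackpropSketch {

    static double[][] backprop(double[][] eps, int size) {
        double[][] grad = new double[eps.length / size][eps[0].length / size];
        for (int i = 0; i < eps.length; i++)
            for (int j = 0; j < eps[0].length; j++)
                grad[i / size][j / size] += eps[i][j];
        return grad;
    }

    public static void main(String[] args) {
        double[][] eps = new double[4][4];
        for (double[] row : eps) java.util.Arrays.fill(row, 1.);
        // With a 4x4 all-ones epsilon and size = 2, every entry of the 2x2
        // result is 4.0, matching expectedContainedEpsilonResult above.
        for (double[] row : backprop(eps, 2))
            System.out.println(java.util.Arrays.toString(row));
    }
}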