Skip to content

Commit

Permalink
Keras upsampling 2D layer test
Browse files Browse the repository at this point in the history
  • Loading branch information
maxpumperla committed Sep 11, 2017
1 parent ded241f commit 40be2c9
Showing 1 changed file with 73 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
/*-
*
* * Copyright 2017 Skymind,Inc.
* *
* * Licensed under the Apache License, Version 2.0 (the "License");
* * you may not use this file except in compliance with the License.
* * You may obtain a copy of the License at
* *
* * http://www.apache.org/licenses/LICENSE-2.0
* *
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS,
* * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* * See the License for the specific language governing permissions and
* * limitations under the License.
*
*/
package org.deeplearning4j.nn.modelimport.keras.layers.convolution;

import org.deeplearning4j.nn.conf.layers.Upsampling2D;
import org.deeplearning4j.nn.conf.layers.ZeroPadding1DLayer;
import org.deeplearning4j.nn.modelimport.keras.config.Keras1LayerConfiguration;
import org.deeplearning4j.nn.modelimport.keras.config.Keras2LayerConfiguration;
import org.deeplearning4j.nn.modelimport.keras.config.KerasLayerConfiguration;
import org.deeplearning4j.nn.modelimport.keras.layers.convolutional.KerasUpsampling2D;
import org.deeplearning4j.nn.modelimport.keras.layers.convolutional.KerasZeroPadding1D;
import org.junit.Test;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import static org.junit.Assert.assertEquals;

/**
* @author Max Pumperla
*/
public class KerasUpsampling2DTest {

private final String LAYER_NAME = "upsampling_2D_layer";
private int[] size = new int[] {2, 2};

private Integer keras1 = 1;
private Integer keras2 = 2;
private Keras1LayerConfiguration conf1 = new Keras1LayerConfiguration();
private Keras2LayerConfiguration conf2 = new Keras2LayerConfiguration();

@Test
public void testUpsampling2DLayer() throws Exception {
buildUpsampling2DLayer(conf1, keras1);
buildUpsampling2DLayer(conf2, keras2);
}


public void buildUpsampling2DLayer(KerasLayerConfiguration conf, Integer kerasVersion) throws Exception {
Map<String, Object> layerConfig = new HashMap<>();
layerConfig.put(conf.getLAYER_FIELD_CLASS_NAME(), conf.getLAYER_CLASS_NAME_MAX_POOLING_1D());
Map<String, Object> config = new HashMap<>();
List<Integer> sizeList = new ArrayList<>();
sizeList.add(size[0]);
sizeList.add(size[1]);
config.put(conf.getLAYER_FIELD_UPSAMPLING_2D_SIZE(), sizeList);
config.put(conf.getLAYER_FIELD_NAME(), LAYER_NAME);
layerConfig.put(conf.getLAYER_FIELD_CONFIG(), config);
layerConfig.put(conf.getLAYER_FIELD_KERAS_VERSION(), kerasVersion);

Upsampling2D layer = new KerasUpsampling2D(layerConfig).getUpsampling2DLayer();
assertEquals(LAYER_NAME, layer.getLayerName());
assertEquals(size[0], layer.getSize());
}

}

0 comments on commit 40be2c9

Please sign in to comment.