finish test and factor out upsampling base class
maxpumperla committed Sep 12, 2017
1 parent f9c2c92 · commit e4a4b47
Showing 9 changed files with 331 additions and 54 deletions.
KerasLayerConfiguration.java
@@ -67,6 +67,7 @@ public class KerasLayerConfiguration {
private final String LAYER_CLASS_NAME_CONVOLUTION_1D = ""; // 1: Convolution1D, 2: Conv1D
private final String LAYER_CLASS_NAME_CONVOLUTION_2D = ""; // 1: Convolution2D, 2: Conv2D
private final String LAYER_CLASS_NAME_LEAKY_RELU = "LeakyReLU";
+ private final String LAYER_CLASS_NAME_UPSAMPLING_1D = "UpSampling1D";
private final String LAYER_CLASS_NAME_UPSAMPLING_2D = "UpSampling2D";


New file: KerasUpsampling1D.java
@@ -0,0 +1,95 @@
/*-
*
* * Copyright 2017 Skymind,Inc.
* *
* * Licensed under the Apache License, Version 2.0 (the "License");
* * you may not use this file except in compliance with the License.
* * You may obtain a copy of the License at
* *
* * http://www.apache.org/licenses/LICENSE-2.0
* *
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS,
* * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* * See the License for the specific language governing permissions and
* * limitations under the License.
*
*/
package org.deeplearning4j.nn.modelimport.keras.layers.convolutional;

import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.Upsampling1D;
import org.deeplearning4j.nn.modelimport.keras.KerasLayer;
import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException;
import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException;

import java.util.Map;


/**
* Imports a Keras UpSampling1D layer as a DL4J Upsampling1D layer.
*
* @author Max Pumperla
*/
public class KerasUpsampling1D extends KerasLayer {

/**
* Constructor from parsed Keras layer configuration dictionary.
*
* @param layerConfig dictionary containing Keras layer configuration.
* @throws InvalidKerasConfigurationException Invalid Keras configuration exception
* @throws UnsupportedKerasConfigurationException Unsupported Keras configuration exception
*/
public KerasUpsampling1D(Map<String, Object> layerConfig)
throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
this(layerConfig, true);
}

/**
* Constructor from parsed Keras layer configuration dictionary.
*
* @param layerConfig dictionary containing Keras layer configuration
* @param enforceTrainingConfig whether to enforce training-related configuration options
* @throws InvalidKerasConfigurationException Invalid Keras configuration exception
* @throws UnsupportedKerasConfigurationException Unsupported Keras configuration exception
*/
public KerasUpsampling1D(Map<String, Object> layerConfig, boolean enforceTrainingConfig)
throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
super(layerConfig, enforceTrainingConfig);

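// Extract the integer upsampling factor from the Keras config (a single-element array in the 1D case).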
int[] size = KerasConvolutionUtils.getUpsamplingSizeFromConfig(layerConfig, 1, conf);

Upsampling1D.Builder builder = new Upsampling1D.Builder()
.name(this.layerName)
.dropOut(this.dropout)
.size(size[0]);

this.layer = builder.build();
this.vertex = null;
}

/**
* Get DL4J Upsampling1D layer.
*
* @return Upsampling1D layer
*/
public Upsampling1D getUpsampling1DLayer() {
return (Upsampling1D) this.layer;
}

/**
* Get layer output type.
*
* @param inputType Array of InputTypes
* @return output type as InputType
* @throws InvalidKerasConfigurationException Invalid Keras configuration exception
*/
@Override
public InputType getOutputType(InputType... inputType) throws InvalidKerasConfigurationException {
if (inputType.length > 1)
throw new InvalidKerasConfigurationException(
"Keras Upsampling 1D layer accepts only one input (received " + inputType.length + ")");
return this.getUpsampling1DLayer().getOutputType(-1, inputType[0]);
}

}
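For reference, a minimal import sketch mirroring the test below (the field-name getters are real and taken from KerasLayerConfiguration; the Keras 2 layout and the literal values are assumptions for illustration):

// Hand-build the parsed config map and import it as a DL4J layer.
KerasLayerConfiguration conf = new Keras2LayerConfiguration();
Map<String, Object> config = new HashMap<>();
config.put(conf.getLAYER_FIELD_UPSAMPLING_1D_SIZE(), 2);
config.put(conf.getLAYER_FIELD_NAME(), "upsampling_1d");
Map<String, Object> layerConfig = new HashMap<>();
layerConfig.put(conf.getLAYER_FIELD_CLASS_NAME(), conf.getLAYER_CLASS_NAME_UPSAMPLING_1D());
layerConfig.put(conf.getLAYER_FIELD_CONFIG(), config);
layerConfig.put(conf.getLAYER_FIELD_KERAS_VERSION(), 2);
Upsampling1D dl4jLayer = new KerasUpsampling1D(layerConfig).getUpsampling1DLayer();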
New file: KerasUpsampling1DTest.java
@@ -0,0 +1,69 @@
/*-
*
* * Copyright 2017 Skymind,Inc.
* *
* * Licensed under the Apache License, Version 2.0 (the "License");
* * you may not use this file except in compliance with the License.
* * You may obtain a copy of the License at
* *
* * http://www.apache.org/licenses/LICENSE-2.0
* *
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS,
* * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* * See the License for the specific language governing permissions and
* * limitations under the License.
*
*/
package org.deeplearning4j.nn.modelimport.keras.layers.convolution;

import org.deeplearning4j.nn.conf.layers.Upsampling1D;
import org.deeplearning4j.nn.conf.layers.Upsampling2D;
import org.deeplearning4j.nn.modelimport.keras.config.Keras1LayerConfiguration;
import org.deeplearning4j.nn.modelimport.keras.config.Keras2LayerConfiguration;
import org.deeplearning4j.nn.modelimport.keras.config.KerasLayerConfiguration;
import org.deeplearning4j.nn.modelimport.keras.layers.convolutional.KerasUpsampling1D;
import org.deeplearning4j.nn.modelimport.keras.layers.convolutional.KerasUpsampling2D;
import org.junit.Test;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import static org.junit.Assert.assertEquals;

/**
* @author Max Pumperla
*/
public class KerasUpsampling1DTest {

private final String LAYER_NAME = "upsampling_1D_layer";
private int size = 4;

private Integer keras1 = 1;
private Integer keras2 = 2;
private Keras1LayerConfiguration conf1 = new Keras1LayerConfiguration();
private Keras2LayerConfiguration conf2 = new Keras2LayerConfiguration();

@Test
public void testUpsampling1DLayer() throws Exception {
buildUpsampling1DLayer(conf1, keras1);
buildUpsampling1DLayer(conf2, keras2);
}

public void buildUpsampling1DLayer(KerasLayerConfiguration conf, Integer kerasVersion) throws Exception {
Map<String, Object> layerConfig = new HashMap<>();
layerConfig.put(conf.getLAYER_FIELD_CLASS_NAME(), conf.getLAYER_CLASS_NAME_UPSAMPLING_1D());
Map<String, Object> config = new HashMap<>();
config.put(conf.getLAYER_FIELD_UPSAMPLING_1D_SIZE(), size);
config.put(conf.getLAYER_FIELD_NAME(), LAYER_NAME);
layerConfig.put(conf.getLAYER_FIELD_CONFIG(), config);
layerConfig.put(conf.getLAYER_FIELD_KERAS_VERSION(), kerasVersion);

Upsampling1D layer = new KerasUpsampling1D(layerConfig).getUpsampling1DLayer();
assertEquals(LAYER_NAME, layer.getLayerName());
assertEquals(size, layer.getSize());
}

}
KerasUpsampling2DTest.java
@@ -55,7 +55,7 @@ public void testUpsampling2DLayer() throws Exception {

public void buildUpsampling2DLayer(KerasLayerConfiguration conf, Integer kerasVersion) throws Exception {
Map<String, Object> layerConfig = new HashMap<>();
- layerConfig.put(conf.getLAYER_FIELD_CLASS_NAME(), conf.getLAYER_CLASS_NAME_MAX_POOLING_1D());
+ layerConfig.put(conf.getLAYER_FIELD_CLASS_NAME(), conf.getLAYER_CLASS_NAME_UPSAMPLING_2D());
Map<String, Object> config = new HashMap<>();
List<Integer> sizeList = new ArrayList<>();
sizeList.add(size[0]);
New file: BaseUpsamplingLayer.java
@@ -0,0 +1,116 @@
/*-
*
* * Copyright 2017 Skymind,Inc.
* *
* * Licensed under the Apache License, Version 2.0 (the "License");
* * you may not use this file except in compliance with the License.
* * You may obtain a copy of the License at
* *
* * http://www.apache.org/licenses/LICENSE-2.0
* *
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS,
* * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* * See the License for the specific language governing permissions and
* * limitations under the License.
*
*/
package org.deeplearning4j.nn.conf.layers;

import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.NoArgsConstructor;
import lombok.ToString;
import org.deeplearning4j.nn.api.ParamInitializer;
import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.memory.LayerMemoryReport;
import org.deeplearning4j.nn.conf.memory.MemoryReport;
import org.deeplearning4j.nn.params.EmptyParamInitializer;
import org.deeplearning4j.optimize.api.IterationListener;
import org.nd4j.linalg.api.ndarray.INDArray;

import java.util.Collection;
import java.util.Map;

/**
* Base class for upsampling layers: holds the shared upsampling size and the parameter-free boilerplate common to its subclasses.
*
* @author Max Pumperla
*/

@Data
@NoArgsConstructor
@ToString(callSuper = true)
@EqualsAndHashCode(callSuper = true)
public abstract class BaseUpsamplingLayer extends Layer {

protected int size;

protected BaseUpsamplingLayer(UpsamplingBuilder builder) {
super(builder);
this.size = builder.size;
}

@Override
public BaseUpsamplingLayer clone() {
BaseUpsamplingLayer clone = (BaseUpsamplingLayer) super.clone();
return clone;
}

@Override
public ParamInitializer initializer() {
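// Upsampling layers are parameter-free, so the shared empty initializer suffices.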
return EmptyParamInitializer.getInstance();
}


@Override
public void setNIn(InputType inputType, boolean override) {
//No op: upsampling layer doesn't have nIn value
}

@Override
public InputPreProcessor getPreProcessorForInputType(InputType inputType) {
if (inputType == null) {
throw new IllegalStateException("Invalid input for Upsampling layer (layer name=\"" + getLayerName()
+ "\"): input is null");
}
return InputTypeUtil.getPreProcessorForInputTypeCnnLayers(inputType, getLayerName());
}

@Override
public double getL1ByParam(String paramName) {
//Not applicable
return 0;
}

@Override
public double getL2ByParam(String paramName) {
//Not applicable
return 0;
}

@Override
public double getLearningRateByParam(String paramName) {
//Not applicable
return 0;
}

@Override
public boolean isPretrainParam(String paramName) {
throw new UnsupportedOperationException("UpsamplingLayer does not contain parameters");
}


@NoArgsConstructor
protected static abstract class UpsamplingBuilder<T extends UpsamplingBuilder<T>>
extends Layer.Builder<T> {
protected int size = 1;

protected UpsamplingBuilder(int size) {
this.size = size;
}
}

}
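A note on the generics: UpsamplingBuilder is parameterized over its own subtype so that chained setters inherited from Layer.Builder keep returning the concrete builder type. A minimal, self-contained sketch of the pattern (illustrative names only, not DL4J code):

// Self-typed builder pattern in miniature: size(...) returns the concrete subtype.
abstract class Base<T extends Base<T>> {
    int size = 1;

    @SuppressWarnings("unchecked")
    T size(int s) { this.size = s; return (T) this; }
}

class Concrete extends Base<Concrete> {
    Concrete other() { return this; }
}

// new Concrete().size(3).other();  // compiles: size(...) yields Concrete, not Base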
Upsampling1D.java
@@ -44,11 +44,11 @@
@NoArgsConstructor
@ToString(callSuper = true)
@EqualsAndHashCode(callSuper = true)
- public class Upsampling1D extends Upsampling2D {
+ public class Upsampling1D extends BaseUpsamplingLayer {

protected int size;

- protected Upsampling1D(Builder builder) {
+ protected Upsampling1D(UpsamplingBuilder builder) {
super(builder);
this.size = builder.size;
}
@@ -74,8 +74,46 @@ public Upsampling1D clone() {
return clone;
}

@Override
public InputType getOutputType(int layerIndex, InputType inputType) {
if (inputType == null || inputType.getType() != InputType.Type.RNN) {
throw new IllegalStateException("Invalid input for 1D Upsampling layer (layer index = " + layerIndex
+ ", layer name = \"" + getLayerName() + "\"): expect RNN input type with size > 0. Got: "
+ inputType);
}
InputType.InputTypeRecurrent recurrent = (InputType.InputTypeRecurrent) inputType;
// Upsampling repeats each time step 'size' times, so the output sequence is 'size' times longer.
return InputType.recurrent(recurrent.getSize(), recurrent.getTimeSeriesLength() * size);
}

@Override
public InputPreProcessor getPreProcessorForInputType(InputType inputType) {
if (inputType == null) {
throw new IllegalStateException("Invalid input for Upsampling layer (layer name=\"" + getLayerName()
+ "\"): input is null");
}
return InputTypeUtil.getPreProcessorForInputTypeCnnLayers(inputType, getLayerName());
}

@Override
public LayerMemoryReport getMemoryReport(InputType inputType) {
InputType.InputTypeRecurrent recurrent = (InputType.InputTypeRecurrent) inputType;
InputType.InputTypeRecurrent outputType = (InputType.InputTypeRecurrent) getOutputType(-1, inputType);

// Working buffer: feature count times the output sequence length (which already includes the upsampling factor).
int im2colSizePerEx = recurrent.getSize() * outputType.getTimeSeriesLength();
int trainingWorkingSizePerEx = im2colSizePerEx;
if (getDropOut() > 0) {
trainingWorkingSizePerEx += inputType.arrayElementsPerExample();
}

return new LayerMemoryReport.Builder(layerName, Upsampling1D.class, inputType, outputType)
.standardMemory(0, 0) //No params
.workingMemory(0, im2colSizePerEx, 0, trainingWorkingSizePerEx)
.cacheMemory(MemoryReport.CACHE_MODE_ALL_ZEROS, MemoryReport.CACHE_MODE_ALL_ZEROS) //No caching
.build();
}

@NoArgsConstructor
- public static class Builder extends Upsampling2D.Upsampling2DBuilder {
+ public static class Builder extends UpsamplingBuilder<Builder> {

public Builder(int size) {
super(size);
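A hedged usage sketch of the refactored layer (assuming the name(...) setter inherited from Layer.Builder and the Builder's build() method, which falls in the truncated region above):

Upsampling1D up = new Upsampling1D.Builder(2).name("upsample").build();
// Recurrent input with 16 features and 50 time steps: with the corrected
// getOutputType above, the output keeps 16 features and grows to 100 time steps.
InputType out = up.getOutputType(0, InputType.recurrent(16, 50));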
