Skip to content

Commit

Permalink
Upsampling 2D import support
Browse files Browse the repository at this point in the history
  • Loading branch information
maxpumperla committed Sep 11, 2017
1 parent ecfd11e commit ded241f
Show file tree
Hide file tree
Showing 8 changed files with 181 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ public class Keras1LayerConfiguration extends KerasLayerConfiguration {
/* Pooling / Upsampling layer properties */
private final String LAYER_FIELD_POOL_1D_SIZE = "pool_length";
private final String LAYER_FIELD_POOL_1D_STRIDES = "stride";
private final String LAYER_FIELD_UPSAMPLING_SIZE = "length";
private final String LAYER_FIELD_UPSAMPLING_1D_SIZE = "length";

/* Keras convolution border modes. */
private final String LAYER_FIELD_BORDER_MODE = "border_mode";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ public class Keras2LayerConfiguration extends KerasLayerConfiguration {
/* Pooling / Upsampling layer properties */
private final String LAYER_FIELD_POOL_1D_SIZE = "pool_size";
private final String LAYER_FIELD_POOL_1D_STRIDES = "strides";
private final String LAYER_FIELD_UPSAMPLING_SIZE = "size";
private final String LAYER_FIELD_UPSAMPLING_1D_SIZE = "size";

/* Keras convolution border modes. */
private final String LAYER_FIELD_BORDER_MODE = "padding";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,9 @@ public class KerasLayerConfiguration {
private final String LAYER_FIELD_POOL_STRIDES = "strides";
private final String LAYER_FIELD_POOL_1D_SIZE = ""; // 1: pool_length, 2: pool_size
private final String LAYER_FIELD_POOL_1D_STRIDES = ""; // 1: stride, 2: strides
private final String LAYER_FIELD_UPSAMPLING_SIZE = ""; // 1: length, 2: size
private final String LAYER_FIELD_UPSAMPLING_1D_SIZE = ""; // 1: length, 2: size
private final String LAYER_FIELD_UPSAMPLING_2D_SIZE = "size";


/* Keras convolution border modes. */
private final String LAYER_FIELD_BORDER_MODE = ""; // 1: border_mode, 2: padding
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,33 @@ public static int[] getDilationRate(Map<String, Object> layerConfig, int dimensi

}

/**
* Get upsampling size from Keras layer configuration.
*
* @param layerConfig dictionary containing Keras layer configuration
* @return
* @throws InvalidKerasConfigurationException
*/
public static int[] getUpsamplingSizeFromConfig(Map<String, Object> layerConfig, int dimension,
KerasLayerConfiguration conf)
throws InvalidKerasConfigurationException {
Map<String, Object> innerConfig = KerasLayerUtils.getInnerLayerConfigFromConfig(layerConfig, conf);
int[] size;
if (innerConfig.containsKey(conf.getLAYER_FIELD_UPSAMPLING_2D_SIZE()) && dimension == 2) {
List<Integer> sizeList = (List<Integer>) innerConfig.get(conf.getLAYER_FIELD_UPSAMPLING_2D_SIZE());
size = ArrayUtil.toArray(sizeList);
} else if (innerConfig.containsKey(conf.getLAYER_FIELD_UPSAMPLING_1D_SIZE()) && dimension == 1) {
int upsamplingSize1D = (int) innerConfig.get(conf.getLAYER_FIELD_UPSAMPLING_1D_SIZE());
size = new int[]{ upsamplingSize1D };
} else {
throw new InvalidKerasConfigurationException("Could not determine kernel size: no "
+ conf.getLAYER_FIELD_UPSAMPLING_1D_SIZE() + ", "
+ conf.getLAYER_FIELD_UPSAMPLING_2D_SIZE());
}
return size;
}



/**
* Get (convolution) kernel size from Keras layer configuration.
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
/*-
*
* * Copyright 2017 Skymind,Inc.
* *
* * Licensed under the Apache License, Version 2.0 (the "License");
* * you may not use this file except in compliance with the License.
* * You may obtain a copy of the License at
* *
* * http://www.apache.org/licenses/LICENSE-2.0
* *
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS,
* * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* * See the License for the specific language governing permissions and
* * limitations under the License.
*
*/
package org.deeplearning4j.nn.modelimport.keras.layers.convolutional;

import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.Upsampling2D;
import org.deeplearning4j.nn.modelimport.keras.KerasLayer;
import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException;
import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException;

import java.util.Map;


/**
* Keras Upsampling2D layer support
*
* @author Max Pumperla
*/
public class KerasUpsampling2D extends KerasLayer {

/**
* Constructor from parsed Keras layer configuration dictionary.
*
* @param layerConfig dictionary containing Keras layer configuration.
* @throws InvalidKerasConfigurationException Invalid Keras configuration exception
* @throws UnsupportedKerasConfigurationException Unsupported Keras configuration exception
*/
public KerasUpsampling2D(Map<String, Object> layerConfig)
throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
this(layerConfig, true);
}

/**
* Constructor from parsed Keras layer configuration dictionary.
*
* @param layerConfig dictionary containing Keras layer configuration
* @param enforceTrainingConfig whether to enforce training-related configuration options
* @throws InvalidKerasConfigurationException Invalid Keras configuration exception
* @throws UnsupportedKerasConfigurationException Invalid Keras configuration exception
*/
public KerasUpsampling2D(Map<String, Object> layerConfig, boolean enforceTrainingConfig)
throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
super(layerConfig, enforceTrainingConfig);

int[] size = KerasConvolutionUtils.getUpsamplingSizeFromConfig(layerConfig, 2, conf);
if (size[0] != size[1])
throw new UnsupportedKerasConfigurationException("First and second size arguments have to be the same" +
"got: " + size[0] + " and " + size[1]);

Upsampling2D.Builder builder = new Upsampling2D.Builder()
.name(this.layerName)
.dropOut(this.dropout)
.size(size[0]);

this.layer = builder.build();
this.vertex = null;
}

/**
* Get DL4J Upsampling2D layer.
*
* @return Upsampling2D layer
*/
public Upsampling2D getUpsampling2DLayer() {
return (Upsampling2D) this.layer;
}

/**
* Get layer output type.
*
* @param inputType Array of InputTypes
* @return output type as InputType
* @throws InvalidKerasConfigurationException
*/
@Override
public InputType getOutputType(InputType... inputType) throws InvalidKerasConfigurationException {
if (inputType.length > 1)
throw new InvalidKerasConfigurationException(
"Keras Subsampling layer accepts only one input (received " + inputType.length + ")");
return this.getUpsampling2DLayer().getOutputType(-1, inputType[0]);
}

}
Original file line number Diff line number Diff line change
@@ -1,3 +1,20 @@
/*-
*
* * Copyright 2017 Skymind,Inc.
* *
* * Licensed under the Apache License, Version 2.0 (the "License");
* * you may not use this file except in compliance with the License.
* * You may obtain a copy of the License at
* *
* * http://www.apache.org/licenses/LICENSE-2.0
* *
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS,
* * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* * See the License for the specific language governing permissions and
* * limitations under the License.
*
*/
package org.deeplearning4j.nn.modelimport.keras.layers.pooling;

import lombok.Data;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,20 @@
/*-
*
* * Copyright 2017 Skymind,Inc.
* *
* * Licensed under the Apache License, Version 2.0 (the "License");
* * you may not use this file except in compliance with the License.
* * You may obtain a copy of the License at
* *
* * http://www.apache.org/licenses/LICENSE-2.0
* *
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS,
* * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* * See the License for the specific language governing permissions and
* * limitations under the License.
*
*/
package org.deeplearning4j.nn.modelimport.keras.layers.pooling;

import lombok.extern.slf4j.Slf4j;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,20 @@
/*-
*
* * Copyright 2017 Skymind,Inc.
* *
* * Licensed under the Apache License, Version 2.0 (the "License");
* * you may not use this file except in compliance with the License.
* * You may obtain a copy of the License at
* *
* * http://www.apache.org/licenses/LICENSE-2.0
* *
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS,
* * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* * See the License for the specific language governing permissions and
* * limitations under the License.
*
*/
package org.deeplearning4j.nn.modelimport.keras.layers.pooling;

import lombok.extern.slf4j.Slf4j;
Expand Down

0 comments on commit ded241f

Please sign in to comment.