forked from deeplearning4j/deeplearning4j
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
finish test and factor out upsampling base class
- Loading branch information
1 parent
f9c2c92
commit e4a4b47
Showing
9 changed files
with
331 additions
and
54 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
95 changes: 95 additions & 0 deletions
95
.../java/org/deeplearning4j/nn/modelimport/keras/layers/convolutional/KerasUpsampling1D.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,95 @@ | ||
/*- | ||
* | ||
* * Copyright 2017 Skymind,Inc. | ||
* * | ||
* * Licensed under the Apache License, Version 2.0 (the "License"); | ||
* * you may not use this file except in compliance with the License. | ||
* * You may obtain a copy of the License at | ||
* * | ||
* * http://www.apache.org/licenses/LICENSE-2.0 | ||
* * | ||
* * Unless required by applicable law or agreed to in writing, software | ||
* * distributed under the License is distributed on an "AS IS" BASIS, | ||
* * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* * See the License for the specific language governing permissions and | ||
* * limitations under the License. | ||
* | ||
*/ | ||
package org.deeplearning4j.nn.modelimport.keras.layers.convolutional; | ||
|
||
import org.deeplearning4j.nn.conf.inputs.InputType; | ||
import org.deeplearning4j.nn.conf.layers.Upsampling1D; | ||
import org.deeplearning4j.nn.modelimport.keras.KerasLayer; | ||
import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException; | ||
import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException; | ||
|
||
import java.util.Map; | ||
|
||
|
||
/** | ||
* Keras Upsampling1D layer support | ||
* | ||
* @author Max Pumperla | ||
*/ | ||
public class KerasUpsampling1D extends KerasLayer { | ||
|
||
/** | ||
* Constructor from parsed Keras layer configuration dictionary. | ||
* | ||
* @param layerConfig dictionary containing Keras layer configuration. | ||
* @throws InvalidKerasConfigurationException Invalid Keras configuration exception | ||
* @throws UnsupportedKerasConfigurationException Unsupported Keras configuration exception | ||
*/ | ||
public KerasUpsampling1D(Map<String, Object> layerConfig) | ||
throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException { | ||
this(layerConfig, true); | ||
} | ||
|
||
/** | ||
* Constructor from parsed Keras layer configuration dictionary. | ||
* | ||
* @param layerConfig dictionary containing Keras layer configuration | ||
* @param enforceTrainingConfig whether to enforce training-related configuration options | ||
* @throws InvalidKerasConfigurationException Invalid Keras configuration exception | ||
* @throws UnsupportedKerasConfigurationException Invalid Keras configuration exception | ||
*/ | ||
public KerasUpsampling1D(Map<String, Object> layerConfig, boolean enforceTrainingConfig) | ||
throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException { | ||
super(layerConfig, enforceTrainingConfig); | ||
|
||
int[] size = KerasConvolutionUtils.getUpsamplingSizeFromConfig(layerConfig, 1, conf); | ||
|
||
Upsampling1D.Builder builder = new Upsampling1D.Builder() | ||
.name(this.layerName) | ||
.dropOut(this.dropout) | ||
.size(size[0]); | ||
|
||
this.layer = builder.build(); | ||
this.vertex = null; | ||
} | ||
|
||
/** | ||
* Get DL4J Upsampling1D layer. | ||
* | ||
* @return Upsampling1D layer | ||
*/ | ||
public Upsampling1D getUpsampling1DLayer() { | ||
return (Upsampling1D) this.layer; | ||
} | ||
|
||
/** | ||
* Get layer output type. | ||
* | ||
* @param inputType Array of InputTypes | ||
* @return output type as InputType | ||
* @throws InvalidKerasConfigurationException | ||
*/ | ||
@Override | ||
public InputType getOutputType(InputType... inputType) throws InvalidKerasConfigurationException { | ||
if (inputType.length > 1) | ||
throw new InvalidKerasConfigurationException( | ||
"Keras Subsampling layer accepts only one input (received " + inputType.length + ")"); | ||
return this.getUpsampling1DLayer().getOutputType(-1, inputType[0]); | ||
} | ||
|
||
} |
69 changes: 69 additions & 0 deletions
69
...ava/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasUpsampling1DTest.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
/*- | ||
* | ||
* * Copyright 2017 Skymind,Inc. | ||
* * | ||
* * Licensed under the Apache License, Version 2.0 (the "License"); | ||
* * you may not use this file except in compliance with the License. | ||
* * You may obtain a copy of the License at | ||
* * | ||
* * http://www.apache.org/licenses/LICENSE-2.0 | ||
* * | ||
* * Unless required by applicable law or agreed to in writing, software | ||
* * distributed under the License is distributed on an "AS IS" BASIS, | ||
* * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* * See the License for the specific language governing permissions and | ||
* * limitations under the License. | ||
* | ||
*/ | ||
package org.deeplearning4j.nn.modelimport.keras.layers.convolution; | ||
|
||
import org.deeplearning4j.nn.conf.layers.Upsampling1D; | ||
import org.deeplearning4j.nn.conf.layers.Upsampling2D; | ||
import org.deeplearning4j.nn.modelimport.keras.config.Keras1LayerConfiguration; | ||
import org.deeplearning4j.nn.modelimport.keras.config.Keras2LayerConfiguration; | ||
import org.deeplearning4j.nn.modelimport.keras.config.KerasLayerConfiguration; | ||
import org.deeplearning4j.nn.modelimport.keras.layers.convolutional.KerasUpsampling1D; | ||
import org.deeplearning4j.nn.modelimport.keras.layers.convolutional.KerasUpsampling2D; | ||
import org.junit.Test; | ||
|
||
import java.util.ArrayList; | ||
import java.util.HashMap; | ||
import java.util.List; | ||
import java.util.Map; | ||
|
||
import static org.junit.Assert.assertEquals; | ||
|
||
/** | ||
* @author Max Pumperla | ||
*/ | ||
public class KerasUpsampling1DTest { | ||
|
||
private final String LAYER_NAME = "upsampling_1D_layer"; | ||
private int size = 4; | ||
|
||
private Integer keras1 = 1; | ||
private Integer keras2 = 2; | ||
private Keras1LayerConfiguration conf1 = new Keras1LayerConfiguration(); | ||
private Keras2LayerConfiguration conf2 = new Keras2LayerConfiguration(); | ||
|
||
@Test | ||
public void testUpsampling1DLayer() throws Exception { | ||
buildUpsampling1DLayer(conf1, keras1); | ||
buildUpsampling1DLayer(conf2, keras2); | ||
} | ||
|
||
public void buildUpsampling1DLayer(KerasLayerConfiguration conf, Integer kerasVersion) throws Exception { | ||
Map<String, Object> layerConfig = new HashMap<>(); | ||
layerConfig.put(conf.getLAYER_FIELD_CLASS_NAME(), conf.getLAYER_CLASS_NAME_UPSAMPLING_1D()); | ||
Map<String, Object> config = new HashMap<>(); | ||
config.put(conf.getLAYER_FIELD_UPSAMPLING_1D_SIZE(), size); | ||
config.put(conf.getLAYER_FIELD_NAME(), LAYER_NAME); | ||
layerConfig.put(conf.getLAYER_FIELD_CONFIG(), config); | ||
layerConfig.put(conf.getLAYER_FIELD_KERAS_VERSION(), kerasVersion); | ||
|
||
Upsampling1D layer = new KerasUpsampling1D(layerConfig).getUpsampling1DLayer(); | ||
assertEquals(LAYER_NAME, layer.getLayerName()); | ||
assertEquals(size, layer.getSize()); | ||
} | ||
|
||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
116 changes: 116 additions & 0 deletions
116
deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/BaseUpsamplingLayer.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,116 @@ | ||
/*- | ||
* | ||
* * Copyright 2017 Skymind,Inc. | ||
* * | ||
* * Licensed under the Apache License, Version 2.0 (the "License"); | ||
* * you may not use this file except in compliance with the License. | ||
* * You may obtain a copy of the License at | ||
* * | ||
* * http://www.apache.org/licenses/LICENSE-2.0 | ||
* * | ||
* * Unless required by applicable law or agreed to in writing, software | ||
* * distributed under the License is distributed on an "AS IS" BASIS, | ||
* * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* * See the License for the specific language governing permissions and | ||
* * limitations under the License. | ||
* | ||
*/ | ||
package org.deeplearning4j.nn.conf.layers; | ||
|
||
import lombok.Data; | ||
import lombok.EqualsAndHashCode; | ||
import lombok.NoArgsConstructor; | ||
import lombok.ToString; | ||
import org.deeplearning4j.nn.api.ParamInitializer; | ||
import org.deeplearning4j.nn.conf.InputPreProcessor; | ||
import org.deeplearning4j.nn.conf.NeuralNetConfiguration; | ||
import org.deeplearning4j.nn.conf.inputs.InputType; | ||
import org.deeplearning4j.nn.conf.memory.LayerMemoryReport; | ||
import org.deeplearning4j.nn.conf.memory.MemoryReport; | ||
import org.deeplearning4j.nn.params.EmptyParamInitializer; | ||
import org.deeplearning4j.optimize.api.IterationListener; | ||
import org.nd4j.linalg.api.ndarray.INDArray; | ||
|
||
import java.util.Collection; | ||
import java.util.Map; | ||
|
||
/** | ||
* Upsampling base layer | ||
* | ||
* @author Max Pumperla | ||
*/ | ||
|
||
@Data | ||
@NoArgsConstructor | ||
@ToString(callSuper = true) | ||
@EqualsAndHashCode(callSuper = true) | ||
public abstract class BaseUpsamplingLayer extends Layer { | ||
|
||
protected int size; | ||
|
||
protected BaseUpsamplingLayer(UpsamplingBuilder builder) { | ||
super(builder); | ||
this.size = builder.size; | ||
} | ||
|
||
@Override | ||
public BaseUpsamplingLayer clone() { | ||
BaseUpsamplingLayer clone = (BaseUpsamplingLayer) super.clone(); | ||
return clone; | ||
} | ||
|
||
@Override | ||
public ParamInitializer initializer() { | ||
return EmptyParamInitializer.getInstance(); | ||
} | ||
|
||
|
||
@Override | ||
public void setNIn(InputType inputType, boolean override) { | ||
//No op: upsampling layer doesn't have nIn value | ||
} | ||
|
||
@Override | ||
public InputPreProcessor getPreProcessorForInputType(InputType inputType) { | ||
if (inputType == null) { | ||
throw new IllegalStateException("Invalid input for Upsampling layer (layer name=\"" + getLayerName() | ||
+ "\"): input is null"); | ||
} | ||
return InputTypeUtil.getPreProcessorForInputTypeCnnLayers(inputType, getLayerName()); | ||
} | ||
|
||
@Override | ||
public double getL1ByParam(String paramName) { | ||
//Not applicable | ||
return 0; | ||
} | ||
|
||
@Override | ||
public double getL2ByParam(String paramName) { | ||
//Not applicable | ||
return 0; | ||
} | ||
|
||
@Override | ||
public double getLearningRateByParam(String paramName) { | ||
//Not applicable | ||
return 0; | ||
} | ||
|
||
@Override | ||
public boolean isPretrainParam(String paramName) { | ||
throw new UnsupportedOperationException("UpsamplingLayer does not contain parameters"); | ||
} | ||
|
||
|
||
@NoArgsConstructor | ||
protected static abstract class UpsamplingBuilder<T extends UpsamplingBuilder<T>> | ||
extends Layer.Builder<T> { | ||
protected int size = 1; | ||
|
||
protected UpsamplingBuilder(int size) { | ||
this.size = size; | ||
} | ||
} | ||
|
||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.