Merge pull request deeplearning4j#4053 from deeplearning4j/mp_fix_embedding_lstm_import

Fix imdb lstm config tests
maxpumperla authored Sep 12, 2017
2 parents 1f04073 + 64c4c45 commit 503c774
Showing 7 changed files with 67 additions and 32 deletions.
@@ -3,11 +3,16 @@
import lombok.Data;
import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.nn.api.layers.LayerConstraint;
import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.LSTM;
import org.deeplearning4j.nn.conf.preprocessor.CnnToFeedForwardPreProcessor;
import org.deeplearning4j.nn.conf.preprocessor.FeedForwardToRnnPreProcessor;
import org.deeplearning4j.nn.conf.preprocessor.RnnToFeedForwardPreProcessor;
import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException;
import org.deeplearning4j.nn.modelimport.keras.KerasLayer;
import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException;
import org.deeplearning4j.nn.modelimport.keras.preprocessors.TensorFlowCnnToFeedForwardPreProcessor;
import org.deeplearning4j.nn.modelimport.keras.utils.KerasConstraintUtils;
import org.deeplearning4j.nn.modelimport.keras.utils.KerasLayerUtils;
import org.deeplearning4j.nn.params.LSTMParamInitializer;
@@ -163,7 +168,11 @@ public InputType getOutputType(InputType... inputType) throws InvalidKerasConfig
if (inputType.length > 1)
throw new InvalidKerasConfigurationException(
"Keras LSTM layer accepts only one input (received " + inputType.length + ")");
-return this.getLSTMLayer().getOutputType(-1, inputType[0]);
+InputPreProcessor preProcessor = getInputPreprocessor(inputType);
+if (preProcessor != null)
+return preProcessor.getOutputType(inputType[0]);
+else
+return this.getLSTMLayer().getOutputType(-1, inputType[0]);
}

/**
@@ -176,6 +185,28 @@ public int getNumParams() {
return kerasMajorVersion == 2 ? NUM_TRAINABLE_PARAMS_KERAS_2 : NUM_TRAINABLE_PARAMS;
}

/**
* Gets appropriate DL4J InputPreProcessor for given InputTypes.
*
* @param inputType Array of InputTypes
* @return DL4J InputPreProcessor
* @throws InvalidKerasConfigurationException Invalid Keras configuration exception
* @see org.deeplearning4j.nn.conf.InputPreProcessor
*/
@Override
public InputPreProcessor getInputPreprocessor(InputType... inputType) throws InvalidKerasConfigurationException {
if (inputType.length > 1)
throw new InvalidKerasConfigurationException(
"Keras LSTM layer accepts only one input (received " + inputType.length + ")");
InputPreProcessor preprocessor = null;
if (inputType[0] instanceof InputType.InputTypeFeedForward) {
preprocessor = new FeedForwardToRnnPreProcessor();
}
return preprocessor;
}



/**
* Set weights for layer.
*
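For context on the layer change above: the new getInputPreprocessor override is what lets an input that DL4J types as feed-forward (for instance the output of the Embedding layer in the IMDB model being imported here) be reshaped into recurrent format before it reaches the LSTM. Below is a minimal, stand-alone sketch of that selection logic using the same public DL4J classes the diff imports; the wrapper class and method names are hypothetical and exist only for illustration.

import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.preprocessor.FeedForwardToRnnPreProcessor;

// Hypothetical illustration class; mirrors the branch added to the Keras LSTM import layer above.
public class LstmPreprocessorSketch {

    // A feed-forward input is reshaped to DL4J's recurrent (3D) activation format
    // before the LSTM; a recurrent input passes through without a preprocessor.
    static InputPreProcessor preprocessorFor(InputType inputType) {
        if (inputType instanceof InputType.InputTypeFeedForward) {
            return new FeedForwardToRnnPreProcessor();
        }
        return null;
    }

    public static void main(String[] args) {
        System.out.println(preprocessorFor(InputType.feedForward(128)));  // FeedForwardToRnnPreProcessor
        System.out.println(preprocessorFor(InputType.recurrent(128)));    // null
    }
}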
@@ -38,15 +38,15 @@ public class Keras1ModelConfigurationTest {

private ClassLoader classLoader = getClass().getClassLoader();

-// @Test
-// public void imdbLstmTfSequentialConfigTest() throws Exception {
-// runSequentialConfigTest("configs/keras1/imdb_lstm_tf_keras_1_config.json");
-// }
-//
-// @Test
-// public void imdbLstmThSequentialConfigTest() throws Exception {
-// runSequentialConfigTest("configs/keras1/imdb_lstm_th_keras_1_config.json");
-// }
+@Test
+public void imdbLstmTfSequentialConfigTest() throws Exception {
+runSequentialConfigTest("configs/keras1/imdb_lstm_tf_keras_1_config.json", true);
+}
+
+@Test
+public void imdbLstmThSequentialConfigTest() throws Exception {
+runSequentialConfigTest("configs/keras1/imdb_lstm_th_keras_1_config.json", true);
+}

@Test
public void mnistMlpTfSequentialConfigTest() throws Exception {
@@ -38,15 +38,15 @@ public class Keras2ModelConfigurationTest {

ClassLoader classLoader = getClass().getClassLoader();

-// @Test
-// public void imdbLstmTfSequentialConfigTest() throws Exception {
-// runSequentialConfigTest("configs/keras2/imdb_lstm_tf_keras_2_config.json");
-// }
-//
-// @Test
-// public void imdbLstmThSequentialConfigTest() throws Exception {
-// runSequentialConfigTest("configs/keras2/imdb_lstm_th_keras_2_config.json");
-// }
+@Test
+public void imdbLstmTfSequentialConfigTest() throws Exception {
+runSequentialConfigTest("configs/keras2/imdb_lstm_tf_keras_2_config.json");
+}
+
+@Test
+public void imdbLstmThSequentialConfigTest() throws Exception {
+runSequentialConfigTest("configs/keras2/imdb_lstm_th_keras_2_config.json");
+}

@Test
public void mnistMlpTfSequentialConfigTest() throws Exception {
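The two test classes above re-enable the previously commented-out IMDB LSTM configuration tests (the Keras 1 variants now pass an extra boolean flag whose meaning is not visible in this diff). The helper runSequentialConfigTest itself is not part of the diff; the sketch below shows roughly what such a helper has to do, assuming the module's KerasModelImport.importKerasSequentialConfiguration entry point. The helper body and resource handling here are illustrative, not the project's actual implementation.

import java.io.File;

import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.modelimport.keras.KerasModelImport;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;

public class SequentialConfigTestSketch {

    // Illustrative stand-in for runSequentialConfigTest: load a Keras model.to_json()
    // file from the test classpath, import it as a DL4J configuration, and initialise
    // a network from it so that every layer mapping in the config is exercised.
    static void runSequentialConfigTest(String resourcePath) throws Exception {
        ClassLoader classLoader = SequentialConfigTestSketch.class.getClassLoader();
        String configPath = new File(classLoader.getResource(resourcePath).toURI()).getAbsolutePath();
        MultiLayerConfiguration config = KerasModelImport.importKerasSequentialConfiguration(configPath);
        MultiLayerNetwork network = new MultiLayerNetwork(config);
        network.init();
    }

    public static void main(String[] args) throws Exception {
        runSequentialConfigTest("configs/keras2/imdb_lstm_tf_keras_2_config.json");
    }
}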
@@ -36,8 +36,8 @@
"consume_less": "cpu",
"stateful": false,
"init": "glorot_uniform",
"inner_init": "orthogonal",
"dropout_U": 0.2,
"inner_init": "glorot_uniform",
"dropout_U": 0.0,
"dropout_W": 0.2,
"input_dim": 128,
"return_sequences": false,
@@ -18,7 +18,7 @@
null
],
"W_regularizer": null,
"dropout": 0.2,
"dropout": 0.2,
"output_dim": 128,
"input_length": null
}
@@ -36,8 +36,8 @@
"consume_less": "cpu",
"stateful": false,
"init": "glorot_uniform",
"inner_init": "orthogonal",
"dropout_U": 0.2,
"inner_init": "glorot_uniform",
"dropout_U": 0.0,
"dropout_W": 0.2,
"input_dim": 128,
"return_sequences": false,
@@ -35,10 +35,12 @@
"recurrent_activation": "hard_sigmoid",
"trainable": true,
"recurrent_initializer": {
"class_name": "Orthogonal",
"class_name": "VarianceScaling",
"config": {
"seed": null,
"gain": 1.0
"distribution": "uniform",
"scale": 1.0,
"seed": null,
"mode": "fan_avg"
}
},
"use_bias": true,
@@ -53,7 +55,7 @@
"units": 128,
"unit_forget_bias": true,
"activity_regularizer": null,
"recurrent_dropout": 0.2,
"recurrent_dropout": 0.0,
"kernel_initializer": {
"class_name": "VarianceScaling",
"config": {
@@ -35,10 +35,12 @@
"recurrent_activation": "hard_sigmoid",
"trainable": true,
"recurrent_initializer": {
"class_name": "Orthogonal",
"class_name": "VarianceScaling",
"config": {
"seed": null,
"gain": 1.0
"distribution": "uniform",
"scale": 1.0,
"seed": null,
"mode": "fan_avg"
}
},
"use_bias": true,
@@ -53,7 +55,7 @@
"units": 128,
"unit_forget_bias": true,
"activity_regularizer": null,
"recurrent_dropout": 0.2,
"recurrent_dropout": 0.0,
"kernel_initializer": {
"class_name": "VarianceScaling",
"config": {
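A note on the four config fixtures above: in both the Keras 1 and the Keras 2 JSON files the LSTM's recurrent initializer is changed from orthogonal to glorot_uniform and its recurrent dropout from 0.2 to 0.0. The Keras 2 files show this as a switch from the Orthogonal class to VarianceScaling with scale 1.0, mode "fan_avg" and a uniform distribution, because that is how Keras 2 serializes glorot_uniform in model.to_json(). Presumably the fixtures were regenerated this way so that every setting in the test configs maps onto something this version of the importer supports.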
