Skip to content

Commit 89ced1a

Browse files
committed
InputPreProcessor InputType API change + CNN preprocessors getOutputType method
1 parent ac527a4 commit 89ced1a

File tree

5 files changed

+66
-0
lines changed

5 files changed

+66
-0
lines changed

deeplearning4j-core/src/main/java/org/deeplearning4j/nn/conf/InputPreProcessor.java

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121

2222
import com.fasterxml.jackson.annotation.JsonSubTypes;
2323
import com.fasterxml.jackson.annotation.JsonTypeInfo;
24+
import org.deeplearning4j.nn.conf.inputs.InputType;
2425
import org.deeplearning4j.nn.conf.preprocessor.*;
2526
import org.nd4j.linalg.api.ndarray.INDArray;
2627

@@ -68,4 +69,12 @@ public interface InputPreProcessor extends Serializable, Cloneable {
6869
INDArray backprop(INDArray output, int miniBatchSize);
6970

7071
InputPreProcessor clone();
72+
73+
/**
74+
* For a given type of input to this preprocessor, what is the type of the output?
75+
*
76+
* @param inputType Type of input for the preprocessor
77+
* @return Type of input after applying the preprocessor
78+
*/
79+
InputType getOutputType(InputType inputType);
7180
}

deeplearning4j-core/src/main/java/org/deeplearning4j/nn/conf/preprocessor/CnnToFeedForwardPreProcessor.java

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
import lombok.Data;
2525

2626
import org.deeplearning4j.nn.conf.InputPreProcessor;
27+
import org.deeplearning4j.nn.conf.inputs.InputType;
2728
import org.nd4j.linalg.api.ndarray.INDArray;
2829
import org.nd4j.linalg.api.shape.Shape;
2930
import org.nd4j.linalg.util.ArrayUtil;
@@ -113,4 +114,15 @@ public CnnToFeedForwardPreProcessor clone() {
113114
throw new RuntimeException(e);
114115
}
115116
}
117+
118+
@Override
119+
public InputType getOutputType(InputType inputType) {
120+
if(inputType == null || inputType.getType() != InputType.Type.CNN){
121+
throw new IllegalStateException("Invalid input type: Expected input of type CNN, got " + inputType);
122+
}
123+
124+
InputType.InputTypeConvolutional c = (InputType.InputTypeConvolutional)inputType;
125+
int outSize = c.getDepth() * c.getHeight() * c.getWidth();
126+
return InputType.feedForward(outSize);
127+
}
116128
}

deeplearning4j-core/src/main/java/org/deeplearning4j/nn/conf/preprocessor/CnnToRnnPreProcessor.java

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import lombok.Getter;
88
import lombok.Setter;
99
import org.deeplearning4j.nn.conf.InputPreProcessor;
10+
import org.deeplearning4j.nn.conf.inputs.InputType;
1011
import org.nd4j.linalg.api.ndarray.INDArray;
1112
import org.nd4j.linalg.util.ArrayUtil;
1213

@@ -91,4 +92,15 @@ public INDArray backprop(INDArray output, int miniBatchSize) {
9192
public CnnToRnnPreProcessor clone() {
9293
return new CnnToRnnPreProcessor(inputHeight,inputWidth,numChannels);
9394
}
95+
96+
@Override
97+
public InputType getOutputType(InputType inputType) {
98+
if(inputType == null || inputType.getType() != InputType.Type.CNN){
99+
throw new IllegalStateException("Invalid input type: Expected input of type CNN, got " + inputType);
100+
}
101+
102+
InputType.InputTypeConvolutional c = (InputType.InputTypeConvolutional)inputType;
103+
int outSize = c.getDepth() * c.getHeight() * c.getWidth();
104+
return InputType.recurrent(outSize);
105+
}
94106
}

deeplearning4j-core/src/main/java/org/deeplearning4j/nn/conf/preprocessor/FeedForwardToCnnPreProcessor.java

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
import lombok.Getter;
2626
import lombok.Setter;
2727
import org.deeplearning4j.nn.conf.InputPreProcessor;
28+
import org.deeplearning4j.nn.conf.inputs.InputType;
2829
import org.nd4j.linalg.api.ndarray.INDArray;
2930
import org.nd4j.linalg.api.shape.Shape;
3031
import org.nd4j.linalg.util.ArrayUtil;
@@ -117,5 +118,20 @@ public FeedForwardToCnnPreProcessor clone() {
117118
}
118119
}
119120

121+
@Override
122+
public InputType getOutputType(InputType inputType) {
123+
if(inputType == null || inputType.getType() != InputType.Type.FF){
124+
throw new IllegalStateException("Invalid input type: Expected input of type FeedForward, got " + inputType);
125+
}
126+
127+
InputType.InputTypeFeedForward c = (InputType.InputTypeFeedForward)inputType;
128+
int expSize = inputHeight * inputWidth * numChannels;
129+
if(c.getSize() != expSize){
130+
throw new IllegalStateException("Invalid input: expected FeedForward input of size " + expSize + " = (d=" + numChannels +
131+
" * w=" + inputWidth + " * h=" + inputHeight + "), got " + inputType);
132+
}
133+
134+
return InputType.convolutional(inputHeight, inputWidth, numChannels);
135+
}
120136

121137
}

deeplearning4j-core/src/main/java/org/deeplearning4j/nn/conf/preprocessor/RnnToCnnPreProcessor.java

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import lombok.Getter;
77
import lombok.Setter;
88
import org.deeplearning4j.nn.conf.InputPreProcessor;
9+
import org.deeplearning4j.nn.conf.inputs.InputType;
910
import org.nd4j.linalg.api.ndarray.INDArray;
1011
import org.nd4j.linalg.util.ArrayUtil;
1112

@@ -82,4 +83,20 @@ public INDArray backprop(INDArray output, int miniBatchSize) {
8283
public RnnToCnnPreProcessor clone() {
8384
return new RnnToCnnPreProcessor(inputHeight, inputWidth, numChannels);
8485
}
86+
87+
@Override
88+
public InputType getOutputType(InputType inputType) {
89+
if(inputType == null || inputType.getType() != InputType.Type.RNN){
90+
throw new IllegalStateException("Invalid input type: Expected input of type RNN, got " + inputType);
91+
}
92+
93+
InputType.InputTypeRecurrent c = (InputType.InputTypeRecurrent)inputType;
94+
int expSize = inputHeight * inputWidth * numChannels;
95+
if(c.getSize() != expSize){
96+
throw new IllegalStateException("Invalid input: expected RNN input of size " + expSize + " = (d=" + numChannels +
97+
" * w=" + inputWidth + " * h=" + inputHeight + "), got " + inputType);
98+
}
99+
100+
return InputType.convolutional(inputHeight, inputWidth, numChannels);
101+
}
85102
}

0 commit comments

Comments
 (0)