,DataSet> {
public BatchDataSetsFunction(int minibatchSize) {
+ super(new BatchDataSetsFunctionAdapter(minibatchSize));
+ }
+}
+
+/**
+ * Function used to batch DataSet objects together. Typically used to combine single-example DataSet objects out of
+ * something like {@link org.deeplearning4j.spark.datavec.DataVecDataSetFunction} together into minibatches.
+ *
+ * Usage:
+ *
+ * {@code
+ * RDD mySingleExampleDataSets = ...;
+ * RDD batchData = mySingleExampleDataSets.mapPartitions(new BatchDataSetsFunction(batchSize));
+ * }
+ *
+ *
+ * @author Alex Black
+ */
+class BatchDataSetsFunctionAdapter implements FlatMapFunctionAdapter,DataSet> {
+ private final int minibatchSize;
+
+ public BatchDataSetsFunctionAdapter(int minibatchSize) {
this.minibatchSize = minibatchSize;
}
diff --git a/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/data/shuffle/SplitDataSetExamplesPairFlatMapFunction.java b/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/data/shuffle/SplitDataSetExamplesPairFlatMapFunction.java
index ad743d992180..1f52bc05e18a 100644
--- a/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/data/shuffle/SplitDataSetExamplesPairFlatMapFunction.java
+++ b/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/data/shuffle/SplitDataSetExamplesPairFlatMapFunction.java
@@ -1,8 +1,8 @@
package org.deeplearning4j.spark.data.shuffle;
import org.apache.spark.api.java.JavaRDD;
-import org.apache.spark.api.java.function.FlatMapFunction;
-import org.apache.spark.api.java.function.PairFlatMapFunction;
+import org.deeplearning4j.spark.util.BasePairFlatMapFunctionAdaptee;
+import org.deeplearning4j.spark.util.PairFlatMapFunctionAdapter;
import org.nd4j.linalg.dataset.DataSet;
import scala.Tuple2;
@@ -18,12 +18,27 @@
*
* @author Alex Black
*/
-public class SplitDataSetExamplesPairFlatMapFunction implements PairFlatMapFunction {
+public class SplitDataSetExamplesPairFlatMapFunction extends BasePairFlatMapFunctionAdaptee {
+
+ public SplitDataSetExamplesPairFlatMapFunction(int maxKeyIndex) {
+ super(new SplitDataSetExamplesPairFlatMapFunctionAdapter(maxKeyIndex));
+ }
+}
+
+/**
+ * A PairFlatMapFunction that splits each example in a {@link DataSet} object into its own {@link DataSet}.
+ * Also adds a random key (integer value) in the range 0 to maxKeyIndex-1.
+ *
+ * Used in {@link org.deeplearning4j.spark.util.SparkUtils#shuffleExamples(JavaRDD, int, int)}
+ *
+ * @author Alex Black
+ */
+class SplitDataSetExamplesPairFlatMapFunctionAdapter implements PairFlatMapFunctionAdapter {
private transient Random r;
private int maxKeyIndex;
- public SplitDataSetExamplesPairFlatMapFunction(int maxKeyIndex){
+ public SplitDataSetExamplesPairFlatMapFunctionAdapter(int maxKeyIndex){
this.maxKeyIndex = maxKeyIndex;
}
diff --git a/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/datavec/RDDMiniBatches.java b/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/datavec/RDDMiniBatches.java
index 494c090ca741..e3eccd87072b 100644
--- a/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/datavec/RDDMiniBatches.java
+++ b/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/datavec/RDDMiniBatches.java
@@ -17,7 +17,8 @@
package org.deeplearning4j.spark.datavec;
import org.apache.spark.api.java.JavaRDD;
-import org.apache.spark.api.java.function.FlatMapFunction;
+import org.datavec.spark.functions.FlatMapFunctionAdapter;
+import org.datavec.spark.transform.BaseFlatMapFunctionAdaptee;
import org.nd4j.linalg.dataset.DataSet;
import java.io.Serializable;
@@ -43,11 +44,17 @@ public JavaRDD miniBatchesJava() {
return toSplitJava.mapPartitions(new MiniBatchFunction(miniBatches));
}
+ public static class MiniBatchFunction extends BaseFlatMapFunctionAdaptee, DataSet> {
- public static class MiniBatchFunction implements FlatMapFunction, DataSet> {
+ public MiniBatchFunction(int batchSize) {
+ super(new MiniBatchFunctionAdapter(batchSize));
+ }
+ }
+
+ static class MiniBatchFunctionAdapter implements FlatMapFunctionAdapter, DataSet> {
private int batchSize = 10;
- public MiniBatchFunction(int batchSize) {
+ public MiniBatchFunctionAdapter(int batchSize) {
this.batchSize = batchSize;
}
diff --git a/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/common/repartition/MapTupleToPairFlatMap.java b/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/common/repartition/MapTupleToPairFlatMap.java
index 1669c14c4009..98294173a881 100644
--- a/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/common/repartition/MapTupleToPairFlatMap.java
+++ b/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/common/repartition/MapTupleToPairFlatMap.java
@@ -1,6 +1,7 @@
package org.deeplearning4j.spark.impl.common.repartition;
-import org.apache.spark.api.java.function.PairFlatMapFunction;
+import org.deeplearning4j.spark.util.BasePairFlatMapFunctionAdaptee;
+import org.deeplearning4j.spark.util.PairFlatMapFunctionAdapter;
import scala.Tuple2;
import java.util.ArrayList;
@@ -13,7 +14,14 @@
*
* @author Alex Black
*/
-public class MapTupleToPairFlatMap implements PairFlatMapFunction>,T,U> {
+public class MapTupleToPairFlatMap extends BasePairFlatMapFunctionAdaptee>,T,U> {
+
+ public MapTupleToPairFlatMap() {
+ super(new MapTupleToPairFlatMapAdapter());
+ }
+}
+
+class MapTupleToPairFlatMapAdapter implements PairFlatMapFunctionAdapter>,T,U> {
@Override
public Iterable> call(Iterator> tuple2Iterator) throws Exception {
diff --git a/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/common/score/BaseVaeReconstructionProbWithKeyFunction.java b/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/common/score/BaseVaeReconstructionProbWithKeyFunctionAdapter.java
similarity index 81%
rename from deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/common/score/BaseVaeReconstructionProbWithKeyFunction.java
rename to deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/common/score/BaseVaeReconstructionProbWithKeyFunctionAdapter.java
index 47bedb9b58c7..9592d56de8d5 100644
--- a/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/common/score/BaseVaeReconstructionProbWithKeyFunction.java
+++ b/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/common/score/BaseVaeReconstructionProbWithKeyFunctionAdapter.java
@@ -11,7 +11,7 @@
* @param <K> Type of key, associated with each example. Used to keep track of which score belongs to which example
* @author Alex Black
*/
-public abstract class BaseVaeReconstructionProbWithKeyFunction extends BaseVaeScoreWithKeyFunction {
+public abstract class BaseVaeReconstructionProbWithKeyFunctionAdapter extends BaseVaeScoreWithKeyFunctionAdapter {
private final boolean useLogProbability;
private final int numSamples;
@@ -23,8 +23,8 @@ public abstract class BaseVaeReconstructionProbWithKeyFunction extends BaseVa
* @param batchSize Batch size to use when scoring
* @param numSamples Number of samples to use when calling {@link VariationalAutoencoder#reconstructionLogProbability(INDArray, int)}
*/
- public BaseVaeReconstructionProbWithKeyFunction(Broadcast params, Broadcast jsonConfig, boolean useLogProbability,
- int batchSize, int numSamples){
+ public BaseVaeReconstructionProbWithKeyFunctionAdapter(Broadcast params, Broadcast jsonConfig, boolean useLogProbability,
+ int batchSize, int numSamples){
super(params, jsonConfig, batchSize);
this.useLogProbability = useLogProbability;
this.numSamples = numSamples;
diff --git a/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/common/score/BaseVaeScoreWithKeyFunction.java b/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/common/score/BaseVaeScoreWithKeyFunctionAdapter.java
similarity index 90%
rename from deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/common/score/BaseVaeScoreWithKeyFunction.java
rename to deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/common/score/BaseVaeScoreWithKeyFunctionAdapter.java
index cf38f168c731..e8ff45d49b53 100644
--- a/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/common/score/BaseVaeScoreWithKeyFunction.java
+++ b/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/common/score/BaseVaeScoreWithKeyFunctionAdapter.java
@@ -22,6 +22,8 @@
import org.apache.spark.api.java.function.PairFlatMapFunction;
import org.apache.spark.broadcast.Broadcast;
import org.deeplearning4j.nn.layers.variational.VariationalAutoencoder;
+import org.deeplearning4j.spark.util.BasePairFlatMapFunctionAdaptee;
+import org.deeplearning4j.spark.util.PairFlatMapFunctionAdapter;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.executioner.GridExecutioner;
import org.nd4j.linalg.factory.Nd4j;
@@ -42,7 +44,7 @@
* @author Alex Black
*/
@Slf4j
-public abstract class BaseVaeScoreWithKeyFunction implements PairFlatMapFunction>, K, Double> {
+public abstract class BaseVaeScoreWithKeyFunctionAdapter implements PairFlatMapFunctionAdapter>, K, Double> {
protected final Broadcast params;
protected final Broadcast jsonConfig;
@@ -54,7 +56,7 @@ public abstract class BaseVaeScoreWithKeyFunction implements PairFlatMapFunct
* @param jsonConfig MultiLayerConfiguration, as json
* @param batchSize Batch size to use when scoring
*/
- public BaseVaeScoreWithKeyFunction(Broadcast params, Broadcast jsonConfig, int batchSize) {
+ public BaseVaeScoreWithKeyFunctionAdapter(Broadcast params, Broadcast jsonConfig, int batchSize) {
this.params = params;
this.jsonConfig = jsonConfig;
this.batchSize = batchSize;
diff --git a/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/graph/scoring/CGVaeReconstructionErrorWithKeyFunction.java b/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/graph/scoring/CGVaeReconstructionErrorWithKeyFunction.java
index f51253a13678..22e14b6deec6 100644
--- a/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/graph/scoring/CGVaeReconstructionErrorWithKeyFunction.java
+++ b/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/graph/scoring/CGVaeReconstructionErrorWithKeyFunction.java
@@ -5,8 +5,7 @@
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.layers.variational.VariationalAutoencoder;
-import org.deeplearning4j.spark.impl.common.score.BaseVaeReconstructionProbWithKeyFunction;
-import org.deeplearning4j.spark.impl.common.score.BaseVaeScoreWithKeyFunction;
+import org.deeplearning4j.spark.impl.common.score.BaseVaeScoreWithKeyFunctionAdapter;
import org.nd4j.linalg.api.ndarray.INDArray;
/**
@@ -18,7 +17,7 @@
* @author Alex Black
* @see CGVaeReconstructionProbWithKeyFunction
*/
-public class CGVaeReconstructionErrorWithKeyFunction extends BaseVaeScoreWithKeyFunction {
+public class CGVaeReconstructionErrorWithKeyFunction extends BaseVaeScoreWithKeyFunctionAdapter {
/**
diff --git a/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/graph/scoring/CGVaeReconstructionProbWithKeyFunction.java b/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/graph/scoring/CGVaeReconstructionProbWithKeyFunction.java
index c0214aa2631b..f293a47401ed 100644
--- a/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/graph/scoring/CGVaeReconstructionProbWithKeyFunction.java
+++ b/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/graph/scoring/CGVaeReconstructionProbWithKeyFunction.java
@@ -5,8 +5,7 @@
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.layers.variational.VariationalAutoencoder;
-import org.deeplearning4j.spark.impl.common.score.BaseVaeReconstructionProbWithKeyFunction;
-import org.deeplearning4j.spark.impl.common.score.BaseVaeScoreWithKeyFunction;
+import org.deeplearning4j.spark.impl.common.score.BaseVaeReconstructionProbWithKeyFunctionAdapter;
import org.nd4j.linalg.api.ndarray.INDArray;
/**
@@ -16,7 +15,7 @@
*
* @author Alex Black
*/
-public class CGVaeReconstructionProbWithKeyFunction extends BaseVaeReconstructionProbWithKeyFunction {
+public class CGVaeReconstructionProbWithKeyFunction extends BaseVaeReconstructionProbWithKeyFunctionAdapter {
/**
diff --git a/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/graph/scoring/GraphFeedForwardWithKeyFunction.java b/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/graph/scoring/GraphFeedForwardWithKeyFunction.java
index 562134d50820..22d4d6920f0a 100644
--- a/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/graph/scoring/GraphFeedForwardWithKeyFunction.java
+++ b/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/graph/scoring/GraphFeedForwardWithKeyFunction.java
@@ -18,10 +18,11 @@
package org.deeplearning4j.spark.impl.graph.scoring;
-import org.apache.spark.api.java.function.PairFlatMapFunction;
import org.apache.spark.broadcast.Broadcast;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.graph.ComputationGraph;
+import org.deeplearning4j.spark.util.BasePairFlatMapFunctionAdaptee;
+import org.deeplearning4j.spark.util.PairFlatMapFunctionAdapter;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.executioner.GridExecutioner;
import org.nd4j.linalg.factory.Nd4j;
@@ -42,7 +43,21 @@
* @param <K> Type of key, associated with each example. Used to keep track of which output belongs to which input example
* @author Alex Black
*/
-public class GraphFeedForwardWithKeyFunction implements PairFlatMapFunction>, K, INDArray[]> {
+public class GraphFeedForwardWithKeyFunction extends BasePairFlatMapFunctionAdaptee>, K, INDArray[]> {
+
+ public GraphFeedForwardWithKeyFunction(Broadcast params, Broadcast jsonConfig, int batchSize) {
+ super(new GraphFeedForwardWithKeyFunctionAdapter(params, jsonConfig, batchSize));
+ }
+}
+
+/**
+ * Function to feed-forward examples, and get the network output (for example, class probabilities).
+ * A key value is used to keep track of which output corresponds to which input.
+ *
+ * @param <K> Type of key, associated with each example. Used to keep track of which output belongs to which input example
+ * @author Alex Black
+ */
+class GraphFeedForwardWithKeyFunctionAdapter implements PairFlatMapFunctionAdapter>, K, INDArray[]> {
protected static Logger log = LoggerFactory.getLogger(GraphFeedForwardWithKeyFunction.class);
@@ -55,7 +70,7 @@ public class GraphFeedForwardWithKeyFunction implements PairFlatMapFunction 1 for efficiency)
*/
- public GraphFeedForwardWithKeyFunction(Broadcast params, Broadcast jsonConfig, int batchSize) {
+ public GraphFeedForwardWithKeyFunctionAdapter(Broadcast params, Broadcast jsonConfig, int batchSize) {
this.params = params;
this.jsonConfig = jsonConfig;
this.batchSize = batchSize;
diff --git a/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/graph/scoring/ScoreExamplesFunction.java b/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/graph/scoring/ScoreExamplesFunction.java
index f077b006e7cb..bcb7bbbabf83 100644
--- a/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/graph/scoring/ScoreExamplesFunction.java
+++ b/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/graph/scoring/ScoreExamplesFunction.java
@@ -18,10 +18,11 @@
package org.deeplearning4j.spark.impl.graph.scoring;
-import org.apache.spark.api.java.function.DoubleFlatMapFunction;
import org.apache.spark.broadcast.Broadcast;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.graph.ComputationGraph;
+import org.deeplearning4j.spark.util.BaseDoubleFlatMapFunctionAdaptee;
+import org.deeplearning4j.spark.util.DoubleFlatMapFunctionAdapter;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.executioner.GridExecutioner;
import org.nd4j.linalg.dataset.api.MultiDataSet;
@@ -34,6 +35,22 @@
import java.util.Iterator;
import java.util.List;
+
+/**Function to score examples individually. Note that scoring is batched for computational efficiency.
+ * This is essentially a Spark implementation of the {@link ComputationGraph#scoreExamples(MultiDataSet, boolean)} method
+ * Note: This method returns a score for each example, but the association between examples and scores is lost. In
+ * cases where we need to know the score for particular examples, use {@link ScoreExamplesWithKeyFunction}
+ * @author Alex Black
+ * @see ScoreExamplesWithKeyFunction
+ */
+public class ScoreExamplesFunction extends BaseDoubleFlatMapFunctionAdaptee> {
+
+ public ScoreExamplesFunction(Broadcast params, Broadcast jsonConfig, boolean addRegularizationTerms,
+ int batchSize) {
+ super(new ScoreExamplesFunctionAdapter(params, jsonConfig, addRegularizationTerms, batchSize));
+ }
+}
+
/**Function to score examples individually. Note that scoring is batched for computational efficiency.
* This is essentially a Spark implementation of the {@link ComputationGraph#scoreExamples(MultiDataSet, boolean)} method
* Note: This method returns a score for each example, but the association between examples and scores is lost. In
@@ -41,7 +58,7 @@
* @author Alex Black
* @see ScoreExamplesWithKeyFunction
*/
-public class ScoreExamplesFunction implements DoubleFlatMapFunction> {
+class ScoreExamplesFunctionAdapter implements DoubleFlatMapFunctionAdapter> {
protected static final Logger log = LoggerFactory.getLogger(ScoreExamplesFunction.class);
private final Broadcast params;
@@ -49,7 +66,7 @@ public class ScoreExamplesFunction implements DoubleFlatMapFunction params, Broadcast jsonConfig, boolean addRegularizationTerms,
+ public ScoreExamplesFunctionAdapter(Broadcast params, Broadcast jsonConfig, boolean addRegularizationTerms,
int batchSize){
this.params = params;
this.jsonConfig = jsonConfig;
diff --git a/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/graph/scoring/ScoreExamplesWithKeyFunction.java b/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/graph/scoring/ScoreExamplesWithKeyFunction.java
index f898540e6579..6ad501a8e251 100644
--- a/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/graph/scoring/ScoreExamplesWithKeyFunction.java
+++ b/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/graph/scoring/ScoreExamplesWithKeyFunction.java
@@ -18,10 +18,11 @@
package org.deeplearning4j.spark.impl.graph.scoring;
-import org.apache.spark.api.java.function.PairFlatMapFunction;
import org.apache.spark.broadcast.Broadcast;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.graph.ComputationGraph;
+import org.deeplearning4j.spark.util.BasePairFlatMapFunctionAdaptee;
+import org.deeplearning4j.spark.util.PairFlatMapFunctionAdapter;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.executioner.GridExecutioner;
import org.nd4j.linalg.dataset.api.MultiDataSet;
@@ -44,7 +45,24 @@
* @param <K> Type of key, associated with each example. Used to keep track of which score belongs to which example
* @see ScoreExamplesFunction
*/
-public class ScoreExamplesWithKeyFunction implements PairFlatMapFunction>,K,Double> {
+public class ScoreExamplesWithKeyFunction extends BasePairFlatMapFunctionAdaptee>,K,Double> {
+
+ public ScoreExamplesWithKeyFunction(Broadcast params, Broadcast jsonConfig, boolean addRegularizationTerms,
+ int batchSize) {
+ super(new ScoreExamplesWithKeyFunctionAdapter(params, jsonConfig, addRegularizationTerms, batchSize));
+ }
+}
+
+/**Function to score examples individually, where each example is associated with a particular key
+ * Note that scoring is batched for computational efficiency.
+ * This is the Spark implementation of the {@link ComputationGraph#scoreExamples(MultiDataSet, boolean)} method
+ * Note: The MultiDataSet objects passed in must have exactly one example in them (otherwise: can't have a 1:1 association
+ * between keys and data sets to score)
+ * @author Alex Black
+ * @param <K> Type of key, associated with each example. Used to keep track of which score belongs to which example
+ * @see ScoreExamplesFunction
+ */
+class ScoreExamplesWithKeyFunctionAdapter implements PairFlatMapFunctionAdapter>,K,Double> {
protected static Logger log = LoggerFactory.getLogger(ScoreExamplesWithKeyFunction.class);
@@ -59,7 +77,7 @@ public class ScoreExamplesWithKeyFunction implements PairFlatMapFunction params, Broadcast jsonConfig, boolean addRegularizationTerms,
+ public ScoreExamplesWithKeyFunctionAdapter(Broadcast params, Broadcast jsonConfig, boolean addRegularizationTerms,
int batchSize){
this.params = params;
this.jsonConfig = jsonConfig;
diff --git a/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/graph/scoring/ScoreFlatMapFunctionCGDataSet.java b/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/graph/scoring/ScoreFlatMapFunctionCGDataSet.java
index 8265adeb0afe..592369f89414 100644
--- a/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/graph/scoring/ScoreFlatMapFunctionCGDataSet.java
+++ b/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/graph/scoring/ScoreFlatMapFunctionCGDataSet.java
@@ -18,8 +18,9 @@
package org.deeplearning4j.spark.impl.graph.scoring;
-import org.apache.spark.api.java.function.FlatMapFunction;
import org.apache.spark.broadcast.Broadcast;
+import org.datavec.spark.functions.FlatMapFunctionAdapter;
+import org.datavec.spark.transform.BaseFlatMapFunctionAdaptee;
import org.deeplearning4j.datasets.iterator.IteratorDataSetIterator;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.graph.ComputationGraph;
@@ -38,7 +39,15 @@
import java.util.List;
/** Function used to score a DataSet using a ComputationGraph */
-public class ScoreFlatMapFunctionCGDataSet implements FlatMapFunction, Tuple2> {
+public class ScoreFlatMapFunctionCGDataSet extends BaseFlatMapFunctionAdaptee, Tuple2> {
+
+ public ScoreFlatMapFunctionCGDataSet(String json, Broadcast params, int minibatchSize) {
+ super(new ScoreFlatMapFunctionCGDataSetAdapter(json, params, minibatchSize));
+ }
+}
+
+/** Function used to score a DataSet using a ComputationGraph */
+class ScoreFlatMapFunctionCGDataSetAdapter implements FlatMapFunctionAdapter, Tuple2> {
private static final Logger log = LoggerFactory.getLogger(ScoreFlatMapFunctionCGDataSet.class);
private String json;
@@ -46,7 +55,7 @@ public class ScoreFlatMapFunctionCGDataSet implements FlatMapFunction params, int minibatchSize){
+ public ScoreFlatMapFunctionCGDataSetAdapter(String json, Broadcast params, int minibatchSize){
this.json = json;
this.params = params;
this.minibatchSize = minibatchSize;
diff --git a/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/graph/scoring/ScoreFlatMapFunctionCGMultiDataSet.java b/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/graph/scoring/ScoreFlatMapFunctionCGMultiDataSet.java
index 55a9367c3196..f197045160df 100644
--- a/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/graph/scoring/ScoreFlatMapFunctionCGMultiDataSet.java
+++ b/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/graph/scoring/ScoreFlatMapFunctionCGMultiDataSet.java
@@ -18,8 +18,9 @@
package org.deeplearning4j.spark.impl.graph.scoring;
-import org.apache.spark.api.java.function.FlatMapFunction;
import org.apache.spark.broadcast.Broadcast;
+import org.datavec.spark.functions.FlatMapFunctionAdapter;
+import org.datavec.spark.transform.BaseFlatMapFunctionAdaptee;
import org.deeplearning4j.datasets.iterator.IteratorMultiDataSetIterator;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.graph.ComputationGraph;
@@ -38,7 +39,15 @@
import java.util.List;
/** Function used to score a MultiDataSet using a given ComputationGraph */
-public class ScoreFlatMapFunctionCGMultiDataSet implements FlatMapFunction,Tuple2> {
+public class ScoreFlatMapFunctionCGMultiDataSet extends BaseFlatMapFunctionAdaptee,Tuple2> {
+
+ public ScoreFlatMapFunctionCGMultiDataSet(String json, Broadcast params, int minibatchSize) {
+ super(new ScoreFlatMapFunctionCGMultiDataSetAdapter(json, params, minibatchSize));
+ }
+}
+
+/** Function used to score a MultiDataSet using a given ComputationGraph */
+class ScoreFlatMapFunctionCGMultiDataSetAdapter implements FlatMapFunctionAdapter,Tuple2> {
private static final Logger log = LoggerFactory.getLogger(ScoreFlatMapFunctionCGMultiDataSet.class);
private String json;
@@ -46,7 +55,7 @@ public class ScoreFlatMapFunctionCGMultiDataSet implements FlatMapFunction