diff --git a/buildmultiplescalaversions.sh b/buildmultiplescalaversions.sh index 3eb4d3ffce9d..cd051845378c 100755 --- a/buildmultiplescalaversions.sh +++ b/buildmultiplescalaversions.sh @@ -2,6 +2,8 @@ set -eu ./change-scala-versions.sh 2.10 mvn "$@" +mvn -Dspark.major.version=2 "$@" ./change-scala-versions.sh 2.11 mvn "$@" +mvn -Dspark.major.version=2 "$@" ./change-scala-versions.sh 2.10 diff --git a/deeplearning4j-scaleout/dl4j-streaming/pom.xml b/deeplearning4j-scaleout/dl4j-streaming/pom.xml index 4d5c5863ecbf..8a7580699200 100644 --- a/deeplearning4j-scaleout/dl4j-streaming/pom.xml +++ b/deeplearning4j-scaleout/dl4j-streaming/pom.xml @@ -4,10 +4,9 @@ 4.0.0 - org.deeplearning4j dl4j-streaming_2.10 jar - 0.7.3-SNAPSHOT + 0.7.3_spark_${spark.major.version}-SNAPSHOT deeplearning4j-scaleout @@ -27,7 +26,7 @@ org.datavec datavec-camel - ${project.version} + ${parent.version} @@ -56,15 +55,10 @@ spark-streaming_2.10 ${spark.version} - - org.apache.spark - spark-streaming-kafka_2.10 - ${spark.version} - org.datavec datavec-spark_2.10 - ${datavec.version} + ${datavec.spark.version} @@ -108,5 +102,90 @@ + + + + + org.codehaus.mojo + build-helper-maven-plugin + 1.12 + + + add-source + generate-sources + add-source + + + src/main/spark-${spark.major.version} + + + + + + + + + + + spark-default + + + !spark.major.version + + + + 1 + 1.6.2 + 2.2.0 + + + + org.apache.spark + spark-streaming-kafka_2.10 + ${spark.version} + + + + + spark-1 + + + spark.major.version + 1 + + + + 1.6.2 + 2.2.0 + + + + org.apache.spark + spark-streaming-kafka_2.10 + ${spark.version} + + + + + spark-2 + + + spark.major.version + 2 + + + + 2.1.0 + 2.2.0 + + + + org.apache.spark + spark-streaming-kafka-0-8_2.10 + ${spark.version} + + + + diff --git a/deeplearning4j-scaleout/dl4j-streaming/src/main/java/org/deeplearning4j/streaming/pipeline/spark/DataSetFlatmap.java b/deeplearning4j-scaleout/dl4j-streaming/src/main/java/org/deeplearning4j/streaming/pipeline/spark/DataSetFlatmap.java index 
060047f98528..acfc825c323c 100644 --- a/deeplearning4j-scaleout/dl4j-streaming/src/main/java/org/deeplearning4j/streaming/pipeline/spark/DataSetFlatmap.java +++ b/deeplearning4j-scaleout/dl4j-streaming/src/main/java/org/deeplearning4j/streaming/pipeline/spark/DataSetFlatmap.java @@ -3,6 +3,8 @@ import org.apache.commons.net.util.Base64; import org.apache.spark.api.java.function.FlatMapFunction; import org.datavec.api.writable.Writable; +import org.datavec.spark.functions.FlatMapFunctionAdapter; +import org.datavec.spark.transform.BaseFlatMapFunctionAdaptee; import org.deeplearning4j.streaming.conversion.dataset.RecordToDataSet; import org.deeplearning4j.streaming.serde.RecordDeSerializer; import org.nd4j.linalg.dataset.DataSet; @@ -15,11 +17,22 @@ * Flat maps a binary dataset string in to a * dataset */ -public class DataSetFlatmap implements FlatMapFunction, DataSet> { +public class DataSetFlatmap extends BaseFlatMapFunctionAdaptee, DataSet> { + + public DataSetFlatmap(int numLabels, RecordToDataSet recordToDataSetFunction) { + super(new DataSetFlatmapAdapter(numLabels, recordToDataSetFunction)); + } +} + +/** + * Flat maps a binary dataset string in to a + * dataset + */ +class DataSetFlatmapAdapter implements FlatMapFunctionAdapter, DataSet> { private int numLabels; private RecordToDataSet recordToDataSetFunction; - public DataSetFlatmap(int numLabels, RecordToDataSet recordToDataSetFunction) { + public DataSetFlatmapAdapter(int numLabels, RecordToDataSet recordToDataSetFunction) { this.numLabels = numLabels; this.recordToDataSetFunction = recordToDataSetFunction; } diff --git a/deeplearning4j-scaleout/dl4j-streaming/src/main/java/org/deeplearning4j/streaming/pipeline/spark/SparkStreamingPipeline.java b/deeplearning4j-scaleout/dl4j-streaming/src/main/java/org/deeplearning4j/streaming/pipeline/spark/SparkStreamingPipeline.java index a68586314aa8..b0e63859aaee 100644 --- 
a/deeplearning4j-scaleout/dl4j-streaming/src/main/java/org/deeplearning4j/streaming/pipeline/spark/SparkStreamingPipeline.java +++ b/deeplearning4j-scaleout/dl4j-streaming/src/main/java/org/deeplearning4j/streaming/pipeline/spark/SparkStreamingPipeline.java @@ -71,9 +71,6 @@ public JavaDStream createStream() { @Override public void startStreamingConsumption(long timeout) { jssc.start(); - if(timeout < 0) - jssc.awaitTermination(); - else - jssc.awaitTermination(timeout); + StreamingContextUtils.awaitTermination(jssc, timeout); } } diff --git a/deeplearning4j-scaleout/dl4j-streaming/src/main/java/org/deeplearning4j/streaming/pipeline/spark/inerference/NDArrayFlatMap.java b/deeplearning4j-scaleout/dl4j-streaming/src/main/java/org/deeplearning4j/streaming/pipeline/spark/inerference/NDArrayFlatMap.java index 1941c2a7146f..3ed313c73718 100644 --- a/deeplearning4j-scaleout/dl4j-streaming/src/main/java/org/deeplearning4j/streaming/pipeline/spark/inerference/NDArrayFlatMap.java +++ b/deeplearning4j-scaleout/dl4j-streaming/src/main/java/org/deeplearning4j/streaming/pipeline/spark/inerference/NDArrayFlatMap.java @@ -1,8 +1,9 @@ package org.deeplearning4j.streaming.pipeline.spark.inerference; import org.apache.commons.net.util.Base64; -import org.apache.spark.api.java.function.FlatMapFunction; import org.datavec.api.writable.Writable; +import org.datavec.spark.functions.FlatMapFunctionAdapter; +import org.datavec.spark.transform.BaseFlatMapFunctionAdaptee; import org.deeplearning4j.streaming.conversion.ndarray.RecordToNDArray; import org.deeplearning4j.streaming.serde.RecordDeSerializer; import org.nd4j.linalg.api.ndarray.INDArray; @@ -16,10 +17,22 @@ * dataset * @author Adam Gibson */ -public class NDArrayFlatMap implements FlatMapFunction, INDArray> { - private RecordToNDArray recordToDataSetFunction; +public class NDArrayFlatMap extends BaseFlatMapFunctionAdaptee, INDArray> { public NDArrayFlatMap(RecordToNDArray recordToDataSetFunction) { + super(new 
NDArrayFlatMapAdapter(recordToDataSetFunction)); + } +} + +/** + * Flat maps a binary dataset string in to a + * dataset + * @author Adam Gibson + */ +class NDArrayFlatMapAdapter implements FlatMapFunctionAdapter, INDArray> { + private RecordToNDArray recordToDataSetFunction; + + public NDArrayFlatMapAdapter(RecordToNDArray recordToDataSetFunction) { this.recordToDataSetFunction = recordToDataSetFunction; } diff --git a/deeplearning4j-scaleout/dl4j-streaming/src/main/java/org/deeplearning4j/streaming/pipeline/spark/inerference/SparkStreamingInferencePipeline.java b/deeplearning4j-scaleout/dl4j-streaming/src/main/java/org/deeplearning4j/streaming/pipeline/spark/inerference/SparkStreamingInferencePipeline.java index 43f4e3332d42..f4f2750fe2bd 100644 --- a/deeplearning4j-scaleout/dl4j-streaming/src/main/java/org/deeplearning4j/streaming/pipeline/spark/inerference/SparkStreamingInferencePipeline.java +++ b/deeplearning4j-scaleout/dl4j-streaming/src/main/java/org/deeplearning4j/streaming/pipeline/spark/inerference/SparkStreamingInferencePipeline.java @@ -16,6 +16,7 @@ import org.apache.spark.streaming.kafka.KafkaUtils; import org.deeplearning4j.streaming.conversion.ndarray.RecordToNDArray; import org.deeplearning4j.streaming.pipeline.kafka.BaseKafkaPipeline; +import org.deeplearning4j.streaming.pipeline.spark.StreamingContextUtils; import org.nd4j.linalg.api.ndarray.INDArray; import java.util.Collections; @@ -90,9 +91,6 @@ public JavaDStream createStream() { @Override public void startStreamingConsumption(long timeout) { jssc.start(); - if(timeout < 0) - jssc.awaitTermination(); - else - jssc.awaitTermination(timeout); + StreamingContextUtils.awaitTermination(jssc, timeout); } } diff --git a/deeplearning4j-scaleout/dl4j-streaming/src/main/java/org/deeplearning4j/streaming/pipeline/spark/PrintDataSet.java b/deeplearning4j-scaleout/dl4j-streaming/src/main/spark-1/org/deeplearning4j/streaming/pipeline/spark/PrintDataSet.java similarity index 100% rename from 
deeplearning4j-scaleout/dl4j-streaming/src/main/java/org/deeplearning4j/streaming/pipeline/spark/PrintDataSet.java rename to deeplearning4j-scaleout/dl4j-streaming/src/main/spark-1/org/deeplearning4j/streaming/pipeline/spark/PrintDataSet.java diff --git a/deeplearning4j-scaleout/dl4j-streaming/src/main/spark-1/org/deeplearning4j/streaming/pipeline/spark/StreamingContextUtils.java b/deeplearning4j-scaleout/dl4j-streaming/src/main/spark-1/org/deeplearning4j/streaming/pipeline/spark/StreamingContextUtils.java new file mode 100644 index 000000000000..0f48ebb892ee --- /dev/null +++ b/deeplearning4j-scaleout/dl4j-streaming/src/main/spark-1/org/deeplearning4j/streaming/pipeline/spark/StreamingContextUtils.java @@ -0,0 +1,23 @@ +package org.deeplearning4j.streaming.pipeline.spark; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.function.Function; +import org.apache.spark.streaming.api.java.JavaDStream; +import org.apache.spark.streaming.api.java.JavaStreamingContext; + +/** + * In order to handle changes between Spark 1.x and 2.x + */ +public class StreamingContextUtils { + + public static void awaitTermination(JavaStreamingContext jssc, long timeout) { + if(timeout < 0) + jssc.awaitTermination(); + else + jssc.awaitTermination(timeout); + } + + public static void foreach(JavaDStream stream, Function, Void> func) { + stream.foreach(func); + } +} diff --git a/deeplearning4j-scaleout/dl4j-streaming/src/main/spark-2/org/deeplearning4j/streaming/pipeline/spark/PrintDataSet.java b/deeplearning4j-scaleout/dl4j-streaming/src/main/spark-2/org/deeplearning4j/streaming/pipeline/spark/PrintDataSet.java new file mode 100644 index 000000000000..4c00889cf8cf --- /dev/null +++ b/deeplearning4j-scaleout/dl4j-streaming/src/main/spark-2/org/deeplearning4j/streaming/pipeline/spark/PrintDataSet.java @@ -0,0 +1,22 @@ +package org.deeplearning4j.streaming.pipeline.spark; + +import org.apache.spark.api.java.JavaRDD; +import 
org.apache.spark.api.java.function.Function; +import org.apache.spark.api.java.function.VoidFunction; +import org.nd4j.linalg.dataset.DataSet; + +/** + * Created by agibsonccc on 6/11/16. + */ +public class PrintDataSet implements VoidFunction> { + @Override + public void call(JavaRDD dataSetJavaRDD) throws Exception { + dataSetJavaRDD.foreach(new VoidFunction() { + @Override + public void call(DataSet dataSet) throws Exception { + System.out.println(dataSet); + } + }); + } +} + diff --git a/deeplearning4j-scaleout/dl4j-streaming/src/main/spark-2/org/deeplearning4j/streaming/pipeline/spark/StreamingContextUtils.java b/deeplearning4j-scaleout/dl4j-streaming/src/main/spark-2/org/deeplearning4j/streaming/pipeline/spark/StreamingContextUtils.java new file mode 100644 index 000000000000..352a0bf9ca49 --- /dev/null +++ b/deeplearning4j-scaleout/dl4j-streaming/src/main/spark-2/org/deeplearning4j/streaming/pipeline/spark/StreamingContextUtils.java @@ -0,0 +1,27 @@ +package org.deeplearning4j.streaming.pipeline.spark; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.function.VoidFunction; +import org.apache.spark.streaming.api.java.JavaDStream; +import org.apache.spark.streaming.api.java.JavaStreamingContext; + +/** + * In order to handle changes between Spark 1.x and 2.x + */ +public class StreamingContextUtils { + + public static void awaitTermination(JavaStreamingContext jssc, long timeout) { + try { + if(timeout < 0) + jssc.awaitTermination(); + else + jssc.awaitTerminationOrTimeout(timeout); + } catch (InterruptedException e) { + e.printStackTrace(); + } + } + + public static void foreach(JavaDStream stream, VoidFunction> func) { + stream.foreachRDD(func); + } +} diff --git a/deeplearning4j-scaleout/dl4j-streaming/src/test/java/org/deeplearning4j/streaming/embedded/JavaDirectKafkaWordCount.java b/deeplearning4j-scaleout/dl4j-streaming/src/test/java/org/deeplearning4j/streaming/embedded/JavaDirectKafkaWordCount.java index 
8455f725b309..e30abda424b3 100644 --- a/deeplearning4j-scaleout/dl4j-streaming/src/test/java/org/deeplearning4j/streaming/embedded/JavaDirectKafkaWordCount.java +++ b/deeplearning4j-scaleout/dl4j-streaming/src/test/java/org/deeplearning4j/streaming/embedded/JavaDirectKafkaWordCount.java @@ -29,6 +29,8 @@ import org.apache.spark.streaming.api.java.JavaPairInputDStream; import org.apache.spark.streaming.api.java.JavaStreamingContext; import org.apache.spark.streaming.kafka.KafkaUtils; +import org.datavec.spark.functions.FlatMapFunctionAdapter; +import org.datavec.spark.transform.BaseFlatMapFunctionAdaptee; import scala.Tuple2; import java.util.*; @@ -87,12 +89,12 @@ public String call(Tuple2 tuple2) { } }); - JavaDStream words = lines.flatMap(new FlatMapFunction() { + JavaDStream words = lines.flatMap(new BaseFlatMapFunctionAdaptee<>(new FlatMapFunctionAdapter() { @Override public Iterable call(String x) { return Arrays.asList(SPACE.split(x)); } - }); + })); JavaPairDStream wordCounts = words.mapToPair( new PairFunction() { @Override @@ -112,4 +114,4 @@ public Integer call(Integer i1, Integer i2) { jssc.start(); jssc.awaitTermination(); } -} \ No newline at end of file +} diff --git a/deeplearning4j-scaleout/dl4j-streaming/src/test/java/org/deeplearning4j/streaming/pipeline/PipelineTest.java b/deeplearning4j-scaleout/dl4j-streaming/src/test/java/org/deeplearning4j/streaming/pipeline/PipelineTest.java index da8d562d9a30..dae20d8c2d46 100644 --- a/deeplearning4j-scaleout/dl4j-streaming/src/test/java/org/deeplearning4j/streaming/pipeline/PipelineTest.java +++ b/deeplearning4j-scaleout/dl4j-streaming/src/test/java/org/deeplearning4j/streaming/pipeline/PipelineTest.java @@ -9,6 +9,7 @@ import org.deeplearning4j.streaming.embedded.TestUtils; import org.deeplearning4j.streaming.pipeline.spark.PrintDataSet; import org.deeplearning4j.streaming.pipeline.spark.SparkStreamingPipeline; +import org.deeplearning4j.streaming.pipeline.spark.StreamingContextUtils; import 
org.junit.AfterClass; import org.junit.BeforeClass; import org.junit.Test; @@ -53,7 +54,7 @@ public void testPipeline() throws Exception { //NOTE THAT YOU NEED TO DO SOMETHING WITH THE STREAM OTHERWISE IT ERRORS OUT. //ALSO NOTE HERE THAT YOU NEED TO HAVE THE FUNCTION BE AN OBJECT NOT AN ANONYMOUS //CLASS BECAUSE OF TASK SERIALIZATION - dataSetJavaDStream.foreach(new PrintDataSet()); + StreamingContextUtils.foreach(dataSetJavaDStream, new PrintDataSet()); pipeline.startStreamingConsumption(1000); diff --git a/deeplearning4j-scaleout/pom.xml b/deeplearning4j-scaleout/pom.xml index 39b60e8b3f39..9b02db87ecec 100644 --- a/deeplearning4j-scaleout/pom.xml +++ b/deeplearning4j-scaleout/pom.xml @@ -17,21 +17,25 @@ ~ */ --> - - 4.0.0 - - org.deeplearning4j - deeplearning4j-parent - 0.7.3-SNAPSHOT - - deeplearning4j-scaleout - pom - DeepLearning4j-scaleout-parent - - dl4j-streaming - deeplearning4j-aws - spark + + 4.0.0 + + org.deeplearning4j + deeplearning4j-parent + 0.7.3-SNAPSHOT + + deeplearning4j-scaleout + 0.7.3_spark_${spark.major.version}-SNAPSHOT + pom + DeepLearning4j-scaleout-parent + + + dl4j-streaming + deeplearning4j-aws + spark deeplearning4j-scaleout-parallelwrapper deeplearning4j-scaleout-parallelwrapper-parameter-server + diff --git a/deeplearning4j-scaleout/spark/dl4j-spark-nlp/pom.xml b/deeplearning4j-scaleout/spark/dl4j-spark-nlp/pom.xml index c14917dc58c8..db3b54c5c856 100644 --- a/deeplearning4j-scaleout/spark/dl4j-spark-nlp/pom.xml +++ b/deeplearning4j-scaleout/spark/dl4j-spark-nlp/pom.xml @@ -25,6 +25,7 @@ 4.0.0 dl4j-spark-nlp_2.10 + 0.7.3_spark_${spark.major.version}-SNAPSHOT jar dl4j-spark-nlp @@ -47,7 +48,7 @@ org.deeplearning4j deeplearning4j-nlp - ${project.version} + ${parent.version} org.deeplearning4j @@ -59,5 +60,10 @@ junit test + + org.datavec + datavec-spark_2.10 + ${project.version} + diff --git a/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/word2vec/FirstIterationFunction.java 
b/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/word2vec/FirstIterationFunction.java index 1a07f760a55f..336d3ccf0e9c 100644 --- a/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/word2vec/FirstIterationFunction.java +++ b/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/word2vec/FirstIterationFunction.java @@ -1,8 +1,9 @@ package org.deeplearning4j.spark.models.embeddings.word2vec; import org.apache.commons.lang3.tuple.Pair; -import org.apache.spark.api.java.function.FlatMapFunction; import org.apache.spark.broadcast.Broadcast; +import org.datavec.spark.functions.FlatMapFunctionAdapter; +import org.datavec.spark.transform.BaseFlatMapFunctionAdaptee; import org.deeplearning4j.models.word2vec.VocabWord; import org.deeplearning4j.models.word2vec.wordstore.VocabCache; import org.nd4j.linalg.api.ndarray.INDArray; @@ -18,229 +19,11 @@ * @author raver119@gmail.com */ public class FirstIterationFunction - implements FlatMapFunction< Iterator, Long>>, Entry > { - - private int ithIteration = 1; - private int vectorLength; - private boolean useAdaGrad; - private int batchSize = 0; - private double negative; - private int window; - private double alpha; - private double minAlpha; - private long totalWordCount; - private long seed; - private int maxExp; - private double[] expTable; - private int iterations; - private Map indexSyn0VecMap; - private Map pointSyn1VecMap; - private AtomicLong nextRandom = new AtomicLong(5); - - private volatile VocabCache vocab; - private volatile NegativeHolder negativeHolder; - private AtomicLong cid = new AtomicLong(0); - private AtomicLong aff = new AtomicLong(0); - - - - + extends BaseFlatMapFunctionAdaptee< Iterator, Long>>, Entry > { public FirstIterationFunction(Broadcast> word2vecVarMapBroadcast, Broadcast expTableBroadcast, Broadcast> vocabCacheBroadcast) { - - Map 
word2vecVarMap = word2vecVarMapBroadcast.getValue(); - this.expTable = expTableBroadcast.getValue(); - this.vectorLength = (int) word2vecVarMap.get("vectorLength"); - this.useAdaGrad = (boolean) word2vecVarMap.get("useAdaGrad"); - this.negative = (double) word2vecVarMap.get("negative"); - this.window = (int) word2vecVarMap.get("window"); - this.alpha = (double) word2vecVarMap.get("alpha"); - this.minAlpha = (double) word2vecVarMap.get("minAlpha"); - this.totalWordCount = (long) word2vecVarMap.get("totalWordCount"); - this.seed = (long) word2vecVarMap.get("seed"); - this.maxExp = (int) word2vecVarMap.get("maxExp"); - this.iterations = (int) word2vecVarMap.get("iterations"); - this.batchSize = (int) word2vecVarMap.get("batchSize"); - this.indexSyn0VecMap = new HashMap<>(); - this.pointSyn1VecMap = new HashMap<>(); - this.vocab = vocabCacheBroadcast.getValue(); - - if (this.vocab == null) throw new RuntimeException("VocabCache is null"); - - if (negative > 0) { - negativeHolder = NegativeHolder.getInstance(); - negativeHolder.initHolder(vocab, expTable, this.vectorLength); - } - } - - - - @Override - public Iterable> call(Iterator, Long>> pairIter) { - while (pairIter.hasNext()) { - List, Long>> batch = new ArrayList<>(); - while (pairIter.hasNext() && batch.size() < batchSize) { - Tuple2, Long> pair = pairIter.next(); - List vocabWordsList = pair._1(); - Long sentenceCumSumCount = pair._2(); - batch.add(Pair.of(vocabWordsList, sentenceCumSumCount)); - } - - for (int i = 0; i < iterations; i++) { - //System.out.println("Training sentence: " + vocabWordsList); - for (Pair, Long> pair: batch) { - List vocabWordsList = pair.getKey(); - Long sentenceCumSumCount = pair.getValue(); - double currentSentenceAlpha = Math.max(minAlpha, - alpha - (alpha - minAlpha) * (sentenceCumSumCount / (double) totalWordCount)); - trainSentence(vocabWordsList, currentSentenceAlpha); - } - } - } - return indexSyn0VecMap.entrySet(); - } - - - public void trainSentence(List vocabWordsList, 
double currentSentenceAlpha) { - - if (vocabWordsList != null && !vocabWordsList.isEmpty()) { - for (int ithWordInSentence = 0; ithWordInSentence < vocabWordsList.size(); ithWordInSentence++) { - // Random value ranging from 0 to window size - nextRandom.set(Math.abs(nextRandom.get() * 25214903917L + 11)); - int b = (int) (long) this.nextRandom.get() % window; - VocabWord currentWord = vocabWordsList.get(ithWordInSentence); - if (currentWord != null) { - skipGram(ithWordInSentence, vocabWordsList, b, currentSentenceAlpha); - } - } - } - } - - public void skipGram(int ithWordInSentence, List vocabWordsList, int b, double currentSentenceAlpha) { - - VocabWord currentWord = vocabWordsList.get(ithWordInSentence); - if (currentWord != null && !vocabWordsList.isEmpty()) { - int end = window * 2 + 1 - b; - for (int a = b; a < end; a++) { - if (a != window) { - int c = ithWordInSentence - window + a; - if (c >= 0 && c < vocabWordsList.size()) { - VocabWord lastWord = vocabWordsList.get(c); - iterateSample(currentWord, lastWord, currentSentenceAlpha); - } - } - } - } - } - - public void iterateSample(VocabWord w1, VocabWord w2, double currentSentenceAlpha) { - - - if (w1 == null || w2 == null || w2.getIndex() < 0 || w2.getIndex() == w1.getIndex()) - return; - final int currentWordIndex = w2.getIndex(); - - // error for current word and context - INDArray neu1e = Nd4j.create(vectorLength); - - // First iteration Syn0 is random numbers - INDArray l1 = null; - if (indexSyn0VecMap.containsKey(vocab.elementAtIndex(currentWordIndex))) { - l1 = indexSyn0VecMap.get(vocab.elementAtIndex(currentWordIndex)); - } else { - l1 = getRandomSyn0Vec(vectorLength, (long) currentWordIndex); - } - - // - for (int i = 0; i < w1.getCodeLength(); i++) { - int code = w1.getCodes().get(i); - int point = w1.getPoints().get(i); - if(point < 0) - throw new IllegalStateException("Illegal point " + point); - // Point to - INDArray syn1; - if (pointSyn1VecMap.containsKey(point)) { - syn1 = 
pointSyn1VecMap.get(point); - } else { - syn1 = Nd4j.zeros(1, vectorLength); // 1 row of vector length of zeros - pointSyn1VecMap.put(point, syn1); - } - - // Dot product of Syn0 and Syn1 vecs - double dot = Nd4j.getBlasWrapper().level1().dot(vectorLength, 1.0, l1, syn1); - - if (dot < -maxExp || dot >= maxExp) - continue; - - int idx = (int) ((dot + maxExp) * ((double) expTable.length / maxExp / 2.0)); - - if (idx > expTable.length) continue; - - //score - double f = expTable[idx]; - //gradient - double g = (1 - code - f) * (useAdaGrad ? w1.getGradient(i, currentSentenceAlpha, currentSentenceAlpha) : currentSentenceAlpha); - - - Nd4j.getBlasWrapper().level1().axpy(vectorLength, g, syn1, neu1e); - Nd4j.getBlasWrapper().level1().axpy(vectorLength, g, l1, syn1); - } - - int target = w1.getIndex(); - int label; - //negative sampling - if(negative > 0) - for (int d = 0; d < negative + 1; d++) { - if (d == 0) - label = 1; - else { - nextRandom.set(Math.abs(nextRandom.get() * 25214903917L + 11)); - int idx = Math.abs((int) (nextRandom.get() >> 16) % negativeHolder.getTable().length()); - - target = negativeHolder.getTable().getInt(idx); - if (target <= 0) - target = (int) nextRandom.get() % (vocab.numWords() - 1) + 1; - - if (target == w1.getIndex()) - continue; - label = 0; - } - - if(target >= negativeHolder.getSyn1Neg().rows() || target < 0) - continue; - - double f = Nd4j.getBlasWrapper().dot(l1,negativeHolder.getSyn1Neg().slice(target)); - double g; - if (f > maxExp) - g = useAdaGrad ? w1.getGradient(target, (label - 1), alpha) : (label - 1) * alpha; - else if (f < -maxExp) - g = label * (useAdaGrad ? w1.getGradient(target, alpha, alpha) : alpha); - else { - int idx = (int) ((f + maxExp) * (expTable.length / maxExp / 2)); - if (idx >= expTable.length) - continue; - - g = useAdaGrad ? 
w1.getGradient(target, label - expTable[idx], alpha) : (label - expTable[idx]) * alpha; - } - - Nd4j.getBlasWrapper().level1().axpy(vectorLength, g, negativeHolder.getSyn1Neg().slice(target),neu1e); - - Nd4j.getBlasWrapper().level1().axpy(vectorLength, g, l1,negativeHolder.getSyn1Neg().slice(target)); - } - - - // Updated the Syn0 vector based on gradient. Syn0 is not random anymore. - Nd4j.getBlasWrapper().level1().axpy(vectorLength, 1.0f, neu1e, l1); - - VocabWord word = vocab.elementAtIndex(currentWordIndex); - indexSyn0VecMap.put(word, l1); - } - - private INDArray getRandomSyn0Vec(int vectorLength, long lseed) { - /* - we use wordIndex as part of seed here, to guarantee that during word syn0 initialization on dwo distinct nodes, initial weights will be the same for the same word - */ - return Nd4j.rand(lseed * seed, new int[]{1 ,vectorLength}).subi(0.5).divi(vectorLength); + super(new FirstIterationFunctionAdapter(word2vecVarMapBroadcast, expTableBroadcast, vocabCacheBroadcast)); } } + diff --git a/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/word2vec/FirstIterationFunctionAdapter.java b/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/word2vec/FirstIterationFunctionAdapter.java new file mode 100644 index 000000000000..a8d387cc55a9 --- /dev/null +++ b/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/word2vec/FirstIterationFunctionAdapter.java @@ -0,0 +1,245 @@ +package org.deeplearning4j.spark.models.embeddings.word2vec; + +import org.apache.commons.lang3.tuple.Pair; +import org.apache.spark.broadcast.Broadcast; +import org.datavec.spark.functions.FlatMapFunctionAdapter; +import org.deeplearning4j.models.word2vec.VocabWord; +import org.deeplearning4j.models.word2vec.wordstore.VocabCache; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.factory.Nd4j; +import scala.Tuple2; + 
+import java.util.*; +import java.util.concurrent.atomic.AtomicLong; + +/** + * @author jeffreytang + * @author raver119@gmail.com + */ +public class FirstIterationFunctionAdapter + implements FlatMapFunctionAdapter< Iterator, Long>>, Map.Entry > { + + private int ithIteration = 1; + private int vectorLength; + private boolean useAdaGrad; + private int batchSize = 0; + private double negative; + private int window; + private double alpha; + private double minAlpha; + private long totalWordCount; + private long seed; + private int maxExp; + private double[] expTable; + private int iterations; + private Map indexSyn0VecMap; + private Map pointSyn1VecMap; + private AtomicLong nextRandom = new AtomicLong(5); + + private volatile VocabCache vocab; + private volatile NegativeHolder negativeHolder; + private AtomicLong cid = new AtomicLong(0); + private AtomicLong aff = new AtomicLong(0); + + + + + + public FirstIterationFunctionAdapter(Broadcast> word2vecVarMapBroadcast, + Broadcast expTableBroadcast, Broadcast> vocabCacheBroadcast) { + + Map word2vecVarMap = word2vecVarMapBroadcast.getValue(); + this.expTable = expTableBroadcast.getValue(); + this.vectorLength = (int) word2vecVarMap.get("vectorLength"); + this.useAdaGrad = (boolean) word2vecVarMap.get("useAdaGrad"); + this.negative = (double) word2vecVarMap.get("negative"); + this.window = (int) word2vecVarMap.get("window"); + this.alpha = (double) word2vecVarMap.get("alpha"); + this.minAlpha = (double) word2vecVarMap.get("minAlpha"); + this.totalWordCount = (long) word2vecVarMap.get("totalWordCount"); + this.seed = (long) word2vecVarMap.get("seed"); + this.maxExp = (int) word2vecVarMap.get("maxExp"); + this.iterations = (int) word2vecVarMap.get("iterations"); + this.batchSize = (int) word2vecVarMap.get("batchSize"); + this.indexSyn0VecMap = new HashMap<>(); + this.pointSyn1VecMap = new HashMap<>(); + this.vocab = vocabCacheBroadcast.getValue(); + + if (this.vocab == null) throw new RuntimeException("VocabCache is 
null"); + + if (negative > 0) { + negativeHolder = NegativeHolder.getInstance(); + negativeHolder.initHolder(vocab, expTable, this.vectorLength); + } + } + + + + @Override + public Iterable> call(Iterator, Long>> pairIter) { + while (pairIter.hasNext()) { + List, Long>> batch = new ArrayList<>(); + while (pairIter.hasNext() && batch.size() < batchSize) { + Tuple2, Long> pair = pairIter.next(); + List vocabWordsList = pair._1(); + Long sentenceCumSumCount = pair._2(); + batch.add(Pair.of(vocabWordsList, sentenceCumSumCount)); + } + + for (int i = 0; i < iterations; i++) { + //System.out.println("Training sentence: " + vocabWordsList); + for (Pair, Long> pair: batch) { + List vocabWordsList = pair.getKey(); + Long sentenceCumSumCount = pair.getValue(); + double currentSentenceAlpha = Math.max(minAlpha, + alpha - (alpha - minAlpha) * (sentenceCumSumCount / (double) totalWordCount)); + trainSentence(vocabWordsList, currentSentenceAlpha); + } + } + } + return indexSyn0VecMap.entrySet(); + } + + + public void trainSentence(List vocabWordsList, double currentSentenceAlpha) { + + if (vocabWordsList != null && !vocabWordsList.isEmpty()) { + for (int ithWordInSentence = 0; ithWordInSentence < vocabWordsList.size(); ithWordInSentence++) { + // Random value ranging from 0 to window size + nextRandom.set(Math.abs(nextRandom.get() * 25214903917L + 11)); + int b = (int) (long) this.nextRandom.get() % window; + VocabWord currentWord = vocabWordsList.get(ithWordInSentence); + if (currentWord != null) { + skipGram(ithWordInSentence, vocabWordsList, b, currentSentenceAlpha); + } + } + } + } + + public void skipGram(int ithWordInSentence, List vocabWordsList, int b, double currentSentenceAlpha) { + + VocabWord currentWord = vocabWordsList.get(ithWordInSentence); + if (currentWord != null && !vocabWordsList.isEmpty()) { + int end = window * 2 + 1 - b; + for (int a = b; a < end; a++) { + if (a != window) { + int c = ithWordInSentence - window + a; + if (c >= 0 && c < 
vocabWordsList.size()) { + VocabWord lastWord = vocabWordsList.get(c); + iterateSample(currentWord, lastWord, currentSentenceAlpha); + } + } + } + } + } + + public void iterateSample(VocabWord w1, VocabWord w2, double currentSentenceAlpha) { + + + if (w1 == null || w2 == null || w2.getIndex() < 0 || w2.getIndex() == w1.getIndex()) + return; + final int currentWordIndex = w2.getIndex(); + + // error for current word and context + INDArray neu1e = Nd4j.create(vectorLength); + + // First iteration Syn0 is random numbers + INDArray l1 = null; + if (indexSyn0VecMap.containsKey(vocab.elementAtIndex(currentWordIndex))) { + l1 = indexSyn0VecMap.get(vocab.elementAtIndex(currentWordIndex)); + } else { + l1 = getRandomSyn0Vec(vectorLength, (long) currentWordIndex); + } + + // + for (int i = 0; i < w1.getCodeLength(); i++) { + int code = w1.getCodes().get(i); + int point = w1.getPoints().get(i); + if(point < 0) + throw new IllegalStateException("Illegal point " + point); + // Point to + INDArray syn1; + if (pointSyn1VecMap.containsKey(point)) { + syn1 = pointSyn1VecMap.get(point); + } else { + syn1 = Nd4j.zeros(1, vectorLength); // 1 row of vector length of zeros + pointSyn1VecMap.put(point, syn1); + } + + // Dot product of Syn0 and Syn1 vecs + double dot = Nd4j.getBlasWrapper().level1().dot(vectorLength, 1.0, l1, syn1); + + if (dot < -maxExp || dot >= maxExp) + continue; + + int idx = (int) ((dot + maxExp) * ((double) expTable.length / maxExp / 2.0)); + + if (idx > expTable.length) continue; + + //score + double f = expTable[idx]; + //gradient + double g = (1 - code - f) * (useAdaGrad ? 
w1.getGradient(i, currentSentenceAlpha, currentSentenceAlpha) : currentSentenceAlpha); + + + Nd4j.getBlasWrapper().level1().axpy(vectorLength, g, syn1, neu1e); + Nd4j.getBlasWrapper().level1().axpy(vectorLength, g, l1, syn1); + } + + int target = w1.getIndex(); + int label; + //negative sampling + if(negative > 0) + for (int d = 0; d < negative + 1; d++) { + if (d == 0) + label = 1; + else { + nextRandom.set(Math.abs(nextRandom.get() * 25214903917L + 11)); + int idx = Math.abs((int) (nextRandom.get() >> 16) % negativeHolder.getTable().length()); + + target = negativeHolder.getTable().getInt(idx); + if (target <= 0) + target = (int) nextRandom.get() % (vocab.numWords() - 1) + 1; + + if (target == w1.getIndex()) + continue; + label = 0; + } + + if(target >= negativeHolder.getSyn1Neg().rows() || target < 0) + continue; + + double f = Nd4j.getBlasWrapper().dot(l1,negativeHolder.getSyn1Neg().slice(target)); + double g; + if (f > maxExp) + g = useAdaGrad ? w1.getGradient(target, (label - 1), alpha) : (label - 1) * alpha; + else if (f < -maxExp) + g = label * (useAdaGrad ? w1.getGradient(target, alpha, alpha) : alpha); + else { + int idx = (int) ((f + maxExp) * (expTable.length / maxExp / 2)); + if (idx >= expTable.length) + continue; + + g = useAdaGrad ? w1.getGradient(target, label - expTable[idx], alpha) : (label - expTable[idx]) * alpha; + } + + Nd4j.getBlasWrapper().level1().axpy(vectorLength, g, negativeHolder.getSyn1Neg().slice(target),neu1e); + + Nd4j.getBlasWrapper().level1().axpy(vectorLength, g, l1,negativeHolder.getSyn1Neg().slice(target)); + } + + + // Updated the Syn0 vector based on gradient. Syn0 is not random anymore. 
+ Nd4j.getBlasWrapper().level1().axpy(vectorLength, 1.0f, neu1e, l1); + + VocabWord word = vocab.elementAtIndex(currentWordIndex); + indexSyn0VecMap.put(word, l1); + } + + private INDArray getRandomSyn0Vec(int vectorLength, long lseed) { + /* + we use wordIndex as part of seed here, to guarantee that during word syn0 initialization on dwo distinct nodes, initial weights will be the same for the same word + */ + return Nd4j.rand(lseed * seed, new int[]{1 ,vectorLength}).subi(0.5).divi(vectorLength); + } +} diff --git a/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/word2vec/SecondIterationFunction.java b/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/word2vec/SecondIterationFunction.java index e2f80f755ec1..8c31094e974a 100644 --- a/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/word2vec/SecondIterationFunction.java +++ b/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/word2vec/SecondIterationFunction.java @@ -1,8 +1,9 @@ package org.deeplearning4j.spark.models.embeddings.word2vec; import org.apache.commons.lang3.tuple.Pair; -import org.apache.spark.api.java.function.FlatMapFunction; import org.apache.spark.broadcast.Broadcast; +import org.datavec.spark.functions.FlatMapFunctionAdapter; +import org.datavec.spark.transform.BaseFlatMapFunctionAdaptee; import org.deeplearning4j.models.word2vec.VocabWord; import org.deeplearning4j.models.word2vec.wordstore.VocabCache; import org.nd4j.linalg.api.ndarray.INDArray; @@ -21,7 +22,20 @@ * @author raver119@gmail.com */ public class SecondIterationFunction - implements FlatMapFunction< Iterator, Long>>, Entry > { + extends BaseFlatMapFunctionAdaptee< Iterator, Long>>, Entry > { + + public SecondIterationFunction(Broadcast> word2vecVarMapBroadcast, + Broadcast expTableBroadcast, Broadcast> vocabCacheBroadcast) { + 
super(new SecondIterationFunctionAdapter(word2vecVarMapBroadcast, expTableBroadcast, vocabCacheBroadcast)); + } +} + +/** + * @author jeffreytang + * @author raver119@gmail.com + */ +class SecondIterationFunctionAdapter + implements FlatMapFunctionAdapter< Iterator, Long>>, Entry > { private int ithIteration = 1; private int vectorLength; @@ -49,7 +63,7 @@ public class SecondIterationFunction - public SecondIterationFunction(Broadcast> word2vecVarMapBroadcast, + public SecondIterationFunctionAdapter(Broadcast> word2vecVarMapBroadcast, Broadcast expTableBroadcast, Broadcast> vocabCacheBroadcast) { Map word2vecVarMap = word2vecVarMapBroadcast.getValue(); diff --git a/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/test/java/org/deeplearning4j/spark/text/TextPipelineTest.java b/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/test/java/org/deeplearning4j/spark/text/TextPipelineTest.java index 5ba4b7b90a03..1f7d8046cd2a 100644 --- a/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/test/java/org/deeplearning4j/spark/text/TextPipelineTest.java +++ b/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/test/java/org/deeplearning4j/spark/text/TextPipelineTest.java @@ -29,6 +29,7 @@ import org.deeplearning4j.models.word2vec.VocabWord; import org.deeplearning4j.models.word2vec.wordstore.VocabCache; import org.deeplearning4j.spark.models.embeddings.word2vec.FirstIterationFunction; +import org.deeplearning4j.spark.models.embeddings.word2vec.FirstIterationFunctionAdapter; import org.deeplearning4j.spark.models.embeddings.word2vec.MapToPairFunction; import org.deeplearning4j.spark.models.embeddings.word2vec.Word2Vec; import org.deeplearning4j.spark.text.functions.CountCumSum; @@ -482,8 +483,8 @@ public void testFirstIteration() throws Exception { Iterator, Long>> iterator = vocabWordListSentenceCumSumRDD.collect().iterator(); - FirstIterationFunction firstIterationFunction = - new FirstIterationFunction(word2vecVarMapBroadcast, expTableBroadcast, pipeline.getBroadCastVocabCache()); 
+ FirstIterationFunctionAdapter firstIterationFunction = + new FirstIterationFunctionAdapter(word2vecVarMapBroadcast, expTableBroadcast, pipeline.getBroadCastVocabCache()); Iterable> ret = firstIterationFunction.call(iterator); assertTrue(ret.iterator().hasNext()); diff --git a/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/pom.xml b/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/pom.xml index 2e9ec4d7ec67..26ea1a0257bb 100644 --- a/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/pom.xml +++ b/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/pom.xml @@ -8,6 +8,7 @@ 4.0.0 dl4j-spark-parameterserver_2.10 + 0.7.3_spark_${spark.major.version}-SNAPSHOT jar dl4j-spark-parameterserver @@ -20,7 +21,7 @@ org.nd4j nd4j-aeron - ${project.version} + ${parent.version} org.deeplearning4j diff --git a/deeplearning4j-scaleout/spark/dl4j-spark/pom.xml b/deeplearning4j-scaleout/spark/dl4j-spark/pom.xml index 0326b8cab6da..a5684e70f0a3 100644 --- a/deeplearning4j-scaleout/spark/dl4j-spark/pom.xml +++ b/deeplearning4j-scaleout/spark/dl4j-spark/pom.xml @@ -25,6 +25,7 @@ 4.0.0 dl4j-spark_2.10 + 0.7.3_spark_${spark.major.version}-SNAPSHOT jar dl4j-spark @@ -33,29 +34,8 @@ UTF-8 UTF-8 - 1.0.4 - 1.6.2 - - - cdh5 - - - org.apache.hadoop - https://repository.cloudera.com/artifactory/cloudera-repos/ - - - - 2.0.0-cdh4.6.0 - 1.2.0-cdh5.3.0 - - - - - - - scala-tools.org @@ -64,7 +44,6 @@ - @@ -72,37 +51,19 @@ org.deeplearning4j deeplearning4j-core - ${project.version} - - - - org.apache.spark - spark-mllib_2.10 - ${spark.version} - - - - org.apache.spark - spark-core_2.10 - ${spark.version} - - - javax.servlet - servlet-api - - + ${parent.version} org.datavec datavec-spark_2.10 - ${datavec.version} + ${datavec.spark.version} org.deeplearning4j deeplearning4j-ui-components - ${project.version} + ${parent.version} @@ -121,7 +82,7 @@ org.deeplearning4j deeplearning4j-play_2.10 - ${project.version} + ${parent.version} test diff --git 
a/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/api/worker/ExecuteWorkerFlatMap.java b/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/api/worker/ExecuteWorkerFlatMap.java index 55ccf0e977f3..391fd90f3506 100644 --- a/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/api/worker/ExecuteWorkerFlatMap.java +++ b/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/api/worker/ExecuteWorkerFlatMap.java @@ -1,6 +1,7 @@ package org.deeplearning4j.spark.api.worker; -import org.apache.spark.api.java.function.FlatMapFunction; +import org.datavec.spark.functions.FlatMapFunctionAdapter; +import org.datavec.spark.transform.BaseFlatMapFunctionAdaptee; import org.deeplearning4j.berkeley.Pair; import org.deeplearning4j.datasets.iterator.AsyncDataSetIterator; import org.deeplearning4j.datasets.iterator.IteratorDataSetIterator; @@ -25,11 +26,24 @@ * * @author Alex Black */ -public class ExecuteWorkerFlatMap implements FlatMapFunction, R> { +public class ExecuteWorkerFlatMap extends BaseFlatMapFunctionAdaptee, R> { + + public ExecuteWorkerFlatMap(TrainingWorker worker) { + super(new ExecuteWorkerFlatMapAdapter(worker)); + } +} + +/** + * A FlatMapFunction for executing training on DataSets. 
+ * Used in both SparkDl4jMultiLayer and SparkComputationGraph implementations + * + * @author Alex Black + */ +class ExecuteWorkerFlatMapAdapter implements FlatMapFunctionAdapter, R> { private final TrainingWorker worker; - public ExecuteWorkerFlatMap(TrainingWorker worker){ + public ExecuteWorkerFlatMapAdapter(TrainingWorker worker){ this.worker = worker; } diff --git a/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/api/worker/ExecuteWorkerMultiDataSetFlatMap.java b/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/api/worker/ExecuteWorkerMultiDataSetFlatMap.java index 60ad484b30b0..263e30e71c01 100644 --- a/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/api/worker/ExecuteWorkerMultiDataSetFlatMap.java +++ b/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/api/worker/ExecuteWorkerMultiDataSetFlatMap.java @@ -1,6 +1,7 @@ package org.deeplearning4j.spark.api.worker; -import org.apache.spark.api.java.function.FlatMapFunction; +import org.datavec.spark.functions.FlatMapFunctionAdapter; +import org.datavec.spark.transform.BaseFlatMapFunctionAdaptee; import org.deeplearning4j.berkeley.Pair; import org.deeplearning4j.datasets.iterator.AsyncMultiDataSetIterator; import org.deeplearning4j.datasets.iterator.IteratorMultiDataSetIterator; @@ -23,11 +24,23 @@ * * @author Alex Black */ -public class ExecuteWorkerMultiDataSetFlatMap implements FlatMapFunction, R> { +public class ExecuteWorkerMultiDataSetFlatMap extends BaseFlatMapFunctionAdaptee, R> { + + public ExecuteWorkerMultiDataSetFlatMap(TrainingWorker worker) { + super(new ExecuteWorkerMultiDataSetFlatMapAdapter<>(worker)); + } +} + +/** + * A FlatMapFunction for executing training on MultiDataSets. Used only in SparkComputationGraph implementation. 
+ * + * @author Alex Black + */ +class ExecuteWorkerMultiDataSetFlatMapAdapter implements FlatMapFunctionAdapter, R> { private final TrainingWorker worker; - public ExecuteWorkerMultiDataSetFlatMap(TrainingWorker worker){ + public ExecuteWorkerMultiDataSetFlatMapAdapter(TrainingWorker worker){ this.worker = worker; } diff --git a/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/api/worker/ExecuteWorkerPDSFlatMap.java b/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/api/worker/ExecuteWorkerPDSFlatMap.java index bae22f8a1840..2701abfd9e74 100644 --- a/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/api/worker/ExecuteWorkerPDSFlatMap.java +++ b/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/api/worker/ExecuteWorkerPDSFlatMap.java @@ -1,7 +1,8 @@ package org.deeplearning4j.spark.api.worker; -import org.apache.spark.api.java.function.FlatMapFunction; import org.apache.spark.input.PortableDataStream; +import org.datavec.spark.functions.FlatMapFunctionAdapter; +import org.datavec.spark.transform.BaseFlatMapFunctionAdaptee; import org.deeplearning4j.spark.api.TrainingResult; import org.deeplearning4j.spark.api.TrainingWorker; import org.deeplearning4j.spark.iterator.PortableDataStreamDataSetIterator; @@ -15,11 +16,24 @@ * * @author Alex Black */ -public class ExecuteWorkerPDSFlatMap implements FlatMapFunction, R> { - private final FlatMapFunction, R> workerFlatMap; +public class ExecuteWorkerPDSFlatMap extends BaseFlatMapFunctionAdaptee, R> { - public ExecuteWorkerPDSFlatMap(TrainingWorker worker){ - this.workerFlatMap = new ExecuteWorkerFlatMap<>(worker); + public ExecuteWorkerPDSFlatMap(TrainingWorker worker) { + super(new ExecuteWorkerPDSFlatMapAdapter<>(worker)); + } +} + +/** + * A FlatMapFunction for executing training on serialized DataSet objects, that can be loaded using a PortableDataStream + * Used in both SparkDl4jMultiLayer and 
SparkComputationGraph implementations + * + * @author Alex Black + */ +class ExecuteWorkerPDSFlatMapAdapter implements FlatMapFunctionAdapter, R> { + private final FlatMapFunctionAdapter, R> workerFlatMap; + + public ExecuteWorkerPDSFlatMapAdapter(TrainingWorker worker){ + this.workerFlatMap = new ExecuteWorkerFlatMapAdapter<>(worker); } @Override diff --git a/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/api/worker/ExecuteWorkerPDSMDSFlatMap.java b/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/api/worker/ExecuteWorkerPDSMDSFlatMap.java index 27c4ae68b1bc..db7c146597db 100644 --- a/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/api/worker/ExecuteWorkerPDSMDSFlatMap.java +++ b/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/api/worker/ExecuteWorkerPDSMDSFlatMap.java @@ -1,7 +1,8 @@ package org.deeplearning4j.spark.api.worker; -import org.apache.spark.api.java.function.FlatMapFunction; import org.apache.spark.input.PortableDataStream; +import org.datavec.spark.functions.FlatMapFunctionAdapter; +import org.datavec.spark.transform.BaseFlatMapFunctionAdaptee; import org.deeplearning4j.spark.api.TrainingResult; import org.deeplearning4j.spark.api.TrainingWorker; import org.deeplearning4j.spark.iterator.PortableDataStreamMultiDataSetIterator; @@ -15,11 +16,24 @@ * * @author Alex Black */ -public class ExecuteWorkerPDSMDSFlatMap implements FlatMapFunction, R> { - private final FlatMapFunction, R> workerFlatMap; +public class ExecuteWorkerPDSMDSFlatMap extends BaseFlatMapFunctionAdaptee, R> { - public ExecuteWorkerPDSMDSFlatMap(TrainingWorker worker){ - this.workerFlatMap = new ExecuteWorkerMultiDataSetFlatMap<>(worker); + public ExecuteWorkerPDSMDSFlatMap(TrainingWorker worker) { + super(new ExecuteWorkerPDSMDSFlatMapAdapter<>(worker)); + } +} + +/** + * A FlatMapFunction for executing training on serialized MultiDataSet objects, that can be 
loaded using a PortableDataStream + * Used for SparkComputationGraph implementations only + * + * @author Alex Black + */ +class ExecuteWorkerPDSMDSFlatMapAdapter implements FlatMapFunctionAdapter, R> { + private final FlatMapFunctionAdapter, R> workerFlatMap; + + public ExecuteWorkerPDSMDSFlatMapAdapter(TrainingWorker worker){ + this.workerFlatMap = new ExecuteWorkerMultiDataSetFlatMapAdapter<>(worker); } @Override diff --git a/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/api/worker/ExecuteWorkerPathFlatMap.java b/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/api/worker/ExecuteWorkerPathFlatMap.java index 66ad94074797..141ef45425d4 100644 --- a/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/api/worker/ExecuteWorkerPathFlatMap.java +++ b/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/api/worker/ExecuteWorkerPathFlatMap.java @@ -1,6 +1,7 @@ package org.deeplearning4j.spark.api.worker; -import org.apache.spark.api.java.function.FlatMapFunction; +import org.datavec.spark.functions.FlatMapFunctionAdapter; +import org.datavec.spark.transform.BaseFlatMapFunctionAdaptee; import org.deeplearning4j.spark.api.TrainingResult; import org.deeplearning4j.spark.api.TrainingWorker; import org.deeplearning4j.spark.api.WorkerConfiguration; @@ -18,12 +19,26 @@ * * @author Alex Black */ -public class ExecuteWorkerPathFlatMap implements FlatMapFunction, R> { - private final FlatMapFunction, R> workerFlatMap; +public class ExecuteWorkerPathFlatMap extends BaseFlatMapFunctionAdaptee, R> { + + public ExecuteWorkerPathFlatMap(TrainingWorker worker) { + super(new ExecuteWorkerPathFlatMapAdapter<>(worker)); + } +} + +/** + * A FlatMapFunction for executing training on serialized DataSet objects, that can be loaded from a path (local or HDFS) + * that is specified as a String + * Used in both SparkDl4jMultiLayer and SparkComputationGraph implementations + * + 
* @author Alex Black + */ +class ExecuteWorkerPathFlatMapAdapter implements FlatMapFunctionAdapter, R> { + private final FlatMapFunctionAdapter, R> workerFlatMap; private final int maxDataSetObjects; - public ExecuteWorkerPathFlatMap(TrainingWorker worker){ - this.workerFlatMap = new ExecuteWorkerFlatMap<>(worker); + public ExecuteWorkerPathFlatMapAdapter(TrainingWorker worker){ + this.workerFlatMap = new ExecuteWorkerFlatMapAdapter<>(worker); //How many dataset objects of size 'dataSetObjectNumExamples' should we load? //Only pass on the required number, not all of them (to avoid async preloading data that won't be used) diff --git a/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/api/worker/ExecuteWorkerPathMDSFlatMap.java b/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/api/worker/ExecuteWorkerPathMDSFlatMap.java index f61b445aa794..3ecb74e1c296 100644 --- a/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/api/worker/ExecuteWorkerPathMDSFlatMap.java +++ b/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/api/worker/ExecuteWorkerPathMDSFlatMap.java @@ -1,6 +1,7 @@ package org.deeplearning4j.spark.api.worker; -import org.apache.spark.api.java.function.FlatMapFunction; +import org.datavec.spark.functions.FlatMapFunctionAdapter; +import org.datavec.spark.transform.BaseFlatMapFunctionAdaptee; import org.deeplearning4j.spark.api.TrainingResult; import org.deeplearning4j.spark.api.TrainingWorker; import org.deeplearning4j.spark.api.WorkerConfiguration; @@ -18,12 +19,26 @@ * * @author Alex Black */ -public class ExecuteWorkerPathMDSFlatMap implements FlatMapFunction, R> { - private final FlatMapFunction, R> workerFlatMap; +public class ExecuteWorkerPathMDSFlatMap extends BaseFlatMapFunctionAdaptee, R> { + + public ExecuteWorkerPathMDSFlatMap(TrainingWorker worker) { + super(new ExecuteWorkerPathMDSFlatMapAdapter<>(worker)); + } +} + +/** + * A 
FlatMapFunction for executing training on serialized DataSet objects, that can be loaded from a path (local or HDFS) + * that is specified as a String + * Used in both SparkDl4jMultiLayer and SparkComputationGraph implementations + * + * @author Alex Black + */ +class ExecuteWorkerPathMDSFlatMapAdapter implements FlatMapFunctionAdapter, R> { + private final FlatMapFunctionAdapter, R> workerFlatMap; private final int maxDataSetObjects; - public ExecuteWorkerPathMDSFlatMap(TrainingWorker worker){ - this.workerFlatMap = new ExecuteWorkerMultiDataSetFlatMap<>(worker); + public ExecuteWorkerPathMDSFlatMapAdapter(TrainingWorker worker){ + this.workerFlatMap = new ExecuteWorkerMultiDataSetFlatMapAdapter<>(worker); //How many dataset objects of size 'dataSetObjectNumExamples' should we load? //Only pass on the required number, not all of them (to avoid async preloading data that won't be used) diff --git a/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/data/BatchDataSetsFunction.java b/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/data/BatchDataSetsFunction.java index 90f37fc0c730..41312ca592a5 100644 --- a/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/data/BatchDataSetsFunction.java +++ b/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/data/BatchDataSetsFunction.java @@ -16,7 +16,8 @@ package org.deeplearning4j.spark.data; -import org.apache.spark.api.java.function.FlatMapFunction; +import org.datavec.spark.functions.FlatMapFunctionAdapter; +import org.datavec.spark.transform.BaseFlatMapFunctionAdaptee; import org.nd4j.linalg.dataset.DataSet; import java.util.ArrayList; @@ -37,10 +38,31 @@ * * @author Alex Black */ -public class BatchDataSetsFunction implements FlatMapFunction,DataSet> { - private final int minibatchSize; +public class BatchDataSetsFunction extends BaseFlatMapFunctionAdaptee,DataSet> { public BatchDataSetsFunction(int 
minibatchSize) { + super(new BatchDataSetsFunctionAdapter(minibatchSize)); + } +} + +/** + * Function used to batch DataSet objects together. Typically used to combine singe-example DataSet objects out of + * something like {@link org.deeplearning4j.spark.datavec.DataVecDataSetFunction} together into minibatches.
+ * + * Usage: + *
+ * {@code
+ *      RDD mySingleExampleDataSets = ...;
+ *      RDD batchData = mySingleExampleDataSets.mapPartitions(new BatchDataSetsFunction(batchSize));
+ * }
+ * 
+ * + * @author Alex Black + */ +class BatchDataSetsFunctionAdapter implements FlatMapFunctionAdapter,DataSet> { + private final int minibatchSize; + + public BatchDataSetsFunctionAdapter(int minibatchSize) { this.minibatchSize = minibatchSize; } diff --git a/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/data/shuffle/SplitDataSetExamplesPairFlatMapFunction.java b/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/data/shuffle/SplitDataSetExamplesPairFlatMapFunction.java index ad743d992180..1f52bc05e18a 100644 --- a/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/data/shuffle/SplitDataSetExamplesPairFlatMapFunction.java +++ b/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/data/shuffle/SplitDataSetExamplesPairFlatMapFunction.java @@ -1,8 +1,8 @@ package org.deeplearning4j.spark.data.shuffle; import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.function.FlatMapFunction; -import org.apache.spark.api.java.function.PairFlatMapFunction; +import org.deeplearning4j.spark.util.BasePairFlatMapFunctionAdaptee; +import org.deeplearning4j.spark.util.PairFlatMapFunctionAdapter; import org.nd4j.linalg.dataset.DataSet; import scala.Tuple2; @@ -18,12 +18,27 @@ * * @author Alex Black */ -public class SplitDataSetExamplesPairFlatMapFunction implements PairFlatMapFunction { +public class SplitDataSetExamplesPairFlatMapFunction extends BasePairFlatMapFunctionAdaptee { + + public SplitDataSetExamplesPairFlatMapFunction(int maxKeyIndex) { + super(new SplitDataSetExamplesPairFlatMapFunctionAdapter(maxKeyIndex)); + } +} + +/** + * A PairFlatMapFunction that splits each example in a {@link DataSet} object into its own {@link DataSet}. + * Also adds a random key (integer value) in the range 0 to maxKeyIndex-1.
+ * + * Used in {@link org.deeplearning4j.spark.util.SparkUtils#shuffleExamples(JavaRDD, int, int)} + * + * @author Alex Black + */ +class SplitDataSetExamplesPairFlatMapFunctionAdapter implements PairFlatMapFunctionAdapter { private transient Random r; private int maxKeyIndex; - public SplitDataSetExamplesPairFlatMapFunction(int maxKeyIndex){ + public SplitDataSetExamplesPairFlatMapFunctionAdapter(int maxKeyIndex){ this.maxKeyIndex = maxKeyIndex; } diff --git a/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/datavec/RDDMiniBatches.java b/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/datavec/RDDMiniBatches.java index 494c090ca741..e3eccd87072b 100644 --- a/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/datavec/RDDMiniBatches.java +++ b/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/datavec/RDDMiniBatches.java @@ -17,7 +17,8 @@ package org.deeplearning4j.spark.datavec; import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.function.FlatMapFunction; +import org.datavec.spark.functions.FlatMapFunctionAdapter; +import org.datavec.spark.transform.BaseFlatMapFunctionAdaptee; import org.nd4j.linalg.dataset.DataSet; import java.io.Serializable; @@ -43,11 +44,17 @@ public JavaRDD miniBatchesJava() { return toSplitJava.mapPartitions(new MiniBatchFunction(miniBatches)); } + public static class MiniBatchFunction extends BaseFlatMapFunctionAdaptee, DataSet> { - public static class MiniBatchFunction implements FlatMapFunction, DataSet> { + public MiniBatchFunction(int batchSize) { + super(new MiniBatchFunctionAdapter(batchSize)); + } + } + + static class MiniBatchFunctionAdapter implements FlatMapFunctionAdapter, DataSet> { private int batchSize = 10; - public MiniBatchFunction(int batchSize) { + public MiniBatchFunctionAdapter(int batchSize) { this.batchSize = batchSize; } diff --git 
a/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/common/repartition/MapTupleToPairFlatMap.java b/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/common/repartition/MapTupleToPairFlatMap.java index 1669c14c4009..98294173a881 100644 --- a/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/common/repartition/MapTupleToPairFlatMap.java +++ b/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/common/repartition/MapTupleToPairFlatMap.java @@ -1,6 +1,7 @@ package org.deeplearning4j.spark.impl.common.repartition; -import org.apache.spark.api.java.function.PairFlatMapFunction; +import org.deeplearning4j.spark.util.BasePairFlatMapFunctionAdaptee; +import org.deeplearning4j.spark.util.PairFlatMapFunctionAdapter; import scala.Tuple2; import java.util.ArrayList; @@ -13,7 +14,14 @@ * * @author Alex Black */ -public class MapTupleToPairFlatMap implements PairFlatMapFunction>,T,U> { +public class MapTupleToPairFlatMap extends BasePairFlatMapFunctionAdaptee>,T,U> { + + public MapTupleToPairFlatMap() { + super(new MapTupleToPairFlatMapAdapter()); + } +} + +class MapTupleToPairFlatMapAdapter implements PairFlatMapFunctionAdapter>,T,U> { @Override public Iterable> call(Iterator> tuple2Iterator) throws Exception { diff --git a/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/common/score/BaseVaeReconstructionProbWithKeyFunction.java b/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/common/score/BaseVaeReconstructionProbWithKeyFunctionAdapter.java similarity index 81% rename from deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/common/score/BaseVaeReconstructionProbWithKeyFunction.java rename to 
deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/common/score/BaseVaeReconstructionProbWithKeyFunctionAdapter.java index 47bedb9b58c7..9592d56de8d5 100644 --- a/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/common/score/BaseVaeReconstructionProbWithKeyFunction.java +++ b/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/common/score/BaseVaeReconstructionProbWithKeyFunctionAdapter.java @@ -11,7 +11,7 @@ * @param Type of key, associated with each example. Used to keep track of which score belongs to which example * @author Alex Black */ -public abstract class BaseVaeReconstructionProbWithKeyFunction extends BaseVaeScoreWithKeyFunction { +public abstract class BaseVaeReconstructionProbWithKeyFunctionAdapter extends BaseVaeScoreWithKeyFunctionAdapter { private final boolean useLogProbability; private final int numSamples; @@ -23,8 +23,8 @@ public abstract class BaseVaeReconstructionProbWithKeyFunction extends BaseVa * @param batchSize Batch size to use when scoring * @param numSamples Number of samples to use when calling {@link VariationalAutoencoder#reconstructionLogProbability(INDArray, int)} */ - public BaseVaeReconstructionProbWithKeyFunction(Broadcast params, Broadcast jsonConfig, boolean useLogProbability, - int batchSize, int numSamples){ + public BaseVaeReconstructionProbWithKeyFunctionAdapter(Broadcast params, Broadcast jsonConfig, boolean useLogProbability, + int batchSize, int numSamples){ super(params, jsonConfig, batchSize); this.useLogProbability = useLogProbability; this.numSamples = numSamples; diff --git a/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/common/score/BaseVaeScoreWithKeyFunction.java b/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/common/score/BaseVaeScoreWithKeyFunctionAdapter.java similarity index 90% rename from 
deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/common/score/BaseVaeScoreWithKeyFunction.java rename to deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/common/score/BaseVaeScoreWithKeyFunctionAdapter.java index cf38f168c731..e8ff45d49b53 100644 --- a/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/common/score/BaseVaeScoreWithKeyFunction.java +++ b/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/common/score/BaseVaeScoreWithKeyFunctionAdapter.java @@ -22,6 +22,8 @@ import org.apache.spark.api.java.function.PairFlatMapFunction; import org.apache.spark.broadcast.Broadcast; import org.deeplearning4j.nn.layers.variational.VariationalAutoencoder; +import org.deeplearning4j.spark.util.BasePairFlatMapFunctionAdaptee; +import org.deeplearning4j.spark.util.PairFlatMapFunctionAdapter; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.executioner.GridExecutioner; import org.nd4j.linalg.factory.Nd4j; @@ -42,7 +44,7 @@ * @author Alex Black */ @Slf4j -public abstract class BaseVaeScoreWithKeyFunction implements PairFlatMapFunction>, K, Double> { +public abstract class BaseVaeScoreWithKeyFunctionAdapter implements PairFlatMapFunctionAdapter>, K, Double> { protected final Broadcast params; protected final Broadcast jsonConfig; @@ -54,7 +56,7 @@ public abstract class BaseVaeScoreWithKeyFunction implements PairFlatMapFunct * @param jsonConfig MultiLayerConfiguration, as json * @param batchSize Batch size to use when scoring */ - public BaseVaeScoreWithKeyFunction(Broadcast params, Broadcast jsonConfig, int batchSize) { + public BaseVaeScoreWithKeyFunctionAdapter(Broadcast params, Broadcast jsonConfig, int batchSize) { this.params = params; this.jsonConfig = jsonConfig; this.batchSize = batchSize; diff --git 
a/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/graph/scoring/CGVaeReconstructionErrorWithKeyFunction.java b/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/graph/scoring/CGVaeReconstructionErrorWithKeyFunction.java index f51253a13678..22e14b6deec6 100644 --- a/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/graph/scoring/CGVaeReconstructionErrorWithKeyFunction.java +++ b/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/graph/scoring/CGVaeReconstructionErrorWithKeyFunction.java @@ -5,8 +5,7 @@ import org.deeplearning4j.nn.conf.ComputationGraphConfiguration; import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.layers.variational.VariationalAutoencoder; -import org.deeplearning4j.spark.impl.common.score.BaseVaeReconstructionProbWithKeyFunction; -import org.deeplearning4j.spark.impl.common.score.BaseVaeScoreWithKeyFunction; +import org.deeplearning4j.spark.impl.common.score.BaseVaeScoreWithKeyFunctionAdapter; import org.nd4j.linalg.api.ndarray.INDArray; /** @@ -18,7 +17,7 @@ * @author Alex Black * @see CGVaeReconstructionProbWithKeyFunction */ -public class CGVaeReconstructionErrorWithKeyFunction extends BaseVaeScoreWithKeyFunction { +public class CGVaeReconstructionErrorWithKeyFunction extends BaseVaeScoreWithKeyFunctionAdapter { /** diff --git a/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/graph/scoring/CGVaeReconstructionProbWithKeyFunction.java b/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/graph/scoring/CGVaeReconstructionProbWithKeyFunction.java index c0214aa2631b..f293a47401ed 100644 --- a/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/graph/scoring/CGVaeReconstructionProbWithKeyFunction.java +++ 
b/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/graph/scoring/CGVaeReconstructionProbWithKeyFunction.java @@ -5,8 +5,7 @@ import org.deeplearning4j.nn.conf.ComputationGraphConfiguration; import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.layers.variational.VariationalAutoencoder; -import org.deeplearning4j.spark.impl.common.score.BaseVaeReconstructionProbWithKeyFunction; -import org.deeplearning4j.spark.impl.common.score.BaseVaeScoreWithKeyFunction; +import org.deeplearning4j.spark.impl.common.score.BaseVaeReconstructionProbWithKeyFunctionAdapter; import org.nd4j.linalg.api.ndarray.INDArray; /** @@ -16,7 +15,7 @@ * * @author Alex Black */ -public class CGVaeReconstructionProbWithKeyFunction extends BaseVaeReconstructionProbWithKeyFunction { +public class CGVaeReconstructionProbWithKeyFunction extends BaseVaeReconstructionProbWithKeyFunctionAdapter { /** diff --git a/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/graph/scoring/GraphFeedForwardWithKeyFunction.java b/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/graph/scoring/GraphFeedForwardWithKeyFunction.java index 562134d50820..22d4d6920f0a 100644 --- a/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/graph/scoring/GraphFeedForwardWithKeyFunction.java +++ b/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/graph/scoring/GraphFeedForwardWithKeyFunction.java @@ -18,10 +18,11 @@ package org.deeplearning4j.spark.impl.graph.scoring; -import org.apache.spark.api.java.function.PairFlatMapFunction; import org.apache.spark.broadcast.Broadcast; import org.deeplearning4j.nn.conf.ComputationGraphConfiguration; import org.deeplearning4j.nn.graph.ComputationGraph; +import org.deeplearning4j.spark.util.BasePairFlatMapFunctionAdaptee; +import org.deeplearning4j.spark.util.PairFlatMapFunctionAdapter; import 
org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.executioner.GridExecutioner; import org.nd4j.linalg.factory.Nd4j; @@ -42,7 +43,21 @@ * @param Type of key, associated with each example. Used to keep track of which output belongs to which input example * @author Alex Black */ -public class GraphFeedForwardWithKeyFunction implements PairFlatMapFunction>, K, INDArray[]> { +public class GraphFeedForwardWithKeyFunction extends BasePairFlatMapFunctionAdaptee>, K, INDArray[]> { + + public GraphFeedForwardWithKeyFunction(Broadcast params, Broadcast jsonConfig, int batchSize) { + super(new GraphFeedForwardWithKeyFunctionAdapter(params, jsonConfig, batchSize)); + } +} + +/** + * Function to feed-forward examples, and get the network output (for example, class probabilities). + * A key value is used to keep track of which output corresponds to which input. + * + * @param Type of key, associated with each example. Used to keep track of which output belongs to which input example + * @author Alex Black + */ +class GraphFeedForwardWithKeyFunctionAdapter implements PairFlatMapFunctionAdapter>, K, INDArray[]> { protected static Logger log = LoggerFactory.getLogger(GraphFeedForwardWithKeyFunction.class); @@ -55,7 +70,7 @@ public class GraphFeedForwardWithKeyFunction implements PairFlatMapFunction 1 for efficiency) */ - public GraphFeedForwardWithKeyFunction(Broadcast params, Broadcast jsonConfig, int batchSize) { + public GraphFeedForwardWithKeyFunctionAdapter(Broadcast params, Broadcast jsonConfig, int batchSize) { this.params = params; this.jsonConfig = jsonConfig; this.batchSize = batchSize; diff --git a/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/graph/scoring/ScoreExamplesFunction.java b/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/graph/scoring/ScoreExamplesFunction.java index f077b006e7cb..bcb7bbbabf83 100644 --- 
a/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/graph/scoring/ScoreExamplesFunction.java +++ b/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/graph/scoring/ScoreExamplesFunction.java @@ -18,10 +18,11 @@ package org.deeplearning4j.spark.impl.graph.scoring; -import org.apache.spark.api.java.function.DoubleFlatMapFunction; import org.apache.spark.broadcast.Broadcast; import org.deeplearning4j.nn.conf.ComputationGraphConfiguration; import org.deeplearning4j.nn.graph.ComputationGraph; +import org.deeplearning4j.spark.util.BaseDoubleFlatMapFunctionAdaptee; +import org.deeplearning4j.spark.util.DoubleFlatMapFunctionAdapter; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.executioner.GridExecutioner; import org.nd4j.linalg.dataset.api.MultiDataSet; @@ -34,6 +35,22 @@ import java.util.Iterator; import java.util.List; + +/**Function to score examples individually. Note that scoring is batched for computational efficiency.
+ * This is essentially a Spark implementation of the {@link ComputationGraph#scoreExamples(MultiDataSet, boolean)} method
+ * Note: This method returns a score for each example, but the association between examples and scores is lost. In + * cases where we need to know the score for particular examples, use {@link ScoreExamplesWithKeyFunction} + * @author Alex Black + * @see ScoreExamplesWithKeyFunction + */ +public class ScoreExamplesFunction extends BaseDoubleFlatMapFunctionAdaptee> { + + public ScoreExamplesFunction(Broadcast params, Broadcast jsonConfig, boolean addRegularizationTerms, + int batchSize) { + super(new ScoreExamplesFunctionAdapter(params, jsonConfig, addRegularizationTerms, batchSize)); + } +} + /**Function to score examples individually. Note that scoring is batched for computational efficiency.
* This is essentially a Spark implementation of the {@link ComputationGraph#scoreExamples(MultiDataSet, boolean)} method
* Note: This method returns a score for each example, but the association between examples and scores is lost. In @@ -41,7 +58,7 @@ * @author Alex Black * @see ScoreExamplesWithKeyFunction */ -public class ScoreExamplesFunction implements DoubleFlatMapFunction> { +class ScoreExamplesFunctionAdapter implements DoubleFlatMapFunctionAdapter> { protected static final Logger log = LoggerFactory.getLogger(ScoreExamplesFunction.class); private final Broadcast params; @@ -49,7 +66,7 @@ public class ScoreExamplesFunction implements DoubleFlatMapFunction params, Broadcast jsonConfig, boolean addRegularizationTerms, + public ScoreExamplesFunctionAdapter(Broadcast params, Broadcast jsonConfig, boolean addRegularizationTerms, int batchSize){ this.params = params; this.jsonConfig = jsonConfig; diff --git a/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/graph/scoring/ScoreExamplesWithKeyFunction.java b/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/graph/scoring/ScoreExamplesWithKeyFunction.java index f898540e6579..6ad501a8e251 100644 --- a/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/graph/scoring/ScoreExamplesWithKeyFunction.java +++ b/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/graph/scoring/ScoreExamplesWithKeyFunction.java @@ -18,10 +18,11 @@ package org.deeplearning4j.spark.impl.graph.scoring; -import org.apache.spark.api.java.function.PairFlatMapFunction; import org.apache.spark.broadcast.Broadcast; import org.deeplearning4j.nn.conf.ComputationGraphConfiguration; import org.deeplearning4j.nn.graph.ComputationGraph; +import org.deeplearning4j.spark.util.BasePairFlatMapFunctionAdaptee; +import org.deeplearning4j.spark.util.PairFlatMapFunctionAdapter; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.executioner.GridExecutioner; import org.nd4j.linalg.dataset.api.MultiDataSet; @@ -44,7 +45,24 @@ 
* @param Type of key, associated with each example. Used to keep track of which score belongs to which example * @see ScoreExamplesFunction */ -public class ScoreExamplesWithKeyFunction implements PairFlatMapFunction>,K,Double> { +public class ScoreExamplesWithKeyFunction extends BasePairFlatMapFunctionAdaptee>,K,Double> { + + public ScoreExamplesWithKeyFunction(Broadcast params, Broadcast jsonConfig, boolean addRegularizationTerms, + int batchSize) { + super(new ScoreExamplesWithKeyFunctionAdapter(params, jsonConfig, addRegularizationTerms, batchSize)); + } +} + +/**Function to score examples individually, where each example is associated with a particular key
+ * Note that scoring is batched for computational efficiency.
+ * This is the Spark implementation of the {@link ComputationGraph#scoreExamples(MultiDataSet, boolean)} method
+ * Note: The MultiDataSet objects passed in must have exactly one example in them (otherwise: can't have a 1:1 association + * between keys and data sets to score) + * @author Alex Black + * @param Type of key, associated with each example. Used to keep track of which score belongs to which example + * @see ScoreExamplesFunction + */ +class ScoreExamplesWithKeyFunctionAdapter implements PairFlatMapFunctionAdapter>,K,Double> { protected static Logger log = LoggerFactory.getLogger(ScoreExamplesWithKeyFunction.class); @@ -59,7 +77,7 @@ public class ScoreExamplesWithKeyFunction implements PairFlatMapFunction params, Broadcast jsonConfig, boolean addRegularizationTerms, + public ScoreExamplesWithKeyFunctionAdapter(Broadcast params, Broadcast jsonConfig, boolean addRegularizationTerms, int batchSize){ this.params = params; this.jsonConfig = jsonConfig; diff --git a/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/graph/scoring/ScoreFlatMapFunctionCGDataSet.java b/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/graph/scoring/ScoreFlatMapFunctionCGDataSet.java index 8265adeb0afe..592369f89414 100644 --- a/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/graph/scoring/ScoreFlatMapFunctionCGDataSet.java +++ b/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/graph/scoring/ScoreFlatMapFunctionCGDataSet.java @@ -18,8 +18,9 @@ package org.deeplearning4j.spark.impl.graph.scoring; -import org.apache.spark.api.java.function.FlatMapFunction; import org.apache.spark.broadcast.Broadcast; +import org.datavec.spark.functions.FlatMapFunctionAdapter; +import org.datavec.spark.transform.BaseFlatMapFunctionAdaptee; import org.deeplearning4j.datasets.iterator.IteratorDataSetIterator; import org.deeplearning4j.nn.conf.ComputationGraphConfiguration; import org.deeplearning4j.nn.graph.ComputationGraph; @@ -38,7 +39,15 @@ import java.util.List; /** 
Function used to score a DataSet using a ComputationGraph */ -public class ScoreFlatMapFunctionCGDataSet implements FlatMapFunction, Tuple2> { +public class ScoreFlatMapFunctionCGDataSet extends BaseFlatMapFunctionAdaptee, Tuple2> { + + public ScoreFlatMapFunctionCGDataSet(String json, Broadcast params, int minibatchSize) { + super(new ScoreFlatMapFunctionCGDataSetAdapter(json, params, minibatchSize)); + } +} + +/** Function used to score a DataSet using a ComputationGraph */ +class ScoreFlatMapFunctionCGDataSetAdapter implements FlatMapFunctionAdapter, Tuple2> { private static final Logger log = LoggerFactory.getLogger(ScoreFlatMapFunctionCGDataSet.class); private String json; @@ -46,7 +55,7 @@ public class ScoreFlatMapFunctionCGDataSet implements FlatMapFunction params, int minibatchSize){ + public ScoreFlatMapFunctionCGDataSetAdapter(String json, Broadcast params, int minibatchSize){ this.json = json; this.params = params; this.minibatchSize = minibatchSize; diff --git a/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/graph/scoring/ScoreFlatMapFunctionCGMultiDataSet.java b/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/graph/scoring/ScoreFlatMapFunctionCGMultiDataSet.java index 55a9367c3196..f197045160df 100644 --- a/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/graph/scoring/ScoreFlatMapFunctionCGMultiDataSet.java +++ b/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/graph/scoring/ScoreFlatMapFunctionCGMultiDataSet.java @@ -18,8 +18,9 @@ package org.deeplearning4j.spark.impl.graph.scoring; -import org.apache.spark.api.java.function.FlatMapFunction; import org.apache.spark.broadcast.Broadcast; +import org.datavec.spark.functions.FlatMapFunctionAdapter; +import org.datavec.spark.transform.BaseFlatMapFunctionAdaptee; import org.deeplearning4j.datasets.iterator.IteratorMultiDataSetIterator; import 
org.deeplearning4j.nn.conf.ComputationGraphConfiguration; import org.deeplearning4j.nn.graph.ComputationGraph; @@ -38,7 +39,15 @@ import java.util.List; /** Function used to score a MultiDataSet using a given ComputationGraph */ -public class ScoreFlatMapFunctionCGMultiDataSet implements FlatMapFunction,Tuple2> { +public class ScoreFlatMapFunctionCGMultiDataSet extends BaseFlatMapFunctionAdaptee,Tuple2> { + + public ScoreFlatMapFunctionCGMultiDataSet(String json, Broadcast params, int minibatchSize) { + super(new ScoreFlatMapFunctionCGMultiDataSetAdapter(json, params, minibatchSize)); + } +} + +/** Function used to score a MultiDataSet using a given ComputationGraph */ +class ScoreFlatMapFunctionCGMultiDataSetAdapter implements FlatMapFunctionAdapter,Tuple2> { private static final Logger log = LoggerFactory.getLogger(ScoreFlatMapFunctionCGMultiDataSet.class); private String json; @@ -46,7 +55,7 @@ public class ScoreFlatMapFunctionCGMultiDataSet implements FlatMapFunction params, int minibatchSize){ + public ScoreFlatMapFunctionCGMultiDataSetAdapter(String json, Broadcast params, int minibatchSize){ this.json = json; this.params = params; this.minibatchSize = minibatchSize; diff --git a/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/layer/IterativeReduceFlatMap.java b/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/layer/IterativeReduceFlatMap.java index 0f17b0905b57..cc0a9bfcb5bd 100644 --- a/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/layer/IterativeReduceFlatMap.java +++ b/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/layer/IterativeReduceFlatMap.java @@ -18,8 +18,9 @@ package org.deeplearning4j.spark.impl.layer; -import org.apache.spark.api.java.function.FlatMapFunction; import org.apache.spark.broadcast.Broadcast; +import org.datavec.spark.functions.FlatMapFunctionAdapter; +import 
org.datavec.spark.transform.BaseFlatMapFunctionAdaptee; import org.deeplearning4j.nn.api.Layer; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.layers.OutputLayer; @@ -40,7 +41,20 @@ * * @author Adam Gibson */ -public class IterativeReduceFlatMap implements FlatMapFunction,INDArray> { +public class IterativeReduceFlatMap extends BaseFlatMapFunctionAdaptee,INDArray> { + + public IterativeReduceFlatMap(String json, Broadcast params) { + super(new IterativeReduceFlatMapAdapter(json, params)); + } +} + +/** + * Iterative reduce with + * flat map using map partitions + * + * @author Adam Gibson + */ +class IterativeReduceFlatMapAdapter implements FlatMapFunctionAdapter,INDArray> { private String json; private Broadcast params; @@ -51,7 +65,7 @@ public class IterativeReduceFlatMap implements FlatMapFunction * @param json json configuration for the network * @param params the parameters to use for the network */ - public IterativeReduceFlatMap(String json, Broadcast params) { + public IterativeReduceFlatMapAdapter(String json, Broadcast params) { this.json = json; this.params = params; } diff --git a/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/multilayer/evaluation/IEvaluateFlatMapFunction.java b/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/multilayer/evaluation/IEvaluateFlatMapFunction.java index 6da0ca8e7340..7679d20f4bff 100644 --- a/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/multilayer/evaluation/IEvaluateFlatMapFunction.java +++ b/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/multilayer/evaluation/IEvaluateFlatMapFunction.java @@ -18,8 +18,9 @@ package org.deeplearning4j.spark.impl.multilayer.evaluation; -import org.apache.spark.api.java.function.FlatMapFunction; import org.apache.spark.broadcast.Broadcast; +import org.datavec.spark.functions.FlatMapFunctionAdapter; 
+import org.datavec.spark.transform.BaseFlatMapFunctionAdaptee; import org.deeplearning4j.eval.IEvaluation; import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; @@ -40,7 +41,21 @@ * * @author Alex Black */ -public class IEvaluateFlatMapFunction implements FlatMapFunction, T> { +public class IEvaluateFlatMapFunction extends BaseFlatMapFunctionAdaptee, T> { + + public IEvaluateFlatMapFunction(Broadcast json, Broadcast params, int evalBatchSize, T evaluation) { + super(new IEvaluateFlatMapFunctionAdapter<>(json, params, evalBatchSize, evaluation)); + } +} + +/** + * Function to evaluate data (using an IEvaluation instance), in a distributed manner + * Flat map function used to batch examples for computational efficiency + reduce number of IEvaluation objects returned + * for network efficiency. + * + * @author Alex Black + */ +class IEvaluateFlatMapFunctionAdapter implements FlatMapFunctionAdapter, T> { protected static Logger log = LoggerFactory.getLogger(IEvaluateFlatMapFunction.class); @@ -56,7 +71,7 @@ public class IEvaluateFlatMapFunction implements FlatMapF * this. 
Used to avoid doing too many at once (and hence memory issues) * @param evaluation Initial evaulation instance (i.e., empty Evaluation or RegressionEvaluation instance) */ - public IEvaluateFlatMapFunction(Broadcast json, Broadcast params, int evalBatchSize, T evaluation){ + public IEvaluateFlatMapFunctionAdapter(Broadcast json, Broadcast params, int evalBatchSize, T evaluation){ this.json = json; this.params = params; this.evalBatchSize = evalBatchSize; diff --git a/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/multilayer/scoring/FeedForwardWithKeyFunction.java b/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/multilayer/scoring/FeedForwardWithKeyFunction.java index 600080917339..07513f5fdcf6 100644 --- a/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/multilayer/scoring/FeedForwardWithKeyFunction.java +++ b/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/multilayer/scoring/FeedForwardWithKeyFunction.java @@ -18,10 +18,11 @@ package org.deeplearning4j.spark.impl.multilayer.scoring; -import org.apache.spark.api.java.function.PairFlatMapFunction; import org.apache.spark.broadcast.Broadcast; import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; +import org.deeplearning4j.spark.util.BasePairFlatMapFunctionAdaptee; +import org.deeplearning4j.spark.util.PairFlatMapFunctionAdapter; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.executioner.GridExecutioner; import org.nd4j.linalg.factory.Nd4j; @@ -39,7 +40,21 @@ * @param Type of key, associated with each example. 
Used to keep track of which output belongs to which input example * @author Alex Black */ -public class FeedForwardWithKeyFunction implements PairFlatMapFunction>, K, INDArray> { +public class FeedForwardWithKeyFunction extends BasePairFlatMapFunctionAdaptee>, K, INDArray> { + + public FeedForwardWithKeyFunction(Broadcast params, Broadcast jsonConfig, int batchSize) { + super(new FeedForwardWithKeyFunctionAdapter(params, jsonConfig, batchSize)); + } +} + +/** + * Function to feed-forward examples, and get the network output (for example, class probabilities). + * A key value is used to keep track of which output corresponds to which input. + * + * @param Type of key, associated with each example. Used to keep track of which output belongs to which input example + * @author Alex Black + */ +class FeedForwardWithKeyFunctionAdapter implements PairFlatMapFunctionAdapter>, K, INDArray> { protected static Logger log = LoggerFactory.getLogger(FeedForwardWithKeyFunction.class); @@ -52,7 +67,7 @@ public class FeedForwardWithKeyFunction implements PairFlatMapFunction 1 for efficiency) */ - public FeedForwardWithKeyFunction(Broadcast params, Broadcast jsonConfig, int batchSize) { + public FeedForwardWithKeyFunctionAdapter(Broadcast params, Broadcast jsonConfig, int batchSize) { this.params = params; this.jsonConfig = jsonConfig; this.batchSize = batchSize; diff --git a/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/multilayer/scoring/ScoreExamplesFunction.java b/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/multilayer/scoring/ScoreExamplesFunction.java index d130836be9f7..9dbc48e65354 100644 --- a/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/multilayer/scoring/ScoreExamplesFunction.java +++ b/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/multilayer/scoring/ScoreExamplesFunction.java @@ -18,10 +18,11 @@ package 
org.deeplearning4j.spark.impl.multilayer.scoring; -import org.apache.spark.api.java.function.DoubleFlatMapFunction; import org.apache.spark.broadcast.Broadcast; import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; +import org.deeplearning4j.spark.util.BaseDoubleFlatMapFunctionAdaptee; +import org.deeplearning4j.spark.util.DoubleFlatMapFunctionAdapter; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.executioner.GridExecutioner; import org.nd4j.linalg.dataset.DataSet; @@ -41,7 +42,22 @@ * @author Alex Black * @see ScoreExamplesWithKeyFunction */ -public class ScoreExamplesFunction implements DoubleFlatMapFunction> { +public class ScoreExamplesFunction extends BaseDoubleFlatMapFunctionAdaptee> { + + public ScoreExamplesFunction(Broadcast params, Broadcast jsonConfig, boolean addRegularizationTerms, + int batchSize) { + super(new ScoreExamplesFunctionAdapter(params, jsonConfig, addRegularizationTerms, batchSize)); + } +} + +/**Function to score examples individually. Note that scoring is batched for computational efficiency.
+ * This is essentially a Spark implementation of the {@link MultiLayerNetwork#scoreExamples(DataSet, boolean)} method
+ * Note: This method returns a score for each example, but the association between examples and scores is lost. In + * cases where we need to know the score for particular examples, use {@link ScoreExamplesWithKeyFunction} + * @author Alex Black + * @see ScoreExamplesWithKeyFunction + */ +class ScoreExamplesFunctionAdapter implements DoubleFlatMapFunctionAdapter> { protected static Logger log = LoggerFactory.getLogger(ScoreExamplesFunction.class); @@ -50,7 +66,7 @@ public class ScoreExamplesFunction implements DoubleFlatMapFunction params, Broadcast jsonConfig, boolean addRegularizationTerms, + public ScoreExamplesFunctionAdapter(Broadcast params, Broadcast jsonConfig, boolean addRegularizationTerms, int batchSize){ this.params = params; this.jsonConfig = jsonConfig; diff --git a/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/multilayer/scoring/ScoreExamplesWithKeyFunction.java b/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/multilayer/scoring/ScoreExamplesWithKeyFunction.java index 05d637a1497c..dc8abea1ff76 100644 --- a/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/multilayer/scoring/ScoreExamplesWithKeyFunction.java +++ b/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/multilayer/scoring/ScoreExamplesWithKeyFunction.java @@ -22,6 +22,8 @@ import org.apache.spark.broadcast.Broadcast; import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; +import org.deeplearning4j.spark.util.BasePairFlatMapFunctionAdaptee; +import org.deeplearning4j.spark.util.PairFlatMapFunctionAdapter; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.executioner.GridExecutioner; import org.nd4j.linalg.dataset.DataSet; @@ -46,7 +48,26 @@ * @author Alex Black * @see ScoreExamplesFunction */ -public class ScoreExamplesWithKeyFunction implements 
PairFlatMapFunction>, K, Double> { +public class ScoreExamplesWithKeyFunction extends BasePairFlatMapFunctionAdaptee>, K, Double> { + + public ScoreExamplesWithKeyFunction(Broadcast params, Broadcast jsonConfig, boolean addRegularizationTerms, + int batchSize) { + super(new ScoreExamplesWithKeyFunctionAdapter(params, jsonConfig, addRegularizationTerms, batchSize)); + } +} + +/** + * Function to score examples individually, where each example is associated with a particular key
+ * Note that scoring is batched for computational efficiency.
+ * This is the Spark implementation of the {@link MultiLayerNetwork#scoreExamples(DataSet, boolean)} method
+ * Note: The DataSet objects passed in must have exactly one example in them (otherwise: can't have a 1:1 association + * between keys and data sets to score) + * + * @param Type of key, associated with each example. Used to keep track of which score belongs to which example + * @author Alex Black + * @see ScoreExamplesFunction + */ +class ScoreExamplesWithKeyFunctionAdapter implements PairFlatMapFunctionAdapter>, K, Double> { protected static Logger log = LoggerFactory.getLogger(ScoreExamplesWithKeyFunction.class); @@ -61,7 +82,7 @@ public class ScoreExamplesWithKeyFunction implements PairFlatMapFunction params, Broadcast jsonConfig, boolean addRegularizationTerms, + public ScoreExamplesWithKeyFunctionAdapter(Broadcast params, Broadcast jsonConfig, boolean addRegularizationTerms, int batchSize) { this.params = params; this.jsonConfig = jsonConfig; diff --git a/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/multilayer/scoring/ScoreFlatMapFunction.java b/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/multilayer/scoring/ScoreFlatMapFunction.java index fb92a079ce74..68b72ee13640 100644 --- a/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/multilayer/scoring/ScoreFlatMapFunction.java +++ b/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/multilayer/scoring/ScoreFlatMapFunction.java @@ -1,7 +1,8 @@ package org.deeplearning4j.spark.impl.multilayer.scoring; -import org.apache.spark.api.java.function.FlatMapFunction; import org.apache.spark.broadcast.Broadcast; +import org.datavec.spark.functions.FlatMapFunctionAdapter; +import org.datavec.spark.transform.BaseFlatMapFunctionAdaptee; import org.deeplearning4j.datasets.iterator.IteratorDataSetIterator; import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; @@ -19,21 +20,30 @@ import java.util.Iterator; import 
java.util.List; -public class ScoreFlatMapFunction implements FlatMapFunction,Tuple2> { +public class ScoreFlatMapFunction extends BaseFlatMapFunctionAdaptee,Tuple2> { + + public ScoreFlatMapFunction(String json, Broadcast params, int minibatchSize){ + super(new ScoreFlatMapFunctionAdapter(json, params, minibatchSize)); + } + +} + +class ScoreFlatMapFunctionAdapter implements FlatMapFunctionAdapter,Tuple2> { + private static final Logger log = LoggerFactory.getLogger(ScoreFlatMapFunction.class); private String json; private Broadcast params; private int minibatchSize; - public ScoreFlatMapFunction(String json, Broadcast params, int minibatchSize){ + public ScoreFlatMapFunctionAdapter(String json, Broadcast params, int minibatchSize){ this.json = json; this.params = params; this.minibatchSize = minibatchSize; } @Override - public Iterable> call(Iterator dataSetIterator) throws Exception { + public Iterable> call(Iterator dataSetIterator) throws Exception { if(!dataSetIterator.hasNext()) { return Collections.singletonList(new Tuple2<>(0,0.0)); } diff --git a/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/multilayer/scoring/VaeReconstructionErrorWithKeyFunction.java b/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/multilayer/scoring/VaeReconstructionErrorWithKeyFunction.java index 7fd6983552c8..b9397b0ceb6a 100644 --- a/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/multilayer/scoring/VaeReconstructionErrorWithKeyFunction.java +++ b/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/multilayer/scoring/VaeReconstructionErrorWithKeyFunction.java @@ -5,9 +5,14 @@ import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.layers.variational.VariationalAutoencoder; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; -import 
org.deeplearning4j.spark.impl.common.score.BaseVaeReconstructionProbWithKeyFunction; -import org.deeplearning4j.spark.impl.common.score.BaseVaeScoreWithKeyFunction; +import org.deeplearning4j.spark.impl.common.score.BaseVaeScoreWithKeyFunctionAdapter; +import org.deeplearning4j.spark.util.BasePairFlatMapFunctionAdaptee; +import org.deeplearning4j.spark.util.PairFlatMapFunctionAdapter; import org.nd4j.linalg.api.ndarray.INDArray; +import scala.Tuple2; + +import java.util.Iterator; + /** * Function to calculate the reconstruction error for a variational autoencoder, that is the first layer in a @@ -18,14 +23,30 @@ * @author Alex Black * @see VaeReconstructionProbWithKeyFunction */ -public class VaeReconstructionErrorWithKeyFunction extends BaseVaeScoreWithKeyFunction { +public class VaeReconstructionErrorWithKeyFunction extends BasePairFlatMapFunctionAdaptee>, K, Double> { + + public VaeReconstructionErrorWithKeyFunction(Broadcast params, Broadcast jsonConfig, int batchSize) { + super(new VaeReconstructionErrorWithKeyFunctionAdapter(params, jsonConfig, batchSize)); + } +} + +/** + * Function to calculate the reconstruction error for a variational autoencoder, that is the first layer in a + * MultiLayerNetwork.
+ * Note that the VAE must be using a loss function, not a {@link org.deeplearning4j.nn.conf.layers.variational.ReconstructionDistribution}
+ * Also note that scoring is batched for computational efficiency.
+ * + * @author Alex Black + * @see VaeReconstructionProbWithKeyFunction + */ +class VaeReconstructionErrorWithKeyFunctionAdapter extends BaseVaeScoreWithKeyFunctionAdapter { /** * @param params MultiLayerNetwork parameters * @param jsonConfig MultiLayerConfiguration, as json * @param batchSize Batch size to use when scoring */ - public VaeReconstructionErrorWithKeyFunction(Broadcast params, Broadcast jsonConfig, int batchSize) { + public VaeReconstructionErrorWithKeyFunctionAdapter(Broadcast params, Broadcast jsonConfig, int batchSize) { super(params, jsonConfig, batchSize); } diff --git a/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/multilayer/scoring/VaeReconstructionProbWithKeyFunction.java b/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/multilayer/scoring/VaeReconstructionProbWithKeyFunction.java index a244ca8c16ce..6607a8fcaa1d 100644 --- a/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/multilayer/scoring/VaeReconstructionProbWithKeyFunction.java +++ b/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/multilayer/scoring/VaeReconstructionProbWithKeyFunction.java @@ -5,9 +5,14 @@ import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.layers.variational.VariationalAutoencoder; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; -import org.deeplearning4j.spark.impl.common.score.BaseVaeReconstructionProbWithKeyFunction; -import org.deeplearning4j.spark.impl.common.score.BaseVaeScoreWithKeyFunction; +import org.deeplearning4j.spark.impl.common.score.BaseVaeReconstructionProbWithKeyFunctionAdapter; +import org.deeplearning4j.spark.util.BasePairFlatMapFunctionAdaptee; +import org.deeplearning4j.spark.util.PairFlatMapFunctionAdapter; import org.nd4j.linalg.api.ndarray.INDArray; +import scala.Tuple2; + +import java.util.Iterator; + /** * Function to calculate the reconstruction 
probability for a variational autoencoder, that is the first layer in a @@ -16,7 +21,21 @@ * * @author Alex Black */ -public class VaeReconstructionProbWithKeyFunction extends BaseVaeReconstructionProbWithKeyFunction { +public class VaeReconstructionProbWithKeyFunction extends BasePairFlatMapFunctionAdaptee>, K, Double> { + + public VaeReconstructionProbWithKeyFunction(Broadcast params, Broadcast jsonConfig, boolean useLogProbability, int batchSize, int numSamples) { + super(new VaeReconstructionProbWithKeyFunctionAdapter(params, jsonConfig, useLogProbability, batchSize, numSamples)); + } +} + +/** + * Function to calculate the reconstruction probability for a variational autoencoder, that is the first layer in a + * MultiLayerNetwork.
+ * Note that scoring is batched for computational efficiency.
+ * + * @author Alex Black + */ +class VaeReconstructionProbWithKeyFunctionAdapter extends BaseVaeReconstructionProbWithKeyFunctionAdapter { /** @@ -26,7 +45,7 @@ public class VaeReconstructionProbWithKeyFunction extends BaseVaeReconstructi * @param batchSize Batch size to use when scoring * @param numSamples Number of samples to use when calling {@link VariationalAutoencoder#reconstructionLogProbability(INDArray, int)} */ - public VaeReconstructionProbWithKeyFunction(Broadcast params, Broadcast jsonConfig, boolean useLogProbability, int batchSize, int numSamples) { + public VaeReconstructionProbWithKeyFunctionAdapter(Broadcast params, Broadcast jsonConfig, boolean useLogProbability, int batchSize, int numSamples) { super(params, jsonConfig, useLogProbability, batchSize, numSamples); } diff --git a/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/iterator/PortableDataStreamMultiDataSetIterator.java b/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/iterator/PortableDataStreamMultiDataSetIterator.java index 2639f006a87f..cd7dc18f44bc 100644 --- a/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/iterator/PortableDataStreamMultiDataSetIterator.java +++ b/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/iterator/PortableDataStreamMultiDataSetIterator.java @@ -72,8 +72,6 @@ public MultiDataSet next() { ds.load(is); } catch(IOException e){ throw new RuntimeException(e); - } finally { - pds.close(); } if(preprocessor != null) preprocessor.preProcess(ds); diff --git a/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/util/DoubleFlatMapFunctionAdapter.java b/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/util/DoubleFlatMapFunctionAdapter.java new file mode 100644 index 000000000000..1aa51e8f1fcd --- /dev/null +++ 
b/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/util/DoubleFlatMapFunctionAdapter.java @@ -0,0 +1,13 @@ +package org.deeplearning4j.spark.util; + +import java.io.Serializable; + +/** + * + * A function that returns zero or more records of type Double from each input record. + * + * Adapter for Spark interface in order to freeze interface changes between spark versions + */ +public interface DoubleFlatMapFunctionAdapter extends Serializable { + Iterable call(T t) throws Exception; +} diff --git a/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/util/MLLibUtil.java b/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/util/MLLibUtil.java index 043ae409b748..96d7a2f8bcbb 100644 --- a/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/util/MLLibUtil.java +++ b/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/util/MLLibUtil.java @@ -34,6 +34,8 @@ import org.datavec.api.records.reader.RecordReader; import org.datavec.api.split.InputStreamInputSplit; import org.datavec.api.writable.Writable; +import org.datavec.spark.functions.FlatMapFunctionAdapter; +import org.datavec.spark.transform.BaseFlatMapFunctionAdaptee; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.dataset.DataSet; import org.nd4j.linalg.factory.Nd4j; @@ -223,12 +225,12 @@ public DataSet call(DataSet v1, DataSet v2) throws Exception { }, (int) (mappedData.count() / batchSize)); - JavaRDD data2 = aggregated.flatMap(new FlatMapFunction, DataSet>() { + JavaRDD data2 = aggregated.flatMap(new BaseFlatMapFunctionAdaptee, DataSet>(new FlatMapFunctionAdapter, DataSet>() { @Override public Iterable call(Tuple2 longDataSetTuple2) throws Exception { return longDataSetTuple2._2(); } - }); + })); return data2; } @@ -365,7 +367,7 @@ public DataSet call(LabeledPoint lp) { } /** - * Convert an rdd of data set in to labeled point. 
+ * Convert an rdd of data set in to labeled point. * @param data the dataset to convert * @return an rdd of labeled point */ diff --git a/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/util/PairFlatMapFunctionAdapter.java b/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/util/PairFlatMapFunctionAdapter.java new file mode 100644 index 000000000000..30e80761ac3e --- /dev/null +++ b/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/util/PairFlatMapFunctionAdapter.java @@ -0,0 +1,16 @@ +package org.deeplearning4j.spark.util; + +import scala.Tuple2; + +import java.io.Serializable; + +/** + * + * A function that returns zero or more key-value pair records from each input record. The + * key-value pairs are represented as scala.Tuple2 objects. + * + * Adapter for Spark interface in order to freeze interface changes between spark versions + */ +public interface PairFlatMapFunctionAdapter extends Serializable { + Iterable> call(T t) throws Exception; +} diff --git a/deeplearning4j-scaleout/spark/dl4j-spark/src/main/spark-1/org/deeplearning4j/spark/util/BaseDoubleFlatMapFunctionAdaptee.java b/deeplearning4j-scaleout/spark/dl4j-spark/src/main/spark-1/org/deeplearning4j/spark/util/BaseDoubleFlatMapFunctionAdaptee.java new file mode 100644 index 000000000000..282287d3991e --- /dev/null +++ b/deeplearning4j-scaleout/spark/dl4j-spark/src/main/spark-1/org/deeplearning4j/spark/util/BaseDoubleFlatMapFunctionAdaptee.java @@ -0,0 +1,25 @@ +package org.deeplearning4j.spark.util; + +import org.apache.spark.api.java.function.DoubleFlatMapFunction; + +import java.util.Iterator; + +/** + * DoubleFlatMapFunction adapter to hide incompatibilities between Spark 1.x and Spark 2.x + * + * This class should be used instead of direct referral to DoubleFlatMapFunction + * + */ +public class BaseDoubleFlatMapFunctionAdaptee implements DoubleFlatMapFunction { + + protected final 
DoubleFlatMapFunctionAdapter adapter; + + public BaseDoubleFlatMapFunctionAdaptee(DoubleFlatMapFunctionAdapter adapter) { + this.adapter = adapter; + } + + @Override + public Iterable call(T t) throws Exception { + return adapter.call(t); + } +} diff --git a/deeplearning4j-scaleout/spark/dl4j-spark/src/main/spark-1/org/deeplearning4j/spark/util/BasePairFlatMapFunctionAdaptee.java b/deeplearning4j-scaleout/spark/dl4j-spark/src/main/spark-1/org/deeplearning4j/spark/util/BasePairFlatMapFunctionAdaptee.java new file mode 100644 index 000000000000..99106b1475da --- /dev/null +++ b/deeplearning4j-scaleout/spark/dl4j-spark/src/main/spark-1/org/deeplearning4j/spark/util/BasePairFlatMapFunctionAdaptee.java @@ -0,0 +1,26 @@ +package org.deeplearning4j.spark.util; + +import org.apache.spark.api.java.function.PairFlatMapFunction; +import scala.Tuple2; + +import java.util.Iterator; + +/** + * PairFlatMapFunction adapter to hide incompatibilities between Spark 1.x and Spark 2.x + * + * This class should be used instead of direct referral to PairFlatMapFunction + * + */ +public class BasePairFlatMapFunctionAdaptee implements PairFlatMapFunction { + + protected final PairFlatMapFunctionAdapter adapter; + + public BasePairFlatMapFunctionAdaptee(PairFlatMapFunctionAdapter adapter) { + this.adapter = adapter; + } + + @Override + public Iterable> call(T t) throws Exception { + return adapter.call(t); + } +} diff --git a/deeplearning4j-scaleout/spark/dl4j-spark/src/main/spark-2/org/deeplearning4j/spark/util/BaseDoubleFlatMapFunctionAdaptee.java b/deeplearning4j-scaleout/spark/dl4j-spark/src/main/spark-2/org/deeplearning4j/spark/util/BaseDoubleFlatMapFunctionAdaptee.java new file mode 100644 index 000000000000..46ab08a54cd5 --- /dev/null +++ b/deeplearning4j-scaleout/spark/dl4j-spark/src/main/spark-2/org/deeplearning4j/spark/util/BaseDoubleFlatMapFunctionAdaptee.java @@ -0,0 +1,25 @@ +package org.deeplearning4j.spark.util; + +import 
org.apache.spark.api.java.function.DoubleFlatMapFunction; + +import java.util.Iterator; + +/** + * DoubleFlatMapFunction adapter to hide incompatibilities between Spark 1.x and Spark 2.x + * + * This class should be used instead of direct referral to DoubleFlatMapFunction + * + */ +public class BaseDoubleFlatMapFunctionAdaptee implements DoubleFlatMapFunction { + + protected final DoubleFlatMapFunctionAdapter adapter; + + public BaseDoubleFlatMapFunctionAdaptee(DoubleFlatMapFunctionAdapter adapter) { + this.adapter = adapter; + } + + @Override + public Iterator call(T t) throws Exception { + return adapter.call(t).iterator(); + } +} diff --git a/deeplearning4j-scaleout/spark/dl4j-spark/src/main/spark-2/org/deeplearning4j/spark/util/BasePairFlatMapFunctionAdaptee.java b/deeplearning4j-scaleout/spark/dl4j-spark/src/main/spark-2/org/deeplearning4j/spark/util/BasePairFlatMapFunctionAdaptee.java new file mode 100644 index 000000000000..5de35a39e4c2 --- /dev/null +++ b/deeplearning4j-scaleout/spark/dl4j-spark/src/main/spark-2/org/deeplearning4j/spark/util/BasePairFlatMapFunctionAdaptee.java @@ -0,0 +1,26 @@ +package org.deeplearning4j.spark.util; + +import org.apache.spark.api.java.function.PairFlatMapFunction; +import scala.Tuple2; + +import java.util.Iterator; + +/** + * PairFlatMapFunction adapter to hide incompatibilities between Spark 1.x and Spark 2.x + * + * This class should be used instead of direct referral to PairFlatMapFunction + * + */ +public class BasePairFlatMapFunctionAdaptee implements PairFlatMapFunction { + + protected final PairFlatMapFunctionAdapter adapter; + + public BasePairFlatMapFunctionAdaptee(PairFlatMapFunctionAdapter adapter) { + this.adapter = adapter; + } + + @Override + public Iterator> call(T t) throws Exception { + return adapter.call(t).iterator(); + } +} diff --git a/deeplearning4j-scaleout/spark/pom.xml b/deeplearning4j-scaleout/spark/pom.xml index 6d351dd4a7a1..b515a4fb88ea 100644 --- a/deeplearning4j-scaleout/spark/pom.xml +++ 
b/deeplearning4j-scaleout/spark/pom.xml @@ -42,6 +42,24 @@ + + + org.codehaus.mojo + build-helper-maven-plugin + 1.12 + + + add-source + generate-sources + add-source + + + src/main/spark-${spark.major.version} + + + + + org.scala-tools maven-scala-plugin @@ -58,8 +76,6 @@ 2.10.6 - - @@ -101,4 +117,65 @@
+ + + spark-default + + + !spark.major.version + + + + 1.6.2 + 2.2.0 + 1 + + + + spark-1 + + + spark.major.version + 1 + + + + 1.6.2 + 2.2.0 + + + + spark-2 + + + spark.major.version + 2 + + + + 2.1.0 + 2.2.0 + + + + com.typesafe.akka + akka-remote_2.11 + 2.3.11 + + + + + cdh5 + + + org.apache.hadoop + https://repository.cloudera.com/artifactory/cloudera-repos/ + + + + 2.0.0-cdh4.6.0 + 1.2.0-cdh5.3.0 + + + diff --git a/perform-release.sh b/perform-release.sh index 016703f97f15..f9bd20987991 100755 --- a/perform-release.sh +++ b/perform-release.sh @@ -25,9 +25,11 @@ mvn versions:set -DallowSnapshots=true -DgenerateBackupPoms=false -DnewVersion=$ source change-scala-versions.sh 2.10 source change-cuda-versions.sh 7.5 mvn clean deploy -Dgpg.executable=gpg2 -DperformRelease -Psonatype-oss-release -DskipTests -DstagingRepositoryId=$STAGING_REPOSITORY +mvn clean deploy -Dgpg.executable=gpg2 -DperformRelease -Psonatype-oss-release -DskipTests -DstagingRepositoryId=$STAGING_REPOSITORY -Dspark.major.version=2 source change-scala-versions.sh 2.11 source change-cuda-versions.sh 8.0 mvn clean deploy -Dgpg.executable=gpg2 -DperformRelease -Psonatype-oss-release -DskipTests -DstagingRepositoryId=$STAGING_REPOSITORY +mvn clean deploy -Dgpg.executable=gpg2 -DperformRelease -Psonatype-oss-release -DskipTests -DstagingRepositoryId=$STAGING_REPOSITORY -Dspark.major.version=2 source change-scala-versions.sh 2.10 source change-cuda-versions.sh 8.0 diff --git a/pom.xml b/pom.xml index d2e434ffe49a..89404afcb724 100644 --- a/pom.xml +++ b/pom.xml @@ -67,6 +67,8 @@ 1.3 0.7.3-SNAPSHOT 0.7.3-SNAPSHOT + 1 + 0.7.3_spark_${spark.major.version}-SNAPSHOT 3.4.1 3.3.1 2.4 @@ -87,7 +89,6 @@ 2.16.3 3.4.6 0.8.2.2 - 1.5.2 0.5.4 3.0.2 3.15.1 @@ -102,6 +103,7 @@ 1.12 1.0.0-beta5 2.19.1 + 0.7.3-SNAPSHOT @@ -114,7 +116,7 @@ org.deeplearning4j dl4j-test-resources - ${project.version} + ${dl4j-test-resources.version} test @@ -131,7 +133,7 @@ org.deeplearning4j dl4j-test-resources - ${project.version} + 
${dl4j-test-resources.version} test @@ -153,7 +155,7 @@ org.deeplearning4j dl4j-test-resources - ${project.version} + ${dl4j-test-resources.version} test diff --git a/runtests.sh b/runtests.sh index de45649fe550..869dd2033a57 100755 --- a/runtests.sh +++ b/runtests.sh @@ -34,6 +34,12 @@ mvn clean test >> $BUILD_OUTPUT 2>&1 # The build finished without returning an error so dump a tail of the output dump_output +# Repeat for Spark 2 +mvn clean test -Dspark.major.version=2 >> $BUILD_OUTPUT 2>&1 + +# The build finished without returning an error so dump a tail of the output +dump_output + # nicely terminate the ping output loop kill $PING_LOOP_PID