Spark 2.x support
Former-commit-id: 657e26a
Lukasz Jastrzebski committed Jan 25, 2017
1 parent 156b0e7 commit 2fc4969
Showing 59 changed files with 1,160 additions and 414 deletions.
2 changes: 2 additions & 0 deletions buildmultiplescalaversions.sh
@@ -2,6 +2,8 @@
set -eu
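# each Scala build now runs twice: the plain mvn call uses the default (Spark 1.x)
# profile, and -Dspark.major.version=2 re-runs it against the Spark 2.x profile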
./change-scala-versions.sh 2.10
mvn "$@"
mvn -Dspark.major.version=2 "$@"
./change-scala-versions.sh 2.11
mvn "$@"
mvn -Dspark.major.version=2 "$@"
./change-scala-versions.sh 2.10
97 changes: 88 additions & 9 deletions deeplearning4j-scaleout/dl4j-streaming/pom.xml
@@ -4,10 +4,9 @@

<modelVersion>4.0.0</modelVersion>

<groupId>org.deeplearning4j</groupId>
<artifactId>dl4j-streaming_2.10</artifactId>
<packaging>jar</packaging>
<version>0.7.3-SNAPSHOT</version>
<version>0.7.3_spark_${spark.major.version}-SNAPSHOT</version>

<parent>
<artifactId>deeplearning4j-scaleout</artifactId>
@@ -27,7 +26,7 @@
<dependency>
<groupId>org.datavec</groupId>
<artifactId>datavec-camel</artifactId>
<version>${project.version}</version>
<version>${parent.version}</version>
</dependency>

<dependency>
@@ -56,15 +55,10 @@
<artifactId>spark-streaming_2.10</artifactId>
<version>${spark.version}</version>
</dependency>
<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-streaming-kafka_2.10</artifactId>
<version>${spark.version}</version>
</dependency>
<dependency>
<groupId>org.datavec</groupId>
<artifactId>datavec-spark_2.10</artifactId>
<version>${datavec.version}</version>
<version>${datavec.spark.version}</version>
</dependency>

<dependency>
@@ -108,5 +102,90 @@
</dependency>
</dependencies>

<build>
<plugins>
<!-- add a source folder containing the code specific to the selected Spark version -->
<plugin>
<groupId>org.codehaus.mojo</groupId>
<artifactId>build-helper-maven-plugin</artifactId>
<version>1.12</version>
<executions>
<execution>
<id>add-source</id>
<phase>generate-sources</phase>
<goals><goal>add-source</goal></goals>
<configuration>
<sources>
<source>src/main/spark-${spark.major.version}</source>
</sources>
</configuration>
</execution>
</executions>
</plugin>
</plugins>
</build>

<profiles>
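<!-- Version selection: with no -Dspark.major.version set, the spark-default profile
applies (Spark 1.6.2); =1 pins the same versions explicitly; =2 moves to
Spark 2.1.0 and the spark-streaming-kafka-0-8 artifact. -->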
<profile>
<id>spark-default</id>
<activation>
<property>
<name>!spark.major.version</name>
</property>
</activation>
<properties>
<spark.major.version>1</spark.major.version>
<spark.version>1.6.2</spark.version>
<hadoop.version>2.2.0</hadoop.version>
</properties>
<dependencies>
<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-streaming-kafka_2.10</artifactId>
<version>${spark.version}</version>
</dependency>
</dependencies>
</profile>
<profile>
<id>spark-1</id>
<activation>
<property>
<name>spark.major.version</name>
<value>1</value>
</property>
</activation>
<properties>
<spark.version>1.6.2</spark.version>
<hadoop.version>2.2.0</hadoop.version>
</properties>
<dependencies>
<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-streaming-kafka_2.10</artifactId>
<version>${spark.version}</version>
</dependency>
</dependencies>
</profile>
<profile>
<id>spark-2</id>
<activation>
<property>
<name>spark.major.version</name>
<value>2</value>
</property>
</activation>
<properties>
<spark.version>2.1.0</spark.version>
<hadoop.version>2.2.0</hadoop.version>
</properties>
<dependencies>
<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-streaming-kafka-0-8_2.10</artifactId>
<version>${spark.version}</version>
</dependency>
</dependencies>
</profile>
</profiles>

</project>
DataSetFlatmap.java

@@ -3,6 +3,8 @@
import org.apache.commons.net.util.Base64;
import org.apache.spark.api.java.function.FlatMapFunction;
import org.datavec.api.writable.Writable;
import org.datavec.spark.functions.FlatMapFunctionAdapter;
import org.datavec.spark.transform.BaseFlatMapFunctionAdaptee;
import org.deeplearning4j.streaming.conversion.dataset.RecordToDataSet;
import org.deeplearning4j.streaming.serde.RecordDeSerializer;
import org.nd4j.linalg.dataset.DataSet;
@@ -15,11 +17,22 @@
* Flat maps a binary dataset string into a DataSet
*/
public class DataSetFlatmap implements FlatMapFunction<Tuple2<String, String>, DataSet> {
public class DataSetFlatmap extends BaseFlatMapFunctionAdaptee<Tuple2<String, String>, DataSet> {

public DataSetFlatmap(int numLabels, RecordToDataSet recordToDataSetFunction) {
super(new DataSetFlatmapAdapter(numLabels, recordToDataSetFunction));
}
}

/**
* Adapter holding the record-to-DataSet conversion logic for {@link DataSetFlatmap}
*/
class DataSetFlatmapAdapter implements FlatMapFunctionAdapter<Tuple2<String, String>, DataSet> {
private int numLabels;
private RecordToDataSet recordToDataSetFunction;

public DataSetFlatmap(int numLabels, RecordToDataSet recordToDataSetFunction) {
public DataSetFlatmapAdapter(int numLabels, RecordToDataSet recordToDataSetFunction) {
this.numLabels = numLabels;
this.recordToDataSetFunction = recordToDataSetFunction;
}
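The split above relies on datavec's FlatMapFunctionAdapter (org.datavec.spark.functions) and BaseFlatMapFunctionAdaptee (org.datavec.spark.transform), which are themselves compiled per Spark version: Spark 1.x flatMap functions return an Iterable, while Spark 2.x ones return an Iterator. The adaptee's body is not part of this commit, so the following is only a sketch of what its Spark 2.x variant presumably looks like:

import java.util.Iterator;

import org.apache.spark.api.java.function.FlatMapFunction;
import org.datavec.spark.functions.FlatMapFunctionAdapter;

// Sketch only: the real class ships with datavec in org.datavec.spark.transform.
public class BaseFlatMapFunctionAdaptee<K, V> implements FlatMapFunction<K, V> {

    private final FlatMapFunctionAdapter<K, V> adapter;

    public BaseFlatMapFunctionAdaptee(FlatMapFunctionAdapter<K, V> adapter) {
        this.adapter = adapter;
    }

    @Override
    public Iterator<V> call(K k) throws Exception {
        // delegate to the version-agnostic adapter, then adapt its Iterable
        // result to the Iterator contract that Spark 2.x flatMap expects
        return adapter.call(k).iterator();
    }
}

A Spark 1.x variant would keep the same shape but return adapter.call(k) directly, since Spark 1.x flatMap expects an Iterable.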
@@ -71,9 +71,6 @@ public JavaDStream<DataSet> createStream() {
@Override
public void startStreamingConsumption(long timeout) {
jssc.start();
if(timeout < 0)
jssc.awaitTermination();
else
jssc.awaitTermination(timeout);
StreamingContextUtils.awaitTermination(jssc, timeout);
}
}
NDArrayFlatMap.java

@@ -1,8 +1,9 @@
package org.deeplearning4j.streaming.pipeline.spark.inerference;

import org.apache.commons.net.util.Base64;
import org.apache.spark.api.java.function.FlatMapFunction;
import org.datavec.api.writable.Writable;
import org.datavec.spark.functions.FlatMapFunctionAdapter;
import org.datavec.spark.transform.BaseFlatMapFunctionAdaptee;
import org.deeplearning4j.streaming.conversion.ndarray.RecordToNDArray;
import org.deeplearning4j.streaming.serde.RecordDeSerializer;
import org.nd4j.linalg.api.ndarray.INDArray;
@@ -16,10 +17,22 @@
* dataset
* @author Adam Gibson
*/
public class NDArrayFlatMap implements FlatMapFunction<Tuple2<String, String>, INDArray> {
private RecordToNDArray recordToDataSetFunction;
public class NDArrayFlatMap extends BaseFlatMapFunctionAdaptee<Tuple2<String, String>, INDArray> {

public NDArrayFlatMap(RecordToNDArray recordToDataSetFunction) {
super(new NDArrayFlatMapAdapter(recordToDataSetFunction));
}
}

/**
* Adapter holding the record-to-NDArray conversion logic for {@link NDArrayFlatMap}
* @author Adam Gibson
*/
class NDArrayFlatMapAdapter implements FlatMapFunctionAdapter<Tuple2<String, String>, INDArray> {
private RecordToNDArray recordToDataSetFunction;

public NDArrayFlatMapAdapter(RecordToNDArray recordToDataSetFunction) {
this.recordToDataSetFunction = recordToDataSetFunction;
}

@@ -16,6 +16,7 @@
import org.apache.spark.streaming.kafka.KafkaUtils;
import org.deeplearning4j.streaming.conversion.ndarray.RecordToNDArray;
import org.deeplearning4j.streaming.pipeline.kafka.BaseKafkaPipeline;
import org.deeplearning4j.streaming.pipeline.spark.StreamingContextUtils;
import org.nd4j.linalg.api.ndarray.INDArray;

import java.util.Collections;
@@ -90,9 +91,6 @@ public JavaDStream<INDArray> createStream() {
@Override
public void startStreamingConsumption(long timeout) {
jssc.start();
if(timeout < 0)
jssc.awaitTermination();
else
jssc.awaitTermination(timeout);
StreamingContextUtils.awaitTermination(jssc, timeout);
}
}
StreamingContextUtils.java (Spark 1.x source tree)

@@ -0,0 +1,23 @@
package org.deeplearning4j.streaming.pipeline.spark;

import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.streaming.api.java.JavaDStream;
import org.apache.spark.streaming.api.java.JavaStreamingContext;

/**
* Wraps streaming-context calls whose APIs changed between Spark 1.x and 2.x
*/
public class StreamingContextUtils {

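// Spark 1.x: the timed awaitTermination(long) overload is still present and declares no checked exception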
public static void awaitTermination(JavaStreamingContext jssc, long timeout) {
if(timeout < 0)
jssc.awaitTermination();
else
jssc.awaitTermination(timeout);
}

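// Spark 1.x: DStream.foreach still accepts a Function<JavaRDD<K>, Void>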
public static <K> void foreach(JavaDStream<K> stream, Function<JavaRDD<K>, Void> func) {
stream.foreach(func);
}
}
PrintDataSet.java

@@ -0,0 +1,22 @@
package org.deeplearning4j.streaming.pipeline.spark;

import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.VoidFunction;
import org.nd4j.linalg.dataset.DataSet;

/**
* Created by agibsonccc on 6/11/16.
*/
public class PrintDataSet implements VoidFunction<JavaRDD<DataSet>> {
@Override
public void call(JavaRDD<DataSet> dataSetJavaRDD) throws Exception {
dataSetJavaRDD.foreach(new VoidFunction<DataSet>() {
@Override
public void call(DataSet dataSet) throws Exception {
System.out.println(dataSet);
}
});
}
}

StreamingContextUtils.java (Spark 2.x source tree)

@@ -0,0 +1,27 @@
package org.deeplearning4j.streaming.pipeline.spark;

import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.VoidFunction;
import org.apache.spark.streaming.api.java.JavaDStream;
import org.apache.spark.streaming.api.java.JavaStreamingContext;

/**
* Wraps streaming-context calls whose APIs changed between Spark 1.x and 2.x
*/
public class StreamingContextUtils {

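// Spark 2.x: awaitTermination() declares InterruptedException, and the timed wait is awaitTerminationOrTimeout(long)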
public static void awaitTermination(JavaStreamingContext jssc, long timeout) {
try {
if(timeout < 0)
jssc.awaitTermination();
else
jssc.awaitTerminationOrTimeout(timeout);
} catch (InterruptedException e) {
e.printStackTrace();
}
}

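// Spark 2.x: DStream.foreach is gone; foreachRDD with a VoidFunction replaces it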
public static <K> void foreach(JavaDStream<K> stream, VoidFunction<JavaRDD<K>> func) {
stream.foreachRDD(func);
}
}
@@ -29,6 +29,8 @@
import org.apache.spark.streaming.api.java.JavaPairInputDStream;
import org.apache.spark.streaming.api.java.JavaStreamingContext;
import org.apache.spark.streaming.kafka.KafkaUtils;
import org.datavec.spark.functions.FlatMapFunctionAdapter;
import org.datavec.spark.transform.BaseFlatMapFunctionAdaptee;
import scala.Tuple2;

import java.util.*;
@@ -87,12 +89,12 @@ public String call(Tuple2<String, String> tuple2) {
}
});

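// Wrapping in BaseFlatMapFunctionAdaptee lets the Iterable-returning adapter compile against either Spark version's flatMap signature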
JavaDStream<String> words = lines.flatMap(new FlatMapFunction<String, String>() {
JavaDStream<String> words = lines.flatMap(new BaseFlatMapFunctionAdaptee<>(new FlatMapFunctionAdapter<String, String>() {
@Override
public Iterable<String> call(String x) {
return Arrays.asList(SPACE.split(x));
}
});
}));
JavaPairDStream<String, Integer> wordCounts = words.mapToPair(
new PairFunction<String, String, Integer>() {
@Override
@@ -112,4 +114,4 @@ public Integer call(Integer i1, Integer i2) {
jssc.start();
jssc.awaitTermination();
}
}
}
@@ -9,6 +9,7 @@
import org.deeplearning4j.streaming.embedded.TestUtils;
import org.deeplearning4j.streaming.pipeline.spark.PrintDataSet;
import org.deeplearning4j.streaming.pipeline.spark.SparkStreamingPipeline;
import org.deeplearning4j.streaming.pipeline.spark.StreamingContextUtils;
import org.junit.AfterClass;
import org.junit.BeforeClass;
import org.junit.Test;
@@ -53,7 +54,7 @@ public void testPipeline() throws Exception {
//NOTE THAT YOU NEED TO DO SOMETHING WITH THE STREAM OTHERWISE IT ERRORS OUT.
//ALSO NOTE HERE THAT YOU NEED TO HAVE THE FUNCTION BE AN OBJECT NOT AN ANONYMOUS
//CLASS BECAUSE OF TASK SERIALIZATION
dataSetJavaDStream.foreach(new PrintDataSet());
StreamingContextUtils.foreach(dataSetJavaDStream, new PrintDataSet());

pipeline.startStreamingConsumption(1000);
