[SPARK-1022] Kafka unit test that actually sends and receives data #557

Closed · wants to merge 2 commits
Changes from all commits
external/kafka/pom.xml (6 additions, 0 deletions)
@@ -81,6 +81,12 @@
</exclusion>
</exclusions>
</dependency>
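    <!-- jopt-simple is excluded from the compile-scope Kafka dependency (see the
         matching SBT change in project/SparkBuild.scala below), so it is re-added
         here in test scope for the embedded Kafka broker that the new tests start. -->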
<dependency>
<groupId>net.sf.jopt-simple</groupId>
<artifactId>jopt-simple</artifactId>
<version>4.5</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.scalatest</groupId>
<artifactId>scalatest_${scala.binary.version}</artifactId>
external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaStreamSuite.java
@@ -17,31 +17,108 @@

package org.apache.spark.streaming.kafka;

import java.io.Serializable;
import java.util.HashMap;
import java.util.List;

-import org.apache.spark.streaming.api.java.JavaPairReceiverInputDStream;
-import org.junit.Test;
-import com.google.common.collect.Maps;
-import kafka.serializer.StringDecoder;
-import org.apache.spark.storage.StorageLevel;
import scala.Predef;
import scala.Tuple2;
import scala.collection.JavaConverters;

import junit.framework.Assert;

import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.streaming.Duration;
import org.apache.spark.streaming.LocalJavaStreamingContext;
import org.apache.spark.streaming.api.java.JavaDStream;
import org.apache.spark.streaming.api.java.JavaPairDStream;
import org.apache.spark.streaming.api.java.JavaStreamingContext;

import org.junit.Test;
import org.junit.Ignore;
import org.junit.After;
import org.junit.Before;

public class JavaKafkaStreamSuite extends LocalJavaStreamingContext implements Serializable {
private transient KafkaStreamSuite testSuite = new KafkaStreamSuite();

@Before
@Override
public void setUp() {
testSuite.beforeFunction();
System.clearProperty("spark.driver.port");
System.setProperty("spark.streaming.clock", "org.apache.spark.streaming.util.SystemClock");
ssc = new JavaStreamingContext("local[2]", "test", new Duration(1000));
}

@After
@Override
public void tearDown() {
ssc.stop();
ssc = null;
System.clearProperty("spark.driver.port");
testSuite.afterFunction();
}

-public class JavaKafkaStreamSuite extends LocalJavaStreamingContext {
-  @Test
  @Ignore @Test
  public void testKafkaStream() {
-    HashMap<String, Integer> topics = Maps.newHashMap();

-    // tests the API, does not actually test data receiving
-    JavaPairReceiverInputDStream<String, String> test1 =
-      KafkaUtils.createStream(ssc, "localhost:12345", "group", topics);
-    JavaPairReceiverInputDStream<String, String> test2 = KafkaUtils.createStream(ssc, "localhost:12345", "group", topics,
-      StorageLevel.MEMORY_AND_DISK_SER_2());

-    HashMap<String, String> kafkaParams = Maps.newHashMap();
-    kafkaParams.put("zookeeper.connect", "localhost:12345");
-    kafkaParams.put("group.id","consumer-group");
-    JavaPairReceiverInputDStream<String, String> test3 = KafkaUtils.createStream(ssc,
-      String.class, String.class, StringDecoder.class, StringDecoder.class,
-      kafkaParams, topics, StorageLevel.MEMORY_AND_DISK_SER_2());
String topic = "topic1";
HashMap<String, Integer> topics = new HashMap<String, Integer>();
topics.put(topic, 1);

HashMap<String, Integer> sent = new HashMap<String, Integer>();
sent.put("a", 5);
sent.put("b", 3);
sent.put("c", 10);

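    // Create a receiver stream for "topic1" on the embedded broker, using the
    // ZooKeeper address exposed by the Scala KafkaStreamSuite fixture.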
JavaPairDStream<String, String> stream = KafkaUtils.createStream(ssc,
testSuite.zkConnect(),
"group",
topics);

final HashMap<String, Long> result = new HashMap<String, Long>();

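    // Drop the Kafka message keys and keep only the payload strings.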
JavaDStream<String> words = stream.map(
new Function<Tuple2<String, String>, String>() {
@Override
public String call(Tuple2<String, String> tuple2) throws Exception {
return tuple2._2();
}
}
);

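    // Count how often each payload appears in every batch and fold the
    // per-batch counts into the shared result map.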
words.countByValue().foreachRDD(
new Function<JavaPairRDD<String, Long>, Void>() {
@Override
public Void call(JavaPairRDD<String, Long> rdd) throws Exception {
List<Tuple2<String, Long>> ret = rdd.collect();
for (Tuple2<String, Long> r : ret) {
if (result.containsKey(r._1())) {
result.put(r._1(), result.get(r._1()) + r._2());
} else {
result.put(r._1(), r._2());
}
}

return null;
}
}
);

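    // Start the context, publish the test messages through the embedded broker,
    // and give the job up to 10 seconds to consume them before asserting.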
ssc.start();

HashMap<String, Object> tmp = new HashMap<String, Object>(sent);
testSuite.produceAndSendTestMessage(topic,
JavaConverters.asScalaMapConverter(tmp).asScala().toMap(
Predef.<Tuple2<String, Object>>conforms()
));

ssc.awaitTermination(10000);

Assert.assertEquals(sent.size(), result.size());
for (String k : sent.keySet()) {
Assert.assertEquals(sent.get(k).intValue(), result.get(k).intValue());
}
}
}
external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaStreamSuite.scala
@@ -17,28 +17,172 @@

package org.apache.spark.streaming.kafka

-import kafka.serializer.StringDecoder
import java.io.File
import java.net.InetSocketAddress
import java.util.{Properties, Random}

import scala.collection.mutable

import kafka.admin.CreateTopicCommand
import kafka.common.TopicAndPartition
import kafka.producer.{KeyedMessage, ProducerConfig, Producer}
import kafka.utils.ZKStringSerializer
import kafka.serializer.StringEncoder
import kafka.server.{KafkaConfig, KafkaServer}

import org.apache.spark.streaming.{StreamingContext, TestSuiteBase}
import org.apache.spark.storage.StorageLevel
import org.apache.spark.streaming.dstream.ReceiverInputDStream
import org.apache.zookeeper.server.ZooKeeperServer
import org.apache.zookeeper.server.NIOServerCnxnFactory

import org.I0Itec.zkclient.ZkClient

class KafkaStreamSuite extends TestSuiteBase {
val zkConnect = "localhost:2181"
var zookeeper: EmbeddedZookeeper = _
var zkClient: ZkClient = _
val zkConnectionTimeout = 6000
val zkSessionTimeout = 6000

val brokerPort = 9092
val brokerProps = getBrokerConfig(brokerPort)
val brokerConf = new KafkaConfig(brokerProps)
var server: KafkaServer = _

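  // Spin up an embedded ZooKeeper server and a single local Kafka broker before
  // each test; afterFunction shuts them down and removes the broker log directories.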
override def beforeFunction() {
// Zookeeper server startup
zookeeper = new EmbeddedZookeeper(zkConnect)
logInfo("==================== 0 ====================")
zkClient = new ZkClient(zkConnect, zkSessionTimeout, zkConnectionTimeout, ZKStringSerializer)
logInfo("==================== 1 ====================")

// Kafka broker startup
server = new KafkaServer(brokerConf)
logInfo("==================== 2 ====================")
server.startup()
logInfo("==================== 3 ====================")
Thread.sleep(2000)
logInfo("==================== 4 ====================")
super.beforeFunction()
}

override def afterFunction() {
server.shutdown()
brokerConf.logDirs.foreach { f => KafkaStreamSuite.deleteDir(new File(f)) }

test("kafka input stream") {
zkClient.close()
zookeeper.shutdown()

super.afterFunction()
}

ignore("kafka input stream") {
val ssc = new StreamingContext(master, framework, batchDuration)
-    val topics = Map("my-topic" -> 1)

-    // tests the API, does not actually test data receiving
-    val test1: ReceiverInputDStream[(String, String)] =
-      KafkaUtils.createStream(ssc, "localhost:1234", "group", topics)
-    val test2: ReceiverInputDStream[(String, String)] =
-      KafkaUtils.createStream(ssc, "localhost:12345", "group", topics, StorageLevel.MEMORY_AND_DISK_SER_2)
-    val kafkaParams = Map("zookeeper.connect"->"localhost:12345","group.id"->"consumer-group")
-    val test3: ReceiverInputDStream[(String, String)] =
-      KafkaUtils.createStream[String, String, StringDecoder, StringDecoder](
-        ssc, kafkaParams, topics, StorageLevel.MEMORY_AND_DISK_SER_2)

-    // TODO: Actually test receiving data
val topic = "topic1"
val sent = Map("a" -> 5, "b" -> 3, "c" -> 10)

val stream = KafkaUtils.createStream(ssc, zkConnect, "group", Map(topic -> 1))
val result = new mutable.HashMap[String, Long]()
stream.map { case (k, v) => v }
.countByValue()
.foreachRDD { r =>
val ret = r.collect()
ret.toMap.foreach { kv =>
val count = result.getOrElseUpdate(kv._1, 0) + kv._2
result.put(kv._1, count)
}
}
ssc.start()
produceAndSendTestMessage(topic, sent)
ssc.awaitTermination(10000)

assert(sent.size === result.size)
sent.keys.foreach { k => assert(sent(k) === result(k).toInt) }

ssc.stop()
}

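  // Broker configuration pointing at the embedded ZooKeeper, with a temporary
  // log directory and log.flush.interval.messages set to 1 so every message is
  // flushed to disk right away.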
private def getBrokerConfig(port: Int): Properties = {
val props = new Properties()
props.put("broker.id", "0")
props.put("host.name", "localhost")
props.put("port", port.toString)
props.put("log.dir", KafkaStreamSuite.tmpDir().getAbsolutePath)
props.put("zookeeper.connect", zkConnect)
props.put("log.flush.interval.messages", "1")
props.put("replica.socket.timeout.ms", "1500")
props
}

private def getProducerConfig(brokerList: String): Properties = {
val props = new Properties()
props.put("metadata.broker.list", brokerList)
props.put("serializer.class", classOf[StringEncoder].getName)
props
}

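  // Expand the (message -> frequency) map into the individual KeyedMessages
  // that the producer will send.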
private def createTestMessage(topic: String, sent: Map[String, Int])
: Seq[KeyedMessage[String, String]] = {
val messages = for ((s, freq) <- sent; i <- 0 until freq) yield {
new KeyedMessage[String, String](topic, s)
}
messages.toSeq
}

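  // Create the topic, wait briefly for its leader metadata to reach the broker,
  // then send the generated messages and close the producer.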
def produceAndSendTestMessage(topic: String, sent: Map[String, Int]) {
val brokerAddr = brokerConf.hostName + ":" + brokerConf.port
val producer = new Producer[String, String](new ProducerConfig(getProducerConfig(brokerAddr)))
CreateTopicCommand.createTopic(zkClient, topic, 1, 1, "0")
logInfo("==================== 5 ====================")
// wait until metadata is propagated
Thread.sleep(1000)
assert(server.apis.leaderCache.keySet.contains(TopicAndPartition(topic, 0)))
producer.send(createTestMessage(topic, sent): _*)
Thread.sleep(1000)

logInfo("==================== 6 ====================")
producer.close()
}
}

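// Helpers shared by the suite for creating and recursively deleting the
// temporary directories used by ZooKeeper and the Kafka broker.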
object KafkaStreamSuite {
val random = new Random()

def tmpDir(): File = {
val tmp = System.getProperty("java.io.tmpdir")
val f = new File(tmp, "spark-kafka-" + random.nextInt(1000))
f.mkdirs()
f
}

def deleteDir(file: File) {
if (file.isFile) {
file.delete()
} else {
for (f <- file.listFiles()) {
deleteDir(f)
}
file.delete()
}
}
}

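// A minimal in-process ZooKeeper server bound to the host and port parsed from zkConnect.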
class EmbeddedZookeeper(val zkConnect: String) {
val random = new Random()
val snapshotDir = KafkaStreamSuite.tmpDir()
val logDir = KafkaStreamSuite.tmpDir()

val zookeeper = new ZooKeeperServer(snapshotDir, logDir, 500)
  val (ip, port) = {
val splits = zkConnect.split(":")
(splits(0), splits(1).toInt)
}
val factory = new NIOServerCnxnFactory()
factory.configure(new InetSocketAddress(ip, port), 16)
factory.startup(zookeeper)

def shutdown() {
factory.shutdown()
KafkaStreamSuite.deleteDir(snapshotDir)
KafkaStreamSuite.deleteDir(logDir)
}
}
project/SparkBuild.scala (2 additions, 1 deletion)
@@ -599,7 +599,8 @@ object SparkBuild extends Build {
exclude("com.sun.jdmk", "jmxtools")
exclude("com.sun.jmx", "jmxri")
exclude("net.sf.jopt-simple", "jopt-simple")
-        excludeAll(excludeNetty, excludeSLF4J)
        excludeAll(excludeNetty, excludeSLF4J),
      "net.sf.jopt-simple" % "jopt-simple" % "4.5" % "test"
)
)
