Skip to content

Commit

Permalink
SAMZA-1627: Watermark broadcast enhancements
Browse files Browse the repository at this point in the history
Currently each upstream task needs to broadcast to every single partition of intermediate streams in order to aggregate watermarks in the consumers. A better way to do this is to have only one downstream consumer doing the aggregation, and then broadcast to all the partitions. This is safe as we can prove the broadcast watermark message is after all the upstream tasks finished producing the events whose event time are before the watermark. This reduced the full message count from O(n^2) to O(n).

Author: xinyuiscool <xiliu@linkedin.com>

Reviewers: Boris S <sborya@gmail.com>

Closes apache#456 from xinyuiscool/SAMZA-1627
  • Loading branch information
xinyuiscool committed Mar 28, 2018
1 parent aff805d commit 5431350
Show file tree
Hide file tree
Showing 5 changed files with 84 additions and 18 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import org.apache.samza.system.StreamMetadataCache;
import org.apache.samza.system.SystemStream;
import org.apache.samza.system.SystemStreamMetadata;
import org.apache.samza.system.SystemStreamPartition;
import org.apache.samza.task.MessageCollector;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
Expand All @@ -35,7 +36,7 @@


/**
* This is a helper class to broadcast control messages to each partition of an intermediate stream
* This is a helper class to send control messages to an intermediate stream
*/
class ControlMessageSender {
private static final Logger LOG = LoggerFactory.getLogger(ControlMessageSender.class);
Expand All @@ -48,20 +49,37 @@ class ControlMessageSender {
}

void send(ControlMessage message, SystemStream systemStream, MessageCollector collector) {
Integer partitionCount = PARTITION_COUNT_CACHE.computeIfAbsent(systemStream, ss -> {
int partitionCount = getPartitionCount(systemStream);
// We pick a partition based on topic hashcode to aggregate the control messages from upstream tasks
// After aggregation the task will broadcast the results to other partitions
int aggregatePartition = systemStream.getStream().hashCode() % partitionCount;

LOG.debug(String.format("Send %s message from task %s to %s partition %s for aggregation",
MessageType.of(message).name(), message.getTaskName(), systemStream, aggregatePartition));

OutgoingMessageEnvelope envelopeOut = new OutgoingMessageEnvelope(systemStream, aggregatePartition, null, message);
collector.send(envelopeOut);
}

void broadcastToOtherPartitions(ControlMessage message, SystemStreamPartition ssp, MessageCollector collector) {
SystemStream systemStream = ssp.getSystemStream();
int partitionCount = getPartitionCount(systemStream);
int currentPartition = ssp.getPartition().getPartitionId();
for (int i = 0; i < partitionCount; i++) {
if (i != currentPartition) {
OutgoingMessageEnvelope envelopeOut = new OutgoingMessageEnvelope(systemStream, i, null, message);
collector.send(envelopeOut);
}
}
}

private int getPartitionCount(SystemStream systemStream) {
return PARTITION_COUNT_CACHE.computeIfAbsent(systemStream, ss -> {
SystemStreamMetadata metadata = metadataCache.getSystemStreamMetadata(ss, true);
if (metadata == null) {
throw new SamzaException("Unable to find metadata for stream " + systemStream);
}
return metadata.getSystemStreamPartitionMetadata().size();
});

LOG.debug(String.format("Broadcast %s message from task %s to %s with %s partition",
MessageType.of(message).name(), message.getTaskName(), systemStream, partitionCount));

for (int i = 0; i < partitionCount; i++) {
OutgoingMessageEnvelope envelopeOut = new OutgoingMessageEnvelope(systemStream, i, null, message);
collector.send(envelopeOut);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,13 @@ private static final class EndOfStreamState {

synchronized void update(String taskName) {
if (taskName != null) {
// aggregate the eos messages
tasks.add(taskName);
isEndOfStream = tasks.size() == expectedTotal;
} else {
// eos is coming from either source or aggregator task
isEndOfStream = true;
}
isEndOfStream = tasks.size() == expectedTotal;
}

boolean isEndOfStream() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ public abstract class OperatorImpl<M, RM> {
// watermark states
private WatermarkStates watermarkStates;
private TaskContext taskContext;
private ControlMessageSender controlMessageSender;

/**
* Initialize this {@link OperatorImpl} and its user-defined functions.
Expand Down Expand Up @@ -114,6 +115,7 @@ public final void init(Config config, TaskContext context) {
TaskContextImpl taskContext = (TaskContextImpl) context;
this.eosStates = (EndOfStreamStates) taskContext.fetchObject(EndOfStreamStates.class.getName());
this.watermarkStates = (WatermarkStates) taskContext.fetchObject(WatermarkStates.class.getName());
this.controlMessageSender = new ControlMessageSender(taskContext.getStreamMetadataCache());

if (taskContext.getJobModel() != null) {
ContainerModel containerModel = taskContext.getJobModel().getContainers()
Expand Down Expand Up @@ -265,6 +267,12 @@ public final void aggregateEndOfStream(EndOfStreamMessage eos, SystemStreamParti
SystemStream stream = ssp.getSystemStream();
if (eosStates.isEndOfStream(stream)) {
LOG.info("Input {} reaches the end for task {}", stream.toString(), taskName.getTaskName());
if (eos.getTaskName() != null) {
// This is the aggregation task, which already received all the eos messages from upstream
// broadcast the end-of-stream to all the peer partitions
controlMessageSender.broadcastToOtherPartitions(new EndOfStreamMessage(), ssp, collector);
}
// populate the end-of-stream through the dag
onEndOfStream(collector, coordinator);

if (eosStates.allEndOfStream()) {
Expand Down Expand Up @@ -322,6 +330,12 @@ public final void aggregateWatermark(WatermarkMessage watermarkMessage, SystemSt
long watermark = watermarkStates.getWatermark(ssp.getSystemStream());
if (watermark != WatermarkStates.WATERMARK_NOT_EXIST) {
LOG.debug("Got watermark {} from stream {}", watermark, ssp.getSystemStream());
if (watermarkMessage.getTaskName() != null) {
// This is the aggregation task, which already received all the watermark messages from upstream
// broadcast the watermark to all the peer partitions
controlMessageSender.broadcastToOtherPartitions(new WatermarkMessage(watermark), ssp, collector);
}
// populate the watermark through the dag
onWatermark(watermark, collector, coordinator);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,12 +63,12 @@ synchronized void update(long timestamp, String taskName) {
}
}

/**
* Check whether we got all the watermarks.
* At a sources, the expectedTotal is 0.
* For any intermediate streams, the expectedTotal is the upstream task count.
*/
if (timestamps.size() == expectedTotal) {
if (taskName == null) {
// we get watermark either from the source or from the aggregator task
watermarkTime = Math.max(watermarkTime, timestamp);
} else if (timestamps.size() == expectedTotal) {
// For any intermediate streams, the expectedTotal is the upstream task count.
// Check whether we got all the watermarks, and set the watermark to be the min.
Optional<Long> min = timestamps.values().stream().min(Long::compare);
watermarkTime = min.orElse(timestamp);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import org.apache.samza.system.StreamMetadataCache;
import org.apache.samza.system.SystemStream;
import org.apache.samza.system.SystemStreamMetadata;
import org.apache.samza.system.SystemStreamPartition;
import org.apache.samza.system.WatermarkMessage;
import org.apache.samza.task.MessageCollector;
import org.junit.Test;
Expand Down Expand Up @@ -68,6 +69,35 @@ public void testSend() {
ControlMessageSender sender = new ControlMessageSender(metadataCache);
WatermarkMessage watermark = new WatermarkMessage(System.currentTimeMillis(), "task 0");
sender.send(watermark, systemStream, collector);
assertEquals(partitions.size(), 4);
assertEquals(partitions.size(), 1);
}

@Test
public void testBroadcast() {
SystemStreamMetadata metadata = mock(SystemStreamMetadata.class);
Map<Partition, SystemStreamMetadata.SystemStreamPartitionMetadata> partitionMetadata = new HashMap<>();
partitionMetadata.put(new Partition(0), mock(SystemStreamMetadata.SystemStreamPartitionMetadata.class));
partitionMetadata.put(new Partition(1), mock(SystemStreamMetadata.SystemStreamPartitionMetadata.class));
partitionMetadata.put(new Partition(2), mock(SystemStreamMetadata.SystemStreamPartitionMetadata.class));
partitionMetadata.put(new Partition(3), mock(SystemStreamMetadata.SystemStreamPartitionMetadata.class));
when(metadata.getSystemStreamPartitionMetadata()).thenReturn(partitionMetadata);
StreamMetadataCache metadataCache = mock(StreamMetadataCache.class);
when(metadataCache.getSystemStreamMetadata(anyObject(), anyBoolean())).thenReturn(metadata);

SystemStream systemStream = new SystemStream("test-system", "test-stream");
Set<Integer> partitions = new HashSet<>();
MessageCollector collector = mock(MessageCollector.class);
doAnswer(invocation -> {
OutgoingMessageEnvelope envelope = (OutgoingMessageEnvelope) invocation.getArguments()[0];
partitions.add((Integer) envelope.getPartitionKey());
assertEquals(envelope.getSystemStream(), systemStream);
return null;
}).when(collector).send(any());

ControlMessageSender sender = new ControlMessageSender(metadataCache);
WatermarkMessage watermark = new WatermarkMessage(System.currentTimeMillis(), "task 0");
SystemStreamPartition ssp = new SystemStreamPartition(systemStream, new Partition(0));
sender.broadcastToOtherPartitions(watermark, ssp, collector);
assertEquals(partitions.size(), 3);
}
}

0 comments on commit 5431350

Please sign in to comment.