Skip to content

Commit

Permalink
use a blocking queue to pass polled messages to the processor for processing
Browse files Browse the repository at this point in the history
  • Loading branch information
yupeng9 committed Dec 31, 2024
1 parent 08f0712 commit 16dd9d0
Show file tree
Hide file tree
Showing 8 changed files with 266 additions and 171 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -56,14 +56,21 @@
import org.opensearch.index.translog.TranslogManager;
import org.opensearch.index.translog.TranslogStats;
import org.opensearch.indices.ingest.DefaultStreamPoller;
import org.opensearch.indices.ingest.MessageProcessor;
import org.opensearch.indices.ingest.StreamPoller;
import org.opensearch.search.suggest.completion.CompletionStats;
import org.opensearch.threadpool.ThreadPool;

import java.io.Closeable;
import java.io.IOException;
import java.util.*;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
Expand Down Expand Up @@ -149,8 +156,9 @@ public Translog.Operation next() {
logger.info("created ingestion consumer for shard [{}]", engineConfig.getShardId());

Map<String, String> commitData = commitDataAsMap();
StreamPoller.ResetState resetState =
StreamPoller.ResetState.valueOf(ingestionSource.getPointerInitReset().toUpperCase(Locale.ROOT));
StreamPoller.ResetState resetState = StreamPoller.ResetState.valueOf(
ingestionSource.getPointerInitReset().toUpperCase(Locale.ROOT)
);
IngestionShardPointer startPointer = null;
Set<IngestionShardPointer> persistedPointers = new HashSet<>();
if (commitData.containsKey(StreamPoller.BATCH_START)) {
Expand All @@ -167,13 +175,7 @@ public Translog.Operation next() {
resetState = StreamPoller.ResetState.NONE;
}

streamPoller = new DefaultStreamPoller(
startPointer,
persistedPointers,
ingestionShardConsumer,
new MessageProcessor(this),
resetState
);
streamPoller = new DefaultStreamPoller(startPointer, persistedPointers, ingestionShardConsumer, this, resetState);
streamPoller.start();
success = true;
} catch (IOException | TranslogCorruptedException e) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,14 @@
import org.opensearch.index.IngestionShardConsumer;
import org.opensearch.index.IngestionShardPointer;
import org.opensearch.index.Message;
import org.opensearch.index.engine.IngestionEngine;

import java.util.List;
import java.util.Locale;
import java.util.Objects;
import java.util.Set;
import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

Expand All @@ -42,7 +46,7 @@ public class DefaultStreamPoller implements StreamPoller {

private ExecutorService consumerThread;

private MessageProcessor processor;
private ExecutorService processorThread;

// start of the batch, inclusive
private IngestionShardPointer batchStartPointer;
Expand All @@ -51,6 +55,10 @@ public class DefaultStreamPoller implements StreamPoller {

private Set<IngestionShardPointer> persistedPointers;

private BlockingQueue<IngestionShardConsumer.ReadResult<? extends IngestionShardPointer, ? extends Message>> blockingQueue;

private MessageProcessorRunnable processorRunnable;

// A pointer to the max persisted pointer for optimizing the check
@Nullable
private IngestionShardPointer maxPersistedPointer;
Expand All @@ -59,20 +67,46 @@ public DefaultStreamPoller(
IngestionShardPointer startPointer,
Set<IngestionShardPointer> persistedPointers,
IngestionShardConsumer consumer,
MessageProcessor processor,
IngestionEngine ingestionEngine,
ResetState resetState
) {
this(
startPointer,
persistedPointers,
consumer,
new MessageProcessorRunnable(new ArrayBlockingQueue<>(100), ingestionEngine),
resetState
);
}

DefaultStreamPoller(
IngestionShardPointer startPointer,
Set<IngestionShardPointer> persistedPointers,
IngestionShardConsumer consumer,
MessageProcessorRunnable processorRunnable,
ResetState resetState
) {
this.consumer = consumer;
this.processor = processor;
this.consumer = Objects.requireNonNull(consumer);
this.resetState = resetState;
batchStartPointer = startPointer;
this.persistedPointers = persistedPointers;
if (!this.persistedPointers.isEmpty()) {
maxPersistedPointer = this.persistedPointers.stream().max(IngestionShardPointer::compareTo).get();
}
this.consumerThread = Executors.newSingleThreadExecutor(r -> new Thread(
this.processorRunnable = processorRunnable;
blockingQueue = processorRunnable.getBlockingQueue();
this.consumerThread = Executors.newSingleThreadExecutor(
r -> new Thread(
r,
String.format(Locale.ROOT, "stream-poller-consumer-%d-%d", consumer.getShardId(), System.currentTimeMillis())
)
);

// TODO: allow multiple threads for processing the messages in parallel
this.processorThread = Executors.newSingleThreadExecutor(
r -> new Thread(
r,
String.format(Locale.ROOT, "stream-poller-%d-%d", consumer.getShardId(), System.currentTimeMillis())
String.format(Locale.ROOT, "stream-poller-processor-%d-%d", consumer.getShardId(), System.currentTimeMillis())
)
);
}
Expand All @@ -83,7 +117,8 @@ public void start() {
throw new RuntimeException("poller is closed!");
}
started = true;
consumerThread.submit(this::startPoll).isDone();
consumerThread.submit(this::startPoll);
processorThread.submit(processorRunnable);
}

/**
Expand Down Expand Up @@ -126,7 +161,7 @@ protected void startPoll() {
// TODO: make sleep time configurable
Thread.sleep(100);
} catch (Throwable e) {
logger.error("Error in pausing the poller of shard {}", consumer.getShardId(), e);
logger.error("Error in pausing the poller of shard {}: {}", consumer.getShardId(), e);
}
continue;
}
Expand All @@ -146,16 +181,15 @@ protected void startPoll() {

state = State.PROCESSING;
// process the records
// TODO: separate threads for processing the messages in parallel
for (IngestionShardConsumer.ReadResult<? extends IngestionShardPointer, ? extends Message> result : results) {
// check if the message is already processed
if (isProcessed(result.getPointer())) {
logger.info("Skipping message with pointer {} as it is already processed", result.getPointer().asString());
continue;
}
processor.process(result.getMessage(), result.getPointer());
blockingQueue.put(result);
logger.debug(
"Processed message {} with pointer {}",
"Put message {} with pointer {} to the blocking queue",
String.valueOf(result.getMessage().getPayload()),
result.getPointer().asString()
);
Expand All @@ -164,7 +198,7 @@ protected void startPoll() {
batchStartPointer = consumer.nextPointer();
} catch (Throwable e) {
// TODO better error handling
logger.error("Error in polling the shard {}", consumer.getShardId(), e);
logger.error("Error in polling the shard {}: {}", consumer.getShardId(), e);
}
}
}
Expand Down Expand Up @@ -210,10 +244,6 @@ public void close() {
logger.info("consumer thread not started");
return;
}
if (consumerThread.isShutdown()) {
logger.info("consumer thread already closed");
return;
}
long startTime = System.currentTimeMillis(); // Record the start time
long timeout = 5000;
while (state != State.CLOSED) {
Expand All @@ -225,10 +255,13 @@ public void close() {
try {
Thread.sleep(100);
} catch (Throwable e) {
logger.error("Error in closing the poller of shard {}", consumer.getShardId(), e);
logger.error("Error in closing the poller of shard {}: {}", consumer.getShardId(), e);
}
}
blockingQueue.clear();
consumerThread.shutdown();
// interrupts the processor
processorThread.shutdownNow();
logger.info("closed the poller of shard {}", consumer.getShardId());
}

Expand Down

This file was deleted.

Loading

0 comments on commit 16dd9d0

Please sign in to comment.