20 changes: 10 additions & 10 deletions .idea/runConfigurations/Debug_OpenSearch.xml

Some generated files are not rendered by default.

1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -16,6 +16,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
- Fix unnecessary refreshes on update preparation failures ([#15261](https://github.com/opensearch-project/OpenSearch/issues/15261))
- Fix NullPointerException in segment replicator ([#18997](https://github.com/opensearch-project/OpenSearch/pull/18997))
- Ensure that plugins that utilize dumpCoverage can write to jacoco.dir when tests.security.manager is enabled ([#18983](https://github.com/opensearch-project/OpenSearch/pull/18983))
- Fix OOM due to buffering a large number of shard results ([#19066](https://github.com/opensearch-project/OpenSearch/pull/19066))

### Dependencies
- Bump `com.netflix.nebula.ospackage-base` from 12.0.0 to 12.1.0 ([#19019](https://github.com/opensearch-project/OpenSearch/pull/19019))
@@ -359,6 +359,8 @@ SearchActionListener<Result> createShardActionListener(
public void innerOnResponse(Result result) {
try {
onShardResult(result, shardIt);
} catch (Exception e) {
logger.trace("Failed to consume the shard {} result: {}", shard.getShardId(), e);
} finally {
executeNext(pendingExecutions, thread);
}
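
The hunk above changes the per-shard listener so that a failure while consuming one shard's result (for example, a tripped circuit breaker) no longer escapes the callback: the error is recorded downstream by the result consumer, and the listener only logs it and lets the phase advance. Below is a minimal, self-contained sketch of that control flow; the class and method names are hypothetical stand-ins, not the OpenSearch code.

```java
import java.util.concurrent.atomic.AtomicInteger;

// Minimal illustration: consuming a shard result may throw, but the phase must
// still advance to the next pending shard in the finally block.
public final class AdvanceOnFailureSketch {
    public static void main(String[] args) {
        AtomicInteger advanced = new AtomicInteger();
        Runnable executeNext = advanced::incrementAndGet; // stands in for executeNext(pendingExecutions, thread)

        for (int shard = 0; shard < 3; shard++) {
            try {
                consumeShardResult(shard); // may throw, e.g. when a circuit breaker trips
            } catch (Exception e) {
                // swallow and log; the consumer has already recorded the failure elsewhere
                System.out.println("Failed to consume the shard " + shard + " result: " + e);
            } finally {
                executeNext.run(); // always keep the phase moving
            }
        }
        System.out.println("advanced " + advanced.get() + " times"); // prints: advanced 3 times
    }

    private static void consumeShardResult(int shard) {
        if (shard == 1) {
            throw new IllegalStateException("simulated buffering failure");
        }
    }
}
```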
@@ -42,6 +42,7 @@
import org.opensearch.core.common.breaker.CircuitBreaker;
import org.opensearch.core.common.breaker.CircuitBreakingException;
import org.opensearch.core.common.io.stream.NamedWriteableRegistry;
import org.opensearch.core.tasks.TaskCancelledException;
import org.opensearch.search.SearchPhaseResult;
import org.opensearch.search.SearchShardTarget;
import org.opensearch.search.aggregations.InternalAggregation.ReduceContextBuilder;
@@ -57,6 +58,7 @@
import java.util.List;
import java.util.concurrent.Executor;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.BooleanSupplier;
import java.util.function.Consumer;

/**
@@ -86,6 +88,30 @@ public class QueryPhaseResultConsumer extends ArraySearchPhaseResults<SearchPhas

final PendingMerges pendingMerges;
private final Consumer<Exception> onPartialMergeFailure;
private final BooleanSupplier isTaskCancelled;

public QueryPhaseResultConsumer(
SearchRequest request,
Executor executor,
CircuitBreaker circuitBreaker,
SearchPhaseController controller,
SearchProgressListener progressListener,
NamedWriteableRegistry namedWriteableRegistry,
int expectedResultSize,
Consumer<Exception> onPartialMergeFailure
) {
this(
request,
executor,
circuitBreaker,
controller,
progressListener,
namedWriteableRegistry,
expectedResultSize,
onPartialMergeFailure,
() -> false
);
}

/**
* Creates a {@link QueryPhaseResultConsumer} that incrementally reduces aggregation results
@@ -99,7 +125,8 @@ public QueryPhaseResultConsumer(
SearchProgressListener progressListener,
NamedWriteableRegistry namedWriteableRegistry,
int expectedResultSize,
Consumer<Exception> onPartialMergeFailure
Consumer<Exception> onPartialMergeFailure,
BooleanSupplier isTaskCancelled
) {
super(expectedResultSize);
this.executor = executor;
@@ -117,6 +144,7 @@ public QueryPhaseResultConsumer(
this.hasAggs = source != null && source.aggregations() != null;
int batchReduceSize = getBatchReduceSize(request.getBatchedReduceSize(), expectedResultSize);
this.pendingMerges = new PendingMerges(batchReduceSize, request.resolveTrackTotalHitsUpTo());
this.isTaskCancelled = isTaskCancelled;
}

int getBatchReduceSize(int requestBatchedReduceSize, int minBatchReduceSize) {
@@ -133,6 +161,7 @@ public void consumeResult(SearchPhaseResult result, Runnable next) {
super.consumeResult(result, () -> {});
QuerySearchResult querySearchResult = result.queryResult();
progressListener.notifyQueryResult(querySearchResult.getShardIndex());
checkCancellation();
pendingMerges.consume(querySearchResult, next);
}

@@ -143,6 +172,7 @@ public SearchPhaseController.ReducedQueryPhase reduce() throws Exception {
} else if (pendingMerges.hasFailure()) {
throw pendingMerges.getFailure();
}
checkCancellation();

// ensure consistent ordering
pendingMerges.sortBuffer();
@@ -186,6 +216,7 @@ private MergeResult partialReduce(
MergeResult lastMerge,
int numReducePhases
) {
checkCancellation();
// ensure consistent ordering
Arrays.sort(toConsume, Comparator.comparingInt(QuerySearchResult::getShardIndex));

@@ -242,6 +273,16 @@ private MergeResult partialReduce(
return new MergeResult(processedShards, newTopDocs, newAggs, hasAggs ? serializedSize : 0);
}

private void checkCancellation() {
if (isTaskCancelled.getAsBoolean()) {
pendingMerges.resetCircuitBreakerForCurrentRequest();
// This check ensures that we do not mask the actual reason for cancellation, i.e. a CircuitBreakingException
if (!pendingMerges.hasFailure()) {
pendingMerges.failure.set(new TaskCancelledException("request has been terminated"));
}
}
}

public int getNumReducePhases() {
return pendingMerges.numReducePhases;
}
@@ -342,9 +383,10 @@ long estimateRamBytesUsedForReduce(long size) {
return Math.round(1.5d * size - size);
}

public void consume(QuerySearchResult result, Runnable next) {
public void consume(QuerySearchResult result, Runnable next) throws CircuitBreakingException {
boolean executeNextImmediately = true;
synchronized (this) {
checkCircuitBreaker(next);
if (hasFailure() || result.isNull()) {
result.consumeAll();
if (result.isNull()) {
@@ -378,17 +420,32 @@ public void consume(QuerySearchResult result, Runnable next) {
}
}

/**
* This method is needed to prevent an OOM when the buffered shard results grow too large.
*/
private void checkCircuitBreaker(Runnable next) throws CircuitBreakingException {
try {
// force a CircuitBreaker evaluation to ensure that buffering has not pushed us over the circuit breaker limit
addEstimateAndMaybeBreak(0);
} catch (CircuitBreakingException e) {
resetCircuitBreakerForCurrentRequest();
// onPartialMergeFailure should only be invoked once since this is responsible for cancelling the
// search task
if (!hasFailure()) {
failure.set(e);
onPartialMergeFailure.accept(e);
}
}
}

private synchronized void onMergeFailure(Exception exc) {
if (hasFailure()) {
assert circuitBreakerBytes == 0;
return;
}
assert circuitBreakerBytes >= 0;
if (circuitBreakerBytes > 0) {
// make sure that we reset the circuit breaker
circuitBreaker.addWithoutBreaking(-circuitBreakerBytes);
circuitBreakerBytes = 0;
}
resetCircuitBreakerForCurrentRequest();
failure.compareAndSet(null, exc);
MergeTask task = runningTask.get();
runningTask.compareAndSet(task, null);
@@ -405,6 +462,13 @@ private synchronized void onMergeFailure(Exception exc) {
}
}

private void resetCircuitBreakerForCurrentRequest() {
if (circuitBreakerBytes > 0) {
circuitBreaker.addWithoutBreaking(-circuitBreakerBytes);
circuitBreakerBytes = 0;
}
}

private void onAfterMerge(MergeTask task, MergeResult newResult, long estimatedSize) {
synchronized (this) {
if (hasFailure()) {
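
Taken together, the QueryPhaseResultConsumer changes above follow one pattern: each buffered shard result first forces a circuit-breaker evaluation (so unbounded buffering trips the breaker instead of causing an OOM), and each reduce first checks whether the task was cancelled, releasing the bytes reserved for the request in either case. The sketch below is a simplified, self-contained illustration of that pattern only; `SimpleBreaker`, `BufferingConsumer`, and the byte accounting are hypothetical stand-ins, not the OpenSearch implementation.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.function.BooleanSupplier;

// Hypothetical stand-ins that mirror only the control flow of the PR.
final class SimpleBreaker {
    private final long limitBytes;
    private long usedBytes;

    SimpleBreaker(long limitBytes) { this.limitBytes = limitBytes; }

    // analogous to addEstimateAndMaybeBreak: reserve bytes or throw
    void addEstimateAndMaybeBreak(long bytes) {
        if (usedBytes + bytes > limitBytes) {
            throw new IllegalStateException("circuit breaker tripped at " + usedBytes + "/" + limitBytes + " bytes");
        }
        usedBytes += bytes;
    }

    // analogous to resetCircuitBreakerForCurrentRequest: release everything held for this request
    void reset() { usedBytes = 0; }
}

final class BufferingConsumer {
    private final SimpleBreaker breaker;
    private final BooleanSupplier isTaskCancelled;
    private final List<int[]> buffer = new ArrayList<>();
    private Exception failure;

    BufferingConsumer(SimpleBreaker breaker, BooleanSupplier isTaskCancelled) {
        this.breaker = breaker;
        this.isTaskCancelled = isTaskCancelled;
    }

    // called once per buffered shard result
    synchronized void consume(int[] shardResult) {
        try {
            // force a breaker evaluation before accepting more buffered data
            breaker.addEstimateAndMaybeBreak((long) shardResult.length * Integer.BYTES);
        } catch (IllegalStateException e) {
            breaker.reset();
            if (failure == null) {
                failure = e; // first failure wins; the real code also cancels the search task here
            }
            return;
        }
        if (failure == null) {
            buffer.add(shardResult);
        }
    }

    // called before any partial or final reduce
    synchronized int reduce() throws Exception {
        if (isTaskCancelled.getAsBoolean()) {
            breaker.reset();
            if (failure == null) {
                failure = new IllegalStateException("request has been terminated");
            }
        }
        if (failure != null) {
            throw failure;
        }
        int total = 0;
        for (int[] part : buffer) {
            total += part.length; // stand-in for the real partial reduce
        }
        return total;
    }
}
```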
@@ -78,6 +78,7 @@
import java.util.List;
import java.util.Map;
import java.util.concurrent.Executor;
import java.util.function.BooleanSupplier;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.function.IntFunction;
@@ -778,6 +779,21 @@ QueryPhaseResultConsumer newSearchPhaseResults(
SearchRequest request,
int numShards,
Consumer<Exception> onPartialMergeFailure
) {
return newSearchPhaseResults(executor, circuitBreaker, listener, request, numShards, onPartialMergeFailure, () -> false);
}

/**
* Returns a new {@link QueryPhaseResultConsumer} instance that reduces search responses incrementally.
*/
QueryPhaseResultConsumer newSearchPhaseResults(
Executor executor,
CircuitBreaker circuitBreaker,
SearchProgressListener listener,
SearchRequest request,
int numShards,
Consumer<Exception> onPartialMergeFailure,
BooleanSupplier isTaskCancelled
) {
return new QueryPhaseResultConsumer(
request,
@@ -787,7 +803,8 @@ QueryPhaseResultConsumer newSearchPhaseResults(
listener,
namedWriteableRegistry,
numShards,
onPartialMergeFailure
onPartialMergeFailure,
isTaskCancelled
);
}

@@ -1270,7 +1270,8 @@ AbstractSearchAsyncAction<? extends SearchPhaseResult> searchAsyncAction(
task.getProgressListener(),
searchRequest,
shardIterators.size(),
exc -> cancelTask(task, exc)
exc -> cancelTask(task, exc),
task::isCancelled
);
AbstractSearchAsyncAction<? extends SearchPhaseResult> searchAsyncAction;
switch (searchRequest.searchType()) {
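
The hunk above passes the task's cancellation status into the consumer as a `BooleanSupplier` method reference (`task::isCancelled`), so the consumer can poll it without depending on the task class itself. A tiny sketch of that wiring, with a hypothetical task type rather than the real one:

```java
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.BooleanSupplier;

public final class CancellationWiringSketch {
    // hypothetical stand-in for the search task
    static final class Task {
        private final AtomicBoolean cancelled = new AtomicBoolean(false);
        boolean isCancelled() { return cancelled.get(); }
        void cancel() { cancelled.set(true); }
    }

    public static void main(String[] args) {
        Task task = new Task();
        // the consumer only ever sees a BooleanSupplier, not the task itself
        BooleanSupplier isTaskCancelled = task::isCancelled;
        System.out.println(isTaskCancelled.getAsBoolean()); // false
        task.cancel();
        System.out.println(isTaskCancelled.getAsBoolean()); // true
    }
}
```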
@@ -36,18 +36,23 @@
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.TotalHits;
import org.opensearch.action.OriginalIndices;
import org.opensearch.common.breaker.TestCircuitBreaker;
import org.opensearch.common.lucene.search.TopDocsAndMaxScore;
import org.opensearch.common.util.BigArrays;
import org.opensearch.common.util.concurrent.OpenSearchExecutors;
import org.opensearch.common.util.concurrent.OpenSearchThreadPoolExecutor;
import org.opensearch.core.common.breaker.CircuitBreaker;
import org.opensearch.core.common.breaker.CircuitBreakingException;
import org.opensearch.core.common.breaker.NoopCircuitBreaker;
import org.opensearch.core.index.shard.ShardId;
import org.opensearch.search.DocValueFormat;
import org.opensearch.search.SearchShardTarget;
import org.opensearch.search.aggregations.AggregationBuilders;
import org.opensearch.search.aggregations.InternalAggregation;
import org.opensearch.search.aggregations.InternalAggregations;
import org.opensearch.search.aggregations.metrics.InternalMax;
import org.opensearch.search.aggregations.pipeline.PipelineAggregator;
import org.opensearch.search.builder.SearchSourceBuilder;
import org.opensearch.search.query.QuerySearchResult;
import org.opensearch.test.OpenSearchTestCase;
import org.opensearch.threadpool.TestThreadPool;
@@ -136,17 +141,7 @@ public void testProgressListenerExceptionsAreCaught() throws Exception {
CountDownLatch partialReduceLatch = new CountDownLatch(10);

for (int i = 0; i < 10; i++) {
SearchShardTarget searchShardTarget = new SearchShardTarget(
"node",
new ShardId("index", "uuid", i),
null,
OriginalIndices.NONE
);
QuerySearchResult querySearchResult = new QuerySearchResult();
TopDocs topDocs = new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]);
querySearchResult.topDocs(new TopDocsAndMaxScore(topDocs, Float.NaN), new DocValueFormat[0]);
querySearchResult.setSearchShardTarget(searchShardTarget);
querySearchResult.setShardIndex(i);
QuerySearchResult querySearchResult = getQuerySearchResult(i);
queryPhaseResultConsumer.consumeResult(querySearchResult, partialReduceLatch::countDown);
}

@@ -159,6 +154,48 @@ public void testProgressListenerExceptionsAreCaught() throws Exception {
assertEquals(1, searchProgressListener.onFinalReduce.get());
}

public void testCircuitBreakerTriggersBeforeBatchedReduce() {
SearchRequest searchRequest = new SearchRequest("index");
searchRequest.source(new SearchSourceBuilder().aggregation(AggregationBuilders.max("max").field("test")).size(1));
final int BATCH_REDUCED_SIZE = 5;
searchRequest.setBatchedReduceSize(BATCH_REDUCED_SIZE);
AtomicReference<Exception> onPartialMergeFailure = new AtomicReference<>();
TestCircuitBreaker testCircuitBreaker = new TestCircuitBreaker();
QueryPhaseResultConsumer queryPhaseResultConsumer = new QueryPhaseResultConsumer(
searchRequest,
executor,
testCircuitBreaker,
searchPhaseController,
SearchProgressListener.NOOP,
writableRegistry(),
10,
onPartialMergeFailure::set
);

for (int i = 0; i < BATCH_REDUCED_SIZE - 2; i++) {
QuerySearchResult querySearchResult = getQuerySearchResult(i);
querySearchResult.aggregations(InternalAggregations.from(List.of(new InternalMax("test", 23, DocValueFormat.RAW, null))));
queryPhaseResultConsumer.consumeResult(querySearchResult, () -> {});
}

testCircuitBreaker.startBreaking();
QuerySearchResult querySearchResult = getQuerySearchResult(7);
querySearchResult.aggregations(InternalAggregations.from(List.of(new InternalMax("test", 23, DocValueFormat.RAW, null))));
queryPhaseResultConsumer.consumeResult(querySearchResult, () -> {});
assertThrows(CircuitBreakingException.class, queryPhaseResultConsumer::reduce);
}

private static QuerySearchResult getQuerySearchResult(int i) {
SearchShardTarget searchShardTarget = new SearchShardTarget("node", new ShardId("index", "uuid", i), null, OriginalIndices.NONE);

QuerySearchResult querySearchResult = new QuerySearchResult();
TopDocs topDocs = new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]);
querySearchResult.topDocs(new TopDocsAndMaxScore(topDocs, Float.NaN), new DocValueFormat[0]);
querySearchResult.setSearchShardTarget(searchShardTarget);
querySearchResult.setShardIndex(i);
return querySearchResult;
}

private static class ThrowingSearchProgressListener extends SearchProgressListener {
private final AtomicInteger onQueryResult = new AtomicInteger(0);
private final AtomicInteger onPartialReduce = new AtomicInteger(0);