Skip to content

Commit 10ff9d3

Browse files
Add circuit breaking logic for shard level results (#19066)
--------- Signed-off-by: Kaushal Kumar <ravi.kaushal97@gmail.com>
1 parent 1273c5a commit 10ff9d3

File tree

8 files changed

+235
-50
lines changed

8 files changed

+235
-50
lines changed

.idea/runConfigurations/Debug_OpenSearch.xml

Lines changed: 10 additions & 10 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
1818
- Fix unnecessary refreshes on update preparation failures ([#15261](https://github.com/opensearch-project/OpenSearch/issues/15261))
1919
- Fix NullPointerException in segment replicator ([#18997](https://github.com/opensearch-project/OpenSearch/pull/18997))
2020
- Ensure that plugins that utilize dumpCoverage can write to jacoco.dir when tests.security.manager is enabled ([#18983](https://github.com/opensearch-project/OpenSearch/pull/18983))
21+
- Fix OOM due to large number of shard result buffering ([#19066](https://github.com/opensearch-project/OpenSearch/pull/19066))
2122

2223
### Dependencies
2324
- Bump `com.netflix.nebula.ospackage-base` from 12.0.0 to 12.1.0 ([#19019](https://github.com/opensearch-project/OpenSearch/pull/19019))

server/src/main/java/org/opensearch/action/search/AbstractSearchAsyncAction.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -359,6 +359,8 @@ SearchActionListener<Result> createShardActionListener(
359359
public void innerOnResponse(Result result) {
360360
try {
361361
onShardResult(result, shardIt);
362+
} catch (Exception e) {
363+
logger.trace("Failed to consume the shard {} result: {}", shard.getShardId(), e);
362364
} finally {
363365
executeNext(pendingExecutions, thread);
364366
}

server/src/main/java/org/opensearch/action/search/QueryPhaseResultConsumer.java

Lines changed: 71 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
import org.opensearch.core.common.breaker.CircuitBreaker;
4343
import org.opensearch.core.common.breaker.CircuitBreakingException;
4444
import org.opensearch.core.common.io.stream.NamedWriteableRegistry;
45+
import org.opensearch.core.tasks.TaskCancelledException;
4546
import org.opensearch.search.SearchPhaseResult;
4647
import org.opensearch.search.SearchShardTarget;
4748
import org.opensearch.search.aggregations.InternalAggregation.ReduceContextBuilder;
@@ -57,6 +58,7 @@
5758
import java.util.List;
5859
import java.util.concurrent.Executor;
5960
import java.util.concurrent.atomic.AtomicReference;
61+
import java.util.function.BooleanSupplier;
6062
import java.util.function.Consumer;
6163

6264
/**
@@ -86,6 +88,30 @@ public class QueryPhaseResultConsumer extends ArraySearchPhaseResults<SearchPhas
8688

8789
final PendingMerges pendingMerges;
8890
private final Consumer<Exception> onPartialMergeFailure;
91+
private final BooleanSupplier isTaskCancelled;
92+
93+
public QueryPhaseResultConsumer(
94+
SearchRequest request,
95+
Executor executor,
96+
CircuitBreaker circuitBreaker,
97+
SearchPhaseController controller,
98+
SearchProgressListener progressListener,
99+
NamedWriteableRegistry namedWriteableRegistry,
100+
int expectedResultSize,
101+
Consumer<Exception> onPartialMergeFailure
102+
) {
103+
this(
104+
request,
105+
executor,
106+
circuitBreaker,
107+
controller,
108+
progressListener,
109+
namedWriteableRegistry,
110+
expectedResultSize,
111+
onPartialMergeFailure,
112+
() -> false
113+
);
114+
}
89115

90116
/**
91117
* Creates a {@link QueryPhaseResultConsumer} that incrementally reduces aggregation results
@@ -99,7 +125,8 @@ public QueryPhaseResultConsumer(
99125
SearchProgressListener progressListener,
100126
NamedWriteableRegistry namedWriteableRegistry,
101127
int expectedResultSize,
102-
Consumer<Exception> onPartialMergeFailure
128+
Consumer<Exception> onPartialMergeFailure,
129+
BooleanSupplier isTaskCancelled
103130
) {
104131
super(expectedResultSize);
105132
this.executor = executor;
@@ -117,6 +144,7 @@ public QueryPhaseResultConsumer(
117144
this.hasAggs = source != null && source.aggregations() != null;
118145
int batchReduceSize = getBatchReduceSize(request.getBatchedReduceSize(), expectedResultSize);
119146
this.pendingMerges = new PendingMerges(batchReduceSize, request.resolveTrackTotalHitsUpTo());
147+
this.isTaskCancelled = isTaskCancelled;
120148
}
121149

122150
int getBatchReduceSize(int requestBatchedReduceSize, int minBatchReduceSize) {
@@ -133,6 +161,7 @@ public void consumeResult(SearchPhaseResult result, Runnable next) {
133161
super.consumeResult(result, () -> {});
134162
QuerySearchResult querySearchResult = result.queryResult();
135163
progressListener.notifyQueryResult(querySearchResult.getShardIndex());
164+
checkCancellation();
136165
pendingMerges.consume(querySearchResult, next);
137166
}
138167

@@ -143,6 +172,7 @@ public SearchPhaseController.ReducedQueryPhase reduce() throws Exception {
143172
} else if (pendingMerges.hasFailure()) {
144173
throw pendingMerges.getFailure();
145174
}
175+
checkCancellation();
146176

147177
// ensure consistent ordering
148178
pendingMerges.sortBuffer();
@@ -186,6 +216,7 @@ private MergeResult partialReduce(
186216
MergeResult lastMerge,
187217
int numReducePhases
188218
) {
219+
checkCancellation();
189220
// ensure consistent ordering
190221
Arrays.sort(toConsume, Comparator.comparingInt(QuerySearchResult::getShardIndex));
191222

@@ -242,6 +273,16 @@ private MergeResult partialReduce(
242273
return new MergeResult(processedShards, newTopDocs, newAggs, hasAggs ? serializedSize : 0);
243274
}
244275

276+
private void checkCancellation() {
277+
if (isTaskCancelled.getAsBoolean()) {
278+
pendingMerges.resetCircuitBreakerForCurrentRequest();
279+
// This check is to ensure that we are not masking the actual reason for cancellation, i.e., CircuitBreakingException
280+
if (!pendingMerges.hasFailure()) {
281+
pendingMerges.failure.set(new TaskCancelledException("request has been terminated"));
282+
}
283+
}
284+
}
285+
245286
public int getNumReducePhases() {
246287
return pendingMerges.numReducePhases;
247288
}
@@ -342,9 +383,10 @@ long estimateRamBytesUsedForReduce(long size) {
342383
return Math.round(1.5d * size - size);
343384
}
344385

345-
public void consume(QuerySearchResult result, Runnable next) {
386+
public void consume(QuerySearchResult result, Runnable next) throws CircuitBreakingException {
346387
boolean executeNextImmediately = true;
347388
synchronized (this) {
389+
checkCircuitBreaker(next);
348390
if (hasFailure() || result.isNull()) {
349391
result.consumeAll();
350392
if (result.isNull()) {
@@ -378,17 +420,32 @@ public void consume(QuerySearchResult result, Runnable next) {
378420
}
379421
}
380422

423+
/**
424+
* Forces a circuit-breaker evaluation before buffering another shard result, to prevent OOM when the buffered results grow too large.
425+
*
426+
*/
427+
private void checkCircuitBreaker(Runnable next) throws CircuitBreakingException {
428+
try {
429+
// force the CircuitBreaker eval to ensure during buffering we did not hit the circuit breaker limit
430+
addEstimateAndMaybeBreak(0);
431+
} catch (CircuitBreakingException e) {
432+
resetCircuitBreakerForCurrentRequest();
433+
// onPartialMergeFailure should only be invoked once since this is responsible for cancelling the
434+
// search task
435+
if (!hasFailure()) {
436+
failure.set(e);
437+
onPartialMergeFailure.accept(e);
438+
}
439+
}
440+
}
441+
381442
private synchronized void onMergeFailure(Exception exc) {
382443
if (hasFailure()) {
383444
assert circuitBreakerBytes == 0;
384445
return;
385446
}
386447
assert circuitBreakerBytes >= 0;
387-
if (circuitBreakerBytes > 0) {
388-
// make sure that we reset the circuit breaker
389-
circuitBreaker.addWithoutBreaking(-circuitBreakerBytes);
390-
circuitBreakerBytes = 0;
391-
}
448+
resetCircuitBreakerForCurrentRequest();
392449
failure.compareAndSet(null, exc);
393450
MergeTask task = runningTask.get();
394451
runningTask.compareAndSet(task, null);
@@ -405,6 +462,13 @@ private synchronized void onMergeFailure(Exception exc) {
405462
}
406463
}
407464

465+
private void resetCircuitBreakerForCurrentRequest() {
466+
if (circuitBreakerBytes > 0) {
467+
circuitBreaker.addWithoutBreaking(-circuitBreakerBytes);
468+
circuitBreakerBytes = 0;
469+
}
470+
}
471+
408472
private void onAfterMerge(MergeTask task, MergeResult newResult, long estimatedSize) {
409473
synchronized (this) {
410474
if (hasFailure()) {

server/src/main/java/org/opensearch/action/search/SearchPhaseController.java

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,7 @@
7878
import java.util.List;
7979
import java.util.Map;
8080
import java.util.concurrent.Executor;
81+
import java.util.function.BooleanSupplier;
8182
import java.util.function.Consumer;
8283
import java.util.function.Function;
8384
import java.util.function.IntFunction;
@@ -778,6 +779,21 @@ QueryPhaseResultConsumer newSearchPhaseResults(
778779
SearchRequest request,
779780
int numShards,
780781
Consumer<Exception> onPartialMergeFailure
782+
) {
783+
return newSearchPhaseResults(executor, circuitBreaker, listener, request, numShards, onPartialMergeFailure, () -> false);
784+
}
785+
786+
/**
787+
* Returns a new {@link QueryPhaseResultConsumer} instance that reduces search responses incrementally.
788+
*/
789+
QueryPhaseResultConsumer newSearchPhaseResults(
790+
Executor executor,
791+
CircuitBreaker circuitBreaker,
792+
SearchProgressListener listener,
793+
SearchRequest request,
794+
int numShards,
795+
Consumer<Exception> onPartialMergeFailure,
796+
BooleanSupplier isTaskCancelled
781797
) {
782798
return new QueryPhaseResultConsumer(
783799
request,
@@ -787,7 +803,8 @@ QueryPhaseResultConsumer newSearchPhaseResults(
787803
listener,
788804
namedWriteableRegistry,
789805
numShards,
790-
onPartialMergeFailure
806+
onPartialMergeFailure,
807+
isTaskCancelled
791808
);
792809
}
793810

server/src/main/java/org/opensearch/action/search/TransportSearchAction.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1270,7 +1270,8 @@ AbstractSearchAsyncAction<? extends SearchPhaseResult> searchAsyncAction(
12701270
task.getProgressListener(),
12711271
searchRequest,
12721272
shardIterators.size(),
1273-
exc -> cancelTask(task, exc)
1273+
exc -> cancelTask(task, exc),
1274+
task::isCancelled
12741275
);
12751276
AbstractSearchAsyncAction<? extends SearchPhaseResult> searchAsyncAction;
12761277
switch (searchRequest.searchType()) {

server/src/test/java/org/opensearch/action/search/QueryPhaseResultConsumerTests.java

Lines changed: 48 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -36,18 +36,23 @@
3636
import org.apache.lucene.search.TopDocs;
3737
import org.apache.lucene.search.TotalHits;
3838
import org.opensearch.action.OriginalIndices;
39+
import org.opensearch.common.breaker.TestCircuitBreaker;
3940
import org.opensearch.common.lucene.search.TopDocsAndMaxScore;
4041
import org.opensearch.common.util.BigArrays;
4142
import org.opensearch.common.util.concurrent.OpenSearchExecutors;
4243
import org.opensearch.common.util.concurrent.OpenSearchThreadPoolExecutor;
4344
import org.opensearch.core.common.breaker.CircuitBreaker;
45+
import org.opensearch.core.common.breaker.CircuitBreakingException;
4446
import org.opensearch.core.common.breaker.NoopCircuitBreaker;
4547
import org.opensearch.core.index.shard.ShardId;
4648
import org.opensearch.search.DocValueFormat;
4749
import org.opensearch.search.SearchShardTarget;
50+
import org.opensearch.search.aggregations.AggregationBuilders;
4851
import org.opensearch.search.aggregations.InternalAggregation;
4952
import org.opensearch.search.aggregations.InternalAggregations;
53+
import org.opensearch.search.aggregations.metrics.InternalMax;
5054
import org.opensearch.search.aggregations.pipeline.PipelineAggregator;
55+
import org.opensearch.search.builder.SearchSourceBuilder;
5156
import org.opensearch.search.query.QuerySearchResult;
5257
import org.opensearch.test.OpenSearchTestCase;
5358
import org.opensearch.threadpool.TestThreadPool;
@@ -136,17 +141,7 @@ public void testProgressListenerExceptionsAreCaught() throws Exception {
136141
CountDownLatch partialReduceLatch = new CountDownLatch(10);
137142

138143
for (int i = 0; i < 10; i++) {
139-
SearchShardTarget searchShardTarget = new SearchShardTarget(
140-
"node",
141-
new ShardId("index", "uuid", i),
142-
null,
143-
OriginalIndices.NONE
144-
);
145-
QuerySearchResult querySearchResult = new QuerySearchResult();
146-
TopDocs topDocs = new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]);
147-
querySearchResult.topDocs(new TopDocsAndMaxScore(topDocs, Float.NaN), new DocValueFormat[0]);
148-
querySearchResult.setSearchShardTarget(searchShardTarget);
149-
querySearchResult.setShardIndex(i);
144+
QuerySearchResult querySearchResult = getQuerySearchResult(i);
150145
queryPhaseResultConsumer.consumeResult(querySearchResult, partialReduceLatch::countDown);
151146
}
152147

@@ -159,6 +154,48 @@ public void testProgressListenerExceptionsAreCaught() throws Exception {
159154
assertEquals(1, searchProgressListener.onFinalReduce.get());
160155
}
161156

157+
public void testCircuitBreakerTriggersBeforeBatchedReduce() {
158+
SearchRequest searchRequest = new SearchRequest("index");
159+
searchRequest.source(new SearchSourceBuilder().aggregation(AggregationBuilders.max("max").field("test")).size(1));
160+
final int BATCH_REDUCED_SIZE = 5;
161+
searchRequest.setBatchedReduceSize(BATCH_REDUCED_SIZE);
162+
AtomicReference<Exception> onPartialMergeFailure = new AtomicReference<>();
163+
TestCircuitBreaker testCircuitBreaker = new TestCircuitBreaker();
164+
QueryPhaseResultConsumer queryPhaseResultConsumer = new QueryPhaseResultConsumer(
165+
searchRequest,
166+
executor,
167+
testCircuitBreaker,
168+
searchPhaseController,
169+
SearchProgressListener.NOOP,
170+
writableRegistry(),
171+
10,
172+
onPartialMergeFailure::set
173+
);
174+
175+
for (int i = 0; i < BATCH_REDUCED_SIZE - 2; i++) {
176+
QuerySearchResult querySearchResult = getQuerySearchResult(i);
177+
querySearchResult.aggregations(InternalAggregations.from(List.of(new InternalMax("test", 23, DocValueFormat.RAW, null))));
178+
queryPhaseResultConsumer.consumeResult(querySearchResult, () -> {});
179+
}
180+
181+
testCircuitBreaker.startBreaking();
182+
QuerySearchResult querySearchResult = getQuerySearchResult(7);
183+
querySearchResult.aggregations(InternalAggregations.from(List.of(new InternalMax("test", 23, DocValueFormat.RAW, null))));
184+
queryPhaseResultConsumer.consumeResult(querySearchResult, () -> {});
185+
assertThrows(CircuitBreakingException.class, queryPhaseResultConsumer::reduce);
186+
}
187+
188+
private static QuerySearchResult getQuerySearchResult(int i) {
189+
SearchShardTarget searchShardTarget = new SearchShardTarget("node", new ShardId("index", "uuid", i), null, OriginalIndices.NONE);
190+
191+
QuerySearchResult querySearchResult = new QuerySearchResult();
192+
TopDocs topDocs = new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]);
193+
querySearchResult.topDocs(new TopDocsAndMaxScore(topDocs, Float.NaN), new DocValueFormat[0]);
194+
querySearchResult.setSearchShardTarget(searchShardTarget);
195+
querySearchResult.setShardIndex(i);
196+
return querySearchResult;
197+
}
198+
162199
private static class ThrowingSearchProgressListener extends SearchProgressListener {
163200
private final AtomicInteger onQueryResult = new AtomicInteger(0);
164201
private final AtomicInteger onPartialReduce = new AtomicInteger(0);

0 commit comments

Comments
 (0)