4242import org .opensearch .core .common .breaker .CircuitBreaker ;
4343import org .opensearch .core .common .breaker .CircuitBreakingException ;
4444import org .opensearch .core .common .io .stream .NamedWriteableRegistry ;
45+ import org .opensearch .core .tasks .TaskCancelledException ;
4546import org .opensearch .search .SearchPhaseResult ;
4647import org .opensearch .search .SearchShardTarget ;
4748import org .opensearch .search .aggregations .InternalAggregation .ReduceContextBuilder ;
5758import java .util .List ;
5859import java .util .concurrent .Executor ;
5960import java .util .concurrent .atomic .AtomicReference ;
61+ import java .util .function .BooleanSupplier ;
6062import java .util .function .Consumer ;
6163
6264/**
@@ -86,6 +88,30 @@ public class QueryPhaseResultConsumer extends ArraySearchPhaseResults<SearchPhas
8688
// Holder for the buffered shard results and the merge state; this class reads its failure, buffer ordering and
// number of reduce phases through it.
8789 final PendingMerges pendingMerges ;
// Callback that receives the exception when the circuit breaker trips while buffering (see checkCircuitBreaker).
// NOTE(review): the in-code comment there says it is responsible for cancelling the search task — confirm at the call site.
8890 private final Consumer <Exception > onPartialMergeFailure ;
// Polled by checkCancellation() to find out whether the owning task was cancelled; the 8-argument constructor
// supplies a constant false.
91+ private final BooleanSupplier isTaskCancelled ;
92+
/**
 * Creates a {@link QueryPhaseResultConsumer} without task-cancellation awareness.
 * <p>
 * Every argument is forwarded unchanged to the nine-argument constructor, and the trailing
 * {@code isTaskCancelled} supplier is fixed to {@code () -> false}, so {@code checkCancellation()} never
 * records a cancellation for consumers built through this overload.
 */
93+ public QueryPhaseResultConsumer (
94+ SearchRequest request ,
95+ Executor executor ,
96+ CircuitBreaker circuitBreaker ,
97+ SearchPhaseController controller ,
98+ SearchProgressListener progressListener ,
99+ NamedWriteableRegistry namedWriteableRegistry ,
100+ int expectedResultSize ,
101+ Consumer <Exception > onPartialMergeFailure
102+ ) {
103+ this (
104+ request ,
105+ executor ,
106+ circuitBreaker ,
107+ controller ,
108+ progressListener ,
109+ namedWriteableRegistry ,
110+ expectedResultSize ,
111+ onPartialMergeFailure ,
// cancellation supplier that always answers "not cancelled"
112+ () -> false
113+ );
114+ }
89115
90116 /**
91117 * Creates a {@link QueryPhaseResultConsumer} that incrementally reduces aggregation results
@@ -99,7 +125,8 @@ public QueryPhaseResultConsumer(
99125 SearchProgressListener progressListener ,
100126 NamedWriteableRegistry namedWriteableRegistry ,
101127 int expectedResultSize ,
102- Consumer <Exception > onPartialMergeFailure
128+ Consumer <Exception > onPartialMergeFailure ,
129+ BooleanSupplier isTaskCancelled
103130 ) {
104131 super (expectedResultSize );
105132 this .executor = executor ;
@@ -117,6 +144,7 @@ public QueryPhaseResultConsumer(
117144 this .hasAggs = source != null && source .aggregations () != null ;
118145 int batchReduceSize = getBatchReduceSize (request .getBatchedReduceSize (), expectedResultSize );
119146 this .pendingMerges = new PendingMerges (batchReduceSize , request .resolveTrackTotalHitsUpTo ());
147+ this .isTaskCancelled = isTaskCancelled ;
120148 }
121149
122150 int getBatchReduceSize (int requestBatchedReduceSize , int minBatchReduceSize ) {
@@ -133,6 +161,7 @@ public void consumeResult(SearchPhaseResult result, Runnable next) {
133161 super .consumeResult (result , () -> {});
134162 QuerySearchResult querySearchResult = result .queryResult ();
135163 progressListener .notifyQueryResult (querySearchResult .getShardIndex ());
164+ checkCancellation ();
136165 pendingMerges .consume (querySearchResult , next );
137166 }
138167
@@ -143,6 +172,7 @@ public SearchPhaseController.ReducedQueryPhase reduce() throws Exception {
143172 } else if (pendingMerges .hasFailure ()) {
144173 throw pendingMerges .getFailure ();
145174 }
175+ checkCancellation ();
146176
147177 // ensure consistent ordering
148178 pendingMerges .sortBuffer ();
@@ -186,6 +216,7 @@ private MergeResult partialReduce(
186216 MergeResult lastMerge ,
187217 int numReducePhases
188218 ) {
219+ checkCancellation ();
189220 // ensure consistent ordering
190221 Arrays .sort (toConsume , Comparator .comparingInt (QuerySearchResult ::getShardIndex ));
191222
@@ -242,6 +273,16 @@ private MergeResult partialReduce(
242273 return new MergeResult (processedShards , newTopDocs , newAggs , hasAggs ? serializedSize : 0 );
243274 }
244275
276+ private void checkCancellation () {
277+ if (isTaskCancelled .getAsBoolean ()) {
278+ pendingMerges .resetCircuitBreakerForCurrentRequest ();
279+ // This check is to ensure that we are not masking the actual reason for cancellation i,e; CircuitBreakingException
280+ if (!pendingMerges .hasFailure ()) {
281+ pendingMerges .failure .set (new TaskCancelledException ("request has been terminated" ));
282+ }
283+ }
284+ }
285+
/**
 * Returns the current value of {@code pendingMerges.numReducePhases}.
 * NOTE(review): presumably the number of reduce phases executed so far — confirm against PendingMerges.
 */
245286 public int getNumReducePhases () {
246287 return pendingMerges .numReducePhases ;
247288 }
@@ -342,9 +383,10 @@ long estimateRamBytesUsedForReduce(long size) {
342383 return Math .round (1.5d * size - size );
343384 }
344385
345- public void consume (QuerySearchResult result , Runnable next ) {
386+ public void consume (QuerySearchResult result , Runnable next ) throws CircuitBreakingException {
346387 boolean executeNextImmediately = true ;
347388 synchronized (this ) {
389+ checkCircuitBreaker (next );
348390 if (hasFailure () || result .isNull ()) {
349391 result .consumeAll ();
350392 if (result .isNull ()) {
@@ -378,17 +420,32 @@ public void consume(QuerySearchResult result, Runnable next) {
378420 }
379421 }
380422
423+ /**
424+ * This method is needed to prevent OOM when the buffered results are too large
425+ *
426+ */
427+ private void checkCircuitBreaker (Runnable next ) throws CircuitBreakingException {
428+ try {
429+ // force the CircuitBreaker eval to ensure during buffering we did not hit the circuit breaker limit
430+ addEstimateAndMaybeBreak (0 );
431+ } catch (CircuitBreakingException e ) {
432+ resetCircuitBreakerForCurrentRequest ();
433+ // onPartialMergeFailure should only be invoked once since this is responsible for cancelling the
434+ // search task
435+ if (!hasFailure ()) {
436+ failure .set (e );
437+ onPartialMergeFailure .accept (e );
438+ }
439+ }
440+ }
441+
381442 private synchronized void onMergeFailure (Exception exc ) {
382443 if (hasFailure ()) {
383444 assert circuitBreakerBytes == 0 ;
384445 return ;
385446 }
386447 assert circuitBreakerBytes >= 0 ;
387- if (circuitBreakerBytes > 0 ) {
388- // make sure that we reset the circuit breaker
389- circuitBreaker .addWithoutBreaking (-circuitBreakerBytes );
390- circuitBreakerBytes = 0 ;
391- }
448+ resetCircuitBreakerForCurrentRequest ();
392449 failure .compareAndSet (null , exc );
393450 MergeTask task = runningTask .get ();
394451 runningTask .compareAndSet (task , null );
@@ -405,6 +462,13 @@ private synchronized void onMergeFailure(Exception exc) {
405462 }
406463 }
407464
465+ private void resetCircuitBreakerForCurrentRequest () {
466+ if (circuitBreakerBytes > 0 ) {
467+ circuitBreaker .addWithoutBreaking (-circuitBreakerBytes );
468+ circuitBreakerBytes = 0 ;
469+ }
470+ }
471+
408472 private void onAfterMerge (MergeTask task , MergeResult newResult , long estimatedSize ) {
409473 synchronized (this ) {
410474 if (hasFailure ()) {
0 commit comments