18 changes: 9 additions & 9 deletions .idea/runConfigurations/Debug_OpenSearch.xml

Some generated files are not rendered by default.

3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -35,6 +35,9 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
### Removed

### Fixed
- Add task cancellation checks in aggregators ([#18426](https://github.com/opensearch-project/OpenSearch/pull/18426))
- Fix OOM due to large number of shard result buffering ([#19066](https://github.com/opensearch-project/OpenSearch/pull/19066))
- Fix QueryPhaseResultConsumer incomplete callback loops ([#19231](https://github.com/opensearch-project/OpenSearch/pull/19231))
- Use Bad Request status for InputCoercionException ([#18161](https://github.com/opensearch-project/OpenSearch/pull/18161))
- Avoid NPE on SnapshotInfo if 'shallow' boolean not present ([#18187](https://github.com/opensearch-project/OpenSearch/issues/18187))
- Null check field names in QueryStringQueryBuilder ([#18194](https://github.com/opensearch-project/OpenSearch/pull/18194))
@@ -42,6 +42,7 @@
import org.opensearch.core.common.breaker.CircuitBreaker;
import org.opensearch.core.common.breaker.CircuitBreakingException;
import org.opensearch.core.common.io.stream.NamedWriteableRegistry;
import org.opensearch.core.tasks.TaskCancelledException;
import org.opensearch.search.SearchPhaseResult;
import org.opensearch.search.SearchShardTarget;
import org.opensearch.search.aggregations.InternalAggregation.ReduceContextBuilder;
@@ -57,6 +58,7 @@
import java.util.List;
import java.util.concurrent.Executor;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.BooleanSupplier;
import java.util.function.Consumer;

/**
@@ -86,6 +88,30 @@ public class QueryPhaseResultConsumer extends ArraySearchPhaseResults<SearchPhas

private final PendingMerges pendingMerges;
private final Consumer<Exception> onPartialMergeFailure;
private final BooleanSupplier isTaskCancelled;

public QueryPhaseResultConsumer(
SearchRequest request,
Executor executor,
CircuitBreaker circuitBreaker,
SearchPhaseController controller,
SearchProgressListener progressListener,
NamedWriteableRegistry namedWriteableRegistry,
int expectedResultSize,
Consumer<Exception> onPartialMergeFailure
) {
this(
request,
executor,
circuitBreaker,
controller,
progressListener,
namedWriteableRegistry,
expectedResultSize,
onPartialMergeFailure,
() -> false
);
}

/**
* Creates a {@link QueryPhaseResultConsumer} that incrementally reduces aggregation results
@@ -99,7 +125,8 @@ public QueryPhaseResultConsumer(
SearchProgressListener progressListener,
NamedWriteableRegistry namedWriteableRegistry,
int expectedResultSize,
Consumer<Exception> onPartialMergeFailure
Consumer<Exception> onPartialMergeFailure,
BooleanSupplier isTaskCancelled
) {
super(expectedResultSize);
this.executor = executor;
@@ -117,6 +144,7 @@ public QueryPhaseResultConsumer(
this.hasAggs = source != null && source.aggregations() != null;
int batchReduceSize = (hasAggs || hasTopDocs) ? Math.min(request.getBatchedReduceSize(), expectedResultSize) : expectedResultSize;
this.pendingMerges = new PendingMerges(batchReduceSize, request.resolveTrackTotalHitsUpTo());
this.isTaskCancelled = isTaskCancelled;
}

@Override
@@ -129,6 +157,7 @@ public void consumeResult(SearchPhaseResult result, Runnable next) {
super.consumeResult(result, () -> {});
QuerySearchResult querySearchResult = result.queryResult();
progressListener.notifyQueryResult(querySearchResult.getShardIndex());
checkCancellation();
pendingMerges.consume(querySearchResult, next);
}

@@ -139,6 +168,7 @@ public SearchPhaseController.ReducedQueryPhase reduce() throws Exception {
} else if (pendingMerges.hasFailure()) {
throw pendingMerges.getFailure();
}
checkCancellation();

// ensure consistent ordering
pendingMerges.sortBuffer();
@@ -182,6 +212,10 @@ private MergeResult partialReduce(
MergeResult lastMerge,
int numReducePhases
) {
checkCancellation();
if (pendingMerges.hasFailure()) {
return lastMerge;
}
// ensure consistent ordering
Arrays.sort(toConsume, Comparator.comparingInt(QuerySearchResult::getShardIndex));

@@ -238,6 +272,16 @@ private MergeResult partialReduce(
return new MergeResult(processedShards, newTopDocs, newAggs, hasAggs ? serializedSize : 0);
}

private void checkCancellation() {
if (isTaskCancelled.getAsBoolean()) {
pendingMerges.resetCircuitBreakerForCurrentRequest();
// This check is to ensure that we are not masking the actual reason for cancellation, i.e. CircuitBreakingException
if (!pendingMerges.hasFailure()) {
pendingMerges.failure.set(new TaskCancelledException("request has been terminated"));
}
}
}

public int getNumReducePhases() {
return pendingMerges.numReducePhases;
}
@@ -338,9 +382,10 @@ long estimateRamBytesUsedForReduce(long size) {
return Math.round(1.5d * size - size);
}

public void consume(QuerySearchResult result, Runnable next) {
public void consume(QuerySearchResult result, Runnable next) throws CircuitBreakingException {
boolean executeNextImmediately = true;
synchronized (this) {
checkCircuitBreaker(next);
if (hasFailure() || result.isNull()) {
result.consumeAll();
if (result.isNull()) {
@@ -374,21 +419,40 @@ public void consume(QuerySearchResult result, Runnable next) {
}
}

/**
* This method is needed to prevent OOM when the buffered results are too large
*/
private void checkCircuitBreaker(Runnable next) throws CircuitBreakingException {
try {
// force the CircuitBreaker eval to ensure during buffering we did not hit the circuit breaker limit
addEstimateAndMaybeBreak(0);
} catch (CircuitBreakingException e) {
resetCircuitBreakerForCurrentRequest();
// onPartialMergeFailure should only be invoked once since this is responsible for cancelling the
// search task
if (!hasFailure()) {
failure.set(e);
onPartialMergeFailure.accept(e);
}
}
}

private synchronized void onMergeFailure(Exception exc) {
if (hasFailure()) {
assert circuitBreakerBytes == 0;
return;
}
assert circuitBreakerBytes >= 0;
if (circuitBreakerBytes > 0) {
// make sure that we reset the circuit breaker
circuitBreaker.addWithoutBreaking(-circuitBreakerBytes);
circuitBreakerBytes = 0;
}
resetCircuitBreakerForCurrentRequest();
failure.compareAndSet(null, exc);
MergeTask task = runningTask.get();
runningTask.compareAndSet(task, null);
onPartialMergeFailure.accept(exc);
clearPendingMerges(task);
}

void clearPendingMerges(MergeTask task) {
List<MergeTask> toCancels = new ArrayList<>();
if (task != null) {
toCancels.add(task);
@@ -401,12 +465,20 @@ private synchronized void onMergeFailure(Exception exc) {
}
}

private void resetCircuitBreakerForCurrentRequest() {
if (circuitBreakerBytes > 0) {
circuitBreaker.addWithoutBreaking(-circuitBreakerBytes);
circuitBreakerBytes = 0;
}
}

private void onAfterMerge(MergeTask task, MergeResult newResult, long estimatedSize) {
synchronized (this) {
runningTask.compareAndSet(task, null);
if (hasFailure()) {
task.cancel();
return;
}
runningTask.compareAndSet(task, null);
mergeResult = newResult;
if (hasAggs) {
// Update the circuit breaker to remove the size of the source aggregations
@@ -427,7 +499,11 @@ private void onAfterMerge(MergeTask task, MergeResult newResult, long estimatedS
private void tryExecuteNext() {
final MergeTask task;
synchronized (this) {
if (queue.isEmpty() || hasFailure() || runningTask.get() != null) {
if (hasFailure()) {
clearPendingMerges(null);
return;
}
if (queue.isEmpty() || runningTask.get() != null) {
return;
}
task = queue.poll();
@@ -443,6 +519,7 @@ protected void doRun() {
try {
final QuerySearchResult[] toConsume = task.consumeBuffer();
if (toConsume == null) {
task.cancel();
return;
}
long estimatedMergeSize = estimateRamBytesUsedForReduce(estimatedTotalSize);
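The QueryPhaseResultConsumer changes above combine two safeguards: consume() now forces a zero-byte circuit-breaker evaluation so an oversized buffer of shard results fails the request instead of running out of memory, and checkCancellation() records a TaskCancelledException only when no earlier failure (such as a CircuitBreakingException) has already been set, so the original cause is not masked. The sketch below is a minimal, self-contained model of that pattern; CancellableResultBuffer and its byte limit are hypothetical stand-ins, not the real OpenSearch types.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.CancellationException;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.BooleanSupplier;
import java.util.function.Consumer;

// Simplified stand-in for the consumer in this PR: it buffers per-shard results,
// fails the request early when the buffered estimate exceeds a memory limit, and
// surfaces task cancellation without overwriting an earlier failure.
final class CancellableResultBuffer<T> {
    private final long maxBufferedBytes;              // stand-in for the circuit breaker limit
    private final BooleanSupplier isTaskCancelled;    // e.g. task::isCancelled
    private final Consumer<Exception> onFailure;      // invoked once; responsible for cancelling the task
    private final AtomicReference<Exception> failure = new AtomicReference<>();
    private final List<T> buffer = new ArrayList<>();
    private long bufferedBytes;

    CancellableResultBuffer(long maxBufferedBytes, BooleanSupplier isTaskCancelled, Consumer<Exception> onFailure) {
        this.maxBufferedBytes = maxBufferedBytes;
        this.isTaskCancelled = isTaskCancelled;
        this.onFailure = onFailure;
    }

    /** Returns true if the result was accepted, false if the request has already failed. */
    synchronized boolean consume(T result, long estimatedBytes) {
        checkCancellation();
        checkMemoryLimit(estimatedBytes);
        if (failure.get() != null) {
            return false;                             // drop the result, the request is failing
        }
        buffer.add(result);
        bufferedBytes += estimatedBytes;
        return true;
    }

    private void checkMemoryLimit(long extraBytes) {
        if (failure.get() == null && bufferedBytes + extraBytes > maxBufferedBytes) {
            Exception exc = new IllegalStateException("buffered shard results exceed " + maxBufferedBytes + " bytes");
            bufferedBytes = 0;                        // release our accounting, like resetting the breaker
            failure.set(exc);
            onFailure.accept(exc);                    // only ever called for the first failure
        }
    }

    private void checkCancellation() {
        // Keep the original failure (e.g. the memory limit) as the reported cause.
        if (isTaskCancelled.getAsBoolean() && failure.get() == null) {
            failure.set(new CancellationException("request has been terminated"));
        }
    }

    synchronized Exception failure() {
        return failure.get();
    }

    public static void main(String[] args) {
        AtomicBoolean cancelled = new AtomicBoolean(false);
        CancellableResultBuffer<String> results =
            new CancellableResultBuffer<>(1024, cancelled::get, exc -> System.out.println("failed: " + exc));
        System.out.println(results.consume("shard-0", 512));   // true
        cancelled.set(true);
        System.out.println(results.consume("shard-1", 512));   // false, request was cancelled
        System.out.println(results.failure());
    }
}
```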
@@ -78,6 +78,7 @@
import java.util.List;
import java.util.Map;
import java.util.concurrent.Executor;
import java.util.function.BooleanSupplier;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.function.IntFunction;
@@ -758,6 +759,21 @@ QueryPhaseResultConsumer newSearchPhaseResults(
SearchRequest request,
int numShards,
Consumer<Exception> onPartialMergeFailure
) {
return newSearchPhaseResults(executor, circuitBreaker, listener, request, numShards, onPartialMergeFailure, () -> false);
}

/**
* Returns a new {@link QueryPhaseResultConsumer} instance that reduces search responses incrementally.
*/
QueryPhaseResultConsumer newSearchPhaseResults(
Executor executor,
CircuitBreaker circuitBreaker,
SearchProgressListener listener,
SearchRequest request,
int numShards,
Consumer<Exception> onPartialMergeFailure,
BooleanSupplier isTaskCancelled
) {
return new QueryPhaseResultConsumer(
request,
@@ -767,7 +783,8 @@ QueryPhaseResultConsumer newSearchPhaseResults(
listener,
namedWriteableRegistry,
numShards,
onPartialMergeFailure
onPartialMergeFailure,
isTaskCancelled
);
}

@@ -1265,7 +1265,8 @@ private AbstractSearchAsyncAction<? extends SearchPhaseResult> searchAsyncAction
task.getProgressListener(),
searchRequest,
shardIterators.size(),
exc -> cancelTask(task, exc)
exc -> cancelTask(task, exc),
task::isCancelled
);
AbstractSearchAsyncAction<? extends SearchPhaseResult> searchAsyncAction;
switch (searchRequest.searchType()) {
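The two hunks above thread the cancellation signal through the call chain: SearchPhaseController gains a newSearchPhaseResults overload whose old signature delegates with a constant `() -> false` supplier, and TransportSearchAction passes `task::isCancelled` for real requests. A small hedged sketch of that wiring follows; DemoTask and the describe methods are hypothetical, standing in for the search task and the overloaded factory methods.

```java
import java.util.function.BooleanSupplier;

// Sketch of how the cancellation signal is threaded through: callers that have a task
// pass a method reference, while legacy call sites fall back to a "never cancelled" default.
final class SupplierWiringDemo {
    static final class DemoTask {
        private volatile boolean cancelled;

        boolean isCancelled() { return cancelled; }
        void cancel() { cancelled = true; }
    }

    // Old signature kept for existing callers, delegating with a default supplier,
    // mirroring the new QueryPhaseResultConsumer / newSearchPhaseResults overloads.
    static String describe(String label) {
        return describe(label, () -> false);
    }

    static String describe(String label, BooleanSupplier isTaskCancelled) {
        return label + ": cancelled=" + isTaskCancelled.getAsBoolean();
    }

    public static void main(String[] args) {
        DemoTask task = new DemoTask();
        System.out.println(describe("before cancel", task::isCancelled)); // cancelled=false
        task.cancel();
        System.out.println(describe("after cancel", task::isCancelled));  // cancelled=true
        System.out.println(describe("legacy caller"));                    // cancelled=false
    }
}
```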
@@ -38,6 +38,7 @@
import org.opensearch.core.common.breaker.CircuitBreaker;
import org.opensearch.core.common.breaker.CircuitBreakingException;
import org.opensearch.core.indices.breaker.CircuitBreakerService;
import org.opensearch.core.tasks.TaskCancelledException;
import org.opensearch.search.SearchShardTarget;
import org.opensearch.search.aggregations.support.ValuesSourceConfig;
import org.opensearch.search.internal.SearchContext;
@@ -328,4 +329,10 @@ protected final InternalAggregations buildEmptySubAggregations() {
public String toString() {
return name;
}

protected void checkCancelled() {
if (context.isCancelled()) {
throw new TaskCancelledException("The query has been cancelled");
}
}
}
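The aggregator base class above now exposes a checkCancelled() helper that throws a TaskCancelledException when the search context reports cancellation, and the bucket aggregators in the following hunks call it before and during the expensive buildAggregations work. A rough stand-alone sketch of that shape, using hypothetical classes and CancellationException in place of TaskCancelledException, might look like this:

```java
import java.util.Arrays;
import java.util.concurrent.CancellationException;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.BooleanSupplier;

// Hypothetical sketch of the aggregator-side change: a shared checkCancelled() helper
// in a base class, called before and inside the bucket-building loop so a cancelled
// search stops doing work instead of building results nobody will read.
final class CancellableAggregatorDemo {
    abstract static class CancellableAggregator {
        private final BooleanSupplier contextCancelled;   // stand-in for context.isCancelled()

        CancellableAggregator(BooleanSupplier contextCancelled) {
            this.contextCancelled = contextCancelled;
        }

        protected final void checkCancelled() {
            if (contextCancelled.getAsBoolean()) {
                throw new CancellationException("The query has been cancelled");
            }
        }

        abstract long[] buildAggregations(long[] owningBucketOrds);
    }

    static final class DoublingAggregator extends CancellableAggregator {
        DoublingAggregator(BooleanSupplier contextCancelled) {
            super(contextCancelled);
        }

        @Override
        long[] buildAggregations(long[] owningBucketOrds) {
            checkCancelled();                             // bail out before building anything
            long[] counts = new long[owningBucketOrds.length];
            for (int i = 0; i < owningBucketOrds.length; i++) {
                checkCancelled();                         // and periodically inside the loop
                counts[i] = owningBucketOrds[i] * 2;
            }
            return counts;
        }
    }

    public static void main(String[] args) {
        AtomicBoolean cancelled = new AtomicBoolean(false);
        DoublingAggregator agg = new DoublingAggregator(cancelled::get);
        System.out.println(Arrays.toString(agg.buildAggregations(new long[] { 1, 2, 3 }))); // [2, 4, 6]
        cancelled.set(true);
        try {
            agg.buildAggregations(new long[] { 4, 5, 6 });
        } catch (CancellationException e) {
            System.out.println("cancelled: " + e.getMessage());
        }
    }
}
```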
@@ -229,9 +229,11 @@ protected void beforeBuildingBuckets(long[] ordsToCollect) throws IOException {}
* array of ordinals
*/
protected final InternalAggregations[] buildSubAggsForBuckets(long[] bucketOrdsToCollect) throws IOException {
checkCancelled();
beforeBuildingBuckets(bucketOrdsToCollect);
InternalAggregation[][] aggregations = new InternalAggregation[subAggregators.length][];
for (int i = 0; i < subAggregators.length; i++) {
checkCancelled();
aggregations[i] = subAggregators[i].buildAggregations(bucketOrdsToCollect);
}
InternalAggregations[] result = new InternalAggregations[bucketOrdsToCollect.length];
@@ -317,6 +319,7 @@ protected final <B> InternalAggregation[] buildAggregationsForFixedBucketCount(
BucketBuilderForFixedCount<B> bucketBuilder,
Function<List<B>, InternalAggregation> resultBuilder
) throws IOException {
checkCancelled();
int totalBuckets = owningBucketOrds.length * bucketsPerOwningBucketOrd;
long[] bucketOrdsToCollect = new long[totalBuckets];
int bucketOrdIdx = 0;
@@ -367,6 +370,7 @@ protected final InternalAggregation[] buildAggregationsForSingleBucket(long[] ow
* `consumeBucketsAndMaybeBreak(owningBucketOrds.length)`
* here but we don't because single bucket aggs never have.
*/
checkCancelled();
InternalAggregations[] subAggregationResults = buildSubAggsForBuckets(owningBucketOrds);
InternalAggregation[] results = new InternalAggregation[owningBucketOrds.length];
for (int ordIdx = 0; ordIdx < owningBucketOrds.length; ordIdx++) {
@@ -397,6 +401,7 @@ protected final <B> InternalAggregation[] buildAggregationsForVariableBuckets(
BucketBuilderForVariable<B> bucketBuilder,
ResultBuilderForVariable<B> resultBuilder
) throws IOException {
checkCancelled();
long totalOrdsToCollect = 0;
for (int ordIdx = 0; ordIdx < owningBucketOrds.length; ordIdx++) {
totalOrdsToCollect += bucketOrds.bucketsInOrd(owningBucketOrds[ordIdx]);
@@ -208,6 +208,7 @@ public void collect(int doc, long bucket) throws IOException {

@Override
public InternalAggregation[] buildAggregations(long[] owningBucketOrds) throws IOException {
checkCancelled();
// Buckets are ordered into groups - [keyed filters] [key1&key2 intersects]
int maxOrd = owningBucketOrds.length * totalNumKeys;
int totalBucketsToBuild = 0;
@@ -257,6 +257,7 @@ protected void doPostCollection() throws IOException {

@Override
public InternalAggregation[] buildAggregations(long[] owningBucketOrds) throws IOException {
checkCancelled();
// Composite aggregator must be at the top of the aggregation tree
assert owningBucketOrds.length == 1 && owningBucketOrds[0] == 0L;
if (deferredCollectors != NO_OP_COLLECTOR) {
@@ -200,6 +200,7 @@ public InternalAggregation[] buildAggregations(long[] owningBucketOrds) throws I
owningBucketOrds,
keys.length + (showOtherBucket ? 1 : 0),
(offsetInOwningOrd, docCount, subAggregationResults) -> {
checkCancelled();
if (offsetInOwningOrd < keys.length) {
return new InternalFilters.InternalBucket(keys[offsetInOwningOrd], docCount, subAggregationResults, keyed);
}
@@ -104,10 +104,12 @@ public AbstractHistogramAggregator(
@Override
public InternalAggregation[] buildAggregations(long[] owningBucketOrds) throws IOException {
return buildAggregationsForVariableBuckets(owningBucketOrds, bucketOrds, (bucketValue, docCount, subAggregationResults) -> {
checkCancelled();
double roundKey = Double.longBitsToDouble(bucketValue);
double key = roundKey * interval + offset;
return new InternalHistogram.Bucket(key, docCount, keyed, formatter, subAggregationResults);
}, (owningBucketOrd, buckets) -> {
checkCancelled();
// the contract of the histogram aggregation is that shards must return buckets ordered by key in ascending order
CollectionUtil.introSort(buckets, BucketOrder.key(true).comparator());

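The comments in the histogram hunks above state the shard-level contract: buckets must be returned ordered by key in ascending order (enforced in the real code with CollectionUtil.introSort and BucketOrder.key(true)). A tiny illustrative sketch of that ordering, with a hypothetical Bucket record in place of InternalHistogram.Bucket, could be:

```java
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;

// Minimal illustration of the contract mentioned above: a shard must hand back its
// histogram buckets ordered by key in ascending order before they are reduced.
final class HistogramBucketOrderDemo {
    record Bucket(double key, long docCount) {}

    static List<Bucket> sortByKeyAscending(List<Bucket> buckets) {
        List<Bucket> sorted = new ArrayList<>(buckets);
        sorted.sort(Comparator.comparingDouble(Bucket::key)); // same ordering BucketOrder.key(true) enforces
        return sorted;
    }

    public static void main(String[] args) {
        List<Bucket> collected = List.of(new Bucket(20.0, 3), new Bucket(0.0, 7), new Bucket(10.0, 5));
        System.out.println(sortByKeyAscending(collected)); // keys come out as 0.0, 10.0, 20.0
    }
}
```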
@@ -279,6 +279,7 @@ protected final InternalAggregation[] buildAggregations(
subAggregationResults
),
(owningBucketOrd, buckets) -> {
checkCancelled();
// the contract of the histogram aggregation is that shards must return
// buckets ordered by key in ascending order
CollectionUtil.introSort(buckets, BucketOrder.key(true).comparator());
@@ -727,6 +728,7 @@ private int increaseRoundingIfNeeded(long owningBucketOrd, int oldEstimatedBucke
private void rebucket() {
rebucketCount++;
try (LongKeyedBucketOrds oldOrds = bucketOrds) {
checkCancelled();
long[] mergeMap = new long[Math.toIntExact(oldOrds.size())];
bucketOrds = new LongKeyedBucketOrds.FromMany(context.bigArrays());
for (long owningBucketOrd = 0; owningBucketOrd <= oldOrds.maxOwningBucketOrd(); owningBucketOrd++) {