Cleanup usages of QueryPhaseResultConsumer (#61713)
This commit generalizes how QueryPhaseResultConsumer is initialized. The query phase always uses this consumer, so it doesn't need to be hidden behind an abstract class.
jimczi committed Sep 2, 2020
1 parent a8bbdd9 commit a0e4331
Showing 6 changed files with 47 additions and 54 deletions.
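
Taken together, the cleanup replaces a thirteen-argument constructor with a seven-argument one that derives its configuration from the SearchRequest itself. A minimal sketch of the new construction path; every variable here except request (executor, controller, listener, registry, numShards, onFailure) is a hypothetical stand-in, not a name taken from the commit:

    // Sketch only: call-site variable names are hypothetical stand-ins.
    QueryPhaseResultConsumer consumer = new QueryPhaseResultConsumer(
        request,     // SearchRequest, which now carries all the tuning knobs
        executor,    // Executor used for asynchronous partial reduces
        controller,  // SearchPhaseController
        listener,    // SearchProgressListener (e.g. SearchProgressListener.NOOP)
        registry,    // NamedWriteableRegistry
        numShards,   // expected number of shard results
        onFailure);  // Consumer<Exception> invoked on partial merge failure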
DfsQueryPhase.java
@@ -41,7 +41,7 @@
* @see CountedCollector#onFailure(int, SearchShardTarget, Exception)
*/
final class DfsQueryPhase extends SearchPhase {
- private final ArraySearchPhaseResults<SearchPhaseResult> queryResult;
+ private final QueryPhaseResultConsumer queryResult;
private final SearchPhaseController searchPhaseController;
private final AtomicArray<DfsSearchResult> dfsSearchResults;
private final Function<ArraySearchPhaseResults<SearchPhaseResult>, SearchPhase> nextPhaseFactory;
QueryPhaseResultConsumer.java
@@ -31,6 +31,7 @@
import org.elasticsearch.search.SearchShardTarget;
import org.elasticsearch.search.aggregations.InternalAggregation.ReduceContextBuilder;
import org.elasticsearch.search.aggregations.InternalAggregations;
+ import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.search.query.QuerySearchResult;

import java.util.ArrayDeque;
@@ -43,6 +44,7 @@
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Consumer;

+ import static org.elasticsearch.action.search.SearchPhaseController.getTopDocsSize;
import static org.elasticsearch.action.search.SearchPhaseController.mergeTopDocs;
import static org.elasticsearch.action.search.SearchPhaseController.setShardIndex;

@@ -52,7 +54,7 @@
* This implementation can be configured to batch up a certain amount of results and reduce
* them asynchronously in the provided {@link Executor} iff the buffer is exhausted.
*/
- class QueryPhaseResultConsumer extends ArraySearchPhaseResults<SearchPhaseResult> {
+ public class QueryPhaseResultConsumer extends ArraySearchPhaseResults<SearchPhaseResult> {
private static final Logger logger = LogManager.getLogger(QueryPhaseResultConsumer.class);

private final Executor executor;
@@ -76,43 +78,39 @@ class QueryPhaseResultConsumer extends ArraySearchPhaseResults<SearchPhaseResult
* Creates a {@link QueryPhaseResultConsumer} that incrementally reduces aggregation results
* as shard results are consumed.
*/
- QueryPhaseResultConsumer(Executor executor,
- SearchPhaseController controller,
- SearchProgressListener progressListener,
- ReduceContextBuilder aggReduceContextBuilder,
- NamedWriteableRegistry namedWriteableRegistry,
- int expectedResultSize,
- int bufferSize,
- boolean hasTopDocs,
- boolean hasAggs,
- int trackTotalHitsUpTo,
- int topNSize,
- boolean performFinalReduce,
- Consumer<Exception> onPartialMergeFailure) {
+ public QueryPhaseResultConsumer(SearchRequest request,
+ Executor executor,
+ SearchPhaseController controller,
+ SearchProgressListener progressListener,
+ NamedWriteableRegistry namedWriteableRegistry,
+ int expectedResultSize,
+ Consumer<Exception> onPartialMergeFailure) {
super(expectedResultSize);
this.executor = executor;
this.controller = controller;
this.progressListener = progressListener;
- this.aggReduceContextBuilder = aggReduceContextBuilder;
+ this.aggReduceContextBuilder = controller.getReduceContext(request);
this.namedWriteableRegistry = namedWriteableRegistry;
- this.topNSize = topNSize;
- this.pendingMerges = new PendingMerges(bufferSize, trackTotalHitsUpTo);
- this.hasTopDocs = hasTopDocs;
- this.hasAggs = hasAggs;
- this.performFinalReduce = performFinalReduce;
+ this.topNSize = getTopDocsSize(request);
+ this.performFinalReduce = request.isFinalReduce();
this.onPartialMergeFailure = onPartialMergeFailure;
+ SearchSourceBuilder source = request.source();
+ this.hasTopDocs = source == null || source.size() != 0;
+ this.hasAggs = source != null && source.aggregations() != null;
+ int bufferSize = (hasAggs || hasTopDocs) ? Math.min(request.getBatchedReduceSize(), expectedResultSize) : expectedResultSize;
+ this.pendingMerges = new PendingMerges(bufferSize, request.resolveTrackTotalHitsUpTo());
}

@Override
- void consumeResult(SearchPhaseResult result, Runnable next) {
+ public void consumeResult(SearchPhaseResult result, Runnable next) {
super.consumeResult(result, () -> {});
QuerySearchResult querySearchResult = result.queryResult();
progressListener.notifyQueryResult(querySearchResult.getShardIndex());
pendingMerges.consume(querySearchResult, next);
}

@Override
- SearchPhaseController.ReducedQueryPhase reduce() throws Exception {
+ public SearchPhaseController.ReducedQueryPhase reduce() throws Exception {
if (pendingMerges.hasPendingMerges()) {
throw new AssertionError("partial reduce in-flight");
} else if (pendingMerges.hasFailure()) {
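
To make the new derivation concrete, here is a hand-traced example (traced against the constructor above, not output from the commit), using a request shaped like the ones in the tests further down:

    // Hand-traced example of the constructor's derivation logic.
    SearchRequest request = new SearchRequest();
    request.source(new SearchSourceBuilder()
        .aggregation(AggregationBuilders.avg("foo"))
        .size(0));                   // size == 0: no top docs requested
    request.setBatchedReduceSize(5);

    // For expectedResultSize == 10 shards, the constructor computes:
    //   hasTopDocs = source == null || source.size() != 0             -> false
    //   hasAggs    = source != null && source.aggregations() != null  -> true
    //   bufferSize = (hasAggs || hasTopDocs)
    //                  ? Math.min(request.getBatchedReduceSize(), expectedResultSize)
    //                  : expectedResultSize                           -> min(5, 10) == 5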
SearchPhaseController.java
@@ -558,23 +558,19 @@ public InternalSearchResponse buildResponse(SearchHits hits) {
}
}

+ InternalAggregation.ReduceContextBuilder getReduceContext(SearchRequest request) {
+ return requestToAggReduceContextBuilder.apply(request);
+ }

/**
- * Returns a new ArraySearchPhaseResults instance. This might return an instance that reduces search responses incrementally.
+ * Returns a new {@link QueryPhaseResultConsumer} instance. This might return an instance that reduces search responses incrementally.
*/
- ArraySearchPhaseResults<SearchPhaseResult> newSearchPhaseResults(Executor executor,
- SearchProgressListener listener,
- SearchRequest request,
- int numShards,
- Consumer<Exception> onPartialMergeFailure) {
- SearchSourceBuilder source = request.source();
- final boolean hasAggs = source != null && source.aggregations() != null;
- final boolean hasTopDocs = source == null || source.size() != 0;
- final int trackTotalHitsUpTo = request.resolveTrackTotalHitsUpTo();
- InternalAggregation.ReduceContextBuilder aggReduceContextBuilder = requestToAggReduceContextBuilder.apply(request);
- int topNSize = getTopDocsSize(request);
- int bufferSize = (hasAggs || hasTopDocs) ? Math.min(request.getBatchedReduceSize(), numShards) : numShards;
- return new QueryPhaseResultConsumer(executor, this, listener, aggReduceContextBuilder, namedWriteableRegistry,
- numShards, bufferSize, hasTopDocs, hasAggs, trackTotalHitsUpTo, topNSize, request.isFinalReduce(), onPartialMergeFailure);
+ QueryPhaseResultConsumer newSearchPhaseResults(Executor executor,
+ SearchProgressListener listener,
+ SearchRequest request,
+ int numShards,
+ Consumer<Exception> onPartialMergeFailure) {
+ return new QueryPhaseResultConsumer(request, executor, this, listener, namedWriteableRegistry, numShards, onPartialMergeFailure);
}

static final class TopDocsStats {
SearchProgressListener.java
@@ -39,7 +39,7 @@
/**
* A listener that allows to track progress of the {@link SearchAction}.
*/
- abstract class SearchProgressListener {
+ public abstract class SearchProgressListener {
private static final Logger logger = LogManager.getLogger(SearchProgressListener.class);

public static final SearchProgressListener NOOP = new SearchProgressListener() {};
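
Since SearchProgressListener is now public, code outside the package can subclass it to track query-phase progress. A sketch under that assumption; the onFinalReduce signature is completed from the truncated line visible in SearchPhaseControllerTests below, so treat it as an assumption rather than the definitive API:

    // Sketch only: overrides the callback exercised in the tests below.
    SearchProgressListener listener = new SearchProgressListener() {
        @Override
        public void onFinalReduce(List<SearchShard> shards, TotalHits totalHits,
                                  InternalAggregations aggs, int reducePhase) {
            // React to the final reduce of the query phase.
        }
    };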
FetchSearchPhaseTests.java
@@ -31,7 +31,6 @@
import org.elasticsearch.search.DocValueFormat;
import org.elasticsearch.search.SearchHit;
import org.elasticsearch.search.SearchHits;
- import org.elasticsearch.search.SearchPhaseResult;
import org.elasticsearch.search.SearchShardTarget;
import org.elasticsearch.search.fetch.FetchSearchResult;
import org.elasticsearch.search.fetch.QueryFetchSearchResult;
@@ -53,7 +52,7 @@ public void testShortcutQueryAndFetchOptimization() {
SearchPhaseController controller = new SearchPhaseController(
writableRegistry(), s -> InternalAggregationTestCase.emptyReduceContextBuilder());
MockSearchPhaseContext mockSearchPhaseContext = new MockSearchPhaseContext(1);
- ArraySearchPhaseResults<SearchPhaseResult> results = controller.newSearchPhaseResults(EsExecutors.newDirectExecutorService(),
+ QueryPhaseResultConsumer results = controller.newSearchPhaseResults(EsExecutors.newDirectExecutorService(),
NOOP, mockSearchPhaseContext.getRequest(), 1, exc -> {});
boolean hasHits = randomBoolean();
final int numHits;
@@ -97,7 +96,7 @@ public void testFetchTwoDocument() {
MockSearchPhaseContext mockSearchPhaseContext = new MockSearchPhaseContext(2);
SearchPhaseController controller = new SearchPhaseController(
writableRegistry(), s -> InternalAggregationTestCase.emptyReduceContextBuilder());
- ArraySearchPhaseResults<SearchPhaseResult> results = controller.newSearchPhaseResults(EsExecutors.newDirectExecutorService(),
+ QueryPhaseResultConsumer results = controller.newSearchPhaseResults(EsExecutors.newDirectExecutorService(),
NOOP, mockSearchPhaseContext.getRequest(), 2, exc -> {});
int resultSetSize = randomIntBetween(2, 10);
final SearchContextId ctx1 = new SearchContextId(UUIDs.randomBase64UUID(), 123);
@@ -158,7 +157,7 @@ public void testFailFetchOneDoc() {
MockSearchPhaseContext mockSearchPhaseContext = new MockSearchPhaseContext(2);
SearchPhaseController controller = new SearchPhaseController(
writableRegistry(), s -> InternalAggregationTestCase.emptyReduceContextBuilder());
- ArraySearchPhaseResults<SearchPhaseResult> results = controller.newSearchPhaseResults(EsExecutors.newDirectExecutorService(),
+ QueryPhaseResultConsumer results = controller.newSearchPhaseResults(EsExecutors.newDirectExecutorService(),
NOOP, mockSearchPhaseContext.getRequest(), 2, exc -> {});
int resultSetSize = randomIntBetween(2, 10);
SearchContextId ctx1 = new SearchContextId(UUIDs.randomBase64UUID(), 123);
@@ -223,7 +222,7 @@ public void testFetchDocsConcurrently() throws InterruptedException {
SearchPhaseController controller = new SearchPhaseController(
writableRegistry(), s -> InternalAggregationTestCase.emptyReduceContextBuilder());
MockSearchPhaseContext mockSearchPhaseContext = new MockSearchPhaseContext(numHits);
- ArraySearchPhaseResults<SearchPhaseResult> results = controller.newSearchPhaseResults(EsExecutors.newDirectExecutorService(), NOOP,
+ QueryPhaseResultConsumer results = controller.newSearchPhaseResults(EsExecutors.newDirectExecutorService(), NOOP,
mockSearchPhaseContext.getRequest(), numHits, exc -> {});
for (int i = 0; i < numHits; i++) {
QuerySearchResult queryResult = new QuerySearchResult(new SearchContextId("", i),
@@ -280,7 +279,7 @@ public void testExceptionFailsPhase() {
MockSearchPhaseContext mockSearchPhaseContext = new MockSearchPhaseContext(2);
SearchPhaseController controller = new SearchPhaseController(
writableRegistry(), s -> InternalAggregationTestCase.emptyReduceContextBuilder());
- ArraySearchPhaseResults<SearchPhaseResult> results =
+ QueryPhaseResultConsumer results =
controller.newSearchPhaseResults(EsExecutors.newDirectExecutorService(),
NOOP, mockSearchPhaseContext.getRequest(), 2, exc -> {});
int resultSetSize = randomIntBetween(2, 10);
@@ -338,7 +337,7 @@ public void testCleanupIrrelevantContexts() { // contexts that are not fetched s
MockSearchPhaseContext mockSearchPhaseContext = new MockSearchPhaseContext(2);
SearchPhaseController controller = new SearchPhaseController(
writableRegistry(), s -> InternalAggregationTestCase.emptyReduceContextBuilder());
- ArraySearchPhaseResults<SearchPhaseResult> results = controller.newSearchPhaseResults(EsExecutors.newDirectExecutorService(),
+ QueryPhaseResultConsumer results = controller.newSearchPhaseResults(EsExecutors.newDirectExecutorService(),
NOOP, mockSearchPhaseContext.getRequest(), 2, exc -> {});
int resultSetSize = 1;
SearchContextId ctx1 = new SearchContextId(UUIDs.randomBase64UUID(), 123);
SearchPhaseControllerTests.java
@@ -555,7 +555,7 @@ public void testConsumerOnlyAggs() throws Exception {
SearchRequest request = randomSearchRequest();
request.source(new SearchSourceBuilder().aggregation(AggregationBuilders.avg("foo")).size(0));
request.setBatchedReduceSize(bufferSize);
- ArraySearchPhaseResults<SearchPhaseResult> consumer = searchPhaseController.newSearchPhaseResults(fixedExecutor,
+ QueryPhaseResultConsumer consumer = searchPhaseController.newSearchPhaseResults(fixedExecutor,
NOOP, request, expectedNumResults, exc -> {});
AtomicInteger max = new AtomicInteger();
CountDownLatch latch = new CountDownLatch(expectedNumResults);
@@ -596,7 +596,7 @@ public void testConsumerOnlyHits() throws Exception {
request.source(new SearchSourceBuilder().size(randomIntBetween(1, 10)));
}
request.setBatchedReduceSize(bufferSize);
- ArraySearchPhaseResults<SearchPhaseResult> consumer = searchPhaseController.newSearchPhaseResults(fixedExecutor,
+ QueryPhaseResultConsumer consumer = searchPhaseController.newSearchPhaseResults(fixedExecutor,
NOOP, request, expectedNumResults, exc -> {});
AtomicInteger max = new AtomicInteger();
CountDownLatch latch = new CountDownLatch(expectedNumResults);
@@ -639,7 +639,7 @@ public void testReduceTopNWithFromOffset() throws Exception {
SearchRequest request = new SearchRequest();
request.source(new SearchSourceBuilder().size(5).from(5));
request.setBatchedReduceSize(randomIntBetween(2, 4));
- ArraySearchPhaseResults<SearchPhaseResult> consumer = searchPhaseController.newSearchPhaseResults(fixedExecutor,
+ QueryPhaseResultConsumer consumer = searchPhaseController.newSearchPhaseResults(fixedExecutor,
NOOP, request, 4, exc -> {});
int score = 100;
CountDownLatch latch = new CountDownLatch(4);
@@ -677,7 +677,7 @@ public void testConsumerSortByField() throws Exception {
SearchRequest request = randomSearchRequest();
int size = randomIntBetween(1, 10);
request.setBatchedReduceSize(bufferSize);
- ArraySearchPhaseResults<SearchPhaseResult> consumer = searchPhaseController.newSearchPhaseResults(fixedExecutor,
+ QueryPhaseResultConsumer consumer = searchPhaseController.newSearchPhaseResults(fixedExecutor,
NOOP, request, expectedNumResults, exc -> {});
AtomicInteger max = new AtomicInteger();
SortField[] sortFields = {new SortField("field", SortField.Type.INT, true)};
@@ -715,7 +715,7 @@ public void testConsumerFieldCollapsing() throws Exception {
SearchRequest request = randomSearchRequest();
int size = randomIntBetween(5, 10);
request.setBatchedReduceSize(bufferSize);
- ArraySearchPhaseResults<SearchPhaseResult> consumer = searchPhaseController.newSearchPhaseResults(fixedExecutor,
+ QueryPhaseResultConsumer consumer = searchPhaseController.newSearchPhaseResults(fixedExecutor,
NOOP, request, expectedNumResults, exc -> {});
SortField[] sortFields = {new SortField("field", SortField.Type.STRING)};
BytesRef a = new BytesRef("a");
@@ -756,7 +756,7 @@ public void testConsumerSuggestions() throws Exception {
int bufferSize = randomIntBetween(2, 200);
SearchRequest request = randomSearchRequest();
request.setBatchedReduceSize(bufferSize);
- ArraySearchPhaseResults<SearchPhaseResult> consumer = searchPhaseController.newSearchPhaseResults(fixedExecutor,
+ QueryPhaseResultConsumer consumer = searchPhaseController.newSearchPhaseResults(fixedExecutor,
NOOP, request, expectedNumResults, exc -> {});
int maxScoreTerm = -1;
int maxScorePhrase = -1;
@@ -882,7 +882,7 @@ public void onFinalReduce(List<SearchShard> shards, TotalHits totalHits, Interna
assertEquals(numReduceListener.incrementAndGet(), reducePhase);
}
};
- ArraySearchPhaseResults<SearchPhaseResult> consumer = searchPhaseController.newSearchPhaseResults(fixedExecutor,
+ QueryPhaseResultConsumer consumer = searchPhaseController.newSearchPhaseResults(fixedExecutor,
progressListener, request, expectedNumResults, exc -> {});
AtomicInteger max = new AtomicInteger();
Thread[] threads = new Thread[expectedNumResults];
@@ -940,7 +940,7 @@ public void testPartialMergeFailure() throws InterruptedException {
request.source(new SearchSourceBuilder().aggregation(AggregationBuilders.avg("foo")).size(0));
request.setBatchedReduceSize(bufferSize);
AtomicBoolean hasConsumedFailure = new AtomicBoolean();
- ArraySearchPhaseResults<SearchPhaseResult> consumer = searchPhaseController.newSearchPhaseResults(fixedExecutor,
+ QueryPhaseResultConsumer consumer = searchPhaseController.newSearchPhaseResults(fixedExecutor,
NOOP, request, expectedNumResults, exc -> hasConsumedFailure.set(true));
CountDownLatch latch = new CountDownLatch(expectedNumResults);
Thread[] threads = new Thread[expectedNumResults];
