diff --git a/sandbox/plugins/concurrent-search/build.gradle b/sandbox/plugins/concurrent-search/build.gradle new file mode 100644 index 0000000000000..acc3cb5092cd8 --- /dev/null +++ b/sandbox/plugins/concurrent-search/build.gradle @@ -0,0 +1,42 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +apply plugin: 'opensearch.opensearchplugin' +apply plugin: 'opensearch.yaml-rest-test' + +opensearchplugin { + name 'concurrent-search' + description 'The experimental plugin which implements concurrent search over Apache Lucene segments' + classname 'org.opensearch.search.ConcurrentSegmentSearchPlugin' + licenseFile rootProject.file('licenses/APACHE-LICENSE-2.0.txt') + noticeFile rootProject.file('NOTICE.txt') +} + +yamlRestTest.enabled = false; +testingConventions.enabled = false; \ No newline at end of file diff --git a/sandbox/plugins/concurrent-search/src/main/java/org/opensearch/search/ConcurrentSegmentSearchPlugin.java b/sandbox/plugins/concurrent-search/src/main/java/org/opensearch/search/ConcurrentSegmentSearchPlugin.java new file mode 100644 index 0000000000000..da999e40f0f07 --- /dev/null +++ b/sandbox/plugins/concurrent-search/src/main/java/org/opensearch/search/ConcurrentSegmentSearchPlugin.java @@ -0,0 +1,53 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.search; + +import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.concurrent.OpenSearchExecutors; +import org.opensearch.plugins.Plugin; +import org.opensearch.plugins.SearchPlugin; +import org.opensearch.search.query.ConcurrentQueryPhaseSearcher; +import org.opensearch.search.query.QueryPhaseSearcher; +import org.opensearch.threadpool.ExecutorBuilder; +import org.opensearch.threadpool.FixedExecutorBuilder; +import org.opensearch.threadpool.ThreadPool; + +import java.util.Collections; +import java.util.List; +import java.util.Optional; + +/** + * The experimental plugin which implements the concurrent search over Apache Lucene segments. + */ +public class ConcurrentSegmentSearchPlugin extends Plugin implements SearchPlugin { + private static final String INDEX_SEARCHER = "index_searcher"; + + /** + * Default constructor + */ + public ConcurrentSegmentSearchPlugin() {} + + @Override + public Optional getQueryPhaseSearcher() { + return Optional.of(new ConcurrentQueryPhaseSearcher()); + } + + @Override + public List> getExecutorBuilders(Settings settings) { + final int allocatedProcessors = OpenSearchExecutors.allocatedProcessors(settings); + return Collections.singletonList( + new FixedExecutorBuilder(settings, INDEX_SEARCHER, allocatedProcessors, 1000, "thread_pool." + INDEX_SEARCHER) + ); + } + + @Override + public Optional getIndexSearcherExecutorProvider() { + return Optional.of((ThreadPool threadPool) -> threadPool.executor(INDEX_SEARCHER)); + } +} diff --git a/sandbox/plugins/concurrent-search/src/main/java/org/opensearch/search/package-info.java b/sandbox/plugins/concurrent-search/src/main/java/org/opensearch/search/package-info.java new file mode 100644 index 0000000000000..041f914fab7d7 --- /dev/null +++ b/sandbox/plugins/concurrent-search/src/main/java/org/opensearch/search/package-info.java @@ -0,0 +1,12 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +/** + * The implementation of the experimental plugin which implements the concurrent search over Apache Lucene segments. + */ +package org.opensearch.search; diff --git a/sandbox/plugins/concurrent-search/src/main/java/org/opensearch/search/query/ConcurrentQueryPhaseSearcher.java b/sandbox/plugins/concurrent-search/src/main/java/org/opensearch/search/query/ConcurrentQueryPhaseSearcher.java new file mode 100644 index 0000000000000..0acf34c946df7 --- /dev/null +++ b/sandbox/plugins/concurrent-search/src/main/java/org/opensearch/search/query/ConcurrentQueryPhaseSearcher.java @@ -0,0 +1,124 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.search.query; + +import static org.opensearch.search.query.TopDocsCollectorContext.createTopDocsCollectorContext; + +import java.io.IOException; +import java.util.LinkedList; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.apache.lucene.search.Collector; +import org.apache.lucene.search.CollectorManager; +import org.apache.lucene.search.Query; +import org.opensearch.search.internal.ContextIndexSearcher; +import org.opensearch.search.internal.SearchContext; +import org.opensearch.search.profile.query.ProfileCollectorManager; +import org.opensearch.search.query.QueryPhase.DefaultQueryPhaseSearcher; +import org.opensearch.search.query.QueryPhase.TimeExceededException; + +/** + * The implementation of the {@link QueryPhaseSearcher} which attempts to use concurrent + * search of Apache Lucene segments if it has been enabled. + */ +public class ConcurrentQueryPhaseSearcher extends DefaultQueryPhaseSearcher { + private static final Logger LOGGER = LogManager.getLogger(ConcurrentQueryPhaseSearcher.class); + + /** + * Default constructor + */ + public ConcurrentQueryPhaseSearcher() {} + + @Override + protected boolean searchWithCollector( + SearchContext searchContext, + ContextIndexSearcher searcher, + Query query, + LinkedList collectors, + boolean hasFilterCollector, + boolean hasTimeout + ) throws IOException { + boolean couldUseConcurrentSegmentSearch = allowConcurrentSegmentSearch(searcher); + + // TODO: support aggregations + if (searchContext.aggregations() != null) { + couldUseConcurrentSegmentSearch = false; + if (LOGGER.isDebugEnabled()) { + LOGGER.debug("Unable to use concurrent search over index segments (experimental): aggregations are present"); + } + } + + if (couldUseConcurrentSegmentSearch) { + if (LOGGER.isDebugEnabled()) { + LOGGER.debug("Using concurrent search over index segments (experimental)"); + } + + return searchWithCollectorManager(searchContext, searcher, query, collectors, hasFilterCollector, hasTimeout); + } else { + return super.searchWithCollector(searchContext, searcher, query, collectors, hasFilterCollector, hasTimeout); + } + } + + private static boolean searchWithCollectorManager( + SearchContext searchContext, + ContextIndexSearcher searcher, + Query query, + LinkedList collectorContexts, + boolean hasFilterCollector, + boolean timeoutSet + ) throws IOException { + // create the top docs collector last when the other collectors are known + final TopDocsCollectorContext topDocsFactory = createTopDocsCollectorContext(searchContext, hasFilterCollector); + // add the top docs collector, the first collector context in the chain + collectorContexts.addFirst(topDocsFactory); + + final QuerySearchResult queryResult = searchContext.queryResult(); + final CollectorManager collectorManager; + + // TODO: support aggregations in concurrent segment search flow + if (searchContext.aggregations() != null) { + throw new UnsupportedOperationException("The concurrent segment search does not support aggregations yet"); + } + + if (searchContext.getProfilers() != null) { + final ProfileCollectorManager profileCollectorManager = + QueryCollectorManagerContext.createQueryCollectorManagerWithProfiler(collectorContexts); + searchContext.getProfilers().getCurrentQueryProfiler().setCollector(profileCollectorManager); + collectorManager = profileCollectorManager; + } else { + // Create multi collector manager instance + collectorManager = QueryCollectorManagerContext.createMultiCollectorManager(collectorContexts); + } + + try { + final ReduceableSearchResult result = searcher.search(query, collectorManager); + result.reduce(queryResult); + } catch (EarlyTerminatingCollector.EarlyTerminationException e) { + queryResult.terminatedEarly(true); + } catch (TimeExceededException e) { + assert timeoutSet : "TimeExceededException thrown even though timeout wasn't set"; + if (searchContext.request().allowPartialSearchResults() == false) { + // Can't rethrow TimeExceededException because not serializable + throw new QueryPhaseExecutionException(searchContext.shardTarget(), "Time exceeded"); + } + queryResult.searchTimedOut(true); + } + if (searchContext.terminateAfter() != SearchContext.DEFAULT_TERMINATE_AFTER && queryResult.terminatedEarly() == null) { + queryResult.terminatedEarly(false); + } + + return topDocsFactory.shouldRescore(); + } + + private static boolean allowConcurrentSegmentSearch(final ContextIndexSearcher searcher) { + return (searcher.getExecutor() != null); + } + +} diff --git a/sandbox/plugins/concurrent-search/src/main/java/org/opensearch/search/query/package-info.java b/sandbox/plugins/concurrent-search/src/main/java/org/opensearch/search/query/package-info.java new file mode 100644 index 0000000000000..0f98ae7682a84 --- /dev/null +++ b/sandbox/plugins/concurrent-search/src/main/java/org/opensearch/search/query/package-info.java @@ -0,0 +1,12 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +/** + * {@link org.opensearch.search.query.QueryPhaseSearcher} implementation for concurrent search + */ +package org.opensearch.search.query; diff --git a/sandbox/plugins/concurrent-search/src/test/java/org/opensearch/search/profile/query/QueryProfilerTests.java b/sandbox/plugins/concurrent-search/src/test/java/org/opensearch/search/profile/query/QueryProfilerTests.java new file mode 100644 index 0000000000000..51cb3c8c0cddc --- /dev/null +++ b/sandbox/plugins/concurrent-search/src/test/java/org/opensearch/search/profile/query/QueryProfilerTests.java @@ -0,0 +1,316 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.search.profile.query; + +import com.carrotsearch.randomizedtesting.annotations.ParametersFactory; + +import org.apache.lucene.document.Document; +import org.apache.lucene.document.Field.Store; +import org.apache.lucene.document.StringField; +import org.apache.lucene.index.DirectoryReader; +import org.apache.lucene.index.IndexReader; +import org.apache.lucene.index.IndexWriter; +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.index.Term; +import org.apache.lucene.search.Explanation; +import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.LRUQueryCache; +import org.apache.lucene.search.LeafCollector; +import org.apache.lucene.search.Query; +import org.apache.lucene.search.QueryCachingPolicy; +import org.apache.lucene.search.QueryVisitor; +import org.apache.lucene.search.ScoreMode; +import org.apache.lucene.search.Scorer; +import org.apache.lucene.search.ScorerSupplier; +import org.apache.lucene.search.Sort; +import org.apache.lucene.search.TermQuery; +import org.apache.lucene.search.TotalHitCountCollector; +import org.apache.lucene.search.Weight; +import org.apache.lucene.store.Directory; +import org.apache.lucene.tests.index.RandomIndexWriter; +import org.apache.lucene.tests.search.RandomApproximationQuery; +import org.apache.lucene.tests.util.TestUtil; +import org.opensearch.core.internal.io.IOUtils; +import org.opensearch.search.internal.ContextIndexSearcher; +import org.opensearch.search.profile.ProfileResult; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.threadpool.ThreadPool; +import org.junit.After; +import org.junit.Before; + +import java.io.IOException; +import java.util.Arrays; +import java.util.Collection; +import java.util.List; +import java.util.Map; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; + +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.greaterThan; + +public class QueryProfilerTests extends OpenSearchTestCase { + + private Directory dir; + private IndexReader reader; + private ContextIndexSearcher searcher; + private ExecutorService executor; + + @ParametersFactory + public static Collection concurrency() { + return Arrays.asList(new Integer[] { 0 }, new Integer[] { 5 }); + } + + public QueryProfilerTests(int concurrency) { + this.executor = (concurrency > 0) ? Executors.newFixedThreadPool(concurrency) : null; + } + + @Before + public void setUp() throws Exception { + super.setUp(); + + dir = newDirectory(); + RandomIndexWriter w = new RandomIndexWriter(random(), dir); + final int numDocs = TestUtil.nextInt(random(), 1, 20); + for (int i = 0; i < numDocs; ++i) { + final int numHoles = random().nextInt(5); + for (int j = 0; j < numHoles; ++j) { + w.addDocument(new Document()); + } + Document doc = new Document(); + doc.add(new StringField("foo", "bar", Store.NO)); + w.addDocument(doc); + } + reader = w.getReader(); + w.close(); + searcher = new ContextIndexSearcher( + reader, + IndexSearcher.getDefaultSimilarity(), + IndexSearcher.getDefaultQueryCache(), + ALWAYS_CACHE_POLICY, + true, + executor + ); + } + + @After + public void tearDown() throws Exception { + super.tearDown(); + + LRUQueryCache cache = (LRUQueryCache) searcher.getQueryCache(); + assertThat(cache.getHitCount(), equalTo(0L)); + assertThat(cache.getCacheCount(), equalTo(0L)); + assertThat(cache.getTotalCount(), equalTo(cache.getMissCount())); + assertThat(cache.getCacheSize(), equalTo(0L)); + + if (executor != null) { + ThreadPool.terminate(executor, 10, TimeUnit.SECONDS); + } + + IOUtils.close(reader, dir); + dir = null; + reader = null; + searcher = null; + } + + public void testBasic() throws IOException { + QueryProfiler profiler = new QueryProfiler(executor != null); + searcher.setProfiler(profiler); + Query query = new TermQuery(new Term("foo", "bar")); + searcher.search(query, 1); + List results = profiler.getTree(); + assertEquals(1, results.size()); + Map breakdown = results.get(0).getTimeBreakdown(); + assertThat(breakdown.get(QueryTimingType.CREATE_WEIGHT.toString()), greaterThan(0L)); + assertThat(breakdown.get(QueryTimingType.BUILD_SCORER.toString()), greaterThan(0L)); + assertThat(breakdown.get(QueryTimingType.NEXT_DOC.toString()), greaterThan(0L)); + assertThat(breakdown.get(QueryTimingType.ADVANCE.toString()), equalTo(0L)); + assertThat(breakdown.get(QueryTimingType.SCORE.toString()), greaterThan(0L)); + assertThat(breakdown.get(QueryTimingType.MATCH.toString()), equalTo(0L)); + + assertThat(breakdown.get(QueryTimingType.CREATE_WEIGHT.toString() + "_count"), greaterThan(0L)); + assertThat(breakdown.get(QueryTimingType.BUILD_SCORER.toString() + "_count"), greaterThan(0L)); + assertThat(breakdown.get(QueryTimingType.NEXT_DOC.toString() + "_count"), greaterThan(0L)); + assertThat(breakdown.get(QueryTimingType.ADVANCE.toString() + "_count"), equalTo(0L)); + assertThat(breakdown.get(QueryTimingType.SCORE.toString() + "_count"), greaterThan(0L)); + assertThat(breakdown.get(QueryTimingType.MATCH.toString() + "_count"), equalTo(0L)); + + long rewriteTime = profiler.getRewriteTime(); + assertThat(rewriteTime, greaterThan(0L)); + } + + public void testNoScoring() throws IOException { + QueryProfiler profiler = new QueryProfiler(executor != null); + searcher.setProfiler(profiler); + Query query = new TermQuery(new Term("foo", "bar")); + searcher.search(query, 1, Sort.INDEXORDER); // scores are not needed + List results = profiler.getTree(); + assertEquals(1, results.size()); + Map breakdown = results.get(0).getTimeBreakdown(); + assertThat(breakdown.get(QueryTimingType.CREATE_WEIGHT.toString()), greaterThan(0L)); + assertThat(breakdown.get(QueryTimingType.BUILD_SCORER.toString()), greaterThan(0L)); + assertThat(breakdown.get(QueryTimingType.NEXT_DOC.toString()), greaterThan(0L)); + assertThat(breakdown.get(QueryTimingType.ADVANCE.toString()), equalTo(0L)); + assertThat(breakdown.get(QueryTimingType.SCORE.toString()), equalTo(0L)); + assertThat(breakdown.get(QueryTimingType.MATCH.toString()), equalTo(0L)); + + assertThat(breakdown.get(QueryTimingType.CREATE_WEIGHT.toString() + "_count"), greaterThan(0L)); + assertThat(breakdown.get(QueryTimingType.BUILD_SCORER.toString() + "_count"), greaterThan(0L)); + assertThat(breakdown.get(QueryTimingType.NEXT_DOC.toString() + "_count"), greaterThan(0L)); + assertThat(breakdown.get(QueryTimingType.ADVANCE.toString() + "_count"), equalTo(0L)); + assertThat(breakdown.get(QueryTimingType.SCORE.toString() + "_count"), equalTo(0L)); + assertThat(breakdown.get(QueryTimingType.MATCH.toString() + "_count"), equalTo(0L)); + + long rewriteTime = profiler.getRewriteTime(); + assertThat(rewriteTime, greaterThan(0L)); + } + + public void testUseIndexStats() throws IOException { + QueryProfiler profiler = new QueryProfiler(executor != null); + searcher.setProfiler(profiler); + Query query = new TermQuery(new Term("foo", "bar")); + searcher.count(query); // will use index stats + List results = profiler.getTree(); + assertEquals(1, results.size()); + ProfileResult result = results.get(0); + assertEquals(0, (long) result.getTimeBreakdown().get("build_scorer_count")); + + long rewriteTime = profiler.getRewriteTime(); + assertThat(rewriteTime, greaterThan(0L)); + } + + public void testApproximations() throws IOException { + QueryProfiler profiler = new QueryProfiler(executor != null); + searcher.setProfiler(profiler); + Query query = new RandomApproximationQuery(new TermQuery(new Term("foo", "bar")), random()); + searcher.count(query); + List results = profiler.getTree(); + assertEquals(1, results.size()); + Map breakdown = results.get(0).getTimeBreakdown(); + assertThat(breakdown.get(QueryTimingType.CREATE_WEIGHT.toString()), greaterThan(0L)); + assertThat(breakdown.get(QueryTimingType.BUILD_SCORER.toString()), greaterThan(0L)); + assertThat(breakdown.get(QueryTimingType.NEXT_DOC.toString()), greaterThan(0L)); + assertThat(breakdown.get(QueryTimingType.ADVANCE.toString()), equalTo(0L)); + assertThat(breakdown.get(QueryTimingType.SCORE.toString()), equalTo(0L)); + assertThat(breakdown.get(QueryTimingType.MATCH.toString()), greaterThan(0L)); + + assertThat(breakdown.get(QueryTimingType.CREATE_WEIGHT.toString() + "_count"), greaterThan(0L)); + assertThat(breakdown.get(QueryTimingType.BUILD_SCORER.toString() + "_count"), greaterThan(0L)); + assertThat(breakdown.get(QueryTimingType.NEXT_DOC.toString() + "_count"), greaterThan(0L)); + assertThat(breakdown.get(QueryTimingType.ADVANCE.toString() + "_count"), equalTo(0L)); + assertThat(breakdown.get(QueryTimingType.SCORE.toString() + "_count"), equalTo(0L)); + assertThat(breakdown.get(QueryTimingType.MATCH.toString() + "_count"), greaterThan(0L)); + + long rewriteTime = profiler.getRewriteTime(); + assertThat(rewriteTime, greaterThan(0L)); + } + + public void testCollector() throws IOException { + TotalHitCountCollector collector = new TotalHitCountCollector(); + ProfileCollector profileCollector = new ProfileCollector(collector); + assertEquals(0, profileCollector.getTime()); + final LeafCollector leafCollector = profileCollector.getLeafCollector(reader.leaves().get(0)); + assertThat(profileCollector.getTime(), greaterThan(0L)); + long time = profileCollector.getTime(); + leafCollector.setScorer(null); + assertThat(profileCollector.getTime(), greaterThan(time)); + time = profileCollector.getTime(); + leafCollector.collect(0); + assertThat(profileCollector.getTime(), greaterThan(time)); + } + + private static class DummyQuery extends Query { + + @Override + public String toString(String field) { + return getClass().getSimpleName(); + } + + @Override + public boolean equals(Object obj) { + return this == obj; + } + + @Override + public int hashCode() { + return 0; + } + + @Override + public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost) throws IOException { + return new Weight(this) { + @Override + public Explanation explain(LeafReaderContext context, int doc) throws IOException { + throw new UnsupportedOperationException(); + } + + @Override + public Scorer scorer(LeafReaderContext context) throws IOException { + throw new UnsupportedOperationException(); + } + + @Override + public ScorerSupplier scorerSupplier(LeafReaderContext context) throws IOException { + return new ScorerSupplier() { + + @Override + public Scorer get(long loadCost) throws IOException { + throw new UnsupportedOperationException(); + } + + @Override + public long cost() { + return 42; + } + }; + } + + @Override + public boolean isCacheable(LeafReaderContext ctx) { + return true; + } + }; + } + + @Override + public void visit(QueryVisitor visitor) { + visitor.visitLeaf(this); + } + } + + public void testScorerSupplier() throws IOException { + Directory dir = newDirectory(); + IndexWriter w = new IndexWriter(dir, newIndexWriterConfig()); + w.addDocument(new Document()); + DirectoryReader reader = DirectoryReader.open(w); + w.close(); + IndexSearcher s = newSearcher(reader); + s.setQueryCache(null); + Weight weight = s.createWeight(s.rewrite(new DummyQuery()), randomFrom(ScoreMode.values()), 1f); + // exception when getting the scorer + expectThrows(UnsupportedOperationException.class, () -> weight.scorer(s.getIndexReader().leaves().get(0))); + // no exception, means scorerSupplier is delegated + weight.scorerSupplier(s.getIndexReader().leaves().get(0)); + reader.close(); + dir.close(); + } + + private static final QueryCachingPolicy ALWAYS_CACHE_POLICY = new QueryCachingPolicy() { + + @Override + public void onUse(Query query) {} + + @Override + public boolean shouldCache(Query query) throws IOException { + return true; + } + + }; +} diff --git a/sandbox/plugins/concurrent-search/src/test/java/org/opensearch/search/query/QueryPhaseTests.java b/sandbox/plugins/concurrent-search/src/test/java/org/opensearch/search/query/QueryPhaseTests.java new file mode 100644 index 0000000000000..83a0a63a6a5c8 --- /dev/null +++ b/sandbox/plugins/concurrent-search/src/test/java/org/opensearch/search/query/QueryPhaseTests.java @@ -0,0 +1,1335 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/* + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.search.query; + +import com.carrotsearch.randomizedtesting.annotations.ParametersFactory; + +import org.apache.lucene.analysis.standard.StandardAnalyzer; +import org.apache.lucene.document.Document; +import org.apache.lucene.document.Field.Store; +import org.apache.lucene.document.LatLonDocValuesField; +import org.apache.lucene.document.LatLonPoint; +import org.apache.lucene.document.LongPoint; +import org.apache.lucene.document.NumericDocValuesField; +import org.apache.lucene.document.SortedDocValuesField; +import org.apache.lucene.document.SortedSetDocValuesField; +import org.apache.lucene.document.StringField; +import org.apache.lucene.document.TextField; +import org.apache.lucene.index.DirectoryReader; +import org.apache.lucene.index.IndexReader; +import org.apache.lucene.index.IndexWriter; +import org.apache.lucene.index.IndexWriterConfig; +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.index.NoMergePolicy; +import org.apache.lucene.index.Term; +import org.apache.lucene.search.BooleanClause; +import org.apache.lucene.search.BooleanClause.Occur; +import org.apache.lucene.search.BooleanQuery; +import org.apache.lucene.search.Collector; +import org.apache.lucene.search.CollectorManager; +import org.apache.lucene.search.ConstantScoreQuery; +import org.apache.lucene.search.DocValuesFieldExistsQuery; +import org.apache.lucene.search.FieldComparator; +import org.apache.lucene.search.FieldDoc; +import org.apache.lucene.search.FilterCollector; +import org.apache.lucene.search.FilterLeafCollector; +import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.LeafCollector; +import org.apache.lucene.search.MatchAllDocsQuery; +import org.apache.lucene.search.MatchNoDocsQuery; +import org.apache.lucene.search.MultiTermQuery; +import org.apache.lucene.search.PrefixQuery; +import org.apache.lucene.search.Query; +import org.apache.lucene.search.ScoreDoc; +import org.apache.lucene.search.Sort; +import org.apache.lucene.search.SortField; +import org.apache.lucene.search.TermQuery; +import org.apache.lucene.search.TopDocs; +import org.apache.lucene.search.TopFieldDocs; +import org.apache.lucene.search.TotalHitCountCollector; +import org.apache.lucene.search.TotalHits; +import org.apache.lucene.search.Weight; +import org.apache.lucene.search.grouping.CollapseTopFieldDocs; +import org.apache.lucene.search.join.BitSetProducer; +import org.apache.lucene.search.join.ScoreMode; +import org.apache.lucene.queries.spans.SpanNearQuery; +import org.apache.lucene.queries.spans.SpanTermQuery; +import org.apache.lucene.store.Directory; +import org.apache.lucene.tests.index.RandomIndexWriter; +import org.apache.lucene.util.BytesRef; +import org.apache.lucene.util.FixedBitSet; +import org.opensearch.action.search.SearchShardTask; +import org.opensearch.common.settings.Settings; +import org.opensearch.index.mapper.DateFieldMapper; +import org.opensearch.index.mapper.MappedFieldType; +import org.opensearch.index.mapper.MapperService; +import org.opensearch.index.mapper.NumberFieldMapper; +import org.opensearch.index.mapper.NumberFieldMapper.NumberFieldType; +import org.opensearch.index.mapper.NumberFieldMapper.NumberType; +import org.opensearch.index.query.ParsedQuery; +import org.opensearch.index.query.QueryShardContext; +import org.opensearch.index.search.OpenSearchToParentBlockJoinQuery; +import org.opensearch.index.shard.IndexShard; +import org.opensearch.index.shard.IndexShardTestCase; +import org.opensearch.lucene.queries.MinDocQuery; +import org.opensearch.search.DocValueFormat; +import org.opensearch.search.collapse.CollapseBuilder; +import org.opensearch.search.internal.ContextIndexSearcher; +import org.opensearch.search.internal.ScrollContext; +import org.opensearch.search.internal.SearchContext; +import org.opensearch.search.sort.SortAndFormats; +import org.opensearch.tasks.TaskCancelledException; +import org.opensearch.test.TestSearchContext; +import org.opensearch.threadpool.ThreadPool; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.List; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; + +import static org.opensearch.search.query.TopDocsCollectorContext.hasInfMaxScore; +import static org.hamcrest.Matchers.anyOf; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.greaterThan; +import static org.hamcrest.Matchers.greaterThanOrEqualTo; +import static org.hamcrest.Matchers.instanceOf; +import static org.hamcrest.Matchers.lessThanOrEqualTo; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; +import static org.mockito.Mockito.spy; + +public class QueryPhaseTests extends IndexShardTestCase { + + private IndexShard indexShard; + private final ExecutorService executor; + private final QueryPhaseSearcher queryPhaseSearcher; + + @ParametersFactory + public static Collection concurrency() { + return Arrays.asList( + new Object[] { 0, QueryPhase.DEFAULT_QUERY_PHASE_SEARCHER }, + new Object[] { 5, new ConcurrentQueryPhaseSearcher() } + ); + } + + public QueryPhaseTests(int concurrency, QueryPhaseSearcher queryPhaseSearcher) { + this.executor = (concurrency > 0) ? Executors.newFixedThreadPool(concurrency) : null; + this.queryPhaseSearcher = queryPhaseSearcher; + } + + @Override + public Settings threadPoolSettings() { + return Settings.builder().put(super.threadPoolSettings()).put("thread_pool.search.min_queue_size", 10).build(); + } + + @Override + public void setUp() throws Exception { + super.setUp(); + indexShard = newShard(true); + } + + @Override + public void tearDown() throws Exception { + super.tearDown(); + closeShards(indexShard); + + if (executor != null) { + ThreadPool.terminate(executor, 10, TimeUnit.SECONDS); + } + } + + private void countTestCase(Query query, IndexReader reader, boolean shouldCollectSearch, boolean shouldCollectCount) throws Exception { + ContextIndexSearcher searcher = shouldCollectSearch + ? newContextSearcher(reader, executor) + : newEarlyTerminationContextSearcher(reader, 0, executor); + TestSearchContext context = new TestSearchContext(null, indexShard, searcher); + context.parsedQuery(new ParsedQuery(query)); + context.setSize(0); + context.setTask(new SearchShardTask(123L, "", "", "", null, Collections.emptyMap())); + final boolean rescore = QueryPhase.executeInternal(context.withCleanQueryResult(), queryPhaseSearcher); + assertFalse(rescore); + + ContextIndexSearcher countSearcher = shouldCollectCount + ? newContextSearcher(reader, executor) + : newEarlyTerminationContextSearcher(reader, 0, executor); + assertEquals(countSearcher.count(query), context.queryResult().topDocs().topDocs.totalHits.value); + } + + private void countTestCase(boolean withDeletions) throws Exception { + Directory dir = newDirectory(); + IndexWriterConfig iwc = newIndexWriterConfig().setMergePolicy(NoMergePolicy.INSTANCE); + RandomIndexWriter w = new RandomIndexWriter(random(), dir, iwc); + final int numDocs = scaledRandomIntBetween(600, 900); + for (int i = 0; i < numDocs; ++i) { + Document doc = new Document(); + if (randomBoolean()) { + doc.add(new StringField("foo", "bar", Store.NO)); + doc.add(new SortedSetDocValuesField("foo", new BytesRef("bar"))); + doc.add(new SortedSetDocValuesField("docValuesOnlyField", new BytesRef("bar"))); + doc.add(new LatLonDocValuesField("latLonDVField", 1.0, 1.0)); + doc.add(new LatLonPoint("latLonDVField", 1.0, 1.0)); + } + if (randomBoolean()) { + doc.add(new StringField("foo", "baz", Store.NO)); + doc.add(new SortedSetDocValuesField("foo", new BytesRef("baz"))); + } + if (withDeletions && (rarely() || i == 0)) { + doc.add(new StringField("delete", "yes", Store.NO)); + } + w.addDocument(doc); + } + if (withDeletions) { + w.deleteDocuments(new Term("delete", "yes")); + } + final IndexReader reader = w.getReader(); + Query matchAll = new MatchAllDocsQuery(); + Query matchAllCsq = new ConstantScoreQuery(matchAll); + Query tq = new TermQuery(new Term("foo", "bar")); + Query tCsq = new ConstantScoreQuery(tq); + Query dvfeq = new DocValuesFieldExistsQuery("foo"); + Query dvfeq_points = new DocValuesFieldExistsQuery("latLonDVField"); + Query dvfeqCsq = new ConstantScoreQuery(dvfeq); + // field with doc-values but not indexed will need to collect + Query dvOnlyfeq = new DocValuesFieldExistsQuery("docValuesOnlyField"); + BooleanQuery bq = new BooleanQuery.Builder().add(matchAll, Occur.SHOULD).add(tq, Occur.MUST).build(); + + countTestCase(matchAll, reader, false, false); + countTestCase(matchAllCsq, reader, false, false); + countTestCase(tq, reader, withDeletions, withDeletions); + countTestCase(tCsq, reader, withDeletions, withDeletions); + countTestCase(dvfeq, reader, withDeletions, true); + countTestCase(dvfeq_points, reader, withDeletions, true); + countTestCase(dvfeqCsq, reader, withDeletions, true); + countTestCase(dvOnlyfeq, reader, true, true); + countTestCase(bq, reader, true, true); + reader.close(); + w.close(); + dir.close(); + } + + public void testCountWithoutDeletions() throws Exception { + countTestCase(false); + } + + public void testCountWithDeletions() throws Exception { + countTestCase(true); + } + + public void testPostFilterDisablesCountOptimization() throws Exception { + Directory dir = newDirectory(); + final Sort sort = new Sort(new SortField("rank", SortField.Type.INT)); + IndexWriterConfig iwc = newIndexWriterConfig().setIndexSort(sort); + RandomIndexWriter w = new RandomIndexWriter(random(), dir, iwc); + Document doc = new Document(); + w.addDocument(doc); + w.close(); + + IndexReader reader = DirectoryReader.open(dir); + + TestSearchContext context = new TestSearchContext(null, indexShard, newEarlyTerminationContextSearcher(reader, 0, executor)); + context.setTask(new SearchShardTask(123L, "", "", "", null, Collections.emptyMap())); + context.parsedQuery(new ParsedQuery(new MatchAllDocsQuery())); + + QueryPhase.executeInternal(context.withCleanQueryResult(), queryPhaseSearcher); + assertEquals(1, context.queryResult().topDocs().topDocs.totalHits.value); + + context.setSearcher(newContextSearcher(reader, executor)); + context.parsedPostFilter(new ParsedQuery(new MatchNoDocsQuery())); + QueryPhase.executeInternal(context.withCleanQueryResult(), queryPhaseSearcher); + assertEquals(0, context.queryResult().topDocs().topDocs.totalHits.value); + reader.close(); + dir.close(); + } + + public void testTerminateAfterWithFilter() throws Exception { + Directory dir = newDirectory(); + final Sort sort = new Sort(new SortField("rank", SortField.Type.INT)); + IndexWriterConfig iwc = newIndexWriterConfig().setIndexSort(sort); + RandomIndexWriter w = new RandomIndexWriter(random(), dir, iwc); + Document doc = new Document(); + for (int i = 0; i < 10; i++) { + doc.add(new StringField("foo", Integer.toString(i), Store.NO)); + } + w.addDocument(doc); + w.close(); + + IndexReader reader = DirectoryReader.open(dir); + + TestSearchContext context = new TestSearchContext(null, indexShard, newContextSearcher(reader, executor)); + context.setTask(new SearchShardTask(123L, "", "", "", null, Collections.emptyMap())); + + context.parsedQuery(new ParsedQuery(new MatchAllDocsQuery())); + context.terminateAfter(1); + context.setSize(10); + for (int i = 0; i < 10; i++) { + context.parsedPostFilter(new ParsedQuery(new TermQuery(new Term("foo", Integer.toString(i))))); + QueryPhase.executeInternal(context.withCleanQueryResult(), queryPhaseSearcher); + assertEquals(1, context.queryResult().topDocs().topDocs.totalHits.value); + assertThat(context.queryResult().topDocs().topDocs.scoreDocs.length, equalTo(1)); + } + reader.close(); + dir.close(); + } + + public void testMinScoreDisablesCountOptimization() throws Exception { + Directory dir = newDirectory(); + final Sort sort = new Sort(new SortField("rank", SortField.Type.INT)); + IndexWriterConfig iwc = newIndexWriterConfig().setIndexSort(sort); + RandomIndexWriter w = new RandomIndexWriter(random(), dir, iwc); + Document doc = new Document(); + w.addDocument(doc); + w.close(); + + IndexReader reader = DirectoryReader.open(dir); + TestSearchContext context = new TestSearchContext(null, indexShard, newEarlyTerminationContextSearcher(reader, 0, executor)); + context.parsedQuery(new ParsedQuery(new MatchAllDocsQuery())); + context.setSize(0); + context.setTask(new SearchShardTask(123L, "", "", "", null, Collections.emptyMap())); + QueryPhase.executeInternal(context.withCleanQueryResult(), queryPhaseSearcher); + assertEquals(1, context.queryResult().topDocs().topDocs.totalHits.value); + + context.minimumScore(100); + QueryPhase.executeInternal(context.withCleanQueryResult(), queryPhaseSearcher); + assertEquals(0, context.queryResult().topDocs().topDocs.totalHits.value); + assertEquals(TotalHits.Relation.EQUAL_TO, context.queryResult().topDocs().topDocs.totalHits.relation); + reader.close(); + dir.close(); + } + + public void testQueryCapturesThreadPoolStats() throws Exception { + Directory dir = newDirectory(); + IndexWriterConfig iwc = newIndexWriterConfig(); + RandomIndexWriter w = new RandomIndexWriter(random(), dir, iwc); + final int numDocs = scaledRandomIntBetween(600, 900); + for (int i = 0; i < numDocs; ++i) { + w.addDocument(new Document()); + } + w.close(); + IndexReader reader = DirectoryReader.open(dir); + TestSearchContext context = new TestSearchContext(null, indexShard, newContextSearcher(reader, executor)); + context.setTask(new SearchShardTask(123L, "", "", "", null, Collections.emptyMap())); + context.parsedQuery(new ParsedQuery(new MatchAllDocsQuery())); + + QueryPhase.executeInternal(context.withCleanQueryResult(), queryPhaseSearcher); + QuerySearchResult results = context.queryResult(); + assertThat(results.serviceTimeEWMA(), greaterThanOrEqualTo(0L)); + assertThat(results.nodeQueueSize(), greaterThanOrEqualTo(0)); + reader.close(); + dir.close(); + } + + public void testInOrderScrollOptimization() throws Exception { + Directory dir = newDirectory(); + final Sort sort = new Sort(new SortField("rank", SortField.Type.INT)); + IndexWriterConfig iwc = newIndexWriterConfig().setIndexSort(sort); + RandomIndexWriter w = new RandomIndexWriter(random(), dir, iwc); + final int numDocs = scaledRandomIntBetween(600, 900); + for (int i = 0; i < numDocs; ++i) { + w.addDocument(new Document()); + } + w.close(); + IndexReader reader = DirectoryReader.open(dir); + ScrollContext scrollContext = new ScrollContext(); + TestSearchContext context = new TestSearchContext(null, indexShard, newContextSearcher(reader, executor), scrollContext); + context.parsedQuery(new ParsedQuery(new MatchAllDocsQuery())); + scrollContext.lastEmittedDoc = null; + scrollContext.maxScore = Float.NaN; + scrollContext.totalHits = null; + context.setTask(new SearchShardTask(123L, "", "", "", null, Collections.emptyMap())); + int size = randomIntBetween(2, 5); + context.setSize(size); + + QueryPhase.executeInternal(context.withCleanQueryResult(), queryPhaseSearcher); + assertThat(context.queryResult().topDocs().topDocs.totalHits.value, equalTo((long) numDocs)); + assertNull(context.queryResult().terminatedEarly()); + assertThat(context.terminateAfter(), equalTo(0)); + assertThat(context.queryResult().getTotalHits().value, equalTo((long) numDocs)); + + context.setSearcher(newEarlyTerminationContextSearcher(reader, size, executor)); + QueryPhase.executeInternal(context.withCleanQueryResult(), queryPhaseSearcher); + assertThat(context.queryResult().topDocs().topDocs.totalHits.value, equalTo((long) numDocs)); + assertThat(context.terminateAfter(), equalTo(size)); + assertThat(context.queryResult().getTotalHits().value, equalTo((long) numDocs)); + assertThat(context.queryResult().topDocs().topDocs.scoreDocs[0].doc, greaterThanOrEqualTo(size)); + reader.close(); + dir.close(); + } + + public void testTerminateAfterEarlyTermination() throws Exception { + Directory dir = newDirectory(); + IndexWriterConfig iwc = newIndexWriterConfig(); + RandomIndexWriter w = new RandomIndexWriter(random(), dir, iwc); + final int numDocs = scaledRandomIntBetween(600, 900); + for (int i = 0; i < numDocs; ++i) { + Document doc = new Document(); + if (randomBoolean()) { + doc.add(new StringField("foo", "bar", Store.NO)); + } + if (randomBoolean()) { + doc.add(new StringField("foo", "baz", Store.NO)); + } + doc.add(new NumericDocValuesField("rank", numDocs - i)); + w.addDocument(doc); + } + w.close(); + final IndexReader reader = DirectoryReader.open(dir); + TestSearchContext context = new TestSearchContext(null, indexShard, newContextSearcher(reader, executor)); + context.setTask(new SearchShardTask(123L, "", "", "", null, Collections.emptyMap())); + context.parsedQuery(new ParsedQuery(new MatchAllDocsQuery())); + + context.terminateAfter(numDocs); + { + context.setSize(10); + final TestTotalHitCountCollectorManager manager = TestTotalHitCountCollectorManager.create(executor); + context.queryCollectorManagers().put(TotalHitCountCollector.class, manager); + QueryPhase.executeInternal(context.withCleanQueryResult(), queryPhaseSearcher); + assertFalse(context.queryResult().terminatedEarly()); + assertThat(context.queryResult().topDocs().topDocs.totalHits.value, equalTo((long) numDocs)); + assertThat(context.queryResult().topDocs().topDocs.scoreDocs.length, equalTo(10)); + assertThat(manager.getTotalHits(), equalTo(numDocs)); + } + + context.terminateAfter(1); + { + context.setSize(1); + QueryPhase.executeInternal(context.withCleanQueryResult(), queryPhaseSearcher); + assertTrue(context.queryResult().terminatedEarly()); + assertThat(context.queryResult().topDocs().topDocs.totalHits.value, equalTo(1L)); + assertThat(context.queryResult().topDocs().topDocs.scoreDocs.length, equalTo(1)); + + context.setSize(0); + QueryPhase.executeInternal(context.withCleanQueryResult(), queryPhaseSearcher); + assertTrue(context.queryResult().terminatedEarly()); + assertThat(context.queryResult().topDocs().topDocs.totalHits.value, equalTo(1L)); + assertThat(context.queryResult().topDocs().topDocs.scoreDocs.length, equalTo(0)); + } + + { + context.setSize(1); + QueryPhase.executeInternal(context.withCleanQueryResult(), queryPhaseSearcher); + assertTrue(context.queryResult().terminatedEarly()); + assertThat(context.queryResult().topDocs().topDocs.totalHits.value, equalTo(1L)); + assertThat(context.queryResult().topDocs().topDocs.scoreDocs.length, equalTo(1)); + } + { + context.setSize(1); + BooleanQuery bq = new BooleanQuery.Builder().add(new TermQuery(new Term("foo", "bar")), Occur.SHOULD) + .add(new TermQuery(new Term("foo", "baz")), Occur.SHOULD) + .build(); + context.parsedQuery(new ParsedQuery(bq)); + QueryPhase.executeInternal(context.withCleanQueryResult(), queryPhaseSearcher); + assertTrue(context.queryResult().terminatedEarly()); + assertThat(context.queryResult().topDocs().topDocs.totalHits.value, equalTo(1L)); + assertThat(context.queryResult().topDocs().topDocs.scoreDocs.length, equalTo(1)); + + context.setSize(0); + context.parsedQuery(new ParsedQuery(bq)); + QueryPhase.executeInternal(context.withCleanQueryResult(), queryPhaseSearcher); + assertTrue(context.queryResult().terminatedEarly()); + assertThat(context.queryResult().topDocs().topDocs.totalHits.value, equalTo(1L)); + assertThat(context.queryResult().topDocs().topDocs.scoreDocs.length, equalTo(0)); + } + { + context.setSize(1); + final TestTotalHitCountCollectorManager manager = TestTotalHitCountCollectorManager.create(executor, 1); + context.queryCollectorManagers().put(TotalHitCountCollector.class, manager); + QueryPhase.executeInternal(context.withCleanQueryResult(), queryPhaseSearcher); + assertTrue(context.queryResult().terminatedEarly()); + assertThat(context.queryResult().topDocs().topDocs.totalHits.value, equalTo(1L)); + assertThat(context.queryResult().topDocs().topDocs.scoreDocs.length, equalTo(1)); + assertThat(manager.getTotalHits(), equalTo(1)); + context.queryCollectorManagers().clear(); + } + { + context.setSize(0); + final TestTotalHitCountCollectorManager manager = TestTotalHitCountCollectorManager.create(executor, 1); + context.queryCollectorManagers().put(TotalHitCountCollector.class, manager); + QueryPhase.executeInternal(context.withCleanQueryResult(), queryPhaseSearcher); + assertTrue(context.queryResult().terminatedEarly()); + assertThat(context.queryResult().topDocs().topDocs.totalHits.value, equalTo(1L)); + assertThat(context.queryResult().topDocs().topDocs.scoreDocs.length, equalTo(0)); + assertThat(manager.getTotalHits(), equalTo(1)); + } + + // tests with trackTotalHits and terminateAfter + context.terminateAfter(10); + context.setSize(0); + for (int trackTotalHits : new int[] { -1, 3, 76, 100 }) { + context.trackTotalHitsUpTo(trackTotalHits); + final TestTotalHitCountCollectorManager manager = TestTotalHitCountCollectorManager.create(executor); + context.queryCollectorManagers().put(TotalHitCountCollector.class, manager); + QueryPhase.executeInternal(context.withCleanQueryResult(), queryPhaseSearcher); + assertTrue(context.queryResult().terminatedEarly()); + if (trackTotalHits == -1) { + assertThat(context.queryResult().topDocs().topDocs.totalHits.value, equalTo(0L)); + } else { + assertThat(context.queryResult().topDocs().topDocs.totalHits.value, equalTo((long) Math.min(trackTotalHits, 10))); + } + assertThat(context.queryResult().topDocs().topDocs.scoreDocs.length, equalTo(0)); + // The concurrent search terminates the collection when the number of hits is reached by each + // concurrent collector. In this case, in general, the number of results are multiplied by the number of + // slices (as the unit of concurrency). To address that, we have to use the shared global state, + // much as HitsThresholdChecker does. + if (executor == null) { + assertThat(manager.getTotalHits(), equalTo(10)); + } + } + + context.terminateAfter(7); + context.setSize(10); + for (int trackTotalHits : new int[] { -1, 3, 75, 100 }) { + context.trackTotalHitsUpTo(trackTotalHits); + QueryPhase.executeInternal(context.withCleanQueryResult(), queryPhaseSearcher); + assertTrue(context.queryResult().terminatedEarly()); + if (trackTotalHits == -1) { + assertThat(context.queryResult().topDocs().topDocs.totalHits.value, equalTo(0L)); + } else { + assertThat(context.queryResult().topDocs().topDocs.totalHits.value, equalTo(7L)); + } + assertThat(context.queryResult().topDocs().topDocs.scoreDocs.length, equalTo(7)); + } + reader.close(); + dir.close(); + } + + public void testIndexSortingEarlyTermination() throws Exception { + Directory dir = newDirectory(); + final Sort sort = new Sort(new SortField("rank", SortField.Type.INT)); + IndexWriterConfig iwc = newIndexWriterConfig().setIndexSort(sort); + RandomIndexWriter w = new RandomIndexWriter(random(), dir, iwc); + final int numDocs = scaledRandomIntBetween(600, 900); + for (int i = 0; i < numDocs; ++i) { + Document doc = new Document(); + if (randomBoolean()) { + doc.add(new StringField("foo", "bar", Store.NO)); + } + if (randomBoolean()) { + doc.add(new StringField("foo", "baz", Store.NO)); + } + doc.add(new NumericDocValuesField("rank", numDocs - i)); + w.addDocument(doc); + } + w.close(); + + final IndexReader reader = DirectoryReader.open(dir); + TestSearchContext context = new TestSearchContext(null, indexShard, newContextSearcher(reader, executor)); + context.parsedQuery(new ParsedQuery(new MatchAllDocsQuery())); + context.setSize(1); + context.setTask(new SearchShardTask(123L, "", "", "", null, Collections.emptyMap())); + context.sort(new SortAndFormats(sort, new DocValueFormat[] { DocValueFormat.RAW })); + + QueryPhase.executeInternal(context.withCleanQueryResult(), queryPhaseSearcher); + assertThat(context.queryResult().topDocs().topDocs.totalHits.value, equalTo((long) numDocs)); + assertThat(context.queryResult().topDocs().topDocs.scoreDocs.length, equalTo(1)); + assertThat(context.queryResult().topDocs().topDocs.scoreDocs[0], instanceOf(FieldDoc.class)); + FieldDoc fieldDoc = (FieldDoc) context.queryResult().topDocs().topDocs.scoreDocs[0]; + assertThat(fieldDoc.fields[0], equalTo(1)); + + { + context.parsedPostFilter(new ParsedQuery(new MinDocQuery(1))); + QueryPhase.executeInternal(context.withCleanQueryResult(), queryPhaseSearcher); + assertNull(context.queryResult().terminatedEarly()); + assertThat(context.queryResult().topDocs().topDocs.totalHits.value, equalTo(numDocs - 1L)); + assertThat(context.queryResult().topDocs().topDocs.scoreDocs.length, equalTo(1)); + assertThat(context.queryResult().topDocs().topDocs.scoreDocs[0], instanceOf(FieldDoc.class)); + assertThat(fieldDoc.fields[0], anyOf(equalTo(1), equalTo(2))); + context.parsedPostFilter(null); + + final TestTotalHitCountCollectorManager manager = TestTotalHitCountCollectorManager.create(executor, sort); + context.queryCollectorManagers().put(TotalHitCountCollector.class, manager); + QueryPhase.executeInternal(context.withCleanQueryResult(), queryPhaseSearcher); + assertNull(context.queryResult().terminatedEarly()); + assertThat(context.queryResult().topDocs().topDocs.totalHits.value, equalTo((long) numDocs)); + assertThat(context.queryResult().topDocs().topDocs.scoreDocs.length, equalTo(1)); + assertThat(context.queryResult().topDocs().topDocs.scoreDocs[0], instanceOf(FieldDoc.class)); + assertThat(fieldDoc.fields[0], anyOf(equalTo(1), equalTo(2))); + // When searching concurrently, each executors short-circuits when "size" is reached, + // including total hits collector + assertThat(manager.getTotalHits(), lessThanOrEqualTo(numDocs)); + + context.queryCollectorManagers().clear(); + } + + { + context.setSearcher(newEarlyTerminationContextSearcher(reader, 1, executor)); + context.trackTotalHitsUpTo(SearchContext.TRACK_TOTAL_HITS_DISABLED); + QueryPhase.executeInternal(context.withCleanQueryResult(), queryPhaseSearcher); + assertNull(context.queryResult().terminatedEarly()); + assertThat(context.queryResult().topDocs().topDocs.scoreDocs.length, equalTo(1)); + assertThat(context.queryResult().topDocs().topDocs.scoreDocs[0], instanceOf(FieldDoc.class)); + assertThat(fieldDoc.fields[0], anyOf(equalTo(1), equalTo(2))); + + QueryPhase.executeInternal(context.withCleanQueryResult(), queryPhaseSearcher); + assertNull(context.queryResult().terminatedEarly()); + assertThat(context.queryResult().topDocs().topDocs.scoreDocs.length, equalTo(1)); + assertThat(context.queryResult().topDocs().topDocs.scoreDocs[0], instanceOf(FieldDoc.class)); + assertThat(fieldDoc.fields[0], anyOf(equalTo(1), equalTo(2))); + } + reader.close(); + dir.close(); + } + + public void testIndexSortScrollOptimization() throws Exception { + Directory dir = newDirectory(); + final Sort indexSort = new Sort(new SortField("rank", SortField.Type.INT), new SortField("tiebreaker", SortField.Type.INT)); + IndexWriterConfig iwc = newIndexWriterConfig().setIndexSort(indexSort); + RandomIndexWriter w = new RandomIndexWriter(random(), dir, iwc); + final int numDocs = scaledRandomIntBetween(600, 900); + for (int i = 0; i < numDocs; ++i) { + Document doc = new Document(); + doc.add(new NumericDocValuesField("rank", random().nextInt())); + doc.add(new NumericDocValuesField("tiebreaker", i)); + w.addDocument(doc); + } + if (randomBoolean()) { + w.forceMerge(randomIntBetween(1, 10)); + } + w.close(); + + final IndexReader reader = DirectoryReader.open(dir); + List searchSortAndFormats = new ArrayList<>(); + searchSortAndFormats.add(new SortAndFormats(indexSort, new DocValueFormat[] { DocValueFormat.RAW, DocValueFormat.RAW })); + // search sort is a prefix of the index sort + searchSortAndFormats.add(new SortAndFormats(new Sort(indexSort.getSort()[0]), new DocValueFormat[] { DocValueFormat.RAW })); + for (SortAndFormats searchSortAndFormat : searchSortAndFormats) { + ScrollContext scrollContext = new ScrollContext(); + TestSearchContext context = new TestSearchContext(null, indexShard, newContextSearcher(reader, executor), scrollContext); + context.parsedQuery(new ParsedQuery(new MatchAllDocsQuery())); + scrollContext.lastEmittedDoc = null; + scrollContext.maxScore = Float.NaN; + scrollContext.totalHits = null; + context.setTask(new SearchShardTask(123L, "", "", "", null, Collections.emptyMap())); + context.setSize(10); + context.sort(searchSortAndFormat); + + QueryPhase.executeInternal(context.withCleanQueryResult(), queryPhaseSearcher); + assertThat(context.queryResult().topDocs().topDocs.totalHits.value, equalTo((long) numDocs)); + assertNull(context.queryResult().terminatedEarly()); + assertThat(context.terminateAfter(), equalTo(0)); + assertThat(context.queryResult().getTotalHits().value, equalTo((long) numDocs)); + int sizeMinus1 = context.queryResult().topDocs().topDocs.scoreDocs.length - 1; + FieldDoc lastDoc = (FieldDoc) context.queryResult().topDocs().topDocs.scoreDocs[sizeMinus1]; + + context.setSearcher(newEarlyTerminationContextSearcher(reader, 10, executor)); + QueryPhase.executeInternal(context.withCleanQueryResult(), queryPhaseSearcher); + assertNull(context.queryResult().terminatedEarly()); + assertThat(context.queryResult().topDocs().topDocs.totalHits.value, equalTo((long) numDocs)); + assertThat(context.terminateAfter(), equalTo(0)); + assertThat(context.queryResult().getTotalHits().value, equalTo((long) numDocs)); + FieldDoc firstDoc = (FieldDoc) context.queryResult().topDocs().topDocs.scoreDocs[0]; + for (int i = 0; i < searchSortAndFormat.sort.getSort().length; i++) { + @SuppressWarnings("unchecked") + FieldComparator comparator = (FieldComparator) searchSortAndFormat.sort.getSort()[i].getComparator( + i, + false + ); + int cmp = comparator.compareValues(firstDoc.fields[i], lastDoc.fields[i]); + if (cmp == 0) { + continue; + } + assertThat(cmp, equalTo(1)); + break; + } + } + reader.close(); + dir.close(); + } + + public void testDisableTopScoreCollection() throws Exception { + Directory dir = newDirectory(); + IndexWriterConfig iwc = newIndexWriterConfig(new StandardAnalyzer()); + RandomIndexWriter w = new RandomIndexWriter(random(), dir, iwc); + Document doc = new Document(); + final int numDocs = 2 * scaledRandomIntBetween(50, 450); + for (int i = 0; i < numDocs; i++) { + doc.clear(); + if (i % 2 == 0) { + doc.add(new TextField("title", "foo bar", Store.NO)); + } else { + doc.add(new TextField("title", "foo", Store.NO)); + } + w.addDocument(doc); + } + w.close(); + + IndexReader reader = DirectoryReader.open(dir); + TestSearchContext context = new TestSearchContext(null, indexShard, newContextSearcher(reader, executor)); + context.setTask(new SearchShardTask(123L, "", "", "", null, Collections.emptyMap())); + Query q = new SpanNearQuery.Builder("title", true).addClause(new SpanTermQuery(new Term("title", "foo"))) + .addClause(new SpanTermQuery(new Term("title", "bar"))) + .build(); + + context.parsedQuery(new ParsedQuery(q)); + context.setSize(3); + context.trackTotalHitsUpTo(3); + TopDocsCollectorContext topDocsContext = TopDocsCollectorContext.createTopDocsCollectorContext(context, false); + assertEquals(topDocsContext.create(null).scoreMode(), org.apache.lucene.search.ScoreMode.COMPLETE); + QueryPhase.executeInternal(context.withCleanQueryResult(), queryPhaseSearcher); + assertEquals(numDocs / 2, context.queryResult().topDocs().topDocs.totalHits.value); + assertEquals(context.queryResult().topDocs().topDocs.totalHits.relation, TotalHits.Relation.EQUAL_TO); + assertThat(context.queryResult().topDocs().topDocs.scoreDocs.length, equalTo(3)); + + context.sort(new SortAndFormats(new Sort(new SortField("other", SortField.Type.INT)), new DocValueFormat[] { DocValueFormat.RAW })); + topDocsContext = TopDocsCollectorContext.createTopDocsCollectorContext(context, false); + assertEquals(topDocsContext.create(null).scoreMode(), org.apache.lucene.search.ScoreMode.TOP_DOCS); + QueryPhase.executeInternal(context.withCleanQueryResult(), queryPhaseSearcher); + assertEquals(numDocs / 2, context.queryResult().topDocs().topDocs.totalHits.value); + assertThat(context.queryResult().topDocs().topDocs.scoreDocs.length, equalTo(3)); + assertEquals(context.queryResult().topDocs().topDocs.totalHits.relation, TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO); + + reader.close(); + dir.close(); + } + + public void testEnhanceSortOnNumeric() throws Exception { + final String fieldNameLong = "long-field"; + final String fieldNameDate = "date-field"; + MappedFieldType fieldTypeLong = new NumberFieldMapper.NumberFieldType(fieldNameLong, NumberFieldMapper.NumberType.LONG); + MappedFieldType fieldTypeDate = new DateFieldMapper.DateFieldType(fieldNameDate); + MapperService mapperService = mock(MapperService.class); + when(mapperService.fieldType(fieldNameLong)).thenReturn(fieldTypeLong); + when(mapperService.fieldType(fieldNameDate)).thenReturn(fieldTypeDate); + // enough docs to have a tree with several leaf nodes + final int numDocs = 3500 * 5; + Directory dir = newDirectory(); + IndexWriter writer = new IndexWriter(dir, new IndexWriterConfig(null)); + long firstValue = randomLongBetween(-10000000L, 10000000L); + long longValue = firstValue; + long dateValue = randomLongBetween(0, 3000000000000L); + for (int i = 1; i <= numDocs; ++i) { + Document doc = new Document(); + + doc.add(new LongPoint(fieldNameLong, longValue)); + doc.add(new NumericDocValuesField(fieldNameLong, longValue)); + + doc.add(new LongPoint(fieldNameDate, dateValue)); + doc.add(new NumericDocValuesField(fieldNameDate, dateValue)); + writer.addDocument(doc); + longValue++; + dateValue++; + if (i % 3500 == 0) writer.commit(); + } + writer.close(); + final IndexReader reader = DirectoryReader.open(dir); + final SortField sortFieldLong = new SortField(fieldNameLong, SortField.Type.LONG); + sortFieldLong.setMissingValue(Long.MAX_VALUE); + final SortField sortFieldDate = new SortField(fieldNameDate, SortField.Type.LONG); + sortFieldDate.setMissingValue(Long.MAX_VALUE); + DocValueFormat dateFormat = fieldTypeDate.docValueFormat(null, null); + final Sort longSort = new Sort(sortFieldLong); + final Sort longDateSort = new Sort(sortFieldLong, sortFieldDate); + final Sort dateSort = new Sort(sortFieldDate); + final Sort dateLongSort = new Sort(sortFieldDate, sortFieldLong); + SortAndFormats longSortAndFormats = new SortAndFormats(longSort, new DocValueFormat[] { DocValueFormat.RAW }); + SortAndFormats longDateSortAndFormats = new SortAndFormats(longDateSort, new DocValueFormat[] { DocValueFormat.RAW, dateFormat }); + SortAndFormats dateSortAndFormats = new SortAndFormats(dateSort, new DocValueFormat[] { dateFormat }); + SortAndFormats dateLongSortAndFormats = new SortAndFormats(dateLongSort, new DocValueFormat[] { dateFormat, DocValueFormat.RAW }); + ParsedQuery query = new ParsedQuery(new MatchAllDocsQuery()); + SearchShardTask task = new SearchShardTask(123L, "", "", "", null, Collections.emptyMap()); + + // 1. Test a sort on long field + { + TestSearchContext searchContext = spy(new TestSearchContext(null, indexShard, newContextSearcher(reader, executor))); + when(searchContext.mapperService()).thenReturn(mapperService); + searchContext.sort(longSortAndFormats); + searchContext.parsedQuery(query); + searchContext.setTask(task); + searchContext.setSize(10); + QueryPhase.executeInternal(searchContext.withCleanQueryResult(), queryPhaseSearcher); + assertSortResults(searchContext.queryResult().topDocs().topDocs, (long) numDocs, false); + } + + // 2. Test a sort on long field + date field + { + TestSearchContext searchContext = spy(new TestSearchContext(null, indexShard, newContextSearcher(reader, executor))); + when(searchContext.mapperService()).thenReturn(mapperService); + searchContext.sort(longDateSortAndFormats); + searchContext.parsedQuery(query); + searchContext.setTask(task); + searchContext.setSize(10); + QueryPhase.executeInternal(searchContext.withCleanQueryResult(), queryPhaseSearcher); + assertSortResults(searchContext.queryResult().topDocs().topDocs, (long) numDocs, true); + } + + // 3. Test a sort on date field + { + TestSearchContext searchContext = spy(new TestSearchContext(null, indexShard, newContextSearcher(reader, executor))); + when(searchContext.mapperService()).thenReturn(mapperService); + searchContext.sort(dateSortAndFormats); + searchContext.parsedQuery(query); + searchContext.setTask(task); + searchContext.setSize(10); + QueryPhase.executeInternal(searchContext.withCleanQueryResult(), queryPhaseSearcher); + assertSortResults(searchContext.queryResult().topDocs().topDocs, (long) numDocs, false); + } + + // 4. Test a sort on date field + long field + { + TestSearchContext searchContext = spy(new TestSearchContext(null, indexShard, newContextSearcher(reader, executor))); + when(searchContext.mapperService()).thenReturn(mapperService); + searchContext.sort(dateLongSortAndFormats); + searchContext.parsedQuery(query); + searchContext.setTask(task); + searchContext.setSize(10); + QueryPhase.executeInternal(searchContext.withCleanQueryResult(), queryPhaseSearcher); + assertSortResults(searchContext.queryResult().topDocs().topDocs, (long) numDocs, true); + } + + // 5. Test that sort optimization is run when from > 0 and size = 0 + { + TestSearchContext searchContext = spy(new TestSearchContext(null, indexShard, newContextSearcher(reader, executor))); + when(searchContext.mapperService()).thenReturn(mapperService); + searchContext.sort(longSortAndFormats); + searchContext.parsedQuery(query); + searchContext.setTask(task); + searchContext.from(5); + searchContext.setSize(0); + QueryPhase.executeInternal(searchContext.withCleanQueryResult(), queryPhaseSearcher); + assertSortResults(searchContext.queryResult().topDocs().topDocs, (long) numDocs, false); + } + + // 6. Test that sort optimization works with from = 0 and size= 0 + { + TestSearchContext searchContext = spy(new TestSearchContext(null, indexShard, newContextSearcher(reader, executor))); + when(searchContext.mapperService()).thenReturn(mapperService); + searchContext.sort(longSortAndFormats); + searchContext.parsedQuery(query); + searchContext.setTask(task); + searchContext.setSize(0); + QueryPhase.executeInternal(searchContext.withCleanQueryResult(), queryPhaseSearcher); + } + + // 7. Test that sort optimization works with search after + { + TestSearchContext searchContext = spy(new TestSearchContext(null, indexShard, newContextSearcher(reader, executor))); + when(searchContext.mapperService()).thenReturn(mapperService); + int afterDocument = (int) randomLongBetween(0, 50); + long afterValue = firstValue + afterDocument; + FieldDoc after = new FieldDoc(afterDocument, Float.NaN, new Long[] { afterValue }); + searchContext.searchAfter(after); + searchContext.sort(longSortAndFormats); + searchContext.parsedQuery(query); + searchContext.setTask(task); + searchContext.setSize(10); + QueryPhase.executeInternal(searchContext.withCleanQueryResult(), queryPhaseSearcher); + final TopDocs topDocs = searchContext.queryResult().topDocs().topDocs; + long topValue = (long) ((FieldDoc) topDocs.scoreDocs[0]).fields[0]; + assertThat(topValue, greaterThan(afterValue)); + assertSortResults(topDocs, (long) numDocs, false); + + final TotalHits totalHits = topDocs.totalHits; + assertEquals(TotalHits.Relation.EQUAL_TO, totalHits.relation); + assertEquals(numDocs, totalHits.value); + } + + reader.close(); + dir.close(); + } + + public void testMaxScoreQueryVisitor() { + BitSetProducer producer = context -> new FixedBitSet(1); + Query query = new OpenSearchToParentBlockJoinQuery(new MatchAllDocsQuery(), producer, ScoreMode.Avg, "nested"); + assertTrue(hasInfMaxScore(query)); + + query = new OpenSearchToParentBlockJoinQuery(new MatchAllDocsQuery(), producer, ScoreMode.None, "nested"); + assertFalse(hasInfMaxScore(query)); + + for (Occur occur : Occur.values()) { + query = new BooleanQuery.Builder().add( + new OpenSearchToParentBlockJoinQuery(new MatchAllDocsQuery(), producer, ScoreMode.Avg, "nested"), + occur + ).build(); + if (occur == Occur.MUST) { + assertTrue(hasInfMaxScore(query)); + } else { + assertFalse(hasInfMaxScore(query)); + } + + query = new BooleanQuery.Builder().add( + new BooleanQuery.Builder().add( + new OpenSearchToParentBlockJoinQuery(new MatchAllDocsQuery(), producer, ScoreMode.Avg, "nested"), + occur + ).build(), + occur + ).build(); + if (occur == Occur.MUST) { + assertTrue(hasInfMaxScore(query)); + } else { + assertFalse(hasInfMaxScore(query)); + } + + query = new BooleanQuery.Builder().add( + new BooleanQuery.Builder().add( + new OpenSearchToParentBlockJoinQuery(new MatchAllDocsQuery(), producer, ScoreMode.Avg, "nested"), + occur + ).build(), + Occur.FILTER + ).build(); + assertFalse(hasInfMaxScore(query)); + + query = new BooleanQuery.Builder().add( + new BooleanQuery.Builder().add(new SpanTermQuery(new Term("field", "foo")), occur) + .add(new OpenSearchToParentBlockJoinQuery(new MatchAllDocsQuery(), producer, ScoreMode.Avg, "nested"), occur) + .build(), + occur + ).build(); + if (occur == Occur.MUST) { + assertTrue(hasInfMaxScore(query)); + } else { + assertFalse(hasInfMaxScore(query)); + } + } + } + + // assert score docs are in order and their number is as expected + private void assertSortResults(TopDocs topDocs, long expectedNumDocs, boolean isDoubleSort) { + if (topDocs.totalHits.relation == TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO) { + assertThat(topDocs.totalHits.value, lessThanOrEqualTo(expectedNumDocs)); + } else { + assertEquals(topDocs.totalHits.value, expectedNumDocs); + } + long cur1, cur2; + long prev1 = Long.MIN_VALUE; + long prev2 = Long.MIN_VALUE; + for (ScoreDoc scoreDoc : topDocs.scoreDocs) { + cur1 = (long) ((FieldDoc) scoreDoc).fields[0]; + assertThat(cur1, greaterThanOrEqualTo(prev1)); // test that docs are properly sorted on the first sort + if (isDoubleSort) { + cur2 = (long) ((FieldDoc) scoreDoc).fields[1]; + if (cur1 == prev1) { + assertThat(cur2, greaterThanOrEqualTo(prev2)); // test that docs are properly sorted on the secondary sort + } + prev2 = cur2; + } + prev1 = cur1; + } + } + + public void testMinScore() throws Exception { + Directory dir = newDirectory(); + IndexWriterConfig iwc = newIndexWriterConfig(); + RandomIndexWriter w = new RandomIndexWriter(random(), dir, iwc); + for (int i = 0; i < 10; i++) { + Document doc = new Document(); + doc.add(new StringField("foo", "bar", Store.NO)); + doc.add(new StringField("filter", "f1", Store.NO)); + w.addDocument(doc); + } + w.close(); + + IndexReader reader = DirectoryReader.open(dir); + TestSearchContext context = new TestSearchContext(null, indexShard, newContextSearcher(reader, executor)); + context.parsedQuery( + new ParsedQuery( + new BooleanQuery.Builder().add(new TermQuery(new Term("foo", "bar")), Occur.MUST) + .add(new TermQuery(new Term("filter", "f1")), Occur.SHOULD) + .build() + ) + ); + context.minimumScore(0.01f); + context.setTask(new SearchShardTask(123L, "", "", "", null, Collections.emptyMap())); + context.setSize(1); + context.trackTotalHitsUpTo(5); + + QueryPhase.executeInternal(context.withCleanQueryResult(), queryPhaseSearcher); + assertEquals(10, context.queryResult().topDocs().topDocs.totalHits.value); + + reader.close(); + dir.close(); + } + + public void testMaxScore() throws Exception { + Directory dir = newDirectory(); + final Sort sort = new Sort(new SortField("filter", SortField.Type.STRING)); + IndexWriterConfig iwc = newIndexWriterConfig().setIndexSort(sort); + RandomIndexWriter w = new RandomIndexWriter(random(), dir, iwc); + + final int numDocs = scaledRandomIntBetween(600, 900); + for (int i = 0; i < numDocs; i++) { + Document doc = new Document(); + doc.add(new StringField("foo", "bar", Store.NO)); + doc.add(new StringField("filter", "f1" + ((i > 0) ? " " + Integer.toString(i) : ""), Store.NO)); + doc.add(new SortedDocValuesField("filter", newBytesRef("f1" + ((i > 0) ? " " + Integer.toString(i) : "")))); + w.addDocument(doc); + } + w.close(); + + IndexReader reader = DirectoryReader.open(dir); + TestSearchContext context = new TestSearchContext(null, indexShard, newContextSearcher(reader, executor)); + context.trackScores(true); + context.parsedQuery( + new ParsedQuery( + new BooleanQuery.Builder().add(new TermQuery(new Term("foo", "bar")), Occur.MUST) + .add(new TermQuery(new Term("filter", "f1")), Occur.SHOULD) + .build() + ) + ); + context.setTask(new SearchShardTask(123L, "", "", "", null, Collections.emptyMap())); + context.setSize(1); + context.trackTotalHitsUpTo(5); + + QueryPhase.executeInternal(context.withCleanQueryResult(), queryPhaseSearcher); + assertFalse(Float.isNaN(context.queryResult().getMaxScore())); + assertEquals(1, context.queryResult().topDocs().topDocs.scoreDocs.length); + assertThat(context.queryResult().topDocs().topDocs.totalHits.value, greaterThanOrEqualTo(6L)); + + context.sort(new SortAndFormats(sort, new DocValueFormat[] { DocValueFormat.RAW })); + QueryPhase.executeInternal(context.withCleanQueryResult(), queryPhaseSearcher); + assertFalse(Float.isNaN(context.queryResult().getMaxScore())); + assertEquals(1, context.queryResult().topDocs().topDocs.scoreDocs.length); + assertThat(context.queryResult().topDocs().topDocs.totalHits.value, greaterThanOrEqualTo(6L)); + + context.trackScores(false); + QueryPhase.executeInternal(context.withCleanQueryResult(), queryPhaseSearcher); + assertTrue(Float.isNaN(context.queryResult().getMaxScore())); + assertEquals(1, context.queryResult().topDocs().topDocs.scoreDocs.length); + assertThat(context.queryResult().topDocs().topDocs.totalHits.value, greaterThanOrEqualTo(6L)); + + reader.close(); + dir.close(); + } + + public void testCollapseQuerySearchResults() throws Exception { + Directory dir = newDirectory(); + final Sort sort = new Sort(new SortField("user", SortField.Type.INT)); + IndexWriterConfig iwc = newIndexWriterConfig().setIndexSort(sort); + RandomIndexWriter w = new RandomIndexWriter(random(), dir, iwc); + + // Always end up with uneven buckets so collapsing is predictable + final int numDocs = 2 * scaledRandomIntBetween(600, 900) - 1; + for (int i = 0; i < numDocs; i++) { + Document doc = new Document(); + doc.add(new StringField("foo", "bar", Store.NO)); + doc.add(new NumericDocValuesField("user", i & 1)); + w.addDocument(doc); + } + w.close(); + + IndexReader reader = DirectoryReader.open(dir); + QueryShardContext queryShardContext = mock(QueryShardContext.class); + when(queryShardContext.fieldMapper("user")).thenReturn( + new NumberFieldType("user", NumberType.INTEGER, true, false, true, false, null, Collections.emptyMap()) + ); + + TestSearchContext context = new TestSearchContext(queryShardContext, indexShard, newContextSearcher(reader, executor)); + context.collapse(new CollapseBuilder("user").build(context.getQueryShardContext())); + context.trackScores(true); + context.parsedQuery(new ParsedQuery(new TermQuery(new Term("foo", "bar")))); + context.setTask(new SearchShardTask(123L, "", "", "", null, Collections.emptyMap())); + context.setSize(2); + context.trackTotalHitsUpTo(5); + + QueryPhase.executeInternal(context.withCleanQueryResult(), queryPhaseSearcher); + assertFalse(Float.isNaN(context.queryResult().getMaxScore())); + assertEquals(2, context.queryResult().topDocs().topDocs.scoreDocs.length); + assertThat(context.queryResult().topDocs().topDocs.totalHits.value, equalTo((long) numDocs)); + assertThat(context.queryResult().topDocs().topDocs, instanceOf(CollapseTopFieldDocs.class)); + + CollapseTopFieldDocs topDocs = (CollapseTopFieldDocs) context.queryResult().topDocs().topDocs; + assertThat(topDocs.collapseValues.length, equalTo(2)); + assertThat(topDocs.collapseValues[0], equalTo(0L)); // user == 0 + assertThat(topDocs.collapseValues[1], equalTo(1L)); // user == 1 + + context.sort(new SortAndFormats(sort, new DocValueFormat[] { DocValueFormat.RAW })); + QueryPhase.executeInternal(context.withCleanQueryResult(), queryPhaseSearcher); + assertFalse(Float.isNaN(context.queryResult().getMaxScore())); + assertEquals(2, context.queryResult().topDocs().topDocs.scoreDocs.length); + assertThat(context.queryResult().topDocs().topDocs.totalHits.value, equalTo((long) numDocs)); + assertThat(context.queryResult().topDocs().topDocs, instanceOf(CollapseTopFieldDocs.class)); + + topDocs = (CollapseTopFieldDocs) context.queryResult().topDocs().topDocs; + assertThat(topDocs.collapseValues.length, equalTo(2)); + assertThat(topDocs.collapseValues[0], equalTo(0L)); // user == 0 + assertThat(topDocs.collapseValues[1], equalTo(1L)); // user == 1 + + context.trackScores(false); + QueryPhase.executeInternal(context.withCleanQueryResult(), queryPhaseSearcher); + assertTrue(Float.isNaN(context.queryResult().getMaxScore())); + assertEquals(2, context.queryResult().topDocs().topDocs.scoreDocs.length); + assertThat(context.queryResult().topDocs().topDocs.totalHits.value, equalTo((long) numDocs)); + assertThat(context.queryResult().topDocs().topDocs, instanceOf(CollapseTopFieldDocs.class)); + + topDocs = (CollapseTopFieldDocs) context.queryResult().topDocs().topDocs; + assertThat(topDocs.collapseValues.length, equalTo(2)); + assertThat(topDocs.collapseValues[0], equalTo(0L)); // user == 0 + assertThat(topDocs.collapseValues[1], equalTo(1L)); // user == 1 + + reader.close(); + dir.close(); + } + + public void testCancellationDuringPreprocess() throws IOException { + try (Directory dir = newDirectory(); RandomIndexWriter w = new RandomIndexWriter(random(), dir, newIndexWriterConfig())) { + + for (int i = 0; i < 10; i++) { + Document doc = new Document(); + StringBuilder sb = new StringBuilder(); + for (int j = 0; j < i; j++) { + sb.append('a'); + } + doc.add(new StringField("foo", sb.toString(), Store.NO)); + w.addDocument(doc); + } + w.flush(); + w.close(); + + try (IndexReader reader = DirectoryReader.open(dir)) { + TestSearchContext context = new TestSearchContextWithRewriteAndCancellation( + null, + indexShard, + newContextSearcher(reader, executor) + ); + PrefixQuery prefixQuery = new PrefixQuery(new Term("foo", "a")); + prefixQuery.setRewriteMethod(MultiTermQuery.SCORING_BOOLEAN_REWRITE); + context.parsedQuery(new ParsedQuery(prefixQuery)); + SearchShardTask task = mock(SearchShardTask.class); + when(task.isCancelled()).thenReturn(true); + context.setTask(task); + expectThrows(TaskCancelledException.class, () -> new QueryPhase().preProcess(context)); + } + } + } + + private static class TestSearchContextWithRewriteAndCancellation extends TestSearchContext { + + private TestSearchContextWithRewriteAndCancellation( + QueryShardContext queryShardContext, + IndexShard indexShard, + ContextIndexSearcher searcher + ) { + super(queryShardContext, indexShard, searcher); + } + + @Override + public void preProcess(boolean rewrite) { + try { + searcher().rewrite(query()); + } catch (IOException e) { + fail("IOException shouldn't be thrown"); + } + } + + @Override + public boolean lowLevelCancellation() { + return true; + } + } + + private static ContextIndexSearcher newContextSearcher(IndexReader reader, ExecutorService executor) throws IOException { + return new ContextIndexSearcher( + reader, + IndexSearcher.getDefaultSimilarity(), + IndexSearcher.getDefaultQueryCache(), + IndexSearcher.getDefaultQueryCachingPolicy(), + true, + executor + ); + } + + private static ContextIndexSearcher newEarlyTerminationContextSearcher(IndexReader reader, int size, ExecutorService executor) + throws IOException { + return new ContextIndexSearcher( + reader, + IndexSearcher.getDefaultSimilarity(), + IndexSearcher.getDefaultQueryCache(), + IndexSearcher.getDefaultQueryCachingPolicy(), + true, + executor + ) { + + @Override + public void search(List leaves, Weight weight, Collector collector) throws IOException { + final Collector in = new AssertingEarlyTerminationFilterCollector(collector, size); + super.search(leaves, weight, in); + } + }; + } + + // used to check that numeric long or date sort optimization was run + private static ContextIndexSearcher newOptimizedContextSearcher(IndexReader reader, int queryType, ExecutorService executor) + throws IOException { + return new ContextIndexSearcher( + reader, + IndexSearcher.getDefaultSimilarity(), + IndexSearcher.getDefaultQueryCache(), + IndexSearcher.getDefaultQueryCachingPolicy(), + true, + executor + ) { + + @Override + public void search( + Query query, + CollectorManager manager, + QuerySearchResult result, + DocValueFormat[] formats, + TotalHits totalHits + ) throws IOException { + assertTrue(query instanceof BooleanQuery); + List clauses = ((BooleanQuery) query).clauses(); + assertTrue(clauses.size() == 2); + assertTrue(clauses.get(0).getOccur() == Occur.FILTER); + assertTrue(clauses.get(1).getOccur() == Occur.SHOULD); + if (queryType == 0) { + assertTrue( + clauses.get(1).getQuery().getClass() == LongPoint.newDistanceFeatureQuery("random_field", 1, 1, 1).getClass() + ); + } + if (queryType == 1) assertTrue(clauses.get(1).getQuery() instanceof DocValuesFieldExistsQuery); + super.search(query, manager, result, formats, totalHits); + } + + @Override + public void search( + List leaves, + Weight weight, + @SuppressWarnings("rawtypes") CollectorManager manager, + QuerySearchResult result, + DocValueFormat[] formats, + TotalHits totalHits + ) throws IOException { + final Query query = weight.getQuery(); + assertTrue(query instanceof BooleanQuery); + List clauses = ((BooleanQuery) query).clauses(); + assertTrue(clauses.size() == 2); + assertTrue(clauses.get(0).getOccur() == Occur.FILTER); + assertTrue(clauses.get(1).getOccur() == Occur.SHOULD); + if (queryType == 0) { + assertTrue( + clauses.get(1).getQuery().getClass() == LongPoint.newDistanceFeatureQuery("random_field", 1, 1, 1).getClass() + ); + } + if (queryType == 1) assertTrue(clauses.get(1).getQuery() instanceof DocValuesFieldExistsQuery); + super.search(leaves, weight, manager, result, formats, totalHits); + } + + @Override + public void search(List leaves, Weight weight, Collector collector) throws IOException { + if (getExecutor() == null) { + assert (false); // should not be there, expected to search with CollectorManager + } else { + super.search(leaves, weight, collector); + } + } + }; + } + + private static class TestTotalHitCountCollectorManager extends TotalHitCountCollectorManager { + private int totalHits; + private final TotalHitCountCollector collector; + private final Integer teminateAfter; + + static TestTotalHitCountCollectorManager create(final ExecutorService executor) { + return create(executor, null, null); + } + + static TestTotalHitCountCollectorManager create(final ExecutorService executor, final Integer teminateAfter) { + return create(executor, null, teminateAfter); + } + + static TestTotalHitCountCollectorManager create(final ExecutorService executor, final Sort sort) { + return create(executor, sort, null); + } + + static TestTotalHitCountCollectorManager create(final ExecutorService executor, final Sort sort, final Integer teminateAfter) { + if (executor == null) { + return new TestTotalHitCountCollectorManager(new TotalHitCountCollector(), sort); + } else { + return new TestTotalHitCountCollectorManager(sort, teminateAfter); + } + } + + private TestTotalHitCountCollectorManager(final TotalHitCountCollector collector, final Sort sort) { + super(sort); + this.collector = collector; + this.teminateAfter = null; + } + + private TestTotalHitCountCollectorManager(final Sort sort, final Integer teminateAfter) { + super(sort); + this.collector = null; + this.teminateAfter = teminateAfter; + } + + @Override + public TotalHitCountCollector newCollector() throws IOException { + return (collector == null) ? super.newCollector() : collector; + } + + @Override + public ReduceableSearchResult reduce(Collection collectors) throws IOException { + final ReduceableSearchResult result = super.reduce(collectors); + totalHits = collectors.stream().mapToInt(TotalHitCountCollector::getTotalHits).sum(); + + if (teminateAfter != null) { + assertThat(totalHits, greaterThanOrEqualTo(teminateAfter)); + totalHits = Math.min(totalHits, teminateAfter); + } + + return result; + } + + public int getTotalHits() { + return (collector == null) ? totalHits : collector.getTotalHits(); + } + } + + private static class AssertingEarlyTerminationFilterCollector extends FilterCollector { + private final int size; + + AssertingEarlyTerminationFilterCollector(Collector in, int size) { + super(in); + this.size = size; + } + + @Override + public LeafCollector getLeafCollector(LeafReaderContext context) throws IOException { + final LeafCollector in = super.getLeafCollector(context); + return new FilterLeafCollector(in) { + int collected; + + @Override + public void collect(int doc) throws IOException { + assert collected <= size : "should not collect more than " + size + " doc per segment, got " + collected; + ++collected; + super.collect(doc); + } + }; + } + } +} diff --git a/sandbox/plugins/concurrent-search/src/test/java/org/opensearch/search/query/QueryProfilePhaseTests.java b/sandbox/plugins/concurrent-search/src/test/java/org/opensearch/search/query/QueryProfilePhaseTests.java new file mode 100644 index 0000000000000..d2cb77f529793 --- /dev/null +++ b/sandbox/plugins/concurrent-search/src/test/java/org/opensearch/search/query/QueryProfilePhaseTests.java @@ -0,0 +1,1182 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.search.query; + +import com.carrotsearch.randomizedtesting.annotations.ParametersFactory; + +import org.apache.lucene.analysis.standard.StandardAnalyzer; +import org.apache.lucene.document.Document; +import org.apache.lucene.document.Field.Store; +import org.apache.lucene.document.NumericDocValuesField; +import org.apache.lucene.document.SortedDocValuesField; +import org.apache.lucene.document.StringField; +import org.apache.lucene.document.TextField; +import org.apache.lucene.index.DirectoryReader; +import org.apache.lucene.index.IndexReader; +import org.apache.lucene.index.IndexWriterConfig; +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.index.Term; +import org.apache.lucene.queries.spans.SpanNearQuery; +import org.apache.lucene.queries.spans.SpanTermQuery; +import org.apache.lucene.search.BooleanClause.Occur; +import org.apache.lucene.search.grouping.CollapseTopFieldDocs; +import org.apache.lucene.search.BooleanQuery; +import org.apache.lucene.search.Collector; +import org.apache.lucene.search.FieldComparator; +import org.apache.lucene.search.FieldDoc; +import org.apache.lucene.search.FilterCollector; +import org.apache.lucene.search.FilterLeafCollector; +import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.LeafCollector; +import org.apache.lucene.search.MatchAllDocsQuery; +import org.apache.lucene.search.MatchNoDocsQuery; +import org.apache.lucene.search.Query; +import org.apache.lucene.search.Sort; +import org.apache.lucene.search.SortField; +import org.apache.lucene.search.TermQuery; +import org.apache.lucene.search.TotalHits; +import org.apache.lucene.search.Weight; +import org.apache.lucene.tests.index.RandomIndexWriter; +import org.apache.lucene.store.Directory; +import org.opensearch.action.search.SearchShardTask; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.xcontent.ToXContent; +import org.opensearch.common.xcontent.XContentBuilder; +import org.opensearch.common.xcontent.json.JsonXContent; +import org.opensearch.index.mapper.NumberFieldMapper.NumberFieldType; +import org.opensearch.index.mapper.NumberFieldMapper.NumberType; +import org.opensearch.index.query.ParsedQuery; +import org.opensearch.index.query.QueryShardContext; +import org.opensearch.index.shard.IndexShard; +import org.opensearch.index.shard.IndexShardTestCase; +import org.opensearch.lucene.queries.MinDocQuery; +import org.opensearch.search.DocValueFormat; +import org.opensearch.search.collapse.CollapseBuilder; +import org.opensearch.search.internal.ContextIndexSearcher; +import org.opensearch.search.internal.ScrollContext; +import org.opensearch.search.internal.SearchContext; +import org.opensearch.search.profile.ProfileResult; +import org.opensearch.search.profile.ProfileShardResult; +import org.opensearch.search.profile.SearchProfileShardResults; +import org.opensearch.search.profile.query.CollectorResult; +import org.opensearch.search.profile.query.QueryProfileShardResult; +import org.opensearch.search.sort.SortAndFormats; +import org.opensearch.test.TestSearchContext; +import org.opensearch.threadpool.ThreadPool; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.OutputStream; +import java.nio.charset.StandardCharsets; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.List; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; +import java.util.function.Consumer; + +import static org.hamcrest.CoreMatchers.not; +import static org.hamcrest.CoreMatchers.nullValue; +import static org.hamcrest.Matchers.anyOf; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.empty; +import static org.hamcrest.Matchers.greaterThan; +import static org.hamcrest.Matchers.greaterThanOrEqualTo; +import static org.hamcrest.Matchers.instanceOf; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; +import static org.hamcrest.Matchers.hasSize; + +public class QueryProfilePhaseTests extends IndexShardTestCase { + + private IndexShard indexShard; + private final ExecutorService executor; + private final QueryPhaseSearcher queryPhaseSearcher; + + @ParametersFactory + public static Collection concurrency() { + return Arrays.asList( + new Object[] { 0, QueryPhase.DEFAULT_QUERY_PHASE_SEARCHER }, + new Object[] { 5, new ConcurrentQueryPhaseSearcher() } + ); + } + + public QueryProfilePhaseTests(int concurrency, QueryPhaseSearcher queryPhaseSearcher) { + this.executor = (concurrency > 0) ? Executors.newFixedThreadPool(concurrency) : null; + this.queryPhaseSearcher = queryPhaseSearcher; + } + + @Override + public Settings threadPoolSettings() { + return Settings.builder().put(super.threadPoolSettings()).put("thread_pool.search.min_queue_size", 10).build(); + } + + @Override + public void setUp() throws Exception { + super.setUp(); + indexShard = newShard(true); + } + + @Override + public void tearDown() throws Exception { + super.tearDown(); + closeShards(indexShard); + + if (executor != null) { + ThreadPool.terminate(executor, 10, TimeUnit.SECONDS); + } + } + + public void testPostFilterDisablesCountOptimization() throws Exception { + Directory dir = newDirectory(); + final Sort sort = new Sort(new SortField("rank", SortField.Type.INT)); + IndexWriterConfig iwc = newIndexWriterConfig().setIndexSort(sort); + RandomIndexWriter w = new RandomIndexWriter(random(), dir, iwc); + Document doc = new Document(); + w.addDocument(doc); + w.close(); + + IndexReader reader = DirectoryReader.open(dir); + + TestSearchContext context = new TestSearchContext(null, indexShard, newEarlyTerminationContextSearcher(reader, 0, executor)); + context.setTask(new SearchShardTask(123L, "", "", "", null, Collections.emptyMap())); + context.parsedQuery(new ParsedQuery(new MatchAllDocsQuery())); + + QueryPhase.executeInternal(context.withCleanQueryResult().withProfilers(), queryPhaseSearcher); + assertEquals(1, context.queryResult().topDocs().topDocs.totalHits.value); + assertProfileData(context, "MatchAllDocsQuery", query -> { + assertThat(query.getTimeBreakdown().keySet(), not(empty())); + assertThat(query.getTimeBreakdown().get("score"), equalTo(0L)); + assertThat(query.getTimeBreakdown().get("score_count"), equalTo(0L)); + assertThat(query.getTimeBreakdown().get("create_weight"), greaterThan(0L)); + assertThat(query.getTimeBreakdown().get("create_weight_count"), equalTo(1L)); + }, collector -> { + assertThat(collector.getReason(), equalTo("search_count")); + assertThat(collector.getTime(), greaterThan(0L)); + assertThat(collector.getProfiledChildren(), empty()); + }); + + context.setSearcher(newContextSearcher(reader, executor)); + context.parsedPostFilter(new ParsedQuery(new MatchNoDocsQuery())); + QueryPhase.executeInternal(context.withCleanQueryResult().withProfilers(), queryPhaseSearcher); + assertEquals(0, context.queryResult().topDocs().topDocs.totalHits.value); + assertProfileData(context, collector -> { + assertThat(collector.getReason(), equalTo("search_post_filter")); + assertThat(collector.getTime(), greaterThan(0L)); + assertThat(collector.getProfiledChildren(), hasSize(1)); + assertThat(collector.getProfiledChildren().get(0).getReason(), equalTo("search_count")); + assertThat(collector.getProfiledChildren().get(0).getTime(), greaterThan(0L)); + }, (query) -> { + assertThat(query.getQueryName(), equalTo("MatchNoDocsQuery")); + assertThat(query.getTimeBreakdown().keySet(), not(empty())); + assertThat(query.getTimeBreakdown().get("score"), equalTo(0L)); + assertThat(query.getTimeBreakdown().get("score_count"), equalTo(0L)); + assertThat(query.getTimeBreakdown().get("create_weight"), greaterThan(0L)); + assertThat(query.getTimeBreakdown().get("create_weight_count"), equalTo(1L)); + }, (query) -> { + assertThat(query.getQueryName(), equalTo("MatchAllDocsQuery")); + assertThat(query.getTimeBreakdown().keySet(), not(empty())); + assertThat(query.getTimeBreakdown().get("score"), equalTo(0L)); + assertThat(query.getTimeBreakdown().get("score_count"), equalTo(0L)); + assertThat(query.getTimeBreakdown().get("create_weight"), greaterThan(0L)); + assertThat(query.getTimeBreakdown().get("create_weight_count"), equalTo(1L)); + }); + + reader.close(); + dir.close(); + } + + public void testTerminateAfterWithFilter() throws Exception { + Directory dir = newDirectory(); + final Sort sort = new Sort(new SortField("rank", SortField.Type.INT)); + IndexWriterConfig iwc = newIndexWriterConfig().setIndexSort(sort); + RandomIndexWriter w = new RandomIndexWriter(random(), dir, iwc); + Document doc = new Document(); + for (int i = 0; i < 10; i++) { + doc.add(new StringField("foo", Integer.toString(i), Store.NO)); + } + w.addDocument(doc); + w.close(); + + IndexReader reader = DirectoryReader.open(dir); + + TestSearchContext context = new TestSearchContext(null, indexShard, newContextSearcher(reader, executor)); + context.setTask(new SearchShardTask(123L, "", "", "", null, Collections.emptyMap())); + + context.parsedQuery(new ParsedQuery(new MatchAllDocsQuery())); + context.terminateAfter(1); + context.setSize(10); + for (int i = 0; i < 10; i++) { + context.parsedPostFilter(new ParsedQuery(new TermQuery(new Term("foo", Integer.toString(i))))); + QueryPhase.executeInternal(context.withCleanQueryResult().withProfilers(), queryPhaseSearcher); + assertEquals(1, context.queryResult().topDocs().topDocs.totalHits.value); + assertThat(context.queryResult().topDocs().topDocs.scoreDocs.length, equalTo(1)); + assertProfileData(context, collector -> { + assertThat(collector.getReason(), equalTo("search_post_filter")); + assertThat(collector.getTime(), greaterThan(0L)); + assertThat(collector.getProfiledChildren(), hasSize(1)); + assertThat(collector.getProfiledChildren().get(0).getReason(), equalTo("search_terminate_after_count")); + assertThat(collector.getProfiledChildren().get(0).getTime(), greaterThan(0L)); + assertThat(collector.getProfiledChildren().get(0).getProfiledChildren(), hasSize(1)); + assertThat(collector.getProfiledChildren().get(0).getProfiledChildren().get(0).getReason(), equalTo("search_top_hits")); + assertThat(collector.getProfiledChildren().get(0).getProfiledChildren().get(0).getTime(), greaterThan(0L)); + }, (query) -> { + assertThat(query.getQueryName(), equalTo("TermQuery")); + assertThat(query.getTimeBreakdown().keySet(), not(empty())); + assertThat(query.getTimeBreakdown().get("score"), equalTo(0L)); + assertThat(query.getTimeBreakdown().get("score_count"), equalTo(0L)); + assertThat(query.getTimeBreakdown().get("create_weight"), greaterThan(0L)); + assertThat(query.getTimeBreakdown().get("create_weight_count"), equalTo(1L)); + }, (query) -> { + assertThat(query.getQueryName(), equalTo("MatchAllDocsQuery")); + assertThat(query.getTimeBreakdown().keySet(), not(empty())); + assertThat(query.getTimeBreakdown().get("score"), greaterThan(0L)); + assertThat(query.getTimeBreakdown().get("score_count"), equalTo(1L)); + assertThat(query.getTimeBreakdown().get("create_weight"), greaterThan(0L)); + assertThat(query.getTimeBreakdown().get("create_weight_count"), equalTo(1L)); + }); + } + reader.close(); + dir.close(); + } + + public void testMinScoreDisablesCountOptimization() throws Exception { + Directory dir = newDirectory(); + final Sort sort = new Sort(new SortField("rank", SortField.Type.INT)); + IndexWriterConfig iwc = newIndexWriterConfig().setIndexSort(sort); + RandomIndexWriter w = new RandomIndexWriter(random(), dir, iwc); + Document doc = new Document(); + w.addDocument(doc); + w.close(); + + IndexReader reader = DirectoryReader.open(dir); + TestSearchContext context = new TestSearchContext(null, indexShard, newEarlyTerminationContextSearcher(reader, 0, executor)); + context.parsedQuery(new ParsedQuery(new MatchAllDocsQuery())); + context.setSize(0); + context.setTask(new SearchShardTask(123L, "", "", "", null, Collections.emptyMap())); + QueryPhase.executeInternal(context.withCleanQueryResult().withProfilers(), queryPhaseSearcher); + assertEquals(1, context.queryResult().topDocs().topDocs.totalHits.value); + assertProfileData(context, "MatchAllDocsQuery", query -> { + assertThat(query.getTimeBreakdown().keySet(), not(empty())); + assertThat(query.getTimeBreakdown().get("score"), equalTo(0L)); + assertThat(query.getTimeBreakdown().get("score_count"), equalTo(0L)); + assertThat(query.getTimeBreakdown().get("create_weight"), greaterThan(0L)); + assertThat(query.getTimeBreakdown().get("create_weight_count"), equalTo(1L)); + }, collector -> { + assertThat(collector.getReason(), equalTo("search_count")); + assertThat(collector.getTime(), greaterThan(0L)); + assertThat(collector.getProfiledChildren(), empty()); + }); + + context.minimumScore(100); + QueryPhase.executeInternal(context.withCleanQueryResult().withProfilers(), queryPhaseSearcher); + assertEquals(0, context.queryResult().topDocs().topDocs.totalHits.value); + assertEquals(TotalHits.Relation.EQUAL_TO, context.queryResult().topDocs().topDocs.totalHits.relation); + assertProfileData(context, "MatchAllDocsQuery", query -> { + assertThat(query.getTimeBreakdown().keySet(), not(empty())); + assertThat(query.getTimeBreakdown().get("score"), greaterThanOrEqualTo(100L)); + assertThat(query.getTimeBreakdown().get("score_count"), equalTo(1L)); + assertThat(query.getTimeBreakdown().get("create_weight"), greaterThan(0L)); + assertThat(query.getTimeBreakdown().get("create_weight_count"), equalTo(1L)); + }, collector -> { + assertThat(collector.getReason(), equalTo("search_min_score")); + assertThat(collector.getTime(), greaterThan(0L)); + assertThat(collector.getProfiledChildren(), hasSize(1)); + assertThat(collector.getProfiledChildren().get(0).getReason(), equalTo("search_count")); + assertThat(collector.getProfiledChildren().get(0).getTime(), greaterThan(0L)); + }); + + reader.close(); + dir.close(); + } + + public void testInOrderScrollOptimization() throws Exception { + Directory dir = newDirectory(); + final Sort sort = new Sort(new SortField("rank", SortField.Type.INT)); + IndexWriterConfig iwc = newIndexWriterConfig().setIndexSort(sort); + RandomIndexWriter w = new RandomIndexWriter(random(), dir, iwc); + final int numDocs = scaledRandomIntBetween(600, 900); + for (int i = 0; i < numDocs; ++i) { + w.addDocument(new Document()); + } + w.close(); + IndexReader reader = DirectoryReader.open(dir); + ScrollContext scrollContext = new ScrollContext(); + TestSearchContext context = new TestSearchContext(null, indexShard, newContextSearcher(reader, executor), scrollContext); + context.parsedQuery(new ParsedQuery(new MatchAllDocsQuery())); + scrollContext.lastEmittedDoc = null; + scrollContext.maxScore = Float.NaN; + scrollContext.totalHits = null; + context.setTask(new SearchShardTask(123L, "", "", "", null, Collections.emptyMap())); + int size = randomIntBetween(2, 5); + context.setSize(size); + + QueryPhase.executeInternal(context.withCleanQueryResult().withProfilers(), queryPhaseSearcher); + assertThat(context.queryResult().topDocs().topDocs.totalHits.value, equalTo((long) numDocs)); + assertNull(context.queryResult().terminatedEarly()); + assertThat(context.terminateAfter(), equalTo(0)); + assertThat(context.queryResult().getTotalHits().value, equalTo((long) numDocs)); + assertProfileData(context, "MatchAllDocsQuery", query -> { + assertThat(query.getTimeBreakdown().keySet(), not(empty())); + assertThat(query.getTimeBreakdown().get("score"), greaterThan(0L)); + assertThat(query.getTimeBreakdown().get("score_count"), greaterThan(0L)); + assertThat(query.getTimeBreakdown().get("create_weight"), greaterThan(0L)); + assertThat(query.getTimeBreakdown().get("create_weight_count"), equalTo(1L)); + }, collector -> { + assertThat(collector.getReason(), equalTo("search_top_hits")); + assertThat(collector.getTime(), greaterThan(0L)); + assertThat(collector.getProfiledChildren(), empty()); + }); + + context.setSearcher(newEarlyTerminationContextSearcher(reader, size, executor)); + QueryPhase.executeInternal(context.withCleanQueryResult().withProfilers(), queryPhaseSearcher); + assertThat(context.queryResult().topDocs().topDocs.totalHits.value, equalTo((long) numDocs)); + assertThat(context.terminateAfter(), equalTo(size)); + assertThat(context.queryResult().getTotalHits().value, equalTo((long) numDocs)); + assertThat(context.queryResult().topDocs().topDocs.scoreDocs[0].doc, greaterThanOrEqualTo(size)); + assertProfileData(context, "ConstantScoreQuery", query -> { + assertThat(query.getTimeBreakdown().keySet(), not(empty())); + assertThat(query.getTimeBreakdown().get("score"), greaterThan(0L)); + assertThat(query.getTimeBreakdown().get("score_count"), greaterThan(0L)); + assertThat(query.getTimeBreakdown().get("create_weight"), greaterThan(0L)); + assertThat(query.getTimeBreakdown().get("create_weight_count"), equalTo(1L)); + assertThat(query.getProfiledChildren().get(0).getTimeBreakdown().get("score"), equalTo(0L)); + assertThat(query.getProfiledChildren().get(0).getTimeBreakdown().get("score_count"), equalTo(0L)); + assertThat(query.getProfiledChildren().get(0).getTimeBreakdown().get("create_weight"), greaterThan(0L)); + assertThat(query.getProfiledChildren().get(0).getTimeBreakdown().get("create_weight_count"), equalTo(1L)); + }, collector -> { + assertThat(collector.getReason(), equalTo("search_terminate_after_count")); + assertThat(collector.getTime(), greaterThan(0L)); + assertThat(collector.getProfiledChildren(), hasSize(1)); + assertThat(collector.getProfiledChildren().get(0).getReason(), equalTo("search_top_hits")); + assertThat(collector.getProfiledChildren().get(0).getTime(), greaterThan(0L)); + }); + + reader.close(); + dir.close(); + } + + public void testTerminateAfterEarlyTermination() throws Exception { + Directory dir = newDirectory(); + IndexWriterConfig iwc = newIndexWriterConfig(); + RandomIndexWriter w = new RandomIndexWriter(random(), dir, iwc); + final int numDocs = scaledRandomIntBetween(600, 900); + for (int i = 0; i < numDocs; ++i) { + Document doc = new Document(); + if (randomBoolean()) { + doc.add(new StringField("foo", "bar", Store.NO)); + } + if (randomBoolean()) { + doc.add(new StringField("foo", "baz", Store.NO)); + } + doc.add(new NumericDocValuesField("rank", numDocs - i)); + w.addDocument(doc); + } + w.close(); + final IndexReader reader = DirectoryReader.open(dir); + TestSearchContext context = new TestSearchContext(null, indexShard, newContextSearcher(reader, executor)); + context.setTask(new SearchShardTask(123L, "", "", "", null, Collections.emptyMap())); + context.parsedQuery(new ParsedQuery(new MatchAllDocsQuery())); + + context.terminateAfter(1); + { + context.setSize(1); + QueryPhase.executeInternal(context.withCleanQueryResult().withProfilers(), queryPhaseSearcher); + assertTrue(context.queryResult().terminatedEarly()); + assertThat(context.queryResult().topDocs().topDocs.totalHits.value, equalTo(1L)); + assertThat(context.queryResult().topDocs().topDocs.scoreDocs.length, equalTo(1)); + assertProfileData(context, "MatchAllDocsQuery", query -> { + assertThat(query.getTimeBreakdown().keySet(), not(empty())); + assertThat(query.getTimeBreakdown().get("score"), greaterThan(0L)); + assertThat(query.getTimeBreakdown().get("score_count"), greaterThan(0L)); + assertThat(query.getTimeBreakdown().get("create_weight"), greaterThan(0L)); + assertThat(query.getTimeBreakdown().get("create_weight_count"), equalTo(1L)); + }, collector -> { + assertThat(collector.getReason(), equalTo("search_terminate_after_count")); + assertThat(collector.getTime(), greaterThan(0L)); + assertThat(collector.getProfiledChildren(), hasSize(1)); + assertThat(collector.getProfiledChildren().get(0).getReason(), equalTo("search_top_hits")); + assertThat(collector.getProfiledChildren().get(0).getTime(), greaterThan(0L)); + }); + + context.setSize(0); + QueryPhase.executeInternal(context.withCleanQueryResult().withProfilers(), queryPhaseSearcher); + assertTrue(context.queryResult().terminatedEarly()); + assertThat(context.queryResult().topDocs().topDocs.totalHits.value, equalTo(1L)); + assertThat(context.queryResult().topDocs().topDocs.scoreDocs.length, equalTo(0)); + assertProfileData(context, "MatchAllDocsQuery", query -> { + assertThat(query.getTimeBreakdown().keySet(), not(empty())); + assertThat(query.getTimeBreakdown().get("score"), equalTo(0L)); + assertThat(query.getTimeBreakdown().get("score_count"), equalTo(0L)); + assertThat(query.getTimeBreakdown().get("create_weight"), greaterThan(0L)); + assertThat(query.getTimeBreakdown().get("create_weight_count"), equalTo(1L)); + }, collector -> { + assertThat(collector.getReason(), equalTo("search_terminate_after_count")); + assertThat(collector.getTime(), greaterThan(0L)); + assertThat(collector.getProfiledChildren(), hasSize(1)); + assertThat(collector.getProfiledChildren().get(0).getReason(), equalTo("search_count")); + assertThat(collector.getProfiledChildren().get(0).getTime(), greaterThan(0L)); + }); + } + + { + context.setSize(1); + QueryPhase.executeInternal(context.withCleanQueryResult().withProfilers(), queryPhaseSearcher); + assertTrue(context.queryResult().terminatedEarly()); + assertThat(context.queryResult().topDocs().topDocs.totalHits.value, equalTo(1L)); + assertThat(context.queryResult().topDocs().topDocs.scoreDocs.length, equalTo(1)); + assertProfileData(context, "MatchAllDocsQuery", query -> { + assertThat(query.getTimeBreakdown().keySet(), not(empty())); + assertThat(query.getTimeBreakdown().get("score"), greaterThan(0L)); + assertThat(query.getTimeBreakdown().get("score_count"), greaterThan(0L)); + assertThat(query.getTimeBreakdown().get("create_weight"), greaterThan(0L)); + assertThat(query.getTimeBreakdown().get("create_weight_count"), equalTo(1L)); + }, collector -> { + assertThat(collector.getReason(), equalTo("search_terminate_after_count")); + assertThat(collector.getTime(), greaterThan(0L)); + assertThat(collector.getProfiledChildren(), hasSize(1)); + assertThat(collector.getProfiledChildren().get(0).getReason(), equalTo("search_top_hits")); + assertThat(collector.getProfiledChildren().get(0).getTime(), greaterThan(0L)); + }); + } + { + context.setSize(1); + BooleanQuery bq = new BooleanQuery.Builder().add(new TermQuery(new Term("foo", "bar")), Occur.SHOULD) + .add(new TermQuery(new Term("foo", "baz")), Occur.SHOULD) + .build(); + context.parsedQuery(new ParsedQuery(bq)); + QueryPhase.executeInternal(context.withCleanQueryResult().withProfilers(), queryPhaseSearcher); + assertTrue(context.queryResult().terminatedEarly()); + assertThat(context.queryResult().topDocs().topDocs.totalHits.value, equalTo(1L)); + assertThat(context.queryResult().topDocs().topDocs.scoreDocs.length, equalTo(1)); + assertProfileData(context, "BooleanQuery", query -> { + assertThat(query.getTimeBreakdown().keySet(), not(empty())); + assertThat(query.getTimeBreakdown().get("score"), greaterThan(0L)); + assertThat(query.getTimeBreakdown().get("score_count"), greaterThan(0L)); + assertThat(query.getTimeBreakdown().get("create_weight"), greaterThan(0L)); + assertThat(query.getTimeBreakdown().get("create_weight_count"), equalTo(1L)); + + assertThat(query.getProfiledChildren(), hasSize(2)); + assertThat(query.getProfiledChildren().get(0).getQueryName(), equalTo("TermQuery")); + assertThat(query.getProfiledChildren().get(0).getTime(), greaterThan(0L)); + assertThat(query.getProfiledChildren().get(0).getTimeBreakdown().get("create_weight"), greaterThan(0L)); + assertThat(query.getProfiledChildren().get(0).getTimeBreakdown().get("create_weight_count"), equalTo(1L)); + + assertThat(query.getProfiledChildren().get(1).getQueryName(), equalTo("TermQuery")); + assertThat(query.getProfiledChildren().get(1).getTime(), greaterThan(0L)); + assertThat(query.getProfiledChildren().get(1).getTimeBreakdown().get("create_weight"), greaterThan(0L)); + assertThat(query.getProfiledChildren().get(1).getTimeBreakdown().get("create_weight_count"), equalTo(1L)); + }, collector -> { + assertThat(collector.getReason(), equalTo("search_terminate_after_count")); + assertThat(collector.getTime(), greaterThan(0L)); + assertThat(collector.getProfiledChildren(), hasSize(1)); + assertThat(collector.getProfiledChildren().get(0).getReason(), equalTo("search_top_hits")); + assertThat(collector.getProfiledChildren().get(0).getTime(), greaterThan(0L)); + }); + context.setSize(0); + context.parsedQuery(new ParsedQuery(bq)); + QueryPhase.executeInternal(context.withCleanQueryResult().withProfilers(), queryPhaseSearcher); + assertTrue(context.queryResult().terminatedEarly()); + assertThat(context.queryResult().topDocs().topDocs.totalHits.value, equalTo(1L)); + assertThat(context.queryResult().topDocs().topDocs.scoreDocs.length, equalTo(0)); + + assertProfileData(context, "BooleanQuery", query -> { + assertThat(query.getTimeBreakdown().keySet(), not(empty())); + assertThat(query.getTimeBreakdown().get("score"), equalTo(0L)); + assertThat(query.getTimeBreakdown().get("score_count"), equalTo(0L)); + assertThat(query.getTimeBreakdown().get("create_weight"), greaterThan(0L)); + assertThat(query.getTimeBreakdown().get("create_weight_count"), equalTo(1L)); + + assertThat(query.getProfiledChildren(), hasSize(2)); + assertThat(query.getProfiledChildren().get(0).getQueryName(), equalTo("TermQuery")); + assertThat(query.getProfiledChildren().get(0).getTime(), greaterThan(0L)); + assertThat(query.getProfiledChildren().get(0).getTimeBreakdown().get("create_weight"), greaterThan(0L)); + assertThat(query.getProfiledChildren().get(0).getTimeBreakdown().get("create_weight_count"), equalTo(1L)); + assertThat(query.getProfiledChildren().get(0).getTimeBreakdown().get("score"), equalTo(0L)); + assertThat(query.getProfiledChildren().get(0).getTimeBreakdown().get("score_count"), equalTo(0L)); + + assertThat(query.getProfiledChildren().get(1).getQueryName(), equalTo("TermQuery")); + assertThat(query.getProfiledChildren().get(1).getTime(), greaterThan(0L)); + assertThat(query.getProfiledChildren().get(1).getTimeBreakdown().get("create_weight"), greaterThan(0L)); + assertThat(query.getProfiledChildren().get(1).getTimeBreakdown().get("create_weight_count"), equalTo(1L)); + assertThat(query.getProfiledChildren().get(1).getTimeBreakdown().get("score"), equalTo(0L)); + assertThat(query.getProfiledChildren().get(1).getTimeBreakdown().get("score_count"), equalTo(0L)); + }, collector -> { + assertThat(collector.getReason(), equalTo("search_terminate_after_count")); + assertThat(collector.getTime(), greaterThan(0L)); + assertThat(collector.getProfiledChildren(), hasSize(1)); + assertThat(collector.getProfiledChildren().get(0).getReason(), equalTo("search_count")); + assertThat(collector.getProfiledChildren().get(0).getTime(), greaterThan(0L)); + }); + } + + context.terminateAfter(7); + context.setSize(10); + for (int trackTotalHits : new int[] { -1, 3, 75, 100 }) { + context.trackTotalHitsUpTo(trackTotalHits); + QueryPhase.executeInternal(context.withCleanQueryResult().withProfilers(), queryPhaseSearcher); + assertTrue(context.queryResult().terminatedEarly()); + if (trackTotalHits == -1) { + assertThat(context.queryResult().topDocs().topDocs.totalHits.value, equalTo(0L)); + } else { + assertThat(context.queryResult().topDocs().topDocs.totalHits.value, equalTo(7L)); + } + assertThat(context.queryResult().topDocs().topDocs.scoreDocs.length, equalTo(7)); + assertProfileData(context, "BooleanQuery", query -> { + assertThat(query.getTimeBreakdown().keySet(), not(empty())); + assertThat(query.getTimeBreakdown().get("score"), greaterThan(0L)); + assertThat(query.getTimeBreakdown().get("score_count"), greaterThanOrEqualTo(7L)); + assertThat(query.getTimeBreakdown().get("create_weight"), greaterThan(0L)); + assertThat(query.getTimeBreakdown().get("create_weight_count"), equalTo(1L)); + + assertThat(query.getProfiledChildren(), hasSize(2)); + assertThat(query.getProfiledChildren().get(0).getQueryName(), equalTo("TermQuery")); + assertThat(query.getProfiledChildren().get(0).getTime(), greaterThan(0L)); + assertThat(query.getProfiledChildren().get(0).getTimeBreakdown().get("create_weight"), greaterThan(0L)); + assertThat(query.getProfiledChildren().get(0).getTimeBreakdown().get("create_weight_count"), equalTo(1L)); + assertThat(query.getProfiledChildren().get(0).getTimeBreakdown().get("score"), greaterThan(0L)); + assertThat(query.getProfiledChildren().get(0).getTimeBreakdown().get("score_count"), greaterThan(0L)); + + assertThat(query.getProfiledChildren().get(1).getQueryName(), equalTo("TermQuery")); + assertThat(query.getProfiledChildren().get(1).getTime(), greaterThan(0L)); + assertThat(query.getProfiledChildren().get(1).getTimeBreakdown().get("create_weight"), greaterThan(0L)); + assertThat(query.getProfiledChildren().get(1).getTimeBreakdown().get("create_weight_count"), equalTo(1L)); + assertThat(query.getProfiledChildren().get(1).getTimeBreakdown().get("score"), greaterThan(0L)); + assertThat(query.getProfiledChildren().get(1).getTimeBreakdown().get("score_count"), greaterThan(0L)); + }, collector -> { + assertThat(collector.getReason(), equalTo("search_terminate_after_count")); + assertThat(collector.getTime(), greaterThan(0L)); + assertThat(collector.getProfiledChildren(), hasSize(1)); + assertThat(collector.getProfiledChildren().get(0).getReason(), equalTo("search_top_hits")); + assertThat(collector.getProfiledChildren().get(0).getTime(), greaterThan(0L)); + }); + } + + reader.close(); + dir.close(); + } + + public void testIndexSortingEarlyTermination() throws Exception { + Directory dir = newDirectory(); + final Sort sort = new Sort(new SortField("rank", SortField.Type.INT)); + IndexWriterConfig iwc = newIndexWriterConfig().setIndexSort(sort); + RandomIndexWriter w = new RandomIndexWriter(random(), dir, iwc); + final int numDocs = scaledRandomIntBetween(600, 900); + for (int i = 0; i < numDocs; ++i) { + Document doc = new Document(); + if (randomBoolean()) { + doc.add(new StringField("foo", "bar", Store.NO)); + } + if (randomBoolean()) { + doc.add(new StringField("foo", "baz", Store.NO)); + } + doc.add(new NumericDocValuesField("rank", numDocs - i)); + w.addDocument(doc); + } + w.close(); + + final IndexReader reader = DirectoryReader.open(dir); + TestSearchContext context = new TestSearchContext(null, indexShard, newContextSearcher(reader, executor)); + context.parsedQuery(new ParsedQuery(new MatchAllDocsQuery())); + context.setSize(1); + context.setTask(new SearchShardTask(123L, "", "", "", null, Collections.emptyMap())); + context.sort(new SortAndFormats(sort, new DocValueFormat[] { DocValueFormat.RAW })); + + QueryPhase.executeInternal(context.withCleanQueryResult().withProfilers(), queryPhaseSearcher); + assertThat(context.queryResult().topDocs().topDocs.totalHits.value, equalTo((long) numDocs)); + assertThat(context.queryResult().topDocs().topDocs.scoreDocs.length, equalTo(1)); + assertThat(context.queryResult().topDocs().topDocs.scoreDocs[0], instanceOf(FieldDoc.class)); + FieldDoc fieldDoc = (FieldDoc) context.queryResult().topDocs().topDocs.scoreDocs[0]; + assertThat(fieldDoc.fields[0], equalTo(1)); + assertProfileData(context, "MatchAllDocsQuery", query -> { + assertThat(query.getTimeBreakdown().keySet(), not(empty())); + assertThat(query.getTimeBreakdown().get("score"), equalTo(0L)); + assertThat(query.getTimeBreakdown().get("score_count"), equalTo(0L)); + assertThat(query.getTimeBreakdown().get("create_weight"), greaterThan(0L)); + assertThat(query.getTimeBreakdown().get("create_weight_count"), equalTo(1L)); + }, collector -> { + assertThat(collector.getReason(), equalTo("search_top_hits")); + assertThat(collector.getTime(), greaterThan(0L)); + assertThat(collector.getProfiledChildren(), empty()); + }); + + { + context.parsedPostFilter(new ParsedQuery(new MinDocQuery(1))); + QueryPhase.executeInternal(context.withCleanQueryResult().withProfilers(), queryPhaseSearcher); + assertNull(context.queryResult().terminatedEarly()); + assertThat(context.queryResult().topDocs().topDocs.totalHits.value, equalTo(numDocs - 1L)); + assertThat(context.queryResult().topDocs().topDocs.scoreDocs.length, equalTo(1)); + assertThat(context.queryResult().topDocs().topDocs.scoreDocs[0], instanceOf(FieldDoc.class)); + assertThat(fieldDoc.fields[0], anyOf(equalTo(1), equalTo(2))); + assertProfileData(context, collector -> { + assertThat(collector.getReason(), equalTo("search_post_filter")); + assertThat(collector.getTime(), greaterThan(0L)); + assertThat(collector.getProfiledChildren(), hasSize(1)); + assertThat(collector.getProfiledChildren().get(0).getReason(), equalTo("search_top_hits")); + assertThat(collector.getProfiledChildren().get(0).getTime(), greaterThan(0L)); + }, (query) -> { + assertThat(query.getQueryName(), equalTo("MinDocQuery")); + assertThat(query.getTimeBreakdown().keySet(), not(empty())); + assertThat(query.getTimeBreakdown().get("score"), equalTo(0L)); + assertThat(query.getTimeBreakdown().get("score_count"), equalTo(0L)); + assertThat(query.getTimeBreakdown().get("create_weight"), greaterThan(0L)); + assertThat(query.getTimeBreakdown().get("create_weight_count"), equalTo(1L)); + }, (query) -> { + assertThat(query.getQueryName(), equalTo("MatchAllDocsQuery")); + assertThat(query.getTimeBreakdown().keySet(), not(empty())); + assertThat(query.getTimeBreakdown().get("score"), equalTo(0L)); + assertThat(query.getTimeBreakdown().get("score_count"), equalTo(0L)); + assertThat(query.getTimeBreakdown().get("create_weight"), greaterThan(0L)); + assertThat(query.getTimeBreakdown().get("create_weight_count"), equalTo(1L)); + }); + context.parsedPostFilter(null); + } + + { + context.setSearcher(newEarlyTerminationContextSearcher(reader, 1, executor)); + context.trackTotalHitsUpTo(SearchContext.TRACK_TOTAL_HITS_DISABLED); + QueryPhase.executeInternal(context.withCleanQueryResult().withProfilers(), queryPhaseSearcher); + assertNull(context.queryResult().terminatedEarly()); + assertThat(context.queryResult().topDocs().topDocs.scoreDocs.length, equalTo(1)); + assertThat(context.queryResult().topDocs().topDocs.scoreDocs[0], instanceOf(FieldDoc.class)); + assertThat(fieldDoc.fields[0], anyOf(equalTo(1), equalTo(2))); + assertProfileData(context, "MatchAllDocsQuery", query -> { + assertThat(query.getTimeBreakdown().keySet(), not(empty())); + assertThat(query.getTimeBreakdown().get("score"), equalTo(0L)); + assertThat(query.getTimeBreakdown().get("score_count"), equalTo(0L)); + assertThat(query.getTimeBreakdown().get("create_weight"), greaterThan(0L)); + assertThat(query.getTimeBreakdown().get("create_weight_count"), equalTo(1L)); + }, collector -> { + assertThat(collector.getReason(), equalTo("search_top_hits")); + assertThat(collector.getTime(), greaterThan(0L)); + assertThat(collector.getProfiledChildren(), empty()); + }); + + QueryPhase.executeInternal(context.withCleanQueryResult().withProfilers(), queryPhaseSearcher); + assertNull(context.queryResult().terminatedEarly()); + assertThat(context.queryResult().topDocs().topDocs.scoreDocs.length, equalTo(1)); + assertThat(context.queryResult().topDocs().topDocs.scoreDocs[0], instanceOf(FieldDoc.class)); + assertThat(fieldDoc.fields[0], anyOf(equalTo(1), equalTo(2))); + assertProfileData(context, "MatchAllDocsQuery", query -> { + assertThat(query.getTimeBreakdown().keySet(), not(empty())); + assertThat(query.getTimeBreakdown().get("score"), equalTo(0L)); + assertThat(query.getTimeBreakdown().get("score_count"), equalTo(0L)); + assertThat(query.getTimeBreakdown().get("create_weight"), greaterThan(0L)); + assertThat(query.getTimeBreakdown().get("create_weight_count"), equalTo(1L)); + }, collector -> { + assertThat(collector.getReason(), equalTo("search_top_hits")); + assertThat(collector.getTime(), greaterThan(0L)); + assertThat(collector.getProfiledChildren(), empty()); + }); + } + + reader.close(); + dir.close(); + } + + public void testIndexSortScrollOptimization() throws Exception { + Directory dir = newDirectory(); + final Sort indexSort = new Sort(new SortField("rank", SortField.Type.INT), new SortField("tiebreaker", SortField.Type.INT)); + IndexWriterConfig iwc = newIndexWriterConfig().setIndexSort(indexSort); + RandomIndexWriter w = new RandomIndexWriter(random(), dir, iwc); + final int numDocs = scaledRandomIntBetween(600, 900); + for (int i = 0; i < numDocs; ++i) { + Document doc = new Document(); + doc.add(new NumericDocValuesField("rank", random().nextInt())); + doc.add(new NumericDocValuesField("tiebreaker", i)); + w.addDocument(doc); + } + if (randomBoolean()) { + w.forceMerge(randomIntBetween(1, 10)); + } + w.close(); + + final IndexReader reader = DirectoryReader.open(dir); + List searchSortAndFormats = new ArrayList<>(); + searchSortAndFormats.add(new SortAndFormats(indexSort, new DocValueFormat[] { DocValueFormat.RAW, DocValueFormat.RAW })); + // search sort is a prefix of the index sort + searchSortAndFormats.add(new SortAndFormats(new Sort(indexSort.getSort()[0]), new DocValueFormat[] { DocValueFormat.RAW })); + for (SortAndFormats searchSortAndFormat : searchSortAndFormats) { + ScrollContext scrollContext = new ScrollContext(); + TestSearchContext context = new TestSearchContext(null, indexShard, newContextSearcher(reader, executor), scrollContext); + context.parsedQuery(new ParsedQuery(new MatchAllDocsQuery())); + scrollContext.lastEmittedDoc = null; + scrollContext.maxScore = Float.NaN; + scrollContext.totalHits = null; + context.setTask(new SearchShardTask(123L, "", "", "", null, Collections.emptyMap())); + context.setSize(10); + context.sort(searchSortAndFormat); + + QueryPhase.executeInternal(context.withCleanQueryResult().withProfilers(), queryPhaseSearcher); + assertThat(context.queryResult().topDocs().topDocs.totalHits.value, equalTo((long) numDocs)); + assertNull(context.queryResult().terminatedEarly()); + assertThat(context.terminateAfter(), equalTo(0)); + assertThat(context.queryResult().getTotalHits().value, equalTo((long) numDocs)); + assertProfileData(context, "MatchAllDocsQuery", query -> { + assertThat(query.getTimeBreakdown().keySet(), not(empty())); + assertThat(query.getTimeBreakdown().get("score"), equalTo(0L)); + assertThat(query.getTimeBreakdown().get("score_count"), equalTo(0L)); + assertThat(query.getTimeBreakdown().get("create_weight"), greaterThan(0L)); + assertThat(query.getTimeBreakdown().get("create_weight_count"), equalTo(1L)); + }, collector -> { + assertThat(collector.getReason(), equalTo("search_top_hits")); + assertThat(collector.getTime(), greaterThan(0L)); + assertThat(collector.getProfiledChildren(), empty()); + }); + + int sizeMinus1 = context.queryResult().topDocs().topDocs.scoreDocs.length - 1; + FieldDoc lastDoc = (FieldDoc) context.queryResult().topDocs().topDocs.scoreDocs[sizeMinus1]; + + context.setSearcher(newEarlyTerminationContextSearcher(reader, 10, executor)); + QueryPhase.executeInternal(context.withCleanQueryResult().withProfilers(), queryPhaseSearcher); + assertNull(context.queryResult().terminatedEarly()); + assertThat(context.queryResult().topDocs().topDocs.totalHits.value, equalTo((long) numDocs)); + assertThat(context.terminateAfter(), equalTo(0)); + assertThat(context.queryResult().getTotalHits().value, equalTo((long) numDocs)); + assertProfileData(context, "ConstantScoreQuery", query -> { + assertThat(query.getTimeBreakdown().keySet(), not(empty())); + assertThat(query.getTimeBreakdown().get("score"), equalTo(0L)); + assertThat(query.getTimeBreakdown().get("score_count"), equalTo(0L)); + assertThat(query.getTimeBreakdown().get("create_weight"), greaterThan(0L)); + assertThat(query.getTimeBreakdown().get("create_weight_count"), equalTo(1L)); + + assertThat(query.getProfiledChildren(), hasSize(1)); + assertThat(query.getProfiledChildren().get(0).getQueryName(), equalTo("SearchAfterSortedDocQuery")); + assertThat(query.getProfiledChildren().get(0).getTime(), greaterThan(0L)); + assertThat(query.getProfiledChildren().get(0).getTimeBreakdown().get("score"), equalTo(0L)); + assertThat(query.getProfiledChildren().get(0).getTimeBreakdown().get("score_count"), equalTo(0L)); + assertThat(query.getProfiledChildren().get(0).getTimeBreakdown().get("create_weight"), greaterThan(0L)); + assertThat(query.getProfiledChildren().get(0).getTimeBreakdown().get("create_weight_count"), equalTo(1L)); + }, collector -> { + assertThat(collector.getReason(), equalTo("search_top_hits")); + assertThat(collector.getTime(), greaterThan(0L)); + assertThat(collector.getProfiledChildren(), empty()); + }); + FieldDoc firstDoc = (FieldDoc) context.queryResult().topDocs().topDocs.scoreDocs[0]; + for (int i = 0; i < searchSortAndFormat.sort.getSort().length; i++) { + @SuppressWarnings("unchecked") + FieldComparator comparator = (FieldComparator) searchSortAndFormat.sort.getSort()[i].getComparator(i, true); + int cmp = comparator.compareValues(firstDoc.fields[i], lastDoc.fields[i]); + if (cmp == 0) { + continue; + } + assertThat(cmp, equalTo(1)); + break; + } + } + reader.close(); + dir.close(); + } + + public void testDisableTopScoreCollection() throws Exception { + Directory dir = newDirectory(); + IndexWriterConfig iwc = newIndexWriterConfig(new StandardAnalyzer()); + RandomIndexWriter w = new RandomIndexWriter(random(), dir, iwc); + Document doc = new Document(); + final int numDocs = 2 * scaledRandomIntBetween(50, 450); + for (int i = 0; i < numDocs; i++) { + doc.clear(); + if (i % 2 == 0) { + doc.add(new TextField("title", "foo bar", Store.NO)); + } else { + doc.add(new TextField("title", "foo", Store.NO)); + } + w.addDocument(doc); + } + w.close(); + + IndexReader reader = DirectoryReader.open(dir); + TestSearchContext context = new TestSearchContext(null, indexShard, newContextSearcher(reader, executor)); + context.setTask(new SearchShardTask(123L, "", "", "", null, Collections.emptyMap())); + Query q = new SpanNearQuery.Builder("title", true).addClause(new SpanTermQuery(new Term("title", "foo"))) + .addClause(new SpanTermQuery(new Term("title", "bar"))) + .build(); + + context.parsedQuery(new ParsedQuery(q)); + context.setSize(3); + context.trackTotalHitsUpTo(3); + TopDocsCollectorContext topDocsContext = TopDocsCollectorContext.createTopDocsCollectorContext(context, false); + assertEquals(topDocsContext.create(null).scoreMode(), org.apache.lucene.search.ScoreMode.COMPLETE); + QueryPhase.executeInternal(context.withCleanQueryResult().withProfilers(), queryPhaseSearcher); + assertEquals(numDocs / 2, context.queryResult().topDocs().topDocs.totalHits.value); + assertEquals(context.queryResult().topDocs().topDocs.totalHits.relation, TotalHits.Relation.EQUAL_TO); + assertThat(context.queryResult().topDocs().topDocs.scoreDocs.length, equalTo(3)); + assertProfileData(context, "SpanNearQuery", query -> { + assertThat(query.getTimeBreakdown().keySet(), not(empty())); + assertThat(query.getTimeBreakdown().get("score"), greaterThan(0L)); + assertThat(query.getTimeBreakdown().get("score_count"), greaterThan(0L)); + assertThat(query.getTimeBreakdown().get("create_weight"), greaterThan(0L)); + assertThat(query.getTimeBreakdown().get("create_weight_count"), equalTo(1L)); + }, collector -> { + assertThat(collector.getReason(), equalTo("search_top_hits")); + assertThat(collector.getTime(), greaterThan(0L)); + assertThat(collector.getProfiledChildren(), empty()); + }); + + context.sort(new SortAndFormats(new Sort(new SortField("other", SortField.Type.INT)), new DocValueFormat[] { DocValueFormat.RAW })); + topDocsContext = TopDocsCollectorContext.createTopDocsCollectorContext(context, false); + assertEquals(topDocsContext.create(null).scoreMode(), org.apache.lucene.search.ScoreMode.TOP_DOCS); + QueryPhase.executeInternal(context.withCleanQueryResult().withProfilers(), queryPhaseSearcher); + assertEquals(numDocs / 2, context.queryResult().topDocs().topDocs.totalHits.value); + assertThat(context.queryResult().topDocs().topDocs.scoreDocs.length, equalTo(3)); + assertEquals(context.queryResult().topDocs().topDocs.totalHits.relation, TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO); + assertProfileData(context, "SpanNearQuery", query -> { + assertThat(query.getTimeBreakdown().keySet(), not(empty())); + assertThat(query.getTimeBreakdown().get("score"), equalTo(0L)); + assertThat(query.getTimeBreakdown().get("score_count"), equalTo(0L)); + assertThat(query.getTimeBreakdown().get("create_weight"), greaterThan(0L)); + assertThat(query.getTimeBreakdown().get("create_weight_count"), equalTo(1L)); + }, collector -> { + assertThat(collector.getReason(), equalTo("search_top_hits")); + assertThat(collector.getTime(), greaterThan(0L)); + assertThat(collector.getProfiledChildren(), empty()); + }); + + reader.close(); + dir.close(); + } + + public void testMinScore() throws Exception { + Directory dir = newDirectory(); + IndexWriterConfig iwc = newIndexWriterConfig(); + RandomIndexWriter w = new RandomIndexWriter(random(), dir, iwc); + for (int i = 0; i < 10; i++) { + Document doc = new Document(); + doc.add(new StringField("foo", "bar", Store.NO)); + doc.add(new StringField("filter", "f1", Store.NO)); + w.addDocument(doc); + } + w.close(); + + IndexReader reader = DirectoryReader.open(dir); + TestSearchContext context = new TestSearchContext(null, indexShard, newContextSearcher(reader, executor)); + context.parsedQuery( + new ParsedQuery( + new BooleanQuery.Builder().add(new TermQuery(new Term("foo", "bar")), Occur.MUST) + .add(new TermQuery(new Term("filter", "f1")), Occur.SHOULD) + .build() + ) + ); + context.minimumScore(0.01f); + context.setTask(new SearchShardTask(123L, "", "", "", null, Collections.emptyMap())); + context.setSize(1); + context.trackTotalHitsUpTo(5); + + QueryPhase.executeInternal(context.withCleanQueryResult().withProfilers(), queryPhaseSearcher); + assertEquals(10, context.queryResult().topDocs().topDocs.totalHits.value); + assertProfileData(context, "BooleanQuery", query -> { + assertThat(query.getTimeBreakdown().keySet(), not(empty())); + assertThat(query.getTimeBreakdown().get("score"), greaterThan(0L)); + assertThat(query.getTimeBreakdown().get("score_count"), equalTo(10L)); + assertThat(query.getTimeBreakdown().get("create_weight"), greaterThan(0L)); + assertThat(query.getTimeBreakdown().get("create_weight_count"), equalTo(1L)); + + assertThat(query.getProfiledChildren(), hasSize(2)); + assertThat(query.getProfiledChildren().get(0).getQueryName(), equalTo("TermQuery")); + assertThat(query.getProfiledChildren().get(0).getTime(), greaterThan(0L)); + assertThat(query.getProfiledChildren().get(0).getTimeBreakdown().get("create_weight"), greaterThan(0L)); + assertThat(query.getProfiledChildren().get(0).getTimeBreakdown().get("create_weight_count"), equalTo(1L)); + + assertThat(query.getProfiledChildren().get(1).getQueryName(), equalTo("TermQuery")); + assertThat(query.getProfiledChildren().get(1).getTime(), greaterThan(0L)); + assertThat(query.getProfiledChildren().get(1).getTimeBreakdown().get("create_weight"), greaterThan(0L)); + assertThat(query.getProfiledChildren().get(1).getTimeBreakdown().get("create_weight_count"), equalTo(1L)); + }, collector -> { + assertThat(collector.getReason(), equalTo("search_min_score")); + assertThat(collector.getTime(), greaterThan(0L)); + assertThat(collector.getProfiledChildren(), hasSize(1)); + assertThat(collector.getProfiledChildren().get(0).getReason(), equalTo("search_top_hits")); + assertThat(collector.getProfiledChildren().get(0).getTime(), greaterThan(0L)); + }); + + reader.close(); + dir.close(); + } + + public void testMaxScore() throws Exception { + Directory dir = newDirectory(); + final Sort sort = new Sort(new SortField("filter", SortField.Type.STRING)); + IndexWriterConfig iwc = newIndexWriterConfig().setIndexSort(sort); + RandomIndexWriter w = new RandomIndexWriter(random(), dir, iwc); + + final int numDocs = scaledRandomIntBetween(600, 900); + for (int i = 0; i < numDocs; i++) { + Document doc = new Document(); + doc.add(new StringField("foo", "bar", Store.NO)); + doc.add(new StringField("filter", "f1" + ((i > 0) ? " " + Integer.toString(i) : ""), Store.NO)); + doc.add(new SortedDocValuesField("filter", newBytesRef("f1" + ((i > 0) ? " " + Integer.toString(i) : "")))); + w.addDocument(doc); + } + w.close(); + + IndexReader reader = DirectoryReader.open(dir); + TestSearchContext context = new TestSearchContext(null, indexShard, newContextSearcher(reader, executor)); + context.trackScores(true); + context.parsedQuery( + new ParsedQuery( + new BooleanQuery.Builder().add(new TermQuery(new Term("foo", "bar")), Occur.MUST) + .add(new TermQuery(new Term("filter", "f1")), Occur.SHOULD) + .build() + ) + ); + context.setTask(new SearchShardTask(123L, "", "", "", null, Collections.emptyMap())); + context.setSize(1); + context.trackTotalHitsUpTo(5); + + QueryPhase.executeInternal(context.withCleanQueryResult().withProfilers(), queryPhaseSearcher); + assertFalse(Float.isNaN(context.queryResult().getMaxScore())); + assertEquals(1, context.queryResult().topDocs().topDocs.scoreDocs.length); + assertThat(context.queryResult().topDocs().topDocs.totalHits.value, greaterThanOrEqualTo(6L)); + assertProfileData(context, "BooleanQuery", query -> { + assertThat(query.getTimeBreakdown().keySet(), not(empty())); + assertThat(query.getTimeBreakdown().get("score"), greaterThan(0L)); + assertThat(query.getTimeBreakdown().get("score_count"), greaterThanOrEqualTo(6L)); + assertThat(query.getTimeBreakdown().get("create_weight"), greaterThan(0L)); + assertThat(query.getTimeBreakdown().get("create_weight_count"), equalTo(1L)); + + assertThat(query.getProfiledChildren(), hasSize(2)); + assertThat(query.getProfiledChildren().get(0).getQueryName(), equalTo("TermQuery")); + assertThat(query.getProfiledChildren().get(0).getTime(), greaterThan(0L)); + assertThat(query.getProfiledChildren().get(0).getTimeBreakdown().get("create_weight"), greaterThan(0L)); + assertThat(query.getProfiledChildren().get(0).getTimeBreakdown().get("create_weight_count"), equalTo(1L)); + + assertThat(query.getProfiledChildren().get(1).getQueryName(), equalTo("TermQuery")); + assertThat(query.getProfiledChildren().get(1).getTime(), greaterThan(0L)); + assertThat(query.getProfiledChildren().get(1).getTimeBreakdown().get("create_weight"), greaterThan(0L)); + assertThat(query.getProfiledChildren().get(1).getTimeBreakdown().get("create_weight_count"), equalTo(1L)); + }, collector -> { + assertThat(collector.getReason(), equalTo("search_top_hits")); + assertThat(collector.getTime(), greaterThan(0L)); + assertThat(collector.getProfiledChildren(), empty()); + }); + + context.sort(new SortAndFormats(sort, new DocValueFormat[] { DocValueFormat.RAW })); + QueryPhase.executeInternal(context.withCleanQueryResult().withProfilers(), queryPhaseSearcher); + assertFalse(Float.isNaN(context.queryResult().getMaxScore())); + assertEquals(1, context.queryResult().topDocs().topDocs.scoreDocs.length); + assertThat(context.queryResult().topDocs().topDocs.totalHits.value, greaterThanOrEqualTo(6L)); + assertProfileData(context, "BooleanQuery", query -> { + assertThat(query.getTimeBreakdown().keySet(), not(empty())); + assertThat(query.getTimeBreakdown().get("score"), greaterThan(0L)); + assertThat(query.getTimeBreakdown().get("score_count"), greaterThanOrEqualTo(6L)); + assertThat(query.getTimeBreakdown().get("create_weight"), greaterThan(0L)); + assertThat(query.getTimeBreakdown().get("create_weight_count"), equalTo(1L)); + + assertThat(query.getProfiledChildren(), hasSize(2)); + assertThat(query.getProfiledChildren().get(0).getQueryName(), equalTo("TermQuery")); + assertThat(query.getProfiledChildren().get(0).getTime(), greaterThan(0L)); + assertThat(query.getProfiledChildren().get(0).getTimeBreakdown().get("create_weight"), greaterThan(0L)); + assertThat(query.getProfiledChildren().get(0).getTimeBreakdown().get("create_weight_count"), equalTo(1L)); + + assertThat(query.getProfiledChildren().get(1).getQueryName(), equalTo("TermQuery")); + assertThat(query.getProfiledChildren().get(1).getTime(), greaterThan(0L)); + assertThat(query.getProfiledChildren().get(1).getTimeBreakdown().get("create_weight"), greaterThan(0L)); + assertThat(query.getProfiledChildren().get(1).getTimeBreakdown().get("create_weight_count"), equalTo(1L)); + }, collector -> { + assertThat(collector.getReason(), equalTo("search_top_hits")); + assertThat(collector.getTime(), greaterThan(0L)); + assertThat(collector.getProfiledChildren(), empty()); + }); + + reader.close(); + dir.close(); + } + + public void testCollapseQuerySearchResults() throws Exception { + Directory dir = newDirectory(); + final Sort sort = new Sort(new SortField("user", SortField.Type.INT)); + IndexWriterConfig iwc = newIndexWriterConfig().setIndexSort(sort); + RandomIndexWriter w = new RandomIndexWriter(random(), dir, iwc); + + // Always end up with uneven buckets so collapsing is predictable + final int numDocs = 2 * scaledRandomIntBetween(600, 900) - 1; + for (int i = 0; i < numDocs; i++) { + Document doc = new Document(); + doc.add(new StringField("foo", "bar", Store.NO)); + doc.add(new NumericDocValuesField("user", i & 1)); + w.addDocument(doc); + } + w.close(); + + IndexReader reader = DirectoryReader.open(dir); + QueryShardContext queryShardContext = mock(QueryShardContext.class); + when(queryShardContext.fieldMapper("user")).thenReturn( + new NumberFieldType("user", NumberType.INTEGER, true, false, true, false, null, Collections.emptyMap()) + ); + + TestSearchContext context = new TestSearchContext(queryShardContext, indexShard, newContextSearcher(reader, executor)); + context.collapse(new CollapseBuilder("user").build(context.getQueryShardContext())); + context.trackScores(true); + context.parsedQuery(new ParsedQuery(new TermQuery(new Term("foo", "bar")))); + context.setTask(new SearchShardTask(123L, "", "", "", null, Collections.emptyMap())); + context.setSize(2); + context.trackTotalHitsUpTo(5); + + QueryPhase.executeInternal(context.withCleanQueryResult().withProfilers(), queryPhaseSearcher); + assertFalse(Float.isNaN(context.queryResult().getMaxScore())); + assertEquals(2, context.queryResult().topDocs().topDocs.scoreDocs.length); + assertThat(context.queryResult().topDocs().topDocs.totalHits.value, equalTo((long) numDocs)); + assertThat(context.queryResult().topDocs().topDocs, instanceOf(CollapseTopFieldDocs.class)); + + assertProfileData(context, "TermQuery", query -> { + assertThat(query.getTimeBreakdown().keySet(), not(empty())); + assertThat(query.getTimeBreakdown().get("score"), greaterThan(0L)); + assertThat(query.getTimeBreakdown().get("score_count"), greaterThanOrEqualTo(6L)); + assertThat(query.getTimeBreakdown().get("create_weight"), greaterThan(0L)); + assertThat(query.getTimeBreakdown().get("create_weight_count"), equalTo(1L)); + assertThat(query.getProfiledChildren(), empty()); + }, collector -> { + assertThat(collector.getReason(), equalTo("search_top_hits")); + assertThat(collector.getTime(), greaterThan(0L)); + assertThat(collector.getProfiledChildren(), empty()); + }); + + context.sort(new SortAndFormats(sort, new DocValueFormat[] { DocValueFormat.RAW })); + QueryPhase.executeInternal(context.withCleanQueryResult().withProfilers(), queryPhaseSearcher); + assertFalse(Float.isNaN(context.queryResult().getMaxScore())); + assertEquals(2, context.queryResult().topDocs().topDocs.scoreDocs.length); + assertThat(context.queryResult().topDocs().topDocs.totalHits.value, equalTo((long) numDocs)); + assertThat(context.queryResult().topDocs().topDocs, instanceOf(CollapseTopFieldDocs.class)); + + assertProfileData(context, "TermQuery", query -> { + assertThat(query.getTimeBreakdown().keySet(), not(empty())); + assertThat(query.getTimeBreakdown().get("score"), greaterThan(0L)); + assertThat(query.getTimeBreakdown().get("score_count"), greaterThanOrEqualTo(6L)); + assertThat(query.getTimeBreakdown().get("create_weight"), greaterThan(0L)); + assertThat(query.getTimeBreakdown().get("create_weight_count"), equalTo(1L)); + assertThat(query.getProfiledChildren(), empty()); + }, collector -> { + assertThat(collector.getReason(), equalTo("search_top_hits")); + assertThat(collector.getTime(), greaterThan(0L)); + assertThat(collector.getProfiledChildren(), empty()); + }); + + reader.close(); + dir.close(); + } + + private void assertProfileData(SearchContext context, String type, Consumer query, Consumer collector) + throws IOException { + assertProfileData(context, collector, (profileResult) -> { + assertThat(profileResult.getQueryName(), equalTo(type)); + assertThat(profileResult.getTime(), greaterThan(0L)); + query.accept(profileResult); + }); + } + + private void assertProfileData(SearchContext context, Consumer collector, Consumer query1) + throws IOException { + assertProfileData(context, Arrays.asList(query1), collector, false); + } + + private void assertProfileData( + SearchContext context, + Consumer collector, + Consumer query1, + Consumer query2 + ) throws IOException { + assertProfileData(context, Arrays.asList(query1, query2), collector, false); + } + + private final void assertProfileData( + SearchContext context, + List> queries, + Consumer collector, + boolean debug + ) throws IOException { + assertThat(context.getProfilers(), not(nullValue())); + + final ProfileShardResult result = SearchProfileShardResults.buildShardResults(context.getProfilers(), null); + if (debug) { + final SearchProfileShardResults results = new SearchProfileShardResults( + Collections.singletonMap(indexShard.shardId().toString(), result) + ); + + try (final XContentBuilder builder = JsonXContent.contentBuilder().prettyPrint()) { + builder.startObject(); + results.toXContent(builder, ToXContent.EMPTY_PARAMS); + builder.endObject(); + builder.flush(); + + final OutputStream out = builder.getOutputStream(); + assertThat(out, instanceOf(ByteArrayOutputStream.class)); + + logger.info(new String(((ByteArrayOutputStream) out).toByteArray(), StandardCharsets.UTF_8)); + } + } + + assertThat(result.getQueryProfileResults(), hasSize(1)); + + final QueryProfileShardResult queryProfileShardResult = result.getQueryProfileResults().get(0); + assertThat(queryProfileShardResult.getQueryResults(), hasSize(queries.size())); + + for (int i = 0; i < queries.size(); ++i) { + queries.get(i).accept(queryProfileShardResult.getQueryResults().get(i)); + } + + collector.accept(queryProfileShardResult.getCollectorResult()); + } + + private static ContextIndexSearcher newContextSearcher(IndexReader reader, ExecutorService executor) throws IOException { + return new ContextIndexSearcher( + reader, + IndexSearcher.getDefaultSimilarity(), + IndexSearcher.getDefaultQueryCache(), + IndexSearcher.getDefaultQueryCachingPolicy(), + true, + executor + ); + } + + private static ContextIndexSearcher newEarlyTerminationContextSearcher(IndexReader reader, int size, ExecutorService executor) + throws IOException { + return new ContextIndexSearcher( + reader, + IndexSearcher.getDefaultSimilarity(), + IndexSearcher.getDefaultQueryCache(), + IndexSearcher.getDefaultQueryCachingPolicy(), + true, + executor + ) { + + @Override + public void search(List leaves, Weight weight, Collector collector) throws IOException { + final Collector in = new AssertingEarlyTerminationFilterCollector(collector, size); + super.search(leaves, weight, in); + } + }; + } + + private static class AssertingEarlyTerminationFilterCollector extends FilterCollector { + private final int size; + + AssertingEarlyTerminationFilterCollector(Collector in, int size) { + super(in); + this.size = size; + } + + @Override + public LeafCollector getLeafCollector(LeafReaderContext context) throws IOException { + final LeafCollector in = super.getLeafCollector(context); + return new FilterLeafCollector(in) { + int collected; + + @Override + public void collect(int doc) throws IOException { + assert collected <= size : "should not collect more than " + size + " doc per segment, got " + collected; + ++collected; + super.collect(doc); + } + }; + } + } +} diff --git a/server/src/main/java/org/opensearch/common/lucene/MinimumScoreCollector.java b/server/src/main/java/org/opensearch/common/lucene/MinimumScoreCollector.java index 81c98c862d2b2..a883e111f7c95 100644 --- a/server/src/main/java/org/opensearch/common/lucene/MinimumScoreCollector.java +++ b/server/src/main/java/org/opensearch/common/lucene/MinimumScoreCollector.java @@ -55,6 +55,10 @@ public MinimumScoreCollector(Collector collector, float minimumScore) { this.minimumScore = minimumScore; } + public Collector getCollector() { + return collector; + } + @Override public void setScorer(Scorable scorer) throws IOException { if (!(scorer instanceof ScoreCachingWrappingScorer)) { diff --git a/server/src/main/java/org/opensearch/common/lucene/search/FilteredCollector.java b/server/src/main/java/org/opensearch/common/lucene/search/FilteredCollector.java index 331b67a40878f..2dcb0578fd23d 100644 --- a/server/src/main/java/org/opensearch/common/lucene/search/FilteredCollector.java +++ b/server/src/main/java/org/opensearch/common/lucene/search/FilteredCollector.java @@ -53,6 +53,10 @@ public FilteredCollector(Collector collector, Weight filter) { this.filter = filter; } + public Collector getCollector() { + return collector; + } + @Override public LeafCollector getLeafCollector(LeafReaderContext context) throws IOException { final ScorerSupplier filterScorerSupplier = filter.scorerSupplier(context); diff --git a/server/src/main/java/org/opensearch/search/DefaultSearchContext.java b/server/src/main/java/org/opensearch/search/DefaultSearchContext.java index bfe8eed05ea9b..6fd78b834344d 100644 --- a/server/src/main/java/org/opensearch/search/DefaultSearchContext.java +++ b/server/src/main/java/org/opensearch/search/DefaultSearchContext.java @@ -36,6 +36,7 @@ import org.apache.lucene.search.BooleanQuery; import org.apache.lucene.search.BoostQuery; import org.apache.lucene.search.Collector; +import org.apache.lucene.search.CollectorManager; import org.apache.lucene.search.FieldDoc; import org.apache.lucene.search.MatchNoDocsQuery; import org.apache.lucene.search.Query; @@ -82,6 +83,7 @@ import org.opensearch.search.profile.Profilers; import org.opensearch.search.query.QueryPhaseExecutionException; import org.opensearch.search.query.QuerySearchResult; +import org.opensearch.search.query.ReduceableSearchResult; import org.opensearch.search.rescore.RescoreContext; import org.opensearch.search.slice.SliceBuilder; import org.opensearch.search.sort.SortAndFormats; @@ -163,7 +165,7 @@ final class DefaultSearchContext extends SearchContext { private Profilers profilers; private final Map searchExtBuilders = new HashMap<>(); - private final Map, Collector> queryCollectors = new HashMap<>(); + private final Map, CollectorManager> queryCollectorManagers = new HashMap<>(); private final QueryShardContext queryShardContext; private final FetchPhase fetchPhase; @@ -823,8 +825,8 @@ public long getRelativeTimeInMillis() { } @Override - public Map, Collector> queryCollectors() { - return queryCollectors; + public Map, CollectorManager> queryCollectorManagers() { + return queryCollectorManagers; } @Override diff --git a/server/src/main/java/org/opensearch/search/aggregations/AggregationPhase.java b/server/src/main/java/org/opensearch/search/aggregations/AggregationPhase.java index be62b33adb356..5a837a6e14c5a 100644 --- a/server/src/main/java/org/opensearch/search/aggregations/AggregationPhase.java +++ b/server/src/main/java/org/opensearch/search/aggregations/AggregationPhase.java @@ -32,6 +32,7 @@ package org.opensearch.search.aggregations; import org.apache.lucene.search.Collector; +import org.apache.lucene.search.CollectorManager; import org.apache.lucene.search.Query; import org.opensearch.common.inject.Inject; import org.opensearch.common.lucene.search.Queries; @@ -40,9 +41,11 @@ import org.opensearch.search.profile.query.CollectorResult; import org.opensearch.search.profile.query.InternalProfileCollector; import org.opensearch.search.query.QueryPhaseExecutionException; +import org.opensearch.search.query.ReduceableSearchResult; import java.io.IOException; import java.util.ArrayList; +import java.util.Collection; import java.util.Collections; import java.util.List; @@ -68,17 +71,18 @@ public void preProcess(SearchContext context) { } context.aggregations().aggregators(aggregators); if (!collectors.isEmpty()) { - Collector collector = MultiBucketCollector.wrap(collectors); - ((BucketCollector) collector).preCollection(); - if (context.getProfilers() != null) { - collector = new InternalProfileCollector( - collector, - CollectorResult.REASON_AGGREGATION, - // TODO: report on child aggs as well - Collections.emptyList() - ); - } - context.queryCollectors().put(AggregationPhase.class, collector); + final Collector collector = createCollector(context, collectors); + context.queryCollectorManagers().put(AggregationPhase.class, new CollectorManager() { + @Override + public Collector newCollector() throws IOException { + return collector; + } + + @Override + public ReduceableSearchResult reduce(Collection collectors) throws IOException { + throw new UnsupportedOperationException("The concurrent aggregation over index segments is not supported"); + } + }); } } catch (IOException e) { throw new AggregationInitializationException("Could not initialize aggregators", e); @@ -147,6 +151,20 @@ public void execute(SearchContext context) { // disable aggregations so that they don't run on next pages in case of scrolling context.aggregations(null); - context.queryCollectors().remove(AggregationPhase.class); + context.queryCollectorManagers().remove(AggregationPhase.class); + } + + private Collector createCollector(SearchContext context, List collectors) throws IOException { + Collector collector = MultiBucketCollector.wrap(collectors); + ((BucketCollector) collector).preCollection(); + if (context.getProfilers() != null) { + collector = new InternalProfileCollector( + collector, + CollectorResult.REASON_AGGREGATION, + // TODO: report on child aggs as well + Collections.emptyList() + ); + } + return collector; } } diff --git a/server/src/main/java/org/opensearch/search/internal/ContextIndexSearcher.java b/server/src/main/java/org/opensearch/search/internal/ContextIndexSearcher.java index 2cc15d4c65b96..2fb5ababe19ad 100644 --- a/server/src/main/java/org/opensearch/search/internal/ContextIndexSearcher.java +++ b/server/src/main/java/org/opensearch/search/internal/ContextIndexSearcher.java @@ -96,16 +96,6 @@ public class ContextIndexSearcher extends IndexSearcher implements Releasable { private QueryProfiler profiler; private MutableQueryTimeout cancellable; - public ContextIndexSearcher( - IndexReader reader, - Similarity similarity, - QueryCache queryCache, - QueryCachingPolicy queryCachingPolicy, - boolean wrapWithExitableDirectoryReader - ) throws IOException { - this(reader, similarity, queryCache, queryCachingPolicy, new MutableQueryTimeout(), wrapWithExitableDirectoryReader, null); - } - public ContextIndexSearcher( IndexReader reader, Similarity similarity, @@ -233,6 +223,25 @@ public void search( result.topDocs(new TopDocsAndMaxScore(mergedTopDocs, Float.NaN), formats); } + public void search( + Query query, + CollectorManager manager, + QuerySearchResult result, + DocValueFormat[] formats, + TotalHits totalHits + ) throws IOException { + TopFieldDocs mergedTopDocs = search(query, manager); + // Lucene sets shards indexes during merging of topDocs from different collectors + // We need to reset shard index; OpenSearch will set shard index later during reduce stage + for (ScoreDoc scoreDoc : mergedTopDocs.scoreDocs) { + scoreDoc.shardIndex = -1; + } + if (totalHits != null) { // we have already precalculated totalHits for the whole index + mergedTopDocs = new TopFieldDocs(totalHits, mergedTopDocs.scoreDocs, mergedTopDocs.fields); + } + result.topDocs(new TopDocsAndMaxScore(mergedTopDocs, Float.NaN), formats); + } + @Override protected void search(List leaves, Weight weight, Collector collector) throws IOException { for (LeafReaderContext ctx : leaves) { // search each subreader @@ -420,8 +429,4 @@ public void clear() { runnables.clear(); } } - - public boolean allowConcurrentSegmentSearch() { - return (getExecutor() != null); - } } diff --git a/server/src/main/java/org/opensearch/search/internal/FilteredSearchContext.java b/server/src/main/java/org/opensearch/search/internal/FilteredSearchContext.java index 6d77558ec3bd0..961d45b0011ef 100644 --- a/server/src/main/java/org/opensearch/search/internal/FilteredSearchContext.java +++ b/server/src/main/java/org/opensearch/search/internal/FilteredSearchContext.java @@ -33,6 +33,7 @@ package org.opensearch.search.internal; import org.apache.lucene.search.Collector; +import org.apache.lucene.search.CollectorManager; import org.apache.lucene.search.FieldDoc; import org.apache.lucene.search.Query; import org.opensearch.action.search.SearchShardTask; @@ -61,6 +62,7 @@ import org.opensearch.search.fetch.subphase.highlight.SearchHighlightContext; import org.opensearch.search.profile.Profilers; import org.opensearch.search.query.QuerySearchResult; +import org.opensearch.search.query.ReduceableSearchResult; import org.opensearch.search.rescore.RescoreContext; import org.opensearch.search.sort.SortAndFormats; import org.opensearch.search.suggest.SuggestionSearchContext; @@ -492,8 +494,8 @@ public Profilers getProfilers() { } @Override - public Map, Collector> queryCollectors() { - return in.queryCollectors(); + public Map, CollectorManager> queryCollectorManagers() { + return in.queryCollectorManagers(); } @Override diff --git a/server/src/main/java/org/opensearch/search/internal/SearchContext.java b/server/src/main/java/org/opensearch/search/internal/SearchContext.java index 7ff0eaed4be63..0c24fbee76335 100644 --- a/server/src/main/java/org/opensearch/search/internal/SearchContext.java +++ b/server/src/main/java/org/opensearch/search/internal/SearchContext.java @@ -32,6 +32,7 @@ package org.opensearch.search.internal; import org.apache.lucene.search.Collector; +import org.apache.lucene.search.CollectorManager; import org.apache.lucene.search.FieldDoc; import org.apache.lucene.search.Query; import org.opensearch.action.search.SearchShardTask; @@ -66,6 +67,7 @@ import org.opensearch.search.fetch.subphase.highlight.SearchHighlightContext; import org.opensearch.search.profile.Profilers; import org.opensearch.search.query.QuerySearchResult; +import org.opensearch.search.query.ReduceableSearchResult; import org.opensearch.search.rescore.RescoreContext; import org.opensearch.search.sort.SortAndFormats; import org.opensearch.search.suggest.SuggestionSearchContext; @@ -388,8 +390,8 @@ public final boolean hasOnlySuggest() { */ public abstract long getRelativeTimeInMillis(); - /** Return a view of the additional query collectors that should be run for this context. */ - public abstract Map, Collector> queryCollectors(); + /** Return a view of the additional query collector managers that should be run for this context. */ + public abstract Map, CollectorManager> queryCollectorManagers(); public abstract QueryShardContext getQueryShardContext(); diff --git a/server/src/main/java/org/opensearch/search/profile/Profilers.java b/server/src/main/java/org/opensearch/search/profile/Profilers.java index 6b9be0167b50f..3cc9b1710d420 100644 --- a/server/src/main/java/org/opensearch/search/profile/Profilers.java +++ b/server/src/main/java/org/opensearch/search/profile/Profilers.java @@ -57,7 +57,7 @@ public Profilers(ContextIndexSearcher searcher) { /** Switch to a new profile. */ public QueryProfiler addQueryProfiler() { - QueryProfiler profiler = new QueryProfiler(searcher.allowConcurrentSegmentSearch()); + QueryProfiler profiler = new QueryProfiler(searcher.getExecutor() != null); searcher.setProfiler(profiler); queryProfilers.add(profiler); return profiler; diff --git a/server/src/main/java/org/opensearch/search/profile/query/InternalProfileCollectorManager.java b/server/src/main/java/org/opensearch/search/profile/query/InternalProfileCollectorManager.java new file mode 100644 index 0000000000000..723dc6ae73904 --- /dev/null +++ b/server/src/main/java/org/opensearch/search/profile/query/InternalProfileCollectorManager.java @@ -0,0 +1,89 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.search.profile.query; + +import org.apache.lucene.search.Collector; +import org.apache.lucene.search.CollectorManager; +import org.opensearch.search.query.EarlyTerminatingListener; +import org.opensearch.search.query.ReduceableSearchResult; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collection; +import java.util.List; + +public class InternalProfileCollectorManager + implements + ProfileCollectorManager, + EarlyTerminatingListener { + private final CollectorManager manager; + private final String reason; + private final List children; + private long time = 0; + + public InternalProfileCollectorManager( + CollectorManager manager, + String reason, + List children + ) { + this.manager = manager; + this.reason = reason; + this.children = children; + } + + @Override + public InternalProfileCollector newCollector() throws IOException { + return new InternalProfileCollector(manager.newCollector(), reason, children); + } + + @SuppressWarnings("unchecked") + @Override + public ReduceableSearchResult reduce(Collection collectors) throws IOException { + final Collection subs = new ArrayList<>(); + + for (final InternalProfileCollector collector : collectors) { + subs.add(collector.getCollector()); + time += collector.getTime(); + } + + return ((CollectorManager) manager).reduce(subs); + } + + @Override + public String getReason() { + return reason; + } + + @Override + public long getTime() { + return time; + } + + @Override + public Collection children() { + return children; + } + + @Override + public String getName() { + return manager.getClass().getSimpleName(); + } + + @Override + public CollectorResult getCollectorTree() { + return InternalProfileCollector.doGetCollectorTree(this); + } + + @Override + public void onEarlyTemination(int maxCountHits, boolean forcedTermination) { + if (manager instanceof EarlyTerminatingListener) { + ((EarlyTerminatingListener) manager).onEarlyTemination(maxCountHits, forcedTermination); + } + } +} diff --git a/server/src/main/java/org/opensearch/search/profile/query/ProfileCollectorManager.java b/server/src/main/java/org/opensearch/search/profile/query/ProfileCollectorManager.java new file mode 100644 index 0000000000000..7037988401fce --- /dev/null +++ b/server/src/main/java/org/opensearch/search/profile/query/ProfileCollectorManager.java @@ -0,0 +1,17 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.search.profile.query; + +import org.apache.lucene.search.Collector; +import org.apache.lucene.search.CollectorManager; + +/** + * Collector manager which supports profiling + */ +public interface ProfileCollectorManager extends CollectorManager, InternalProfileComponent {} diff --git a/server/src/main/java/org/opensearch/search/query/EarlyTerminatingCollector.java b/server/src/main/java/org/opensearch/search/query/EarlyTerminatingCollector.java index 3ee8430522891..56cb49835dcc4 100644 --- a/server/src/main/java/org/opensearch/search/query/EarlyTerminatingCollector.java +++ b/server/src/main/java/org/opensearch/search/query/EarlyTerminatingCollector.java @@ -95,6 +95,10 @@ public void collect(int doc) throws IOException { }; } + Collector getCollector() { + return in; + } + /** * Returns true if this collector has early terminated. */ diff --git a/server/src/main/java/org/opensearch/search/query/EarlyTerminatingCollectorManager.java b/server/src/main/java/org/opensearch/search/query/EarlyTerminatingCollectorManager.java new file mode 100644 index 0000000000000..02a2aa9519b70 --- /dev/null +++ b/server/src/main/java/org/opensearch/search/query/EarlyTerminatingCollectorManager.java @@ -0,0 +1,74 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.search.query; + +import org.apache.lucene.search.Collector; +import org.apache.lucene.search.CollectorManager; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collection; +import java.util.List; + +public class EarlyTerminatingCollectorManager + implements + CollectorManager, + EarlyTerminatingListener { + + private final CollectorManager manager; + private final int maxCountHits; + private boolean forceTermination; + + EarlyTerminatingCollectorManager(CollectorManager manager, int maxCountHits, boolean forceTermination) { + this.manager = manager; + this.maxCountHits = maxCountHits; + this.forceTermination = forceTermination; + } + + @Override + public EarlyTerminatingCollector newCollector() throws IOException { + return new EarlyTerminatingCollector(manager.newCollector(), maxCountHits, false /* forced termination is not supported */); + } + + @SuppressWarnings("unchecked") + @Override + public ReduceableSearchResult reduce(Collection collectors) throws IOException { + final List innerCollectors = new ArrayList<>(collectors.size()); + + boolean didTerminateEarly = false; + for (EarlyTerminatingCollector collector : collectors) { + innerCollectors.add((C) collector.getCollector()); + if (collector.hasEarlyTerminated()) { + didTerminateEarly = true; + } + } + + if (didTerminateEarly) { + onEarlyTemination(maxCountHits, forceTermination); + + final ReduceableSearchResult result = manager.reduce(innerCollectors); + return new ReduceableSearchResult() { + @Override + public void reduce(QuerySearchResult r) throws IOException { + result.reduce(r); + r.terminatedEarly(true); + } + }; + } + + return manager.reduce(innerCollectors); + } + + @Override + public void onEarlyTemination(int maxCountHits, boolean forcedTermination) { + if (manager instanceof EarlyTerminatingListener) { + ((EarlyTerminatingListener) manager).onEarlyTemination(maxCountHits, forcedTermination); + } + } +} diff --git a/server/src/main/java/org/opensearch/search/query/EarlyTerminatingListener.java b/server/src/main/java/org/opensearch/search/query/EarlyTerminatingListener.java new file mode 100644 index 0000000000000..1f7bd04b30832 --- /dev/null +++ b/server/src/main/java/org/opensearch/search/query/EarlyTerminatingListener.java @@ -0,0 +1,22 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.search.query; + +/** + * Early termination event listener. It is used during concurrent segment search + * to propagate the early termination intent. + */ +public interface EarlyTerminatingListener { + /** + * Early termination event notification + * @param maxCountHits desired maximum number of hits + * @param forcedTermination :true" if forced termination has been requested, "false" otherwise + */ + void onEarlyTemination(int maxCountHits, boolean forcedTermination); +} diff --git a/server/src/main/java/org/opensearch/search/query/FilteredCollectorManager.java b/server/src/main/java/org/opensearch/search/query/FilteredCollectorManager.java new file mode 100644 index 0000000000000..ef47cf2a388f3 --- /dev/null +++ b/server/src/main/java/org/opensearch/search/query/FilteredCollectorManager.java @@ -0,0 +1,45 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.search.query; + +import org.apache.lucene.search.Collector; +import org.apache.lucene.search.CollectorManager; +import org.apache.lucene.search.Weight; +import org.opensearch.common.lucene.search.FilteredCollector; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collection; + +class FilteredCollectorManager implements CollectorManager { + private final CollectorManager manager; + private final Weight filter; + + FilteredCollectorManager(CollectorManager manager, Weight filter) { + this.manager = manager; + this.filter = filter; + } + + @Override + public FilteredCollector newCollector() throws IOException { + return new FilteredCollector(manager.newCollector(), filter); + } + + @Override + @SuppressWarnings("unchecked") + public ReduceableSearchResult reduce(Collection collectors) throws IOException { + final Collection subCollectors = new ArrayList<>(); + + for (final FilteredCollector collector : collectors) { + subCollectors.add(collector.getCollector()); + } + + return ((CollectorManager) manager).reduce(subCollectors); + } +} diff --git a/server/src/main/java/org/opensearch/search/query/MinimumCollectorManager.java b/server/src/main/java/org/opensearch/search/query/MinimumCollectorManager.java new file mode 100644 index 0000000000000..22b25222b639d --- /dev/null +++ b/server/src/main/java/org/opensearch/search/query/MinimumCollectorManager.java @@ -0,0 +1,44 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.search.query; + +import org.apache.lucene.search.Collector; +import org.apache.lucene.search.CollectorManager; +import org.opensearch.common.lucene.MinimumScoreCollector; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collection; + +class MinimumCollectorManager implements CollectorManager { + private final CollectorManager manager; + private final float minimumScore; + + MinimumCollectorManager(CollectorManager manager, float minimumScore) { + this.manager = manager; + this.minimumScore = minimumScore; + } + + @Override + public MinimumScoreCollector newCollector() throws IOException { + return new MinimumScoreCollector(manager.newCollector(), minimumScore); + } + + @Override + @SuppressWarnings("unchecked") + public ReduceableSearchResult reduce(Collection collectors) throws IOException { + final Collection subCollectors = new ArrayList<>(); + + for (final MinimumScoreCollector collector : collectors) { + subCollectors.add(collector.getCollector()); + } + + return ((CollectorManager) manager).reduce(subCollectors); + } +} diff --git a/server/src/main/java/org/opensearch/search/query/MultiCollectorWrapper.java b/server/src/main/java/org/opensearch/search/query/MultiCollectorWrapper.java new file mode 100644 index 0000000000000..0ee423b48caeb --- /dev/null +++ b/server/src/main/java/org/opensearch/search/query/MultiCollectorWrapper.java @@ -0,0 +1,58 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.search.query; + +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.search.Collector; +import org.apache.lucene.search.LeafCollector; +import org.apache.lucene.search.MultiCollector; +import org.apache.lucene.search.ScoreMode; + +import java.io.IOException; +import java.util.Arrays; +import java.util.Collection; +import java.util.List; + +/** + * Wraps MultiCollector and provide access to underlying collectors. + * Please check out https://github.com/apache/lucene/pull/455. + */ +public class MultiCollectorWrapper implements Collector { + private final MultiCollector delegate; + private final Collection collectors; + + MultiCollectorWrapper(MultiCollector delegate, Collection collectors) { + this.delegate = delegate; + this.collectors = collectors; + } + + @Override + public LeafCollector getLeafCollector(LeafReaderContext context) throws IOException { + return delegate.getLeafCollector(context); + } + + @Override + public ScoreMode scoreMode() { + return delegate.scoreMode(); + } + + public Collection getCollectors() { + return collectors; + } + + public static Collector wrap(Collector... collectors) { + final List collectorsList = Arrays.asList(collectors); + final Collector collector = MultiCollector.wrap(collectorsList); + if (collector instanceof MultiCollector) { + return new MultiCollectorWrapper((MultiCollector) collector, collectorsList); + } else { + return collector; + } + } +} diff --git a/server/src/main/java/org/opensearch/search/query/QueryCollectorContext.java b/server/src/main/java/org/opensearch/search/query/QueryCollectorContext.java index d1ff855888f0b..95ad514adf97d 100644 --- a/server/src/main/java/org/opensearch/search/query/QueryCollectorContext.java +++ b/server/src/main/java/org/opensearch/search/query/QueryCollectorContext.java @@ -33,6 +33,7 @@ package org.opensearch.search.query; import org.apache.lucene.search.Collector; +import org.apache.lucene.search.CollectorManager; import org.apache.lucene.search.IndexSearcher; import org.apache.lucene.search.MultiCollector; import org.apache.lucene.search.Query; @@ -42,6 +43,7 @@ import org.opensearch.common.lucene.MinimumScoreCollector; import org.opensearch.common.lucene.search.FilteredCollector; import org.opensearch.search.profile.query.InternalProfileCollector; +import org.opensearch.search.profile.query.InternalProfileCollectorManager; import java.io.IOException; import java.util.ArrayList; @@ -54,7 +56,7 @@ import static org.opensearch.search.profile.query.CollectorResult.REASON_SEARCH_POST_FILTER; import static org.opensearch.search.profile.query.CollectorResult.REASON_SEARCH_TERMINATE_AFTER_COUNT; -abstract class QueryCollectorContext { +public abstract class QueryCollectorContext { private static final Collector EMPTY_COLLECTOR = new SimpleCollector() { @Override public void collect(int doc) {} @@ -77,6 +79,8 @@ public ScoreMode scoreMode() { */ abstract Collector create(Collector in) throws IOException; + abstract CollectorManager createManager(CollectorManager in) throws IOException; + /** * Wraps this collector with a profiler */ @@ -85,6 +89,18 @@ protected InternalProfileCollector createWithProfiler(InternalProfileCollector i return new InternalProfileCollector(collector, profilerName, in != null ? Collections.singletonList(in) : Collections.emptyList()); } + /** + * Wraps this collector manager with a profiler + */ + protected InternalProfileCollectorManager createWithProfiler(InternalProfileCollectorManager in) throws IOException { + final CollectorManager manager = createManager(in); + return new InternalProfileCollectorManager( + manager, + profilerName, + in != null ? Collections.singletonList(in) : Collections.emptyList() + ); + } + /** * Post-process result after search execution. * @@ -126,6 +142,11 @@ static QueryCollectorContext createMinScoreCollectorContext(float minScore) { Collector create(Collector in) { return new MinimumScoreCollector(in, minScore); } + + @Override + CollectorManager createManager(CollectorManager in) throws IOException { + return new MinimumCollectorManager(in, minScore); + } }; } @@ -139,35 +160,58 @@ Collector create(Collector in) throws IOException { final Weight filterWeight = searcher.createWeight(searcher.rewrite(query), ScoreMode.COMPLETE_NO_SCORES, 1f); return new FilteredCollector(in, filterWeight); } + + @Override + CollectorManager createManager(CollectorManager in) throws IOException { + final Weight filterWeight = searcher.createWeight(searcher.rewrite(query), ScoreMode.COMPLETE_NO_SCORES, 1f); + return new FilteredCollectorManager(in, filterWeight); + } }; } /** - * Creates a multi collector from the provided subs + * Creates a multi collector manager from the provided subs */ - static QueryCollectorContext createMultiCollectorContext(Collection subs) { + static QueryCollectorContext createMultiCollectorContext( + Collection> subs + ) { return new QueryCollectorContext(REASON_SEARCH_MULTI) { @Override - Collector create(Collector in) { + Collector create(Collector in) throws IOException { List subCollectors = new ArrayList<>(); subCollectors.add(in); - subCollectors.addAll(subs); + for (CollectorManager manager : subs) { + subCollectors.add(manager.newCollector()); + } return MultiCollector.wrap(subCollectors); } @Override - protected InternalProfileCollector createWithProfiler(InternalProfileCollector in) { + protected InternalProfileCollector createWithProfiler(InternalProfileCollector in) throws IOException { final List subCollectors = new ArrayList<>(); subCollectors.add(in); - if (subs.stream().anyMatch((col) -> col instanceof InternalProfileCollector == false)) { - throw new IllegalArgumentException("non-profiling collector"); - } - for (Collector collector : subs) { + + for (CollectorManager manager : subs) { + final Collector collector = manager.newCollector(); + if (!(collector instanceof InternalProfileCollector)) { + throw new IllegalArgumentException("non-profiling collector"); + } subCollectors.add((InternalProfileCollector) collector); } + final Collector collector = MultiCollector.wrap(subCollectors); return new InternalProfileCollector(collector, REASON_SEARCH_MULTI, subCollectors); } + + @Override + CollectorManager createManager( + CollectorManager in + ) throws IOException { + final List> managers = new ArrayList<>(); + managers.add(in); + managers.addAll(subs); + return QueryCollectorManagerContext.createOpaqueCollectorManager(managers); + } }; } @@ -192,6 +236,13 @@ Collector create(Collector in) { this.collector = MultiCollector.wrap(subCollectors); return collector; } + + @Override + CollectorManager createManager( + CollectorManager in + ) throws IOException { + return new EarlyTerminatingCollectorManager<>(in, numHits, true); + } }; } } diff --git a/server/src/main/java/org/opensearch/search/query/QueryCollectorManagerContext.java b/server/src/main/java/org/opensearch/search/query/QueryCollectorManagerContext.java new file mode 100644 index 0000000000000..c98f4884bb030 --- /dev/null +++ b/server/src/main/java/org/opensearch/search/query/QueryCollectorManagerContext.java @@ -0,0 +1,99 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.search.query; + +import org.apache.lucene.search.Collector; +import org.apache.lucene.search.CollectorManager; +import org.apache.lucene.search.MultiCollectorManager; +import org.opensearch.search.profile.query.InternalProfileCollectorManager; +import org.opensearch.search.profile.query.ProfileCollectorManager; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collection; +import java.util.List; + +public abstract class QueryCollectorManagerContext { + private static class QueryCollectorManager implements CollectorManager { + private final MultiCollectorManager manager; + + private QueryCollectorManager(Collection> managers) { + this.manager = new MultiCollectorManager(managers.toArray(new CollectorManager[0])); + } + + @Override + public Collector newCollector() throws IOException { + return manager.newCollector(); + } + + @Override + public ReduceableSearchResult reduce(Collection collectors) throws IOException { + final Object[] results = manager.reduce(collectors); + + final ReduceableSearchResult[] transformed = new ReduceableSearchResult[results.length]; + for (int i = 0; i < results.length; ++i) { + assert results[i] instanceof ReduceableSearchResult; + transformed[i] = (ReduceableSearchResult) results[i]; + } + + return reduceWith(transformed); + } + + protected ReduceableSearchResult reduceWith(final ReduceableSearchResult[] results) { + return (QuerySearchResult result) -> { + for (final ReduceableSearchResult r : results) { + r.reduce(result); + } + }; + } + } + + private static class OpaqueQueryCollectorManager extends QueryCollectorManager { + private OpaqueQueryCollectorManager(Collection> managers) { + super(managers); + } + + @Override + protected ReduceableSearchResult reduceWith(final ReduceableSearchResult[] results) { + return (QuerySearchResult result) -> {}; + } + } + + public static CollectorManager createOpaqueCollectorManager( + List> managers + ) throws IOException { + return new OpaqueQueryCollectorManager(managers); + } + + public static CollectorManager createMultiCollectorManager( + List collectors + ) throws IOException { + final Collection> managers = new ArrayList<>(); + + CollectorManager manager = null; + for (QueryCollectorContext ctx : collectors) { + manager = ctx.createManager(manager); + managers.add(manager); + } + + return new QueryCollectorManager(managers); + } + + public static ProfileCollectorManager createQueryCollectorManagerWithProfiler( + List collectors + ) throws IOException { + InternalProfileCollectorManager manager = null; + + for (QueryCollectorContext ctx : collectors) { + manager = ctx.createWithProfiler(manager); + } + + return manager; + } +} diff --git a/server/src/main/java/org/opensearch/search/query/QueryPhase.java b/server/src/main/java/org/opensearch/search/query/QueryPhase.java index 3edbc16cd613f..1501067ec7983 100644 --- a/server/src/main/java/org/opensearch/search/query/QueryPhase.java +++ b/server/src/main/java/org/opensearch/search/query/QueryPhase.java @@ -238,9 +238,9 @@ static boolean executeInternal(SearchContext searchContext, QueryPhaseSearcher q // this collector can filter documents during the collection hasFilterCollector = true; } - if (searchContext.queryCollectors().isEmpty() == false) { + if (searchContext.queryCollectorManagers().isEmpty() == false) { // plug in additional collectors, like aggregations - collectors.add(createMultiCollectorContext(searchContext.queryCollectors().values())); + collectors.add(createMultiCollectorContext(searchContext.queryCollectorManagers().values())); } if (searchContext.minimumScore() != null) { // apply the minimum score after multi collector so we filter aggs as well diff --git a/server/src/main/java/org/opensearch/search/query/ReduceableSearchResult.java b/server/src/main/java/org/opensearch/search/query/ReduceableSearchResult.java new file mode 100644 index 0000000000000..48e8d7198ea3b --- /dev/null +++ b/server/src/main/java/org/opensearch/search/query/ReduceableSearchResult.java @@ -0,0 +1,23 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.search.query; + +import java.io.IOException; + +/** + * The search result callback returned by reduce phase of the collector manager. + */ +public interface ReduceableSearchResult { + /** + * Apply the reduce operation to the query search results + * @param result query search results + * @throws IOException exception if reduce operation failed + */ + void reduce(QuerySearchResult result) throws IOException; +} diff --git a/server/src/main/java/org/opensearch/search/query/TopDocsCollectorContext.java b/server/src/main/java/org/opensearch/search/query/TopDocsCollectorContext.java index 9cf7dca3c4caf..7319357f11831 100644 --- a/server/src/main/java/org/opensearch/search/query/TopDocsCollectorContext.java +++ b/server/src/main/java/org/opensearch/search/query/TopDocsCollectorContext.java @@ -44,6 +44,7 @@ import org.apache.lucene.search.BooleanClause; import org.apache.lucene.search.BoostQuery; import org.apache.lucene.search.Collector; +import org.apache.lucene.search.CollectorManager; import org.apache.lucene.search.ConstantScoreQuery; import org.apache.lucene.search.DocValuesFieldExistsQuery; import org.apache.lucene.search.FieldDoc; @@ -80,6 +81,9 @@ import org.opensearch.search.sort.SortAndFormats; import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; import java.util.Objects; import java.util.function.Supplier; @@ -89,7 +93,7 @@ /** * A {@link QueryCollectorContext} that creates top docs collector */ -abstract class TopDocsCollectorContext extends QueryCollectorContext { +public abstract class TopDocsCollectorContext extends QueryCollectorContext { protected final int numHits; TopDocsCollectorContext(String profilerName, int numHits) { @@ -107,7 +111,7 @@ final int numHits() { /** * Returns true if the top docs should be re-scored after initial search */ - boolean shouldRescore() { + public boolean shouldRescore() { return false; } @@ -115,6 +119,8 @@ static class EmptyTopDocsCollectorContext extends TopDocsCollectorContext { private final Sort sort; private final Collector collector; private final Supplier hitCountSupplier; + private final int trackTotalHitsUpTo; + private final int hitCount; /** * Ctr @@ -132,16 +138,18 @@ private EmptyTopDocsCollectorContext( ) throws IOException { super(REASON_SEARCH_COUNT, 0); this.sort = sortAndFormats == null ? null : sortAndFormats.sort; - if (trackTotalHitsUpTo == SearchContext.TRACK_TOTAL_HITS_DISABLED) { + this.trackTotalHitsUpTo = trackTotalHitsUpTo; + if (this.trackTotalHitsUpTo == SearchContext.TRACK_TOTAL_HITS_DISABLED) { this.collector = new EarlyTerminatingCollector(new TotalHitCountCollector(), 0, false); // for bwc hit count is set to 0, it will be converted to -1 by the coordinating node this.hitCountSupplier = () -> new TotalHits(0, TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO); + this.hitCount = Integer.MIN_VALUE; } else { TotalHitCountCollector hitCountCollector = new TotalHitCountCollector(); // implicit total hit counts are valid only when there is no filter collector in the chain - int hitCount = hasFilterCollector ? -1 : shortcutTotalHitCount(reader, query); - if (hitCount == -1) { - if (trackTotalHitsUpTo == SearchContext.TRACK_TOTAL_HITS_ACCURATE) { + this.hitCount = hasFilterCollector ? -1 : shortcutTotalHitCount(reader, query); + if (this.hitCount == -1) { + if (this.trackTotalHitsUpTo == SearchContext.TRACK_TOTAL_HITS_ACCURATE) { this.collector = hitCountCollector; this.hitCountSupplier = () -> new TotalHits(hitCountCollector.getTotalHits(), TotalHits.Relation.EQUAL_TO); } else { @@ -159,6 +167,39 @@ private EmptyTopDocsCollectorContext( } } + @Override + CollectorManager createManager(CollectorManager in) throws IOException { + assert in == null; + + CollectorManager manager = null; + + if (trackTotalHitsUpTo == SearchContext.TRACK_TOTAL_HITS_DISABLED) { + manager = new EarlyTerminatingCollectorManager<>( + new TotalHitCountCollectorManager.Empty(new TotalHits(0, TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO), sort), + 0, + false + ); + } else { + if (hitCount == -1) { + if (trackTotalHitsUpTo != SearchContext.TRACK_TOTAL_HITS_ACCURATE) { + manager = new EarlyTerminatingCollectorManager<>( + new TotalHitCountCollectorManager(sort), + trackTotalHitsUpTo, + false + ); + } + } else { + manager = new EarlyTerminatingCollectorManager<>( + new TotalHitCountCollectorManager.Empty(new TotalHits(hitCount, TotalHits.Relation.EQUAL_TO), sort), + 0, + false + ); + } + } + + return manager; + } + @Override Collector create(Collector in) { assert in == null; @@ -181,7 +222,11 @@ void postProcess(QuerySearchResult result) { static class CollapsingTopDocsCollectorContext extends TopDocsCollectorContext { private final DocValueFormat[] sortFmt; private final CollapsingTopDocsCollector topDocsCollector; + private final Collector collector; private final Supplier maxScoreSupplier; + private final CollapseContext collapseContext; + private final boolean trackMaxScore; + private final Sort sort; /** * Ctr @@ -199,30 +244,94 @@ private CollapsingTopDocsCollectorContext( super(REASON_SEARCH_TOP_HITS, numHits); assert numHits > 0; assert collapseContext != null; - Sort sort = sortAndFormats == null ? Sort.RELEVANCE : sortAndFormats.sort; + this.sort = sortAndFormats == null ? Sort.RELEVANCE : sortAndFormats.sort; this.sortFmt = sortAndFormats == null ? new DocValueFormat[] { DocValueFormat.RAW } : sortAndFormats.formats; + this.collapseContext = collapseContext; this.topDocsCollector = collapseContext.createTopDocs(sort, numHits); + this.trackMaxScore = trackMaxScore; - MaxScoreCollector maxScoreCollector; + MaxScoreCollector maxScoreCollector = null; if (trackMaxScore) { maxScoreCollector = new MaxScoreCollector(); maxScoreSupplier = maxScoreCollector::getMaxScore; } else { + maxScoreCollector = null; maxScoreSupplier = () -> Float.NaN; } + + this.collector = MultiCollector.wrap(topDocsCollector, maxScoreCollector); } @Override Collector create(Collector in) throws IOException { assert in == null; - return topDocsCollector; + return collector; } @Override void postProcess(QuerySearchResult result) throws IOException { - CollapseTopFieldDocs topDocs = topDocsCollector.getTopDocs(); + final CollapseTopFieldDocs topDocs = topDocsCollector.getTopDocs(); result.topDocs(new TopDocsAndMaxScore(topDocs, maxScoreSupplier.get()), sortFmt); } + + @Override + CollectorManager createManager(CollectorManager in) throws IOException { + return new CollectorManager() { + @Override + public Collector newCollector() throws IOException { + MaxScoreCollector maxScoreCollector = null; + + if (trackMaxScore) { + maxScoreCollector = new MaxScoreCollector(); + } + + return MultiCollectorWrapper.wrap(collapseContext.createTopDocs(sort, numHits), maxScoreCollector); + } + + @Override + public ReduceableSearchResult reduce(Collection collectors) throws IOException { + final Collection subs = new ArrayList<>(); + for (final Collector collector : collectors) { + if (collector instanceof MultiCollectorWrapper) { + subs.addAll(((MultiCollectorWrapper) collector).getCollectors()); + } else { + subs.add(collector); + } + } + + final Collection topFieldDocs = new ArrayList(); + float maxScore = Float.NaN; + + for (final Collector collector : subs) { + if (collector instanceof CollapsingTopDocsCollector) { + topFieldDocs.add(((CollapsingTopDocsCollector) collector).getTopDocs()); + } else if (collector instanceof MaxScoreCollector) { + float score = ((MaxScoreCollector) collector).getMaxScore(); + if (Float.isNaN(maxScore)) { + maxScore = score; + } else { + maxScore = Math.max(maxScore, score); + } + } + } + + return reduceWith(topFieldDocs, maxScore); + } + }; + } + + protected ReduceableSearchResult reduceWith(final Collection topFieldDocs, float maxScore) { + return (QuerySearchResult result) -> { + final CollapseTopFieldDocs topDocs = CollapseTopFieldDocs.merge( + sort, + 0, + numHits, + topFieldDocs.toArray(new CollapseTopFieldDocs[0]), + true + ); + result.topDocs(new TopDocsAndMaxScore(topDocs, maxScore), sortFmt); + }; + } } abstract static class SimpleTopDocsCollectorContext extends TopDocsCollectorContext { @@ -240,11 +349,38 @@ private static TopDocsCollector createCollector( } } + private static CollectorManager, ? extends TopDocs> createCollectorManager( + @Nullable SortAndFormats sortAndFormats, + int numHits, + @Nullable ScoreDoc searchAfter, + int hitCountThreshold + ) { + if (sortAndFormats == null) { + // See please https://github.com/apache/lucene/pull/450, should be fixed in 9.x + if (searchAfter != null) { + return TopScoreDocCollector.createSharedManager( + numHits, + new FieldDoc(searchAfter.doc, searchAfter.score), + hitCountThreshold + ); + } else { + return TopScoreDocCollector.createSharedManager(numHits, null, hitCountThreshold); + } + } else { + return TopFieldCollector.createSharedManager(sortAndFormats.sort, numHits, (FieldDoc) searchAfter, hitCountThreshold); + } + } + protected final @Nullable SortAndFormats sortAndFormats; private final Collector collector; private final Supplier totalHitsSupplier; private final Supplier topDocsSupplier; private final Supplier maxScoreSupplier; + private final ScoreDoc searchAfter; + private final int trackTotalHitsUpTo; + private final boolean trackMaxScore; + private final boolean hasInfMaxScore; + private final int hitCount; /** * Ctr @@ -269,24 +405,30 @@ private SimpleTopDocsCollectorContext( ) throws IOException { super(REASON_SEARCH_TOP_HITS, numHits); this.sortAndFormats = sortAndFormats; + this.searchAfter = searchAfter; + this.trackTotalHitsUpTo = trackTotalHitsUpTo; + this.trackMaxScore = trackMaxScore; + this.hasInfMaxScore = hasInfMaxScore(query); final TopDocsCollector topDocsCollector; - if ((sortAndFormats == null || SortField.FIELD_SCORE.equals(sortAndFormats.sort.getSort()[0])) && hasInfMaxScore(query)) { + if ((sortAndFormats == null || SortField.FIELD_SCORE.equals(sortAndFormats.sort.getSort()[0])) && hasInfMaxScore) { // disable max score optimization since we have a mandatory clause // that doesn't track the maximum score topDocsCollector = createCollector(sortAndFormats, numHits, searchAfter, Integer.MAX_VALUE); topDocsSupplier = new CachedSupplier<>(topDocsCollector::topDocs); totalHitsSupplier = () -> topDocsSupplier.get().totalHits; + hitCount = Integer.MIN_VALUE; } else if (trackTotalHitsUpTo == SearchContext.TRACK_TOTAL_HITS_DISABLED) { // don't compute hit counts via the collector topDocsCollector = createCollector(sortAndFormats, numHits, searchAfter, 1); topDocsSupplier = new CachedSupplier<>(topDocsCollector::topDocs); totalHitsSupplier = () -> new TotalHits(0, TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO); + hitCount = -1; } else { // implicit total hit counts are valid only when there is no filter collector in the chain - final int hitCount = hasFilterCollector ? -1 : shortcutTotalHitCount(reader, query); - if (hitCount == -1) { + this.hitCount = hasFilterCollector ? -1 : shortcutTotalHitCount(reader, query); + if (this.hitCount == -1) { topDocsCollector = createCollector(sortAndFormats, numHits, searchAfter, trackTotalHitsUpTo); topDocsSupplier = new CachedSupplier<>(topDocsCollector::topDocs); totalHitsSupplier = () -> topDocsSupplier.get().totalHits; @@ -294,7 +436,7 @@ private SimpleTopDocsCollectorContext( // don't compute hit counts via the collector topDocsCollector = createCollector(sortAndFormats, numHits, searchAfter, 1); topDocsSupplier = new CachedSupplier<>(topDocsCollector::topDocs); - totalHitsSupplier = () -> new TotalHits(hitCount, TotalHits.Relation.EQUAL_TO); + totalHitsSupplier = () -> new TotalHits(this.hitCount, TotalHits.Relation.EQUAL_TO); } } MaxScoreCollector maxScoreCollector = null; @@ -315,7 +457,98 @@ private SimpleTopDocsCollectorContext( } this.collector = MultiCollector.wrap(topDocsCollector, maxScoreCollector); + } + + private class SimpleTopDocsCollectorManager + implements + CollectorManager, + EarlyTerminatingListener { + private Integer terminatedAfter; + private final CollectorManager, ? extends TopDocs> manager; + + private SimpleTopDocsCollectorManager() { + if ((sortAndFormats == null || SortField.FIELD_SCORE.equals(sortAndFormats.sort.getSort()[0])) && hasInfMaxScore) { + // disable max score optimization since we have a mandatory clause + // that doesn't track the maximum score + manager = createCollectorManager(sortAndFormats, numHits, searchAfter, Integer.MAX_VALUE); + } else if (trackTotalHitsUpTo == SearchContext.TRACK_TOTAL_HITS_DISABLED) { + // don't compute hit counts via the collector + manager = createCollectorManager(sortAndFormats, numHits, searchAfter, 1); + } else { + // implicit total hit counts are valid only when there is no filter collector in the chain + if (hitCount == -1) { + manager = createCollectorManager(sortAndFormats, numHits, searchAfter, trackTotalHitsUpTo); + } else { + // don't compute hit counts via the collector + manager = createCollectorManager(sortAndFormats, numHits, searchAfter, 1); + } + } + } + + @Override + public void onEarlyTemination(int maxCountHits, boolean forcedTermination) { + terminatedAfter = maxCountHits; + } + + @Override + public Collector newCollector() throws IOException { + MaxScoreCollector maxScoreCollector = null; + + if (sortAndFormats != null && trackMaxScore) { + maxScoreCollector = new MaxScoreCollector(); + } + + return MultiCollectorWrapper.wrap(manager.newCollector(), maxScoreCollector); + } + + @SuppressWarnings("unchecked") + @Override + public ReduceableSearchResult reduce(Collection collectors) throws IOException { + final Collection> topDocsCollectors = new ArrayList<>(); + final Collection maxScoreCollectors = new ArrayList<>(); + + for (final Collector collector : collectors) { + if (collector instanceof MultiCollectorWrapper) { + for (final Collector sub : (((MultiCollectorWrapper) collector).getCollectors())) { + if (sub instanceof TopDocsCollector) { + topDocsCollectors.add((TopDocsCollector) sub); + } else if (sub instanceof MaxScoreCollector) { + maxScoreCollectors.add((MaxScoreCollector) sub); + } + } + } else if (collector instanceof TopDocsCollector) { + topDocsCollectors.add((TopDocsCollector) collector); + } else if (collector instanceof MaxScoreCollector) { + maxScoreCollectors.add((MaxScoreCollector) collector); + } + } + + float maxScore = Float.NaN; + for (final MaxScoreCollector collector : maxScoreCollectors) { + float score = collector.getMaxScore(); + if (Float.isNaN(maxScore)) { + maxScore = score; + } else { + maxScore = Math.max(maxScore, score); + } + } + final TopDocs topDocs = ((CollectorManager, ? extends TopDocs>) manager).reduce(topDocsCollectors); + return reduceWith(topDocs, maxScore, terminatedAfter); + } + } + + @Override + CollectorManager createManager(CollectorManager in) throws IOException { + assert in == null; + return new SimpleTopDocsCollectorManager(); + } + + protected ReduceableSearchResult reduceWith(final TopDocs topDocs, final float maxScore, final Integer terminatedAfter) { + return (QuerySearchResult result) -> { + final TopDocsAndMaxScore topDocsAndMaxScore = newTopDocs(topDocs, maxScore, terminatedAfter); + result.topDocs(topDocsAndMaxScore, sortAndFormats == null ? null : sortAndFormats.formats); + }; } @Override @@ -324,6 +557,50 @@ Collector create(Collector in) { return collector; } + TopDocsAndMaxScore newTopDocs(final TopDocs topDocs, final float maxScore, final Integer terminatedAfter) { + TotalHits totalHits = null; + + if ((sortAndFormats == null || SortField.FIELD_SCORE.equals(sortAndFormats.sort.getSort()[0])) && hasInfMaxScore) { + totalHits = topDocs.totalHits; + } else if (trackTotalHitsUpTo == SearchContext.TRACK_TOTAL_HITS_DISABLED) { + // don't compute hit counts via the collector + totalHits = new TotalHits(0, TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO); + } else { + if (hitCount == -1) { + totalHits = topDocs.totalHits; + } else { + totalHits = new TotalHits(hitCount, TotalHits.Relation.EQUAL_TO); + } + } + + // Since we cannot support early forced termination, we have to simulate it by + // artificially reducing the number of total hits and doc scores. + ScoreDoc[] scoreDocs = topDocs.scoreDocs; + if (terminatedAfter != null) { + if (totalHits.value > terminatedAfter) { + totalHits = new TotalHits(terminatedAfter, TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO); + } + + if (scoreDocs != null && scoreDocs.length > terminatedAfter) { + scoreDocs = Arrays.copyOf(scoreDocs, terminatedAfter); + } + } + + final TopDocs newTopDocs; + if (topDocs instanceof TopFieldDocs) { + TopFieldDocs fieldDocs = (TopFieldDocs) topDocs; + newTopDocs = new TopFieldDocs(totalHits, scoreDocs, fieldDocs.fields); + } else { + newTopDocs = new TopDocs(totalHits, scoreDocs); + } + + if (Float.isNaN(maxScore) && newTopDocs.scoreDocs.length > 0 && sortAndFormats == null) { + return new TopDocsAndMaxScore(newTopDocs, newTopDocs.scoreDocs[0].score); + } else { + return new TopDocsAndMaxScore(newTopDocs, maxScore); + } + } + TopDocsAndMaxScore newTopDocs() { TopDocs in = topDocsSupplier.get(); float maxScore = maxScoreSupplier.get(); @@ -373,6 +650,35 @@ private ScrollingTopDocsCollectorContext( this.numberOfShards = numberOfShards; } + @Override + protected ReduceableSearchResult reduceWith(final TopDocs topDocs, final float maxScore, final Integer terminatedAfter) { + return (QuerySearchResult result) -> { + final TopDocsAndMaxScore topDocsAndMaxScore = newTopDocs(topDocs, maxScore, terminatedAfter); + + if (scrollContext.totalHits == null) { + // first round + scrollContext.totalHits = topDocsAndMaxScore.topDocs.totalHits; + scrollContext.maxScore = topDocsAndMaxScore.maxScore; + } else { + // subsequent round: the total number of hits and + // the maximum score were computed on the first round + topDocsAndMaxScore.topDocs.totalHits = scrollContext.totalHits; + topDocsAndMaxScore.maxScore = scrollContext.maxScore; + } + + if (numberOfShards == 1) { + // if we fetch the document in the same roundtrip, we already know the last emitted doc + if (topDocsAndMaxScore.topDocs.scoreDocs.length > 0) { + // set the last emitted doc + scrollContext.lastEmittedDoc = topDocsAndMaxScore.topDocs.scoreDocs[topDocsAndMaxScore.topDocs.scoreDocs.length + - 1]; + } + } + + result.topDocs(topDocsAndMaxScore, sortAndFormats == null ? null : sortAndFormats.formats); + }; + } + @Override void postProcess(QuerySearchResult result) throws IOException { final TopDocsAndMaxScore topDocs = newTopDocs(); @@ -457,7 +763,7 @@ static int shortcutTotalHitCount(IndexReader reader, Query query) throws IOExcep * Creates a {@link TopDocsCollectorContext} from the provided searchContext. * @param hasFilterCollector True if the collector chain contains at least one collector that can filters document. */ - static TopDocsCollectorContext createTopDocsCollectorContext(SearchContext searchContext, boolean hasFilterCollector) + public static TopDocsCollectorContext createTopDocsCollectorContext(SearchContext searchContext, boolean hasFilterCollector) throws IOException { final IndexReader reader = searchContext.searcher().getIndexReader(); final Query query = searchContext.query(); @@ -515,7 +821,7 @@ static TopDocsCollectorContext createTopDocsCollectorContext(SearchContext searc hasFilterCollector ) { @Override - boolean shouldRescore() { + public boolean shouldRescore() { return rescore; } }; diff --git a/server/src/main/java/org/opensearch/search/query/TotalHitCountCollectorManager.java b/server/src/main/java/org/opensearch/search/query/TotalHitCountCollectorManager.java new file mode 100644 index 0000000000000..8f50f7b693df1 --- /dev/null +++ b/server/src/main/java/org/opensearch/search/query/TotalHitCountCollectorManager.java @@ -0,0 +1,106 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.search.query; + +import org.apache.lucene.search.CollectorManager; +import org.apache.lucene.search.ScoreMode; +import org.apache.lucene.search.Sort; +import org.apache.lucene.search.TopDocs; +import org.apache.lucene.search.TopFieldDocs; +import org.apache.lucene.search.TotalHitCountCollector; +import org.apache.lucene.search.TotalHits; +import org.opensearch.common.lucene.Lucene; +import org.opensearch.common.lucene.search.TopDocsAndMaxScore; + +import java.io.IOException; +import java.util.Collection; + +public class TotalHitCountCollectorManager + implements + CollectorManager, + EarlyTerminatingListener { + + private static final TotalHitCountCollector EMPTY_COLLECTOR = new TotalHitCountCollector() { + @Override + public void collect(int doc) {} + + @Override + public ScoreMode scoreMode() { + return ScoreMode.COMPLETE_NO_SCORES; + } + }; + + private final Sort sort; + private Integer terminatedAfter; + + public TotalHitCountCollectorManager(final Sort sort) { + this.sort = sort; + } + + @Override + public void onEarlyTemination(int maxCountHits, boolean forcedTermination) { + terminatedAfter = maxCountHits; + } + + @Override + public TotalHitCountCollector newCollector() throws IOException { + return new TotalHitCountCollector(); + } + + @Override + public ReduceableSearchResult reduce(Collection collectors) throws IOException { + return (QuerySearchResult result) -> { + final TotalHits.Relation relation = (terminatedAfter != null) + ? TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO + : TotalHits.Relation.EQUAL_TO; + + int totalHits = collectors.stream().mapToInt(TotalHitCountCollector::getTotalHits).sum(); + if (terminatedAfter != null && totalHits > terminatedAfter) { + totalHits = terminatedAfter; + } + + final TotalHits totalHitCount = new TotalHits(totalHits, relation); + final TopDocs topDocs = (sort != null) + ? new TopFieldDocs(totalHitCount, Lucene.EMPTY_SCORE_DOCS, sort.getSort()) + : new TopDocs(totalHitCount, Lucene.EMPTY_SCORE_DOCS); + + result.topDocs(new TopDocsAndMaxScore(topDocs, Float.NaN), null); + }; + } + + static class Empty implements CollectorManager { + private final TotalHits totalHits; + private final Sort sort; + + Empty(final TotalHits totalHits, final Sort sort) { + this.totalHits = totalHits; + this.sort = sort; + } + + @Override + public TotalHitCountCollector newCollector() throws IOException { + return EMPTY_COLLECTOR; + } + + @Override + public ReduceableSearchResult reduce(Collection collectors) throws IOException { + return (QuerySearchResult result) -> { + final TopDocs topDocs; + + if (sort != null) { + topDocs = new TopFieldDocs(totalHits, Lucene.EMPTY_SCORE_DOCS, sort.getSort()); + } else { + topDocs = new TopDocs(totalHits, Lucene.EMPTY_SCORE_DOCS); + } + + result.topDocs(new TopDocsAndMaxScore(topDocs, Float.NaN), null); + }; + } + } +} diff --git a/server/src/test/java/org/opensearch/search/DefaultSearchContextTests.java b/server/src/test/java/org/opensearch/search/DefaultSearchContextTests.java index e1cf74bdd6aeb..f6ca12f1c514c 100644 --- a/server/src/test/java/org/opensearch/search/DefaultSearchContextTests.java +++ b/server/src/test/java/org/opensearch/search/DefaultSearchContextTests.java @@ -32,6 +32,8 @@ package org.opensearch.search; +import com.carrotsearch.randomizedtesting.annotations.ParametersFactory; + import org.apache.lucene.index.IndexReader; import org.apache.lucene.tests.index.RandomIndexWriter; import org.apache.lucene.search.IndexSearcher; @@ -76,7 +78,12 @@ import org.opensearch.threadpool.ThreadPool; import java.io.IOException; +import java.util.Arrays; +import java.util.Collection; import java.util.UUID; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; import java.util.function.Function; import java.util.function.Supplier; @@ -91,6 +98,25 @@ import static org.mockito.Mockito.when; public class DefaultSearchContextTests extends OpenSearchTestCase { + private final ExecutorService executor; + + @ParametersFactory + public static Collection concurrency() { + return Arrays.asList(new Integer[] { 0 }, new Integer[] { 5 }); + } + + public DefaultSearchContextTests(int concurrency) { + this.executor = (concurrency > 0) ? Executors.newFixedThreadPool(concurrency) : null; + } + + @Override + public void tearDown() throws Exception { + super.tearDown(); + + if (executor != null) { + ThreadPool.terminate(executor, 10, TimeUnit.SECONDS); + } + } public void testPreProcess() throws Exception { TimeValue timeout = new TimeValue(randomIntBetween(1, 100)); @@ -183,7 +209,7 @@ protected Engine.Searcher acquireSearcherInternal(String source) { false, Version.CURRENT, false, - null + executor ); contextWithoutScroll.from(300); contextWithoutScroll.close(); @@ -225,7 +251,7 @@ protected Engine.Searcher acquireSearcherInternal(String source) { false, Version.CURRENT, false, - null + executor ); context1.from(300); exception = expectThrows(IllegalArgumentException.class, () -> context1.preProcess(false)); @@ -295,7 +321,7 @@ protected Engine.Searcher acquireSearcherInternal(String source) { false, Version.CURRENT, false, - null + executor ); SliceBuilder sliceBuilder = mock(SliceBuilder.class); @@ -334,7 +360,7 @@ protected Engine.Searcher acquireSearcherInternal(String source) { false, Version.CURRENT, false, - null + executor ); ParsedQuery parsedQuery = ParsedQuery.parsedMatchAllQuery(); context3.sliceBuilder(null).parsedQuery(parsedQuery).preProcess(false); @@ -365,7 +391,7 @@ protected Engine.Searcher acquireSearcherInternal(String source) { false, Version.CURRENT, false, - null + executor ); context4.sliceBuilder(new SliceBuilder(1, 2)).parsedQuery(parsedQuery).preProcess(false); Query query1 = context4.query(); @@ -446,7 +472,7 @@ protected Engine.Searcher acquireSearcherInternal(String source) { false, Version.CURRENT, false, - null + executor ); assertThat(context.searcher().hasCancellations(), is(false)); context.searcher().addQueryCancellation(() -> {}); diff --git a/server/src/test/java/org/opensearch/search/SearchCancellationTests.java b/server/src/test/java/org/opensearch/search/SearchCancellationTests.java index 1927558f94094..f479f3a1b99f1 100644 --- a/server/src/test/java/org/opensearch/search/SearchCancellationTests.java +++ b/server/src/test/java/org/opensearch/search/SearchCancellationTests.java @@ -108,7 +108,8 @@ public void testAddingCancellationActions() throws IOException { IndexSearcher.getDefaultSimilarity(), IndexSearcher.getDefaultQueryCache(), IndexSearcher.getDefaultQueryCachingPolicy(), - true + true, + null ); NullPointerException npe = expectThrows(NullPointerException.class, () -> searcher.addQueryCancellation(null)); assertEquals("cancellation runnable should not be null", npe.getMessage()); @@ -127,7 +128,8 @@ public void testCancellableCollector() throws IOException { IndexSearcher.getDefaultSimilarity(), IndexSearcher.getDefaultQueryCache(), IndexSearcher.getDefaultQueryCachingPolicy(), - true + true, + null ); searcher.search(new MatchAllDocsQuery(), collector1); @@ -154,7 +156,8 @@ public void testExitableDirectoryReader() throws IOException { IndexSearcher.getDefaultSimilarity(), IndexSearcher.getDefaultQueryCache(), IndexSearcher.getDefaultQueryCachingPolicy(), - true + true, + null ); searcher.addQueryCancellation(cancellation); CompiledAutomaton automaton = new CompiledAutomaton(new RegExp("a.*").toAutomaton()); diff --git a/server/src/test/java/org/opensearch/search/internal/ContextIndexSearcherTests.java b/server/src/test/java/org/opensearch/search/internal/ContextIndexSearcherTests.java index de0a31b9dc04b..eb7dde4b0b2ce 100644 --- a/server/src/test/java/org/opensearch/search/internal/ContextIndexSearcherTests.java +++ b/server/src/test/java/org/opensearch/search/internal/ContextIndexSearcherTests.java @@ -258,7 +258,8 @@ public void onRemoval(ShardId shardId, Accountable accountable) { IndexSearcher.getDefaultSimilarity(), IndexSearcher.getDefaultQueryCache(), IndexSearcher.getDefaultQueryCachingPolicy(), - true + true, + null ); for (LeafReaderContext context : searcher.getIndexReader().leaves()) { diff --git a/server/src/test/java/org/opensearch/search/profile/query/QueryProfilerTests.java b/server/src/test/java/org/opensearch/search/profile/query/QueryProfilerTests.java index afaab15e1431e..7f4dcdaed2aa1 100644 --- a/server/src/test/java/org/opensearch/search/profile/query/QueryProfilerTests.java +++ b/server/src/test/java/org/opensearch/search/profile/query/QueryProfilerTests.java @@ -32,8 +32,6 @@ package org.opensearch.search.profile.query; -import com.carrotsearch.randomizedtesting.annotations.ParametersFactory; - import org.apache.lucene.document.Document; import org.apache.lucene.document.Field.Store; import org.apache.lucene.document.StringField; @@ -64,18 +62,12 @@ import org.opensearch.search.internal.ContextIndexSearcher; import org.opensearch.search.profile.ProfileResult; import org.opensearch.test.OpenSearchTestCase; -import org.opensearch.threadpool.ThreadPool; import org.junit.After; import org.junit.Before; import java.io.IOException; -import java.util.Arrays; -import java.util.Collection; import java.util.List; import java.util.Map; -import java.util.concurrent.ExecutorService; -import java.util.concurrent.Executors; -import java.util.concurrent.TimeUnit; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.greaterThan; @@ -85,16 +77,6 @@ public class QueryProfilerTests extends OpenSearchTestCase { private Directory dir; private IndexReader reader; private ContextIndexSearcher searcher; - private ExecutorService executor; - - @ParametersFactory - public static Collection concurrency() { - return Arrays.asList(new Integer[] { 0 }, new Integer[] { 5 }); - } - - public QueryProfilerTests(int concurrency) { - this.executor = (concurrency > 0) ? Executors.newFixedThreadPool(concurrency) : null; - } @Before public void setUp() throws Exception { @@ -120,7 +102,7 @@ public void setUp() throws Exception { IndexSearcher.getDefaultQueryCache(), ALWAYS_CACHE_POLICY, true, - executor + null ); } @@ -134,10 +116,6 @@ public void tearDown() throws Exception { assertThat(cache.getTotalCount(), equalTo(cache.getMissCount())); assertThat(cache.getCacheSize(), equalTo(0L)); - if (executor != null) { - ThreadPool.terminate(executor, 10, TimeUnit.SECONDS); - } - IOUtils.close(reader, dir); dir = null; reader = null; @@ -145,7 +123,7 @@ public void tearDown() throws Exception { } public void testBasic() throws IOException { - QueryProfiler profiler = new QueryProfiler(searcher.allowConcurrentSegmentSearch()); + QueryProfiler profiler = new QueryProfiler(false); searcher.setProfiler(profiler); Query query = new TermQuery(new Term("foo", "bar")); searcher.search(query, 1); @@ -171,7 +149,7 @@ public void testBasic() throws IOException { } public void testNoScoring() throws IOException { - QueryProfiler profiler = new QueryProfiler(searcher.allowConcurrentSegmentSearch()); + QueryProfiler profiler = new QueryProfiler(false); searcher.setProfiler(profiler); Query query = new TermQuery(new Term("foo", "bar")); searcher.search(query, 1, Sort.INDEXORDER); // scores are not needed @@ -197,7 +175,7 @@ public void testNoScoring() throws IOException { } public void testUseIndexStats() throws IOException { - QueryProfiler profiler = new QueryProfiler(searcher.allowConcurrentSegmentSearch()); + QueryProfiler profiler = new QueryProfiler(false); searcher.setProfiler(profiler); Query query = new TermQuery(new Term("foo", "bar")); searcher.count(query); // will use index stats @@ -211,7 +189,7 @@ public void testUseIndexStats() throws IOException { } public void testApproximations() throws IOException { - QueryProfiler profiler = new QueryProfiler(searcher.allowConcurrentSegmentSearch()); + QueryProfiler profiler = new QueryProfiler(false); searcher.setProfiler(profiler); Query query = new RandomApproximationQuery(new TermQuery(new Term("foo", "bar")), random()); searcher.count(query); diff --git a/server/src/test/java/org/opensearch/search/query/QueryPhaseTests.java b/server/src/test/java/org/opensearch/search/query/QueryPhaseTests.java index b87c11dce5be2..1232347edea64 100644 --- a/server/src/test/java/org/opensearch/search/query/QueryPhaseTests.java +++ b/server/src/test/java/org/opensearch/search/query/QueryPhaseTests.java @@ -39,6 +39,7 @@ import org.apache.lucene.document.LatLonPoint; import org.apache.lucene.document.LongPoint; import org.apache.lucene.document.NumericDocValuesField; +import org.apache.lucene.document.SortedDocValuesField; import org.apache.lucene.document.SortedSetDocValuesField; import org.apache.lucene.document.StringField; import org.apache.lucene.document.TextField; @@ -77,6 +78,7 @@ import org.apache.lucene.search.TotalHitCountCollector; import org.apache.lucene.search.TotalHits; import org.apache.lucene.search.Weight; +import org.apache.lucene.search.grouping.CollapseTopFieldDocs; import org.apache.lucene.search.join.BitSetProducer; import org.apache.lucene.search.join.ScoreMode; import org.apache.lucene.store.Directory; @@ -88,12 +90,15 @@ import org.opensearch.index.mapper.MappedFieldType; import org.opensearch.index.mapper.MapperService; import org.opensearch.index.mapper.NumberFieldMapper; +import org.opensearch.index.mapper.NumberFieldMapper.NumberFieldType; +import org.opensearch.index.mapper.NumberFieldMapper.NumberType; import org.opensearch.index.query.ParsedQuery; import org.opensearch.index.query.QueryShardContext; import org.opensearch.index.search.OpenSearchToParentBlockJoinQuery; import org.opensearch.index.shard.IndexShard; import org.opensearch.index.shard.IndexShardTestCase; import org.opensearch.search.DocValueFormat; +import org.opensearch.search.collapse.CollapseBuilder; import org.opensearch.search.internal.ContextIndexSearcher; import org.opensearch.search.internal.ScrollContext; import org.opensearch.search.internal.SearchContext; @@ -144,7 +149,7 @@ private void countTestCase(Query query, IndexReader reader, boolean shouldCollec context.parsedQuery(new ParsedQuery(query)); context.setSize(0); context.setTask(new SearchShardTask(123L, "", "", "", null, Collections.emptyMap())); - final boolean rescore = QueryPhase.executeInternal(context); + final boolean rescore = QueryPhase.executeInternal(context.withCleanQueryResult()); assertFalse(rescore); ContextIndexSearcher countSearcher = shouldCollectCount @@ -157,7 +162,7 @@ private void countTestCase(boolean withDeletions) throws Exception { Directory dir = newDirectory(); IndexWriterConfig iwc = newIndexWriterConfig().setMergePolicy(NoMergePolicy.INSTANCE); RandomIndexWriter w = new RandomIndexWriter(random(), dir, iwc); - final int numDocs = scaledRandomIntBetween(100, 200); + final int numDocs = scaledRandomIntBetween(600, 900); for (int i = 0; i < numDocs; ++i) { Document doc = new Document(); if (randomBoolean()) { @@ -228,12 +233,12 @@ public void testPostFilterDisablesCountOptimization() throws Exception { context.setTask(new SearchShardTask(123L, "", "", "", null, Collections.emptyMap())); context.parsedQuery(new ParsedQuery(new MatchAllDocsQuery())); - QueryPhase.executeInternal(context); + QueryPhase.executeInternal(context.withCleanQueryResult()); assertEquals(1, context.queryResult().topDocs().topDocs.totalHits.value); context.setSearcher(newContextSearcher(reader)); context.parsedPostFilter(new ParsedQuery(new MatchNoDocsQuery())); - QueryPhase.executeInternal(context); + QueryPhase.executeInternal(context.withCleanQueryResult()); assertEquals(0, context.queryResult().topDocs().topDocs.totalHits.value); reader.close(); dir.close(); @@ -261,7 +266,7 @@ public void testTerminateAfterWithFilter() throws Exception { context.setSize(10); for (int i = 0; i < 10; i++) { context.parsedPostFilter(new ParsedQuery(new TermQuery(new Term("foo", Integer.toString(i))))); - QueryPhase.executeInternal(context); + QueryPhase.executeInternal(context.withCleanQueryResult()); assertEquals(1, context.queryResult().topDocs().topDocs.totalHits.value); assertThat(context.queryResult().topDocs().topDocs.scoreDocs.length, equalTo(1)); } @@ -283,12 +288,13 @@ public void testMinScoreDisablesCountOptimization() throws Exception { context.parsedQuery(new ParsedQuery(new MatchAllDocsQuery())); context.setSize(0); context.setTask(new SearchShardTask(123L, "", "", "", null, Collections.emptyMap())); - QueryPhase.executeInternal(context); + QueryPhase.executeInternal(context.withCleanQueryResult()); assertEquals(1, context.queryResult().topDocs().topDocs.totalHits.value); context.minimumScore(100); - QueryPhase.executeInternal(context); + QueryPhase.executeInternal(context.withCleanQueryResult()); assertEquals(0, context.queryResult().topDocs().topDocs.totalHits.value); + assertEquals(TotalHits.Relation.EQUAL_TO, context.queryResult().topDocs().topDocs.totalHits.relation); reader.close(); dir.close(); } @@ -297,7 +303,7 @@ public void testQueryCapturesThreadPoolStats() throws Exception { Directory dir = newDirectory(); IndexWriterConfig iwc = newIndexWriterConfig(); RandomIndexWriter w = new RandomIndexWriter(random(), dir, iwc); - final int numDocs = scaledRandomIntBetween(100, 200); + final int numDocs = scaledRandomIntBetween(600, 900); for (int i = 0; i < numDocs; ++i) { w.addDocument(new Document()); } @@ -307,7 +313,7 @@ public void testQueryCapturesThreadPoolStats() throws Exception { context.setTask(new SearchShardTask(123L, "", "", "", null, Collections.emptyMap())); context.parsedQuery(new ParsedQuery(new MatchAllDocsQuery())); - QueryPhase.executeInternal(context); + QueryPhase.executeInternal(context.withCleanQueryResult()); QuerySearchResult results = context.queryResult(); assertThat(results.serviceTimeEWMA(), greaterThanOrEqualTo(0L)); assertThat(results.nodeQueueSize(), greaterThanOrEqualTo(0)); @@ -320,7 +326,7 @@ public void testInOrderScrollOptimization() throws Exception { final Sort sort = new Sort(new SortField("rank", SortField.Type.INT)); IndexWriterConfig iwc = newIndexWriterConfig().setIndexSort(sort); RandomIndexWriter w = new RandomIndexWriter(random(), dir, iwc); - final int numDocs = scaledRandomIntBetween(100, 200); + final int numDocs = scaledRandomIntBetween(600, 900); for (int i = 0; i < numDocs; ++i) { w.addDocument(new Document()); } @@ -336,14 +342,14 @@ public void testInOrderScrollOptimization() throws Exception { int size = randomIntBetween(2, 5); context.setSize(size); - QueryPhase.executeInternal(context); + QueryPhase.executeInternal(context.withCleanQueryResult()); assertThat(context.queryResult().topDocs().topDocs.totalHits.value, equalTo((long) numDocs)); assertNull(context.queryResult().terminatedEarly()); assertThat(context.terminateAfter(), equalTo(0)); assertThat(context.queryResult().getTotalHits().value, equalTo((long) numDocs)); context.setSearcher(newEarlyTerminationContextSearcher(reader, size)); - QueryPhase.executeInternal(context); + QueryPhase.executeInternal(context.withCleanQueryResult()); assertThat(context.queryResult().topDocs().topDocs.totalHits.value, equalTo((long) numDocs)); assertThat(context.terminateAfter(), equalTo(size)); assertThat(context.queryResult().getTotalHits().value, equalTo((long) numDocs)); @@ -356,7 +362,7 @@ public void testTerminateAfterEarlyTermination() throws Exception { Directory dir = newDirectory(); IndexWriterConfig iwc = newIndexWriterConfig(); RandomIndexWriter w = new RandomIndexWriter(random(), dir, iwc); - final int numDocs = scaledRandomIntBetween(100, 200); + final int numDocs = scaledRandomIntBetween(600, 900); for (int i = 0; i < numDocs; ++i) { Document doc = new Document(); if (randomBoolean()) { @@ -377,25 +383,25 @@ public void testTerminateAfterEarlyTermination() throws Exception { context.terminateAfter(numDocs); { context.setSize(10); - TotalHitCountCollector collector = new TotalHitCountCollector(); - context.queryCollectors().put(TotalHitCountCollector.class, collector); - QueryPhase.executeInternal(context); + final TestTotalHitCountCollectorManager manager = TestTotalHitCountCollectorManager.create(); + context.queryCollectorManagers().put(TotalHitCountCollector.class, manager); + QueryPhase.executeInternal(context.withCleanQueryResult()); assertFalse(context.queryResult().terminatedEarly()); assertThat(context.queryResult().topDocs().topDocs.totalHits.value, equalTo((long) numDocs)); assertThat(context.queryResult().topDocs().topDocs.scoreDocs.length, equalTo(10)); - assertThat(collector.getTotalHits(), equalTo(numDocs)); + assertThat(manager.getTotalHits(), equalTo(numDocs)); } context.terminateAfter(1); { context.setSize(1); - QueryPhase.executeInternal(context); + QueryPhase.executeInternal(context.withCleanQueryResult()); assertTrue(context.queryResult().terminatedEarly()); assertThat(context.queryResult().topDocs().topDocs.totalHits.value, equalTo(1L)); assertThat(context.queryResult().topDocs().topDocs.scoreDocs.length, equalTo(1)); context.setSize(0); - QueryPhase.executeInternal(context); + QueryPhase.executeInternal(context.withCleanQueryResult()); assertTrue(context.queryResult().terminatedEarly()); assertThat(context.queryResult().topDocs().topDocs.totalHits.value, equalTo(1L)); assertThat(context.queryResult().topDocs().topDocs.scoreDocs.length, equalTo(0)); @@ -403,7 +409,7 @@ public void testTerminateAfterEarlyTermination() throws Exception { { context.setSize(1); - QueryPhase.executeInternal(context); + QueryPhase.executeInternal(context.withCleanQueryResult()); assertTrue(context.queryResult().terminatedEarly()); assertThat(context.queryResult().topDocs().topDocs.totalHits.value, equalTo(1L)); assertThat(context.queryResult().topDocs().topDocs.scoreDocs.length, equalTo(1)); @@ -414,38 +420,38 @@ public void testTerminateAfterEarlyTermination() throws Exception { .add(new TermQuery(new Term("foo", "baz")), Occur.SHOULD) .build(); context.parsedQuery(new ParsedQuery(bq)); - QueryPhase.executeInternal(context); + QueryPhase.executeInternal(context.withCleanQueryResult()); assertTrue(context.queryResult().terminatedEarly()); assertThat(context.queryResult().topDocs().topDocs.totalHits.value, equalTo(1L)); assertThat(context.queryResult().topDocs().topDocs.scoreDocs.length, equalTo(1)); context.setSize(0); context.parsedQuery(new ParsedQuery(bq)); - QueryPhase.executeInternal(context); + QueryPhase.executeInternal(context.withCleanQueryResult()); assertTrue(context.queryResult().terminatedEarly()); assertThat(context.queryResult().topDocs().topDocs.totalHits.value, equalTo(1L)); assertThat(context.queryResult().topDocs().topDocs.scoreDocs.length, equalTo(0)); } { context.setSize(1); - TotalHitCountCollector collector = new TotalHitCountCollector(); - context.queryCollectors().put(TotalHitCountCollector.class, collector); - QueryPhase.executeInternal(context); + final TestTotalHitCountCollectorManager manager = TestTotalHitCountCollectorManager.create(); + context.queryCollectorManagers().put(TotalHitCountCollector.class, manager); + QueryPhase.executeInternal(context.withCleanQueryResult()); assertTrue(context.queryResult().terminatedEarly()); assertThat(context.queryResult().topDocs().topDocs.totalHits.value, equalTo(1L)); assertThat(context.queryResult().topDocs().topDocs.scoreDocs.length, equalTo(1)); - assertThat(collector.getTotalHits(), equalTo(1)); - context.queryCollectors().clear(); + assertThat(manager.getTotalHits(), equalTo(1)); + context.queryCollectorManagers().clear(); } { context.setSize(0); - TotalHitCountCollector collector = new TotalHitCountCollector(); - context.queryCollectors().put(TotalHitCountCollector.class, collector); - QueryPhase.executeInternal(context); + final TestTotalHitCountCollectorManager manager = TestTotalHitCountCollectorManager.create(); + context.queryCollectorManagers().put(TotalHitCountCollector.class, manager); + QueryPhase.executeInternal(context.withCleanQueryResult()); assertTrue(context.queryResult().terminatedEarly()); assertThat(context.queryResult().topDocs().topDocs.totalHits.value, equalTo(1L)); assertThat(context.queryResult().topDocs().topDocs.scoreDocs.length, equalTo(0)); - assertThat(collector.getTotalHits(), equalTo(1)); + assertThat(manager.getTotalHits(), equalTo(1)); } // tests with trackTotalHits and terminateAfter @@ -453,9 +459,9 @@ public void testTerminateAfterEarlyTermination() throws Exception { context.setSize(0); for (int trackTotalHits : new int[] { -1, 3, 76, 100 }) { context.trackTotalHitsUpTo(trackTotalHits); - TotalHitCountCollector collector = new TotalHitCountCollector(); - context.queryCollectors().put(TotalHitCountCollector.class, collector); - QueryPhase.executeInternal(context); + final TestTotalHitCountCollectorManager manager = TestTotalHitCountCollectorManager.create(); + context.queryCollectorManagers().put(TotalHitCountCollector.class, manager); + QueryPhase.executeInternal(context.withCleanQueryResult()); assertTrue(context.queryResult().terminatedEarly()); if (trackTotalHits == -1) { assertThat(context.queryResult().topDocs().topDocs.totalHits.value, equalTo(0L)); @@ -463,16 +469,14 @@ public void testTerminateAfterEarlyTermination() throws Exception { assertThat(context.queryResult().topDocs().topDocs.totalHits.value, equalTo((long) Math.min(trackTotalHits, 10))); } assertThat(context.queryResult().topDocs().topDocs.scoreDocs.length, equalTo(0)); - assertThat(collector.getTotalHits(), equalTo(10)); + assertThat(manager.getTotalHits(), equalTo(10)); } context.terminateAfter(7); context.setSize(10); for (int trackTotalHits : new int[] { -1, 3, 75, 100 }) { context.trackTotalHitsUpTo(trackTotalHits); - EarlyTerminatingCollector collector = new EarlyTerminatingCollector(new TotalHitCountCollector(), 1, false); - context.queryCollectors().put(EarlyTerminatingCollector.class, collector); - QueryPhase.executeInternal(context); + QueryPhase.executeInternal(context.withCleanQueryResult()); assertTrue(context.queryResult().terminatedEarly()); if (trackTotalHits == -1) { assertThat(context.queryResult().topDocs().topDocs.totalHits.value, equalTo(0L)); @@ -490,7 +494,7 @@ public void testIndexSortingEarlyTermination() throws Exception { final Sort sort = new Sort(new SortField("rank", SortField.Type.INT)); IndexWriterConfig iwc = newIndexWriterConfig().setIndexSort(sort); RandomIndexWriter w = new RandomIndexWriter(random(), dir, iwc); - final int numDocs = scaledRandomIntBetween(100, 200); + final int numDocs = scaledRandomIntBetween(600, 900); for (int i = 0; i < numDocs; ++i) { Document doc = new Document(); if (randomBoolean()) { @@ -511,7 +515,7 @@ public void testIndexSortingEarlyTermination() throws Exception { context.setTask(new SearchShardTask(123L, "", "", "", null, Collections.emptyMap())); context.sort(new SortAndFormats(sort, new DocValueFormat[] { DocValueFormat.RAW })); - QueryPhase.executeInternal(context); + QueryPhase.executeInternal(context.withCleanQueryResult()); assertThat(context.queryResult().topDocs().topDocs.totalHits.value, equalTo((long) numDocs)); assertThat(context.queryResult().topDocs().topDocs.scoreDocs.length, equalTo(1)); assertThat(context.queryResult().topDocs().topDocs.scoreDocs[0], instanceOf(FieldDoc.class)); @@ -520,7 +524,7 @@ public void testIndexSortingEarlyTermination() throws Exception { { context.parsedPostFilter(new ParsedQuery(new MinDocQuery(1))); - QueryPhase.executeInternal(context); + QueryPhase.executeInternal(context.withCleanQueryResult()); assertNull(context.queryResult().terminatedEarly()); assertThat(context.queryResult().topDocs().topDocs.totalHits.value, equalTo(numDocs - 1L)); assertThat(context.queryResult().topDocs().topDocs.scoreDocs.length, equalTo(1)); @@ -528,28 +532,28 @@ public void testIndexSortingEarlyTermination() throws Exception { assertThat(fieldDoc.fields[0], anyOf(equalTo(1), equalTo(2))); context.parsedPostFilter(null); - final TotalHitCountCollector totalHitCountCollector = new TotalHitCountCollector(); - context.queryCollectors().put(TotalHitCountCollector.class, totalHitCountCollector); - QueryPhase.executeInternal(context); + final TestTotalHitCountCollectorManager manager = TestTotalHitCountCollectorManager.create(sort); + context.queryCollectorManagers().put(TotalHitCountCollector.class, manager); + QueryPhase.executeInternal(context.withCleanQueryResult()); assertNull(context.queryResult().terminatedEarly()); assertThat(context.queryResult().topDocs().topDocs.totalHits.value, equalTo((long) numDocs)); assertThat(context.queryResult().topDocs().topDocs.scoreDocs.length, equalTo(1)); assertThat(context.queryResult().topDocs().topDocs.scoreDocs[0], instanceOf(FieldDoc.class)); assertThat(fieldDoc.fields[0], anyOf(equalTo(1), equalTo(2))); - assertThat(totalHitCountCollector.getTotalHits(), equalTo(numDocs)); - context.queryCollectors().clear(); + assertThat(manager.getTotalHits(), equalTo(numDocs)); + context.queryCollectorManagers().clear(); } { context.setSearcher(newEarlyTerminationContextSearcher(reader, 1)); context.trackTotalHitsUpTo(SearchContext.TRACK_TOTAL_HITS_DISABLED); - QueryPhase.executeInternal(context); + QueryPhase.executeInternal(context.withCleanQueryResult()); assertNull(context.queryResult().terminatedEarly()); assertThat(context.queryResult().topDocs().topDocs.scoreDocs.length, equalTo(1)); assertThat(context.queryResult().topDocs().topDocs.scoreDocs[0], instanceOf(FieldDoc.class)); assertThat(fieldDoc.fields[0], anyOf(equalTo(1), equalTo(2))); - QueryPhase.executeInternal(context); + QueryPhase.executeInternal(context.withCleanQueryResult()); assertNull(context.queryResult().terminatedEarly()); assertThat(context.queryResult().topDocs().topDocs.scoreDocs.length, equalTo(1)); assertThat(context.queryResult().topDocs().topDocs.scoreDocs[0], instanceOf(FieldDoc.class)); @@ -564,7 +568,7 @@ public void testIndexSortScrollOptimization() throws Exception { final Sort indexSort = new Sort(new SortField("rank", SortField.Type.INT), new SortField("tiebreaker", SortField.Type.INT)); IndexWriterConfig iwc = newIndexWriterConfig().setIndexSort(indexSort); RandomIndexWriter w = new RandomIndexWriter(random(), dir, iwc); - final int numDocs = scaledRandomIntBetween(100, 200); + final int numDocs = scaledRandomIntBetween(600, 900); for (int i = 0; i < numDocs; ++i) { Document doc = new Document(); doc.add(new NumericDocValuesField("rank", random().nextInt())); @@ -592,7 +596,7 @@ public void testIndexSortScrollOptimization() throws Exception { context.setSize(10); context.sort(searchSortAndFormat); - QueryPhase.executeInternal(context); + QueryPhase.executeInternal(context.withCleanQueryResult()); assertThat(context.queryResult().topDocs().topDocs.totalHits.value, equalTo((long) numDocs)); assertNull(context.queryResult().terminatedEarly()); assertThat(context.terminateAfter(), equalTo(0)); @@ -601,7 +605,7 @@ public void testIndexSortScrollOptimization() throws Exception { FieldDoc lastDoc = (FieldDoc) context.queryResult().topDocs().topDocs.scoreDocs[sizeMinus1]; context.setSearcher(newEarlyTerminationContextSearcher(reader, 10)); - QueryPhase.executeInternal(context); + QueryPhase.executeInternal(context.withCleanQueryResult()); assertNull(context.queryResult().terminatedEarly()); assertThat(context.queryResult().topDocs().topDocs.totalHits.value, equalTo((long) numDocs)); assertThat(context.terminateAfter(), equalTo(0)); @@ -630,7 +634,8 @@ public void testDisableTopScoreCollection() throws Exception { IndexWriterConfig iwc = newIndexWriterConfig(new StandardAnalyzer()); RandomIndexWriter w = new RandomIndexWriter(random(), dir, iwc); Document doc = new Document(); - for (int i = 0; i < 10; i++) { + final int numDocs = 2 * scaledRandomIntBetween(50, 450); + for (int i = 0; i < numDocs; i++) { doc.clear(); if (i % 2 == 0) { doc.add(new TextField("title", "foo bar", Store.NO)); @@ -653,16 +658,16 @@ public void testDisableTopScoreCollection() throws Exception { context.trackTotalHitsUpTo(3); TopDocsCollectorContext topDocsContext = TopDocsCollectorContext.createTopDocsCollectorContext(context, false); assertEquals(topDocsContext.create(null).scoreMode(), org.apache.lucene.search.ScoreMode.COMPLETE); - QueryPhase.executeInternal(context); - assertEquals(5, context.queryResult().topDocs().topDocs.totalHits.value); + QueryPhase.executeInternal(context.withCleanQueryResult()); + assertEquals(numDocs / 2, context.queryResult().topDocs().topDocs.totalHits.value); assertEquals(context.queryResult().topDocs().topDocs.totalHits.relation, TotalHits.Relation.EQUAL_TO); assertThat(context.queryResult().topDocs().topDocs.scoreDocs.length, equalTo(3)); context.sort(new SortAndFormats(new Sort(new SortField("other", SortField.Type.INT)), new DocValueFormat[] { DocValueFormat.RAW })); topDocsContext = TopDocsCollectorContext.createTopDocsCollectorContext(context, false); assertEquals(topDocsContext.create(null).scoreMode(), org.apache.lucene.search.ScoreMode.TOP_DOCS); - QueryPhase.executeInternal(context); - assertEquals(5, context.queryResult().topDocs().topDocs.totalHits.value); + QueryPhase.executeInternal(context.withCleanQueryResult()); + assertEquals(numDocs / 2, context.queryResult().topDocs().topDocs.totalHits.value); assertThat(context.queryResult().topDocs().topDocs.scoreDocs.length, equalTo(3)); assertEquals(context.queryResult().topDocs().topDocs.totalHits.relation, TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO); @@ -724,7 +729,7 @@ public void testEnhanceSortOnNumeric() throws Exception { searchContext.parsedQuery(query); searchContext.setTask(task); searchContext.setSize(10); - QueryPhase.executeInternal(searchContext); + QueryPhase.executeInternal(searchContext.withCleanQueryResult()); assertSortResults(searchContext.queryResult().topDocs().topDocs, (long) numDocs, false); } @@ -736,7 +741,7 @@ public void testEnhanceSortOnNumeric() throws Exception { searchContext.parsedQuery(query); searchContext.setTask(task); searchContext.setSize(10); - QueryPhase.executeInternal(searchContext); + QueryPhase.executeInternal(searchContext.withCleanQueryResult()); assertSortResults(searchContext.queryResult().topDocs().topDocs, (long) numDocs, true); } @@ -748,7 +753,7 @@ public void testEnhanceSortOnNumeric() throws Exception { searchContext.parsedQuery(query); searchContext.setTask(task); searchContext.setSize(10); - QueryPhase.executeInternal(searchContext); + QueryPhase.executeInternal(searchContext.withCleanQueryResult()); assertSortResults(searchContext.queryResult().topDocs().topDocs, (long) numDocs, false); } @@ -773,7 +778,7 @@ public void testEnhanceSortOnNumeric() throws Exception { searchContext.setTask(task); searchContext.from(5); searchContext.setSize(0); - QueryPhase.executeInternal(searchContext); + QueryPhase.executeInternal(searchContext.withCleanQueryResult()); assertSortResults(searchContext.queryResult().topDocs().topDocs, (long) numDocs, false); } @@ -800,11 +805,15 @@ public void testEnhanceSortOnNumeric() throws Exception { searchContext.parsedQuery(query); searchContext.setTask(task); searchContext.setSize(10); - QueryPhase.executeInternal(searchContext); + QueryPhase.executeInternal(searchContext.withCleanQueryResult()); final TopDocs topDocs = searchContext.queryResult().topDocs().topDocs; long topValue = (long) ((FieldDoc) topDocs.scoreDocs[0]).fields[0]; assertThat(topValue, greaterThan(afterValue)); assertSortResults(topDocs, (long) numDocs, false); + + final TotalHits totalHits = topDocs.totalHits; + assertEquals(TotalHits.Relation.EQUAL_TO, totalHits.relation); + assertEquals(numDocs, totalHits.value); } reader.close(); @@ -916,13 +925,133 @@ public void testMinScore() throws Exception { context.setSize(1); context.trackTotalHitsUpTo(5); - QueryPhase.executeInternal(context); + QueryPhase.executeInternal(context.withCleanQueryResult()); assertEquals(10, context.queryResult().topDocs().topDocs.totalHits.value); reader.close(); dir.close(); } + public void testMaxScore() throws Exception { + Directory dir = newDirectory(); + final Sort sort = new Sort(new SortField("filter", SortField.Type.STRING)); + IndexWriterConfig iwc = newIndexWriterConfig().setIndexSort(sort); + RandomIndexWriter w = new RandomIndexWriter(random(), dir, iwc); + + final int numDocs = scaledRandomIntBetween(600, 900); + for (int i = 0; i < numDocs; i++) { + Document doc = new Document(); + doc.add(new StringField("foo", "bar", Store.NO)); + doc.add(new StringField("filter", "f1" + ((i > 0) ? " " + Integer.toString(i) : ""), Store.NO)); + doc.add(new SortedDocValuesField("filter", newBytesRef("f1" + ((i > 0) ? " " + Integer.toString(i) : "")))); + w.addDocument(doc); + } + w.close(); + + IndexReader reader = DirectoryReader.open(dir); + TestSearchContext context = new TestSearchContext(null, indexShard, newContextSearcher(reader)); + context.trackScores(true); + context.parsedQuery( + new ParsedQuery( + new BooleanQuery.Builder().add(new TermQuery(new Term("foo", "bar")), Occur.MUST) + .add(new TermQuery(new Term("filter", "f1")), Occur.SHOULD) + .build() + ) + ); + context.setTask(new SearchShardTask(123L, "", "", "", null, Collections.emptyMap())); + context.setSize(1); + context.trackTotalHitsUpTo(5); + + QueryPhase.executeInternal(context.withCleanQueryResult()); + assertFalse(Float.isNaN(context.queryResult().getMaxScore())); + assertEquals(1, context.queryResult().topDocs().topDocs.scoreDocs.length); + assertThat(context.queryResult().topDocs().topDocs.totalHits.value, greaterThanOrEqualTo(6L)); + + context.sort(new SortAndFormats(sort, new DocValueFormat[] { DocValueFormat.RAW })); + QueryPhase.executeInternal(context.withCleanQueryResult()); + assertFalse(Float.isNaN(context.queryResult().getMaxScore())); + assertEquals(1, context.queryResult().topDocs().topDocs.scoreDocs.length); + assertThat(context.queryResult().topDocs().topDocs.totalHits.value, greaterThanOrEqualTo(6L)); + + context.trackScores(false); + QueryPhase.executeInternal(context.withCleanQueryResult()); + assertTrue(Float.isNaN(context.queryResult().getMaxScore())); + assertEquals(1, context.queryResult().topDocs().topDocs.scoreDocs.length); + assertThat(context.queryResult().topDocs().topDocs.totalHits.value, greaterThanOrEqualTo(6L)); + + reader.close(); + dir.close(); + } + + public void testCollapseQuerySearchResults() throws Exception { + Directory dir = newDirectory(); + final Sort sort = new Sort(new SortField("user", SortField.Type.INT)); + IndexWriterConfig iwc = newIndexWriterConfig().setIndexSort(sort); + RandomIndexWriter w = new RandomIndexWriter(random(), dir, iwc); + + // Always end up with uneven buckets so collapsing is predictable + final int numDocs = 2 * scaledRandomIntBetween(600, 900) - 1; + for (int i = 0; i < numDocs; i++) { + Document doc = new Document(); + doc.add(new StringField("foo", "bar", Store.NO)); + doc.add(new NumericDocValuesField("user", i & 1)); + w.addDocument(doc); + } + w.close(); + + IndexReader reader = DirectoryReader.open(dir); + QueryShardContext queryShardContext = mock(QueryShardContext.class); + when(queryShardContext.fieldMapper("user")).thenReturn( + new NumberFieldType("user", NumberType.INTEGER, true, false, true, false, null, Collections.emptyMap()) + ); + + TestSearchContext context = new TestSearchContext(queryShardContext, indexShard, newContextSearcher(reader)); + context.collapse(new CollapseBuilder("user").build(context.getQueryShardContext())); + context.trackScores(true); + context.parsedQuery(new ParsedQuery(new TermQuery(new Term("foo", "bar")))); + context.setTask(new SearchShardTask(123L, "", "", "", null, Collections.emptyMap())); + context.setSize(2); + context.trackTotalHitsUpTo(5); + + QueryPhase.executeInternal(context.withCleanQueryResult()); + assertFalse(Float.isNaN(context.queryResult().getMaxScore())); + assertEquals(2, context.queryResult().topDocs().topDocs.scoreDocs.length); + assertThat(context.queryResult().topDocs().topDocs.totalHits.value, equalTo((long) numDocs)); + assertThat(context.queryResult().topDocs().topDocs, instanceOf(CollapseTopFieldDocs.class)); + + CollapseTopFieldDocs topDocs = (CollapseTopFieldDocs) context.queryResult().topDocs().topDocs; + assertThat(topDocs.collapseValues.length, equalTo(2)); + assertThat(topDocs.collapseValues[0], equalTo(0L)); // user == 0 + assertThat(topDocs.collapseValues[1], equalTo(1L)); // user == 1 + + context.sort(new SortAndFormats(sort, new DocValueFormat[] { DocValueFormat.RAW })); + QueryPhase.executeInternal(context.withCleanQueryResult()); + assertFalse(Float.isNaN(context.queryResult().getMaxScore())); + assertEquals(2, context.queryResult().topDocs().topDocs.scoreDocs.length); + assertThat(context.queryResult().topDocs().topDocs.totalHits.value, equalTo((long) numDocs)); + assertThat(context.queryResult().topDocs().topDocs, instanceOf(CollapseTopFieldDocs.class)); + + topDocs = (CollapseTopFieldDocs) context.queryResult().topDocs().topDocs; + assertThat(topDocs.collapseValues.length, equalTo(2)); + assertThat(topDocs.collapseValues[0], equalTo(0L)); // user == 0 + assertThat(topDocs.collapseValues[1], equalTo(1L)); // user == 1 + + context.trackScores(false); + QueryPhase.executeInternal(context.withCleanQueryResult()); + assertTrue(Float.isNaN(context.queryResult().getMaxScore())); + assertEquals(2, context.queryResult().topDocs().topDocs.scoreDocs.length); + assertThat(context.queryResult().topDocs().topDocs.totalHits.value, equalTo((long) numDocs)); + assertThat(context.queryResult().topDocs().topDocs, instanceOf(CollapseTopFieldDocs.class)); + + topDocs = (CollapseTopFieldDocs) context.queryResult().topDocs().topDocs; + assertThat(topDocs.collapseValues.length, equalTo(2)); + assertThat(topDocs.collapseValues[0], equalTo(0L)); // user == 0 + assertThat(topDocs.collapseValues[1], equalTo(1L)); // user == 1 + + reader.close(); + dir.close(); + } + public void testCancellationDuringPreprocess() throws IOException { try (Directory dir = newDirectory(); RandomIndexWriter w = new RandomIndexWriter(random(), dir, newIndexWriterConfig())) { @@ -982,7 +1111,8 @@ private static ContextIndexSearcher newContextSearcher(IndexReader reader) throw IndexSearcher.getDefaultSimilarity(), IndexSearcher.getDefaultQueryCache(), IndexSearcher.getDefaultQueryCachingPolicy(), - true + true, + null ); } @@ -992,7 +1122,8 @@ private static ContextIndexSearcher newEarlyTerminationContextSearcher(IndexRead IndexSearcher.getDefaultSimilarity(), IndexSearcher.getDefaultQueryCache(), IndexSearcher.getDefaultQueryCachingPolicy(), - true + true, + null ) { @Override @@ -1003,6 +1134,32 @@ public void search(List leaves, Weight weight, Collector coll }; } + private static class TestTotalHitCountCollectorManager extends TotalHitCountCollectorManager { + private final TotalHitCountCollector collector; + + static TestTotalHitCountCollectorManager create() { + return create(null); + } + + static TestTotalHitCountCollectorManager create(final Sort sort) { + return new TestTotalHitCountCollectorManager(new TotalHitCountCollector(), sort); + } + + private TestTotalHitCountCollectorManager(final TotalHitCountCollector collector, final Sort sort) { + super(sort); + this.collector = collector; + } + + @Override + public TotalHitCountCollector newCollector() throws IOException { + return collector; + } + + public int getTotalHits() { + return collector.getTotalHits(); + } + } + private static class AssertingEarlyTerminationFilterCollector extends FilterCollector { private final int size; diff --git a/server/src/test/java/org/opensearch/search/query/QueryProfilePhaseTests.java b/server/src/test/java/org/opensearch/search/query/QueryProfilePhaseTests.java new file mode 100644 index 0000000000000..dfa41edb5cff2 --- /dev/null +++ b/server/src/test/java/org/opensearch/search/query/QueryProfilePhaseTests.java @@ -0,0 +1,1158 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.search.query; + +import org.apache.lucene.analysis.standard.StandardAnalyzer; +import org.apache.lucene.document.Document; +import org.apache.lucene.document.Field.Store; +import org.apache.lucene.document.NumericDocValuesField; +import org.apache.lucene.document.SortedDocValuesField; +import org.apache.lucene.document.StringField; +import org.apache.lucene.document.TextField; +import org.apache.lucene.index.DirectoryReader; +import org.apache.lucene.index.IndexReader; +import org.apache.lucene.index.IndexWriterConfig; +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.index.Term; +import org.apache.lucene.queries.spans.SpanNearQuery; +import org.apache.lucene.queries.spans.SpanTermQuery; +import org.apache.lucene.search.BooleanClause.Occur; +import org.apache.lucene.search.grouping.CollapseTopFieldDocs; +import org.apache.lucene.search.BooleanQuery; +import org.apache.lucene.search.Collector; +import org.apache.lucene.search.FieldComparator; +import org.apache.lucene.search.FieldDoc; +import org.apache.lucene.search.FilterCollector; +import org.apache.lucene.search.FilterLeafCollector; +import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.LeafCollector; +import org.apache.lucene.search.MatchAllDocsQuery; +import org.apache.lucene.search.MatchNoDocsQuery; +import org.apache.lucene.search.Query; +import org.apache.lucene.search.Sort; +import org.apache.lucene.search.SortField; +import org.apache.lucene.search.TermQuery; +import org.apache.lucene.search.TotalHits; +import org.apache.lucene.search.Weight; +import org.apache.lucene.store.Directory; +import org.apache.lucene.tests.index.RandomIndexWriter; +import org.opensearch.action.search.SearchShardTask; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.xcontent.ToXContent; +import org.opensearch.common.xcontent.XContentBuilder; +import org.opensearch.common.xcontent.json.JsonXContent; +import org.opensearch.index.mapper.NumberFieldMapper.NumberFieldType; +import org.opensearch.index.mapper.NumberFieldMapper.NumberType; +import org.opensearch.index.query.ParsedQuery; +import org.opensearch.index.query.QueryShardContext; +import org.opensearch.index.shard.IndexShard; +import org.opensearch.index.shard.IndexShardTestCase; +import org.opensearch.lucene.queries.MinDocQuery; +import org.opensearch.search.DocValueFormat; +import org.opensearch.search.collapse.CollapseBuilder; +import org.opensearch.search.internal.ContextIndexSearcher; +import org.opensearch.search.internal.ScrollContext; +import org.opensearch.search.internal.SearchContext; +import org.opensearch.search.profile.ProfileResult; +import org.opensearch.search.profile.ProfileShardResult; +import org.opensearch.search.profile.SearchProfileShardResults; +import org.opensearch.search.profile.query.CollectorResult; +import org.opensearch.search.profile.query.QueryProfileShardResult; +import org.opensearch.search.sort.SortAndFormats; +import org.opensearch.test.TestSearchContext; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.OutputStream; +import java.nio.charset.StandardCharsets; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.function.Consumer; + +import static org.hamcrest.CoreMatchers.not; +import static org.hamcrest.CoreMatchers.nullValue; +import static org.hamcrest.Matchers.anyOf; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.empty; +import static org.hamcrest.Matchers.greaterThan; +import static org.hamcrest.Matchers.greaterThanOrEqualTo; +import static org.hamcrest.Matchers.instanceOf; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; +import static org.hamcrest.Matchers.hasSize; + +public class QueryProfilePhaseTests extends IndexShardTestCase { + + private IndexShard indexShard; + + @Override + public Settings threadPoolSettings() { + return Settings.builder().put(super.threadPoolSettings()).put("thread_pool.search.min_queue_size", 10).build(); + } + + @Override + public void setUp() throws Exception { + super.setUp(); + indexShard = newShard(true); + } + + @Override + public void tearDown() throws Exception { + super.tearDown(); + closeShards(indexShard); + } + + public void testPostFilterDisablesCountOptimization() throws Exception { + Directory dir = newDirectory(); + final Sort sort = new Sort(new SortField("rank", SortField.Type.INT)); + IndexWriterConfig iwc = newIndexWriterConfig().setIndexSort(sort); + RandomIndexWriter w = new RandomIndexWriter(random(), dir, iwc); + Document doc = new Document(); + w.addDocument(doc); + w.close(); + + IndexReader reader = DirectoryReader.open(dir); + + TestSearchContext context = new TestSearchContext(null, indexShard, newEarlyTerminationContextSearcher(reader, 0)); + context.setTask(new SearchShardTask(123L, "", "", "", null, Collections.emptyMap())); + context.parsedQuery(new ParsedQuery(new MatchAllDocsQuery())); + + QueryPhase.executeInternal(context.withCleanQueryResult().withProfilers()); + assertEquals(1, context.queryResult().topDocs().topDocs.totalHits.value); + assertProfileData(context, "MatchAllDocsQuery", query -> { + assertThat(query.getTimeBreakdown().keySet(), not(empty())); + assertThat(query.getTimeBreakdown().get("score"), equalTo(0L)); + assertThat(query.getTimeBreakdown().get("score_count"), equalTo(0L)); + assertThat(query.getTimeBreakdown().get("create_weight"), greaterThan(0L)); + assertThat(query.getTimeBreakdown().get("create_weight_count"), equalTo(1L)); + }, collector -> { + assertThat(collector.getReason(), equalTo("search_count")); + assertThat(collector.getTime(), greaterThan(0L)); + assertThat(collector.getProfiledChildren(), empty()); + }); + + context.setSearcher(newContextSearcher(reader)); + context.parsedPostFilter(new ParsedQuery(new MatchNoDocsQuery())); + QueryPhase.executeInternal(context.withCleanQueryResult().withProfilers()); + assertEquals(0, context.queryResult().topDocs().topDocs.totalHits.value); + assertProfileData(context, collector -> { + assertThat(collector.getReason(), equalTo("search_post_filter")); + assertThat(collector.getTime(), greaterThan(0L)); + assertThat(collector.getProfiledChildren(), hasSize(1)); + assertThat(collector.getProfiledChildren().get(0).getReason(), equalTo("search_count")); + assertThat(collector.getProfiledChildren().get(0).getTime(), greaterThan(0L)); + }, (query) -> { + assertThat(query.getQueryName(), equalTo("MatchNoDocsQuery")); + assertThat(query.getTimeBreakdown().keySet(), not(empty())); + assertThat(query.getTimeBreakdown().get("score"), equalTo(0L)); + assertThat(query.getTimeBreakdown().get("score_count"), equalTo(0L)); + assertThat(query.getTimeBreakdown().get("create_weight"), greaterThan(0L)); + assertThat(query.getTimeBreakdown().get("create_weight_count"), equalTo(1L)); + }, (query) -> { + assertThat(query.getQueryName(), equalTo("MatchAllDocsQuery")); + assertThat(query.getTimeBreakdown().keySet(), not(empty())); + assertThat(query.getTimeBreakdown().get("score"), equalTo(0L)); + assertThat(query.getTimeBreakdown().get("score_count"), equalTo(0L)); + assertThat(query.getTimeBreakdown().get("create_weight"), greaterThan(0L)); + assertThat(query.getTimeBreakdown().get("create_weight_count"), equalTo(1L)); + }); + + reader.close(); + dir.close(); + } + + public void testTerminateAfterWithFilter() throws Exception { + Directory dir = newDirectory(); + final Sort sort = new Sort(new SortField("rank", SortField.Type.INT)); + IndexWriterConfig iwc = newIndexWriterConfig().setIndexSort(sort); + RandomIndexWriter w = new RandomIndexWriter(random(), dir, iwc); + Document doc = new Document(); + for (int i = 0; i < 10; i++) { + doc.add(new StringField("foo", Integer.toString(i), Store.NO)); + } + w.addDocument(doc); + w.close(); + + IndexReader reader = DirectoryReader.open(dir); + + TestSearchContext context = new TestSearchContext(null, indexShard, newContextSearcher(reader)); + context.setTask(new SearchShardTask(123L, "", "", "", null, Collections.emptyMap())); + + context.parsedQuery(new ParsedQuery(new MatchAllDocsQuery())); + context.terminateAfter(1); + context.setSize(10); + for (int i = 0; i < 10; i++) { + context.parsedPostFilter(new ParsedQuery(new TermQuery(new Term("foo", Integer.toString(i))))); + QueryPhase.executeInternal(context.withCleanQueryResult().withProfilers()); + assertEquals(1, context.queryResult().topDocs().topDocs.totalHits.value); + assertThat(context.queryResult().topDocs().topDocs.scoreDocs.length, equalTo(1)); + assertProfileData(context, collector -> { + assertThat(collector.getReason(), equalTo("search_post_filter")); + assertThat(collector.getTime(), greaterThan(0L)); + assertThat(collector.getProfiledChildren(), hasSize(1)); + assertThat(collector.getProfiledChildren().get(0).getReason(), equalTo("search_terminate_after_count")); + assertThat(collector.getProfiledChildren().get(0).getTime(), greaterThan(0L)); + assertThat(collector.getProfiledChildren().get(0).getProfiledChildren(), hasSize(1)); + assertThat(collector.getProfiledChildren().get(0).getProfiledChildren().get(0).getReason(), equalTo("search_top_hits")); + assertThat(collector.getProfiledChildren().get(0).getProfiledChildren().get(0).getTime(), greaterThan(0L)); + }, (query) -> { + assertThat(query.getQueryName(), equalTo("TermQuery")); + assertThat(query.getTimeBreakdown().keySet(), not(empty())); + assertThat(query.getTimeBreakdown().get("score"), equalTo(0L)); + assertThat(query.getTimeBreakdown().get("score_count"), equalTo(0L)); + assertThat(query.getTimeBreakdown().get("create_weight"), greaterThan(0L)); + assertThat(query.getTimeBreakdown().get("create_weight_count"), equalTo(1L)); + }, (query) -> { + assertThat(query.getQueryName(), equalTo("MatchAllDocsQuery")); + assertThat(query.getTimeBreakdown().keySet(), not(empty())); + assertThat(query.getTimeBreakdown().get("score"), greaterThan(0L)); + assertThat(query.getTimeBreakdown().get("score_count"), equalTo(1L)); + assertThat(query.getTimeBreakdown().get("create_weight"), greaterThan(0L)); + assertThat(query.getTimeBreakdown().get("create_weight_count"), equalTo(1L)); + }); + } + reader.close(); + dir.close(); + } + + public void testMinScoreDisablesCountOptimization() throws Exception { + Directory dir = newDirectory(); + final Sort sort = new Sort(new SortField("rank", SortField.Type.INT)); + IndexWriterConfig iwc = newIndexWriterConfig().setIndexSort(sort); + RandomIndexWriter w = new RandomIndexWriter(random(), dir, iwc); + Document doc = new Document(); + w.addDocument(doc); + w.close(); + + IndexReader reader = DirectoryReader.open(dir); + TestSearchContext context = new TestSearchContext(null, indexShard, newEarlyTerminationContextSearcher(reader, 0)); + context.parsedQuery(new ParsedQuery(new MatchAllDocsQuery())); + context.setSize(0); + context.setTask(new SearchShardTask(123L, "", "", "", null, Collections.emptyMap())); + QueryPhase.executeInternal(context.withCleanQueryResult().withProfilers()); + assertEquals(1, context.queryResult().topDocs().topDocs.totalHits.value); + assertProfileData(context, "MatchAllDocsQuery", query -> { + assertThat(query.getTimeBreakdown().keySet(), not(empty())); + assertThat(query.getTimeBreakdown().get("score"), equalTo(0L)); + assertThat(query.getTimeBreakdown().get("score_count"), equalTo(0L)); + assertThat(query.getTimeBreakdown().get("create_weight"), greaterThan(0L)); + assertThat(query.getTimeBreakdown().get("create_weight_count"), equalTo(1L)); + }, collector -> { + assertThat(collector.getReason(), equalTo("search_count")); + assertThat(collector.getTime(), greaterThan(0L)); + assertThat(collector.getProfiledChildren(), empty()); + }); + + context.minimumScore(100); + QueryPhase.executeInternal(context.withCleanQueryResult().withProfilers()); + assertEquals(0, context.queryResult().topDocs().topDocs.totalHits.value); + assertEquals(TotalHits.Relation.EQUAL_TO, context.queryResult().topDocs().topDocs.totalHits.relation); + assertProfileData(context, "MatchAllDocsQuery", query -> { + assertThat(query.getTimeBreakdown().keySet(), not(empty())); + assertThat(query.getTimeBreakdown().get("score"), greaterThanOrEqualTo(100L)); + assertThat(query.getTimeBreakdown().get("score_count"), equalTo(1L)); + assertThat(query.getTimeBreakdown().get("create_weight"), greaterThan(0L)); + assertThat(query.getTimeBreakdown().get("create_weight_count"), equalTo(1L)); + }, collector -> { + assertThat(collector.getReason(), equalTo("search_min_score")); + assertThat(collector.getTime(), greaterThan(0L)); + assertThat(collector.getProfiledChildren(), hasSize(1)); + assertThat(collector.getProfiledChildren().get(0).getReason(), equalTo("search_count")); + assertThat(collector.getProfiledChildren().get(0).getTime(), greaterThan(0L)); + }); + + reader.close(); + dir.close(); + } + + public void testInOrderScrollOptimization() throws Exception { + Directory dir = newDirectory(); + final Sort sort = new Sort(new SortField("rank", SortField.Type.INT)); + IndexWriterConfig iwc = newIndexWriterConfig().setIndexSort(sort); + RandomIndexWriter w = new RandomIndexWriter(random(), dir, iwc); + final int numDocs = scaledRandomIntBetween(600, 900); + for (int i = 0; i < numDocs; ++i) { + w.addDocument(new Document()); + } + w.close(); + IndexReader reader = DirectoryReader.open(dir); + ScrollContext scrollContext = new ScrollContext(); + TestSearchContext context = new TestSearchContext(null, indexShard, newContextSearcher(reader), scrollContext); + context.parsedQuery(new ParsedQuery(new MatchAllDocsQuery())); + scrollContext.lastEmittedDoc = null; + scrollContext.maxScore = Float.NaN; + scrollContext.totalHits = null; + context.setTask(new SearchShardTask(123L, "", "", "", null, Collections.emptyMap())); + int size = randomIntBetween(2, 5); + context.setSize(size); + + QueryPhase.executeInternal(context.withCleanQueryResult().withProfilers()); + assertThat(context.queryResult().topDocs().topDocs.totalHits.value, equalTo((long) numDocs)); + assertNull(context.queryResult().terminatedEarly()); + assertThat(context.terminateAfter(), equalTo(0)); + assertThat(context.queryResult().getTotalHits().value, equalTo((long) numDocs)); + assertProfileData(context, "MatchAllDocsQuery", query -> { + assertThat(query.getTimeBreakdown().keySet(), not(empty())); + assertThat(query.getTimeBreakdown().get("score"), greaterThan(0L)); + assertThat(query.getTimeBreakdown().get("score_count"), greaterThan(0L)); + assertThat(query.getTimeBreakdown().get("create_weight"), greaterThan(0L)); + assertThat(query.getTimeBreakdown().get("create_weight_count"), equalTo(1L)); + }, collector -> { + assertThat(collector.getReason(), equalTo("search_top_hits")); + assertThat(collector.getTime(), greaterThan(0L)); + assertThat(collector.getProfiledChildren(), empty()); + }); + + context.setSearcher(newEarlyTerminationContextSearcher(reader, size)); + QueryPhase.executeInternal(context.withCleanQueryResult().withProfilers()); + assertThat(context.queryResult().topDocs().topDocs.totalHits.value, equalTo((long) numDocs)); + assertThat(context.terminateAfter(), equalTo(size)); + assertThat(context.queryResult().getTotalHits().value, equalTo((long) numDocs)); + assertThat(context.queryResult().topDocs().topDocs.scoreDocs[0].doc, greaterThanOrEqualTo(size)); + assertProfileData(context, "ConstantScoreQuery", query -> { + assertThat(query.getTimeBreakdown().keySet(), not(empty())); + assertThat(query.getTimeBreakdown().get("score"), greaterThan(0L)); + assertThat(query.getTimeBreakdown().get("score_count"), greaterThan(0L)); + assertThat(query.getTimeBreakdown().get("create_weight"), greaterThan(0L)); + assertThat(query.getTimeBreakdown().get("create_weight_count"), equalTo(1L)); + assertThat(query.getProfiledChildren().get(0).getTimeBreakdown().get("score"), equalTo(0L)); + assertThat(query.getProfiledChildren().get(0).getTimeBreakdown().get("score_count"), equalTo(0L)); + assertThat(query.getProfiledChildren().get(0).getTimeBreakdown().get("create_weight"), greaterThan(0L)); + assertThat(query.getProfiledChildren().get(0).getTimeBreakdown().get("create_weight_count"), equalTo(1L)); + }, collector -> { + assertThat(collector.getReason(), equalTo("search_terminate_after_count")); + assertThat(collector.getTime(), greaterThan(0L)); + assertThat(collector.getProfiledChildren(), hasSize(1)); + assertThat(collector.getProfiledChildren().get(0).getReason(), equalTo("search_top_hits")); + assertThat(collector.getProfiledChildren().get(0).getTime(), greaterThan(0L)); + }); + + reader.close(); + dir.close(); + } + + public void testTerminateAfterEarlyTermination() throws Exception { + Directory dir = newDirectory(); + IndexWriterConfig iwc = newIndexWriterConfig(); + RandomIndexWriter w = new RandomIndexWriter(random(), dir, iwc); + final int numDocs = scaledRandomIntBetween(600, 900); + for (int i = 0; i < numDocs; ++i) { + Document doc = new Document(); + if (randomBoolean()) { + doc.add(new StringField("foo", "bar", Store.NO)); + } + if (randomBoolean()) { + doc.add(new StringField("foo", "baz", Store.NO)); + } + doc.add(new NumericDocValuesField("rank", numDocs - i)); + w.addDocument(doc); + } + w.close(); + final IndexReader reader = DirectoryReader.open(dir); + TestSearchContext context = new TestSearchContext(null, indexShard, newContextSearcher(reader)); + context.setTask(new SearchShardTask(123L, "", "", "", null, Collections.emptyMap())); + context.parsedQuery(new ParsedQuery(new MatchAllDocsQuery())); + + context.terminateAfter(1); + { + context.setSize(1); + QueryPhase.executeInternal(context.withCleanQueryResult().withProfilers()); + assertTrue(context.queryResult().terminatedEarly()); + assertThat(context.queryResult().topDocs().topDocs.totalHits.value, equalTo(1L)); + assertThat(context.queryResult().topDocs().topDocs.scoreDocs.length, equalTo(1)); + assertProfileData(context, "MatchAllDocsQuery", query -> { + assertThat(query.getTimeBreakdown().keySet(), not(empty())); + assertThat(query.getTimeBreakdown().get("score"), greaterThan(0L)); + assertThat(query.getTimeBreakdown().get("score_count"), greaterThan(0L)); + assertThat(query.getTimeBreakdown().get("create_weight"), greaterThan(0L)); + assertThat(query.getTimeBreakdown().get("create_weight_count"), equalTo(1L)); + }, collector -> { + assertThat(collector.getReason(), equalTo("search_terminate_after_count")); + assertThat(collector.getTime(), greaterThan(0L)); + assertThat(collector.getProfiledChildren(), hasSize(1)); + assertThat(collector.getProfiledChildren().get(0).getReason(), equalTo("search_top_hits")); + assertThat(collector.getProfiledChildren().get(0).getTime(), greaterThan(0L)); + }); + + context.setSize(0); + QueryPhase.executeInternal(context.withCleanQueryResult().withProfilers()); + assertTrue(context.queryResult().terminatedEarly()); + assertThat(context.queryResult().topDocs().topDocs.totalHits.value, equalTo(1L)); + assertThat(context.queryResult().topDocs().topDocs.scoreDocs.length, equalTo(0)); + assertProfileData(context, "MatchAllDocsQuery", query -> { + assertThat(query.getTimeBreakdown().keySet(), not(empty())); + assertThat(query.getTimeBreakdown().get("score"), equalTo(0L)); + assertThat(query.getTimeBreakdown().get("score_count"), equalTo(0L)); + assertThat(query.getTimeBreakdown().get("create_weight"), greaterThan(0L)); + assertThat(query.getTimeBreakdown().get("create_weight_count"), equalTo(1L)); + }, collector -> { + assertThat(collector.getReason(), equalTo("search_terminate_after_count")); + assertThat(collector.getTime(), greaterThan(0L)); + assertThat(collector.getProfiledChildren(), hasSize(1)); + assertThat(collector.getProfiledChildren().get(0).getReason(), equalTo("search_count")); + assertThat(collector.getProfiledChildren().get(0).getTime(), greaterThan(0L)); + }); + } + + { + context.setSize(1); + QueryPhase.executeInternal(context.withCleanQueryResult().withProfilers()); + assertTrue(context.queryResult().terminatedEarly()); + assertThat(context.queryResult().topDocs().topDocs.totalHits.value, equalTo(1L)); + assertThat(context.queryResult().topDocs().topDocs.scoreDocs.length, equalTo(1)); + assertProfileData(context, "MatchAllDocsQuery", query -> { + assertThat(query.getTimeBreakdown().keySet(), not(empty())); + assertThat(query.getTimeBreakdown().get("score"), greaterThan(0L)); + assertThat(query.getTimeBreakdown().get("score_count"), greaterThan(0L)); + assertThat(query.getTimeBreakdown().get("create_weight"), greaterThan(0L)); + assertThat(query.getTimeBreakdown().get("create_weight_count"), equalTo(1L)); + }, collector -> { + assertThat(collector.getReason(), equalTo("search_terminate_after_count")); + assertThat(collector.getTime(), greaterThan(0L)); + assertThat(collector.getProfiledChildren(), hasSize(1)); + assertThat(collector.getProfiledChildren().get(0).getReason(), equalTo("search_top_hits")); + assertThat(collector.getProfiledChildren().get(0).getTime(), greaterThan(0L)); + }); + } + { + context.setSize(1); + BooleanQuery bq = new BooleanQuery.Builder().add(new TermQuery(new Term("foo", "bar")), Occur.SHOULD) + .add(new TermQuery(new Term("foo", "baz")), Occur.SHOULD) + .build(); + context.parsedQuery(new ParsedQuery(bq)); + QueryPhase.executeInternal(context.withCleanQueryResult().withProfilers()); + assertTrue(context.queryResult().terminatedEarly()); + assertThat(context.queryResult().topDocs().topDocs.totalHits.value, equalTo(1L)); + assertThat(context.queryResult().topDocs().topDocs.scoreDocs.length, equalTo(1)); + assertProfileData(context, "BooleanQuery", query -> { + assertThat(query.getTimeBreakdown().keySet(), not(empty())); + assertThat(query.getTimeBreakdown().get("score"), greaterThan(0L)); + assertThat(query.getTimeBreakdown().get("score_count"), greaterThan(0L)); + assertThat(query.getTimeBreakdown().get("create_weight"), greaterThan(0L)); + assertThat(query.getTimeBreakdown().get("create_weight_count"), equalTo(1L)); + + assertThat(query.getProfiledChildren(), hasSize(2)); + assertThat(query.getProfiledChildren().get(0).getQueryName(), equalTo("TermQuery")); + assertThat(query.getProfiledChildren().get(0).getTime(), greaterThan(0L)); + assertThat(query.getProfiledChildren().get(0).getTimeBreakdown().get("create_weight"), greaterThan(0L)); + assertThat(query.getProfiledChildren().get(0).getTimeBreakdown().get("create_weight_count"), equalTo(1L)); + + assertThat(query.getProfiledChildren().get(1).getQueryName(), equalTo("TermQuery")); + assertThat(query.getProfiledChildren().get(1).getTime(), greaterThan(0L)); + assertThat(query.getProfiledChildren().get(1).getTimeBreakdown().get("create_weight"), greaterThan(0L)); + assertThat(query.getProfiledChildren().get(1).getTimeBreakdown().get("create_weight_count"), equalTo(1L)); + }, collector -> { + assertThat(collector.getReason(), equalTo("search_terminate_after_count")); + assertThat(collector.getTime(), greaterThan(0L)); + assertThat(collector.getProfiledChildren(), hasSize(1)); + assertThat(collector.getProfiledChildren().get(0).getReason(), equalTo("search_top_hits")); + assertThat(collector.getProfiledChildren().get(0).getTime(), greaterThan(0L)); + }); + context.setSize(0); + context.parsedQuery(new ParsedQuery(bq)); + QueryPhase.executeInternal(context.withCleanQueryResult().withProfilers()); + assertTrue(context.queryResult().terminatedEarly()); + assertThat(context.queryResult().topDocs().topDocs.totalHits.value, equalTo(1L)); + assertThat(context.queryResult().topDocs().topDocs.scoreDocs.length, equalTo(0)); + + assertProfileData(context, "BooleanQuery", query -> { + assertThat(query.getTimeBreakdown().keySet(), not(empty())); + assertThat(query.getTimeBreakdown().get("score"), equalTo(0L)); + assertThat(query.getTimeBreakdown().get("score_count"), equalTo(0L)); + assertThat(query.getTimeBreakdown().get("create_weight"), greaterThan(0L)); + assertThat(query.getTimeBreakdown().get("create_weight_count"), equalTo(1L)); + + assertThat(query.getProfiledChildren(), hasSize(2)); + assertThat(query.getProfiledChildren().get(0).getQueryName(), equalTo("TermQuery")); + assertThat(query.getProfiledChildren().get(0).getTime(), greaterThan(0L)); + assertThat(query.getProfiledChildren().get(0).getTimeBreakdown().get("create_weight"), greaterThan(0L)); + assertThat(query.getProfiledChildren().get(0).getTimeBreakdown().get("create_weight_count"), equalTo(1L)); + assertThat(query.getProfiledChildren().get(0).getTimeBreakdown().get("score"), equalTo(0L)); + assertThat(query.getProfiledChildren().get(0).getTimeBreakdown().get("score_count"), equalTo(0L)); + + assertThat(query.getProfiledChildren().get(1).getQueryName(), equalTo("TermQuery")); + assertThat(query.getProfiledChildren().get(1).getTime(), greaterThan(0L)); + assertThat(query.getProfiledChildren().get(1).getTimeBreakdown().get("create_weight"), greaterThan(0L)); + assertThat(query.getProfiledChildren().get(1).getTimeBreakdown().get("create_weight_count"), equalTo(1L)); + assertThat(query.getProfiledChildren().get(1).getTimeBreakdown().get("score"), equalTo(0L)); + assertThat(query.getProfiledChildren().get(1).getTimeBreakdown().get("score_count"), equalTo(0L)); + }, collector -> { + assertThat(collector.getReason(), equalTo("search_terminate_after_count")); + assertThat(collector.getTime(), greaterThan(0L)); + assertThat(collector.getProfiledChildren(), hasSize(1)); + assertThat(collector.getProfiledChildren().get(0).getReason(), equalTo("search_count")); + assertThat(collector.getProfiledChildren().get(0).getTime(), greaterThan(0L)); + }); + } + + context.terminateAfter(7); + context.setSize(10); + for (int trackTotalHits : new int[] { -1, 3, 75, 100 }) { + context.trackTotalHitsUpTo(trackTotalHits); + QueryPhase.executeInternal(context.withCleanQueryResult().withProfilers()); + assertTrue(context.queryResult().terminatedEarly()); + if (trackTotalHits == -1) { + assertThat(context.queryResult().topDocs().topDocs.totalHits.value, equalTo(0L)); + } else { + assertThat(context.queryResult().topDocs().topDocs.totalHits.value, equalTo(7L)); + } + assertThat(context.queryResult().topDocs().topDocs.scoreDocs.length, equalTo(7)); + assertProfileData(context, "BooleanQuery", query -> { + assertThat(query.getTimeBreakdown().keySet(), not(empty())); + assertThat(query.getTimeBreakdown().get("score"), greaterThan(0L)); + assertThat(query.getTimeBreakdown().get("score_count"), greaterThanOrEqualTo(7L)); + assertThat(query.getTimeBreakdown().get("create_weight"), greaterThan(0L)); + assertThat(query.getTimeBreakdown().get("create_weight_count"), equalTo(1L)); + + assertThat(query.getProfiledChildren(), hasSize(2)); + assertThat(query.getProfiledChildren().get(0).getQueryName(), equalTo("TermQuery")); + assertThat(query.getProfiledChildren().get(0).getTime(), greaterThan(0L)); + assertThat(query.getProfiledChildren().get(0).getTimeBreakdown().get("create_weight"), greaterThan(0L)); + assertThat(query.getProfiledChildren().get(0).getTimeBreakdown().get("create_weight_count"), equalTo(1L)); + assertThat(query.getProfiledChildren().get(0).getTimeBreakdown().get("score"), greaterThan(0L)); + assertThat(query.getProfiledChildren().get(0).getTimeBreakdown().get("score_count"), greaterThan(0L)); + + assertThat(query.getProfiledChildren().get(1).getQueryName(), equalTo("TermQuery")); + assertThat(query.getProfiledChildren().get(1).getTime(), greaterThan(0L)); + assertThat(query.getProfiledChildren().get(1).getTimeBreakdown().get("create_weight"), greaterThan(0L)); + assertThat(query.getProfiledChildren().get(1).getTimeBreakdown().get("create_weight_count"), equalTo(1L)); + assertThat(query.getProfiledChildren().get(1).getTimeBreakdown().get("score"), greaterThan(0L)); + assertThat(query.getProfiledChildren().get(1).getTimeBreakdown().get("score_count"), greaterThan(0L)); + }, collector -> { + assertThat(collector.getReason(), equalTo("search_terminate_after_count")); + assertThat(collector.getTime(), greaterThan(0L)); + assertThat(collector.getProfiledChildren(), hasSize(1)); + assertThat(collector.getProfiledChildren().get(0).getReason(), equalTo("search_top_hits")); + assertThat(collector.getProfiledChildren().get(0).getTime(), greaterThan(0L)); + }); + } + + reader.close(); + dir.close(); + } + + public void testIndexSortingEarlyTermination() throws Exception { + Directory dir = newDirectory(); + final Sort sort = new Sort(new SortField("rank", SortField.Type.INT)); + IndexWriterConfig iwc = newIndexWriterConfig().setIndexSort(sort); + RandomIndexWriter w = new RandomIndexWriter(random(), dir, iwc); + final int numDocs = scaledRandomIntBetween(600, 900); + for (int i = 0; i < numDocs; ++i) { + Document doc = new Document(); + if (randomBoolean()) { + doc.add(new StringField("foo", "bar", Store.NO)); + } + if (randomBoolean()) { + doc.add(new StringField("foo", "baz", Store.NO)); + } + doc.add(new NumericDocValuesField("rank", numDocs - i)); + w.addDocument(doc); + } + w.close(); + + final IndexReader reader = DirectoryReader.open(dir); + TestSearchContext context = new TestSearchContext(null, indexShard, newContextSearcher(reader)); + context.parsedQuery(new ParsedQuery(new MatchAllDocsQuery())); + context.setSize(1); + context.setTask(new SearchShardTask(123L, "", "", "", null, Collections.emptyMap())); + context.sort(new SortAndFormats(sort, new DocValueFormat[] { DocValueFormat.RAW })); + + QueryPhase.executeInternal(context.withCleanQueryResult().withProfilers()); + assertThat(context.queryResult().topDocs().topDocs.totalHits.value, equalTo((long) numDocs)); + assertThat(context.queryResult().topDocs().topDocs.scoreDocs.length, equalTo(1)); + assertThat(context.queryResult().topDocs().topDocs.scoreDocs[0], instanceOf(FieldDoc.class)); + FieldDoc fieldDoc = (FieldDoc) context.queryResult().topDocs().topDocs.scoreDocs[0]; + assertThat(fieldDoc.fields[0], equalTo(1)); + assertProfileData(context, "MatchAllDocsQuery", query -> { + assertThat(query.getTimeBreakdown().keySet(), not(empty())); + assertThat(query.getTimeBreakdown().get("score"), equalTo(0L)); + assertThat(query.getTimeBreakdown().get("score_count"), equalTo(0L)); + assertThat(query.getTimeBreakdown().get("create_weight"), greaterThan(0L)); + assertThat(query.getTimeBreakdown().get("create_weight_count"), equalTo(1L)); + }, collector -> { + assertThat(collector.getReason(), equalTo("search_top_hits")); + assertThat(collector.getTime(), greaterThan(0L)); + assertThat(collector.getProfiledChildren(), empty()); + }); + + { + context.parsedPostFilter(new ParsedQuery(new MinDocQuery(1))); + QueryPhase.executeInternal(context.withCleanQueryResult().withProfilers()); + assertNull(context.queryResult().terminatedEarly()); + assertThat(context.queryResult().topDocs().topDocs.totalHits.value, equalTo(numDocs - 1L)); + assertThat(context.queryResult().topDocs().topDocs.scoreDocs.length, equalTo(1)); + assertThat(context.queryResult().topDocs().topDocs.scoreDocs[0], instanceOf(FieldDoc.class)); + assertThat(fieldDoc.fields[0], anyOf(equalTo(1), equalTo(2))); + assertProfileData(context, collector -> { + assertThat(collector.getReason(), equalTo("search_post_filter")); + assertThat(collector.getTime(), greaterThan(0L)); + assertThat(collector.getProfiledChildren(), hasSize(1)); + assertThat(collector.getProfiledChildren().get(0).getReason(), equalTo("search_top_hits")); + assertThat(collector.getProfiledChildren().get(0).getTime(), greaterThan(0L)); + }, (query) -> { + assertThat(query.getQueryName(), equalTo("MinDocQuery")); + assertThat(query.getTimeBreakdown().keySet(), not(empty())); + assertThat(query.getTimeBreakdown().get("score"), equalTo(0L)); + assertThat(query.getTimeBreakdown().get("score_count"), equalTo(0L)); + assertThat(query.getTimeBreakdown().get("create_weight"), greaterThan(0L)); + assertThat(query.getTimeBreakdown().get("create_weight_count"), equalTo(1L)); + }, (query) -> { + assertThat(query.getQueryName(), equalTo("MatchAllDocsQuery")); + assertThat(query.getTimeBreakdown().keySet(), not(empty())); + assertThat(query.getTimeBreakdown().get("score"), equalTo(0L)); + assertThat(query.getTimeBreakdown().get("score_count"), equalTo(0L)); + assertThat(query.getTimeBreakdown().get("create_weight"), greaterThan(0L)); + assertThat(query.getTimeBreakdown().get("create_weight_count"), equalTo(1L)); + }); + context.parsedPostFilter(null); + } + + { + context.setSearcher(newEarlyTerminationContextSearcher(reader, 1)); + context.trackTotalHitsUpTo(SearchContext.TRACK_TOTAL_HITS_DISABLED); + QueryPhase.executeInternal(context.withCleanQueryResult().withProfilers()); + assertNull(context.queryResult().terminatedEarly()); + assertThat(context.queryResult().topDocs().topDocs.scoreDocs.length, equalTo(1)); + assertThat(context.queryResult().topDocs().topDocs.scoreDocs[0], instanceOf(FieldDoc.class)); + assertThat(fieldDoc.fields[0], anyOf(equalTo(1), equalTo(2))); + assertProfileData(context, "MatchAllDocsQuery", query -> { + assertThat(query.getTimeBreakdown().keySet(), not(empty())); + assertThat(query.getTimeBreakdown().get("score"), equalTo(0L)); + assertThat(query.getTimeBreakdown().get("score_count"), equalTo(0L)); + assertThat(query.getTimeBreakdown().get("create_weight"), greaterThan(0L)); + assertThat(query.getTimeBreakdown().get("create_weight_count"), equalTo(1L)); + }, collector -> { + assertThat(collector.getReason(), equalTo("search_top_hits")); + assertThat(collector.getTime(), greaterThan(0L)); + assertThat(collector.getProfiledChildren(), empty()); + }); + + QueryPhase.executeInternal(context.withCleanQueryResult().withProfilers()); + assertNull(context.queryResult().terminatedEarly()); + assertThat(context.queryResult().topDocs().topDocs.scoreDocs.length, equalTo(1)); + assertThat(context.queryResult().topDocs().topDocs.scoreDocs[0], instanceOf(FieldDoc.class)); + assertThat(fieldDoc.fields[0], anyOf(equalTo(1), equalTo(2))); + assertProfileData(context, "MatchAllDocsQuery", query -> { + assertThat(query.getTimeBreakdown().keySet(), not(empty())); + assertThat(query.getTimeBreakdown().get("score"), equalTo(0L)); + assertThat(query.getTimeBreakdown().get("score_count"), equalTo(0L)); + assertThat(query.getTimeBreakdown().get("create_weight"), greaterThan(0L)); + assertThat(query.getTimeBreakdown().get("create_weight_count"), equalTo(1L)); + }, collector -> { + assertThat(collector.getReason(), equalTo("search_top_hits")); + assertThat(collector.getTime(), greaterThan(0L)); + assertThat(collector.getProfiledChildren(), empty()); + }); + } + + reader.close(); + dir.close(); + } + + public void testIndexSortScrollOptimization() throws Exception { + Directory dir = newDirectory(); + final Sort indexSort = new Sort(new SortField("rank", SortField.Type.INT), new SortField("tiebreaker", SortField.Type.INT)); + IndexWriterConfig iwc = newIndexWriterConfig().setIndexSort(indexSort); + RandomIndexWriter w = new RandomIndexWriter(random(), dir, iwc); + final int numDocs = scaledRandomIntBetween(600, 900); + for (int i = 0; i < numDocs; ++i) { + Document doc = new Document(); + doc.add(new NumericDocValuesField("rank", random().nextInt())); + doc.add(new NumericDocValuesField("tiebreaker", i)); + w.addDocument(doc); + } + if (randomBoolean()) { + w.forceMerge(randomIntBetween(1, 10)); + } + w.close(); + + final IndexReader reader = DirectoryReader.open(dir); + List searchSortAndFormats = new ArrayList<>(); + searchSortAndFormats.add(new SortAndFormats(indexSort, new DocValueFormat[] { DocValueFormat.RAW, DocValueFormat.RAW })); + // search sort is a prefix of the index sort + searchSortAndFormats.add(new SortAndFormats(new Sort(indexSort.getSort()[0]), new DocValueFormat[] { DocValueFormat.RAW })); + for (SortAndFormats searchSortAndFormat : searchSortAndFormats) { + ScrollContext scrollContext = new ScrollContext(); + TestSearchContext context = new TestSearchContext(null, indexShard, newContextSearcher(reader), scrollContext); + context.parsedQuery(new ParsedQuery(new MatchAllDocsQuery())); + scrollContext.lastEmittedDoc = null; + scrollContext.maxScore = Float.NaN; + scrollContext.totalHits = null; + context.setTask(new SearchShardTask(123L, "", "", "", null, Collections.emptyMap())); + context.setSize(10); + context.sort(searchSortAndFormat); + + QueryPhase.executeInternal(context.withCleanQueryResult().withProfilers()); + assertThat(context.queryResult().topDocs().topDocs.totalHits.value, equalTo((long) numDocs)); + assertNull(context.queryResult().terminatedEarly()); + assertThat(context.terminateAfter(), equalTo(0)); + assertThat(context.queryResult().getTotalHits().value, equalTo((long) numDocs)); + assertProfileData(context, "MatchAllDocsQuery", query -> { + assertThat(query.getTimeBreakdown().keySet(), not(empty())); + assertThat(query.getTimeBreakdown().get("score"), equalTo(0L)); + assertThat(query.getTimeBreakdown().get("score_count"), equalTo(0L)); + assertThat(query.getTimeBreakdown().get("create_weight"), greaterThan(0L)); + assertThat(query.getTimeBreakdown().get("create_weight_count"), equalTo(1L)); + }, collector -> { + assertThat(collector.getReason(), equalTo("search_top_hits")); + assertThat(collector.getTime(), greaterThan(0L)); + assertThat(collector.getProfiledChildren(), empty()); + }); + + int sizeMinus1 = context.queryResult().topDocs().topDocs.scoreDocs.length - 1; + FieldDoc lastDoc = (FieldDoc) context.queryResult().topDocs().topDocs.scoreDocs[sizeMinus1]; + + context.setSearcher(newEarlyTerminationContextSearcher(reader, 10)); + QueryPhase.executeInternal(context.withCleanQueryResult().withProfilers()); + assertNull(context.queryResult().terminatedEarly()); + assertThat(context.queryResult().topDocs().topDocs.totalHits.value, equalTo((long) numDocs)); + assertThat(context.terminateAfter(), equalTo(0)); + assertThat(context.queryResult().getTotalHits().value, equalTo((long) numDocs)); + assertProfileData(context, "ConstantScoreQuery", query -> { + assertThat(query.getTimeBreakdown().keySet(), not(empty())); + assertThat(query.getTimeBreakdown().get("score"), equalTo(0L)); + assertThat(query.getTimeBreakdown().get("score_count"), equalTo(0L)); + assertThat(query.getTimeBreakdown().get("create_weight"), greaterThan(0L)); + assertThat(query.getTimeBreakdown().get("create_weight_count"), equalTo(1L)); + + assertThat(query.getProfiledChildren(), hasSize(1)); + assertThat(query.getProfiledChildren().get(0).getQueryName(), equalTo("SearchAfterSortedDocQuery")); + assertThat(query.getProfiledChildren().get(0).getTime(), greaterThan(0L)); + assertThat(query.getProfiledChildren().get(0).getTimeBreakdown().get("score"), equalTo(0L)); + assertThat(query.getProfiledChildren().get(0).getTimeBreakdown().get("score_count"), equalTo(0L)); + assertThat(query.getProfiledChildren().get(0).getTimeBreakdown().get("create_weight"), greaterThan(0L)); + assertThat(query.getProfiledChildren().get(0).getTimeBreakdown().get("create_weight_count"), equalTo(1L)); + }, collector -> { + assertThat(collector.getReason(), equalTo("search_top_hits")); + assertThat(collector.getTime(), greaterThan(0L)); + assertThat(collector.getProfiledChildren(), empty()); + }); + FieldDoc firstDoc = (FieldDoc) context.queryResult().topDocs().topDocs.scoreDocs[0]; + for (int i = 0; i < searchSortAndFormat.sort.getSort().length; i++) { + @SuppressWarnings("unchecked") + FieldComparator comparator = (FieldComparator) searchSortAndFormat.sort.getSort()[i].getComparator( + i, + false + ); + int cmp = comparator.compareValues(firstDoc.fields[i], lastDoc.fields[i]); + if (cmp == 0) { + continue; + } + assertThat(cmp, equalTo(1)); + break; + } + } + reader.close(); + dir.close(); + } + + public void testDisableTopScoreCollection() throws Exception { + Directory dir = newDirectory(); + IndexWriterConfig iwc = newIndexWriterConfig(new StandardAnalyzer()); + RandomIndexWriter w = new RandomIndexWriter(random(), dir, iwc); + Document doc = new Document(); + final int numDocs = 2 * scaledRandomIntBetween(50, 450); + for (int i = 0; i < numDocs; i++) { + doc.clear(); + if (i % 2 == 0) { + doc.add(new TextField("title", "foo bar", Store.NO)); + } else { + doc.add(new TextField("title", "foo", Store.NO)); + } + w.addDocument(doc); + } + w.close(); + + IndexReader reader = DirectoryReader.open(dir); + TestSearchContext context = new TestSearchContext(null, indexShard, newContextSearcher(reader)); + context.setTask(new SearchShardTask(123L, "", "", "", null, Collections.emptyMap())); + Query q = new SpanNearQuery.Builder("title", true).addClause(new SpanTermQuery(new Term("title", "foo"))) + .addClause(new SpanTermQuery(new Term("title", "bar"))) + .build(); + + context.parsedQuery(new ParsedQuery(q)); + context.setSize(3); + context.trackTotalHitsUpTo(3); + TopDocsCollectorContext topDocsContext = TopDocsCollectorContext.createTopDocsCollectorContext(context, false); + assertEquals(topDocsContext.create(null).scoreMode(), org.apache.lucene.search.ScoreMode.COMPLETE); + QueryPhase.executeInternal(context.withCleanQueryResult().withProfilers()); + assertEquals(numDocs / 2, context.queryResult().topDocs().topDocs.totalHits.value); + assertEquals(context.queryResult().topDocs().topDocs.totalHits.relation, TotalHits.Relation.EQUAL_TO); + assertThat(context.queryResult().topDocs().topDocs.scoreDocs.length, equalTo(3)); + assertProfileData(context, "SpanNearQuery", query -> { + assertThat(query.getTimeBreakdown().keySet(), not(empty())); + assertThat(query.getTimeBreakdown().get("score"), greaterThan(0L)); + assertThat(query.getTimeBreakdown().get("score_count"), greaterThan(0L)); + assertThat(query.getTimeBreakdown().get("create_weight"), greaterThan(0L)); + assertThat(query.getTimeBreakdown().get("create_weight_count"), equalTo(1L)); + }, collector -> { + assertThat(collector.getReason(), equalTo("search_top_hits")); + assertThat(collector.getTime(), greaterThan(0L)); + assertThat(collector.getProfiledChildren(), empty()); + }); + + context.sort(new SortAndFormats(new Sort(new SortField("other", SortField.Type.INT)), new DocValueFormat[] { DocValueFormat.RAW })); + topDocsContext = TopDocsCollectorContext.createTopDocsCollectorContext(context, false); + assertEquals(topDocsContext.create(null).scoreMode(), org.apache.lucene.search.ScoreMode.TOP_DOCS); + QueryPhase.executeInternal(context.withCleanQueryResult().withProfilers()); + assertEquals(numDocs / 2, context.queryResult().topDocs().topDocs.totalHits.value); + assertThat(context.queryResult().topDocs().topDocs.scoreDocs.length, equalTo(3)); + assertEquals(context.queryResult().topDocs().topDocs.totalHits.relation, TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO); + assertProfileData(context, "SpanNearQuery", query -> { + assertThat(query.getTimeBreakdown().keySet(), not(empty())); + assertThat(query.getTimeBreakdown().get("score"), equalTo(0L)); + assertThat(query.getTimeBreakdown().get("score_count"), equalTo(0L)); + assertThat(query.getTimeBreakdown().get("create_weight"), greaterThan(0L)); + assertThat(query.getTimeBreakdown().get("create_weight_count"), equalTo(1L)); + }, collector -> { + assertThat(collector.getReason(), equalTo("search_top_hits")); + assertThat(collector.getTime(), greaterThan(0L)); + assertThat(collector.getProfiledChildren(), empty()); + }); + + reader.close(); + dir.close(); + } + + public void testMinScore() throws Exception { + Directory dir = newDirectory(); + IndexWriterConfig iwc = newIndexWriterConfig(); + RandomIndexWriter w = new RandomIndexWriter(random(), dir, iwc); + for (int i = 0; i < 10; i++) { + Document doc = new Document(); + doc.add(new StringField("foo", "bar", Store.NO)); + doc.add(new StringField("filter", "f1", Store.NO)); + w.addDocument(doc); + } + w.close(); + + IndexReader reader = DirectoryReader.open(dir); + TestSearchContext context = new TestSearchContext(null, indexShard, newContextSearcher(reader)); + context.parsedQuery( + new ParsedQuery( + new BooleanQuery.Builder().add(new TermQuery(new Term("foo", "bar")), Occur.MUST) + .add(new TermQuery(new Term("filter", "f1")), Occur.SHOULD) + .build() + ) + ); + context.minimumScore(0.01f); + context.setTask(new SearchShardTask(123L, "", "", "", null, Collections.emptyMap())); + context.setSize(1); + context.trackTotalHitsUpTo(5); + + QueryPhase.executeInternal(context.withCleanQueryResult().withProfilers()); + assertEquals(10, context.queryResult().topDocs().topDocs.totalHits.value); + assertProfileData(context, "BooleanQuery", query -> { + assertThat(query.getTimeBreakdown().keySet(), not(empty())); + assertThat(query.getTimeBreakdown().get("score"), greaterThan(0L)); + assertThat(query.getTimeBreakdown().get("score_count"), equalTo(10L)); + assertThat(query.getTimeBreakdown().get("create_weight"), greaterThan(0L)); + assertThat(query.getTimeBreakdown().get("create_weight_count"), equalTo(1L)); + + assertThat(query.getProfiledChildren(), hasSize(2)); + assertThat(query.getProfiledChildren().get(0).getQueryName(), equalTo("TermQuery")); + assertThat(query.getProfiledChildren().get(0).getTime(), greaterThan(0L)); + assertThat(query.getProfiledChildren().get(0).getTimeBreakdown().get("create_weight"), greaterThan(0L)); + assertThat(query.getProfiledChildren().get(0).getTimeBreakdown().get("create_weight_count"), equalTo(1L)); + + assertThat(query.getProfiledChildren().get(1).getQueryName(), equalTo("TermQuery")); + assertThat(query.getProfiledChildren().get(1).getTime(), greaterThan(0L)); + assertThat(query.getProfiledChildren().get(1).getTimeBreakdown().get("create_weight"), greaterThan(0L)); + assertThat(query.getProfiledChildren().get(1).getTimeBreakdown().get("create_weight_count"), equalTo(1L)); + }, collector -> { + assertThat(collector.getReason(), equalTo("search_min_score")); + assertThat(collector.getTime(), greaterThan(0L)); + assertThat(collector.getProfiledChildren(), hasSize(1)); + assertThat(collector.getProfiledChildren().get(0).getReason(), equalTo("search_top_hits")); + assertThat(collector.getProfiledChildren().get(0).getTime(), greaterThan(0L)); + }); + + reader.close(); + dir.close(); + } + + public void testMaxScore() throws Exception { + Directory dir = newDirectory(); + final Sort sort = new Sort(new SortField("filter", SortField.Type.STRING)); + IndexWriterConfig iwc = newIndexWriterConfig().setIndexSort(sort); + RandomIndexWriter w = new RandomIndexWriter(random(), dir, iwc); + + final int numDocs = scaledRandomIntBetween(600, 900); + for (int i = 0; i < numDocs; i++) { + Document doc = new Document(); + doc.add(new StringField("foo", "bar", Store.NO)); + doc.add(new StringField("filter", "f1" + ((i > 0) ? " " + Integer.toString(i) : ""), Store.NO)); + doc.add(new SortedDocValuesField("filter", newBytesRef("f1" + ((i > 0) ? " " + Integer.toString(i) : "")))); + w.addDocument(doc); + } + w.close(); + + IndexReader reader = DirectoryReader.open(dir); + TestSearchContext context = new TestSearchContext(null, indexShard, newContextSearcher(reader)); + context.trackScores(true); + context.parsedQuery( + new ParsedQuery( + new BooleanQuery.Builder().add(new TermQuery(new Term("foo", "bar")), Occur.MUST) + .add(new TermQuery(new Term("filter", "f1")), Occur.SHOULD) + .build() + ) + ); + context.setTask(new SearchShardTask(123L, "", "", "", null, Collections.emptyMap())); + context.setSize(1); + context.trackTotalHitsUpTo(5); + + QueryPhase.executeInternal(context.withCleanQueryResult().withProfilers()); + assertFalse(Float.isNaN(context.queryResult().getMaxScore())); + assertEquals(1, context.queryResult().topDocs().topDocs.scoreDocs.length); + assertThat(context.queryResult().topDocs().topDocs.totalHits.value, greaterThanOrEqualTo(6L)); + assertProfileData(context, "BooleanQuery", query -> { + assertThat(query.getTimeBreakdown().keySet(), not(empty())); + assertThat(query.getTimeBreakdown().get("score"), greaterThan(0L)); + assertThat(query.getTimeBreakdown().get("score_count"), greaterThanOrEqualTo(6L)); + assertThat(query.getTimeBreakdown().get("create_weight"), greaterThan(0L)); + assertThat(query.getTimeBreakdown().get("create_weight_count"), equalTo(1L)); + + assertThat(query.getProfiledChildren(), hasSize(2)); + assertThat(query.getProfiledChildren().get(0).getQueryName(), equalTo("TermQuery")); + assertThat(query.getProfiledChildren().get(0).getTime(), greaterThan(0L)); + assertThat(query.getProfiledChildren().get(0).getTimeBreakdown().get("create_weight"), greaterThan(0L)); + assertThat(query.getProfiledChildren().get(0).getTimeBreakdown().get("create_weight_count"), equalTo(1L)); + + assertThat(query.getProfiledChildren().get(1).getQueryName(), equalTo("TermQuery")); + assertThat(query.getProfiledChildren().get(1).getTime(), greaterThan(0L)); + assertThat(query.getProfiledChildren().get(1).getTimeBreakdown().get("create_weight"), greaterThan(0L)); + assertThat(query.getProfiledChildren().get(1).getTimeBreakdown().get("create_weight_count"), equalTo(1L)); + }, collector -> { + assertThat(collector.getReason(), equalTo("search_top_hits")); + assertThat(collector.getTime(), greaterThan(0L)); + assertThat(collector.getProfiledChildren(), empty()); + }); + + context.sort(new SortAndFormats(sort, new DocValueFormat[] { DocValueFormat.RAW })); + QueryPhase.executeInternal(context.withCleanQueryResult().withProfilers()); + assertFalse(Float.isNaN(context.queryResult().getMaxScore())); + assertEquals(1, context.queryResult().topDocs().topDocs.scoreDocs.length); + assertThat(context.queryResult().topDocs().topDocs.totalHits.value, greaterThanOrEqualTo(6L)); + assertProfileData(context, "BooleanQuery", query -> { + assertThat(query.getTimeBreakdown().keySet(), not(empty())); + assertThat(query.getTimeBreakdown().get("score"), greaterThan(0L)); + assertThat(query.getTimeBreakdown().get("score_count"), greaterThanOrEqualTo(6L)); + assertThat(query.getTimeBreakdown().get("create_weight"), greaterThan(0L)); + assertThat(query.getTimeBreakdown().get("create_weight_count"), equalTo(1L)); + + assertThat(query.getProfiledChildren(), hasSize(2)); + assertThat(query.getProfiledChildren().get(0).getQueryName(), equalTo("TermQuery")); + assertThat(query.getProfiledChildren().get(0).getTime(), greaterThan(0L)); + assertThat(query.getProfiledChildren().get(0).getTimeBreakdown().get("create_weight"), greaterThan(0L)); + assertThat(query.getProfiledChildren().get(0).getTimeBreakdown().get("create_weight_count"), equalTo(1L)); + + assertThat(query.getProfiledChildren().get(1).getQueryName(), equalTo("TermQuery")); + assertThat(query.getProfiledChildren().get(1).getTime(), greaterThan(0L)); + assertThat(query.getProfiledChildren().get(1).getTimeBreakdown().get("create_weight"), greaterThan(0L)); + assertThat(query.getProfiledChildren().get(1).getTimeBreakdown().get("create_weight_count"), equalTo(1L)); + }, collector -> { + assertThat(collector.getReason(), equalTo("search_top_hits")); + assertThat(collector.getTime(), greaterThan(0L)); + assertThat(collector.getProfiledChildren(), empty()); + }); + + reader.close(); + dir.close(); + } + + public void testCollapseQuerySearchResults() throws Exception { + Directory dir = newDirectory(); + final Sort sort = new Sort(new SortField("user", SortField.Type.INT)); + IndexWriterConfig iwc = newIndexWriterConfig().setIndexSort(sort); + RandomIndexWriter w = new RandomIndexWriter(random(), dir, iwc); + + // Always end up with uneven buckets so collapsing is predictable + final int numDocs = 2 * scaledRandomIntBetween(600, 900) - 1; + for (int i = 0; i < numDocs; i++) { + Document doc = new Document(); + doc.add(new StringField("foo", "bar", Store.NO)); + doc.add(new NumericDocValuesField("user", i & 1)); + w.addDocument(doc); + } + w.close(); + + IndexReader reader = DirectoryReader.open(dir); + QueryShardContext queryShardContext = mock(QueryShardContext.class); + when(queryShardContext.fieldMapper("user")).thenReturn( + new NumberFieldType("user", NumberType.INTEGER, true, false, true, false, null, Collections.emptyMap()) + ); + + TestSearchContext context = new TestSearchContext(queryShardContext, indexShard, newContextSearcher(reader)); + context.collapse(new CollapseBuilder("user").build(context.getQueryShardContext())); + context.trackScores(true); + context.parsedQuery(new ParsedQuery(new TermQuery(new Term("foo", "bar")))); + context.setTask(new SearchShardTask(123L, "", "", "", null, Collections.emptyMap())); + context.setSize(2); + context.trackTotalHitsUpTo(5); + + QueryPhase.executeInternal(context.withCleanQueryResult().withProfilers()); + assertFalse(Float.isNaN(context.queryResult().getMaxScore())); + assertEquals(2, context.queryResult().topDocs().topDocs.scoreDocs.length); + assertThat(context.queryResult().topDocs().topDocs.totalHits.value, equalTo((long) numDocs)); + assertThat(context.queryResult().topDocs().topDocs, instanceOf(CollapseTopFieldDocs.class)); + + assertProfileData(context, "TermQuery", query -> { + assertThat(query.getTimeBreakdown().keySet(), not(empty())); + assertThat(query.getTimeBreakdown().get("score"), greaterThan(0L)); + assertThat(query.getTimeBreakdown().get("score_count"), greaterThanOrEqualTo(6L)); + assertThat(query.getTimeBreakdown().get("create_weight"), greaterThan(0L)); + assertThat(query.getTimeBreakdown().get("create_weight_count"), equalTo(1L)); + assertThat(query.getProfiledChildren(), empty()); + }, collector -> { + assertThat(collector.getReason(), equalTo("search_top_hits")); + assertThat(collector.getTime(), greaterThan(0L)); + assertThat(collector.getProfiledChildren(), empty()); + }); + + context.sort(new SortAndFormats(sort, new DocValueFormat[] { DocValueFormat.RAW })); + QueryPhase.executeInternal(context.withCleanQueryResult().withProfilers()); + assertFalse(Float.isNaN(context.queryResult().getMaxScore())); + assertEquals(2, context.queryResult().topDocs().topDocs.scoreDocs.length); + assertThat(context.queryResult().topDocs().topDocs.totalHits.value, equalTo((long) numDocs)); + assertThat(context.queryResult().topDocs().topDocs, instanceOf(CollapseTopFieldDocs.class)); + + assertProfileData(context, "TermQuery", query -> { + assertThat(query.getTimeBreakdown().keySet(), not(empty())); + assertThat(query.getTimeBreakdown().get("score"), greaterThan(0L)); + assertThat(query.getTimeBreakdown().get("score_count"), greaterThanOrEqualTo(6L)); + assertThat(query.getTimeBreakdown().get("create_weight"), greaterThan(0L)); + assertThat(query.getTimeBreakdown().get("create_weight_count"), equalTo(1L)); + assertThat(query.getProfiledChildren(), empty()); + }, collector -> { + assertThat(collector.getReason(), equalTo("search_top_hits")); + assertThat(collector.getTime(), greaterThan(0L)); + assertThat(collector.getProfiledChildren(), empty()); + }); + + reader.close(); + dir.close(); + } + + private void assertProfileData(SearchContext context, String type, Consumer query, Consumer collector) + throws IOException { + assertProfileData(context, collector, (profileResult) -> { + assertThat(profileResult.getQueryName(), equalTo(type)); + assertThat(profileResult.getTime(), greaterThan(0L)); + query.accept(profileResult); + }); + } + + private void assertProfileData(SearchContext context, Consumer collector, Consumer query1) + throws IOException { + assertProfileData(context, Arrays.asList(query1), collector, false); + } + + private void assertProfileData( + SearchContext context, + Consumer collector, + Consumer query1, + Consumer query2 + ) throws IOException { + assertProfileData(context, Arrays.asList(query1, query2), collector, false); + } + + private final void assertProfileData( + SearchContext context, + List> queries, + Consumer collector, + boolean debug + ) throws IOException { + assertThat(context.getProfilers(), not(nullValue())); + + final ProfileShardResult result = SearchProfileShardResults.buildShardResults(context.getProfilers(), null); + if (debug) { + final SearchProfileShardResults results = new SearchProfileShardResults( + Collections.singletonMap(indexShard.shardId().toString(), result) + ); + + try (final XContentBuilder builder = JsonXContent.contentBuilder().prettyPrint()) { + builder.startObject(); + results.toXContent(builder, ToXContent.EMPTY_PARAMS); + builder.endObject(); + builder.flush(); + + final OutputStream out = builder.getOutputStream(); + assertThat(out, instanceOf(ByteArrayOutputStream.class)); + + logger.info(new String(((ByteArrayOutputStream) out).toByteArray(), StandardCharsets.UTF_8)); + } + } + + assertThat(result.getQueryProfileResults(), hasSize(1)); + + final QueryProfileShardResult queryProfileShardResult = result.getQueryProfileResults().get(0); + assertThat(queryProfileShardResult.getQueryResults(), hasSize(queries.size())); + + for (int i = 0; i < queries.size(); ++i) { + queries.get(i).accept(queryProfileShardResult.getQueryResults().get(i)); + } + + collector.accept(queryProfileShardResult.getCollectorResult()); + } + + private static ContextIndexSearcher newContextSearcher(IndexReader reader) throws IOException { + return new ContextIndexSearcher( + reader, + IndexSearcher.getDefaultSimilarity(), + IndexSearcher.getDefaultQueryCache(), + IndexSearcher.getDefaultQueryCachingPolicy(), + true, + null + ); + } + + private static ContextIndexSearcher newEarlyTerminationContextSearcher(IndexReader reader, int size) throws IOException { + return new ContextIndexSearcher( + reader, + IndexSearcher.getDefaultSimilarity(), + IndexSearcher.getDefaultQueryCache(), + IndexSearcher.getDefaultQueryCachingPolicy(), + true, + null + ) { + + @Override + public void search(List leaves, Weight weight, Collector collector) throws IOException { + final Collector in = new AssertingEarlyTerminationFilterCollector(collector, size); + super.search(leaves, weight, in); + } + }; + } + + private static class AssertingEarlyTerminationFilterCollector extends FilterCollector { + private final int size; + + AssertingEarlyTerminationFilterCollector(Collector in, int size) { + super(in); + this.size = size; + } + + @Override + public LeafCollector getLeafCollector(LeafReaderContext context) throws IOException { + final LeafCollector in = super.getLeafCollector(context); + return new FilterLeafCollector(in) { + int collected; + + @Override + public void collect(int doc) throws IOException { + assert collected <= size : "should not collect more than " + size + " doc per segment, got " + collected; + ++collected; + super.collect(doc); + } + }; + } + } +} diff --git a/test/framework/src/main/java/org/opensearch/search/aggregations/AggregatorTestCase.java b/test/framework/src/main/java/org/opensearch/search/aggregations/AggregatorTestCase.java index 38a0253305833..832328cb0242f 100644 --- a/test/framework/src/main/java/org/opensearch/search/aggregations/AggregatorTestCase.java +++ b/test/framework/src/main/java/org/opensearch/search/aggregations/AggregatorTestCase.java @@ -334,7 +334,8 @@ public boolean shouldCache(Query query) { indexSearcher.getSimilarity(), queryCache, queryCachingPolicy, - false + false, + null ); SearchContext searchContext = mock(SearchContext.class); diff --git a/test/framework/src/main/java/org/opensearch/test/TestSearchContext.java b/test/framework/src/main/java/org/opensearch/test/TestSearchContext.java index 0e91332892a55..0b2235a0afedd 100644 --- a/test/framework/src/main/java/org/opensearch/test/TestSearchContext.java +++ b/test/framework/src/main/java/org/opensearch/test/TestSearchContext.java @@ -32,6 +32,7 @@ package org.opensearch.test; import org.apache.lucene.search.Collector; +import org.apache.lucene.search.CollectorManager; import org.apache.lucene.search.FieldDoc; import org.apache.lucene.search.Query; import org.opensearch.action.OriginalIndices; @@ -70,6 +71,7 @@ import org.opensearch.search.internal.ShardSearchRequest; import org.opensearch.search.profile.Profilers; import org.opensearch.search.query.QuerySearchResult; +import org.opensearch.search.query.ReduceableSearchResult; import org.opensearch.search.rescore.RescoreContext; import org.opensearch.search.sort.SortAndFormats; import org.opensearch.search.suggest.SuggestionSearchContext; @@ -90,7 +92,7 @@ public class TestSearchContext extends SearchContext { final BigArrays bigArrays; final IndexService indexService; final BitsetFilterCache fixedBitSetFilterCache; - final Map, Collector> queryCollectors = new HashMap<>(); + final Map, CollectorManager> queryCollectorManagers = new HashMap<>(); final IndexShard indexShard; final QuerySearchResult queryResult = new QuerySearchResult(); final QueryShardContext queryShardContext; @@ -110,7 +112,9 @@ public class TestSearchContext extends SearchContext { private SearchContextAggregations aggregations; private ScrollContext scrollContext; private FieldDoc searchAfter; - private final long originNanoTime = System.nanoTime(); + private Profilers profilers; + private CollapseContext collapse; + private final Map searchExtBuilders = new HashMap<>(); public TestSearchContext(BigArrays bigArrays, IndexService indexService) { @@ -405,12 +409,13 @@ public FieldDoc searchAfter() { @Override public SearchContext collapse(CollapseContext collapse) { - return null; + this.collapse = collapse; + return this; } @Override public CollapseContext collapse() { - return null; + return collapse; } @Override @@ -596,12 +601,12 @@ public long getRelativeTimeInMillis() { @Override public Profilers getProfilers() { - return null; // no profiling + return profilers; } @Override - public Map, Collector> queryCollectors() { - return queryCollectors; + public Map, CollectorManager> queryCollectorManagers() { + return queryCollectorManagers; } @Override @@ -633,4 +638,21 @@ public void addRescore(RescoreContext rescore) { public ReaderContext readerContext() { throw new UnsupportedOperationException(); } + + /** + * Clean the query results by consuming all of it + */ + public TestSearchContext withCleanQueryResult() { + queryResult.consumeAll(); + profilers = null; + return this; + } + + /** + * Add profilers to the query + */ + public TestSearchContext withProfilers() { + this.profilers = new Profilers(searcher); + return this; + } }