diff --git a/CHANGELOG.md b/CHANGELOG.md index 2ddf281f6..8dcdc721b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -22,7 +22,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ### Features ### Enhancements ### Bug Fixes -- Adding two phase iterator for hybrid query ([#624](https://github.com/opensearch-project/neural-search/pull/624)) +- Fix runtime exceptions in hybrid query for case when sub-query scorer return TwoPhase iterator that is incompatible with DISI iterator ([#624](https://github.com/opensearch-project/neural-search/pull/624)) ### Infrastructure ### Documentation ### Maintenance diff --git a/src/main/java/org/opensearch/neuralsearch/query/HybridQueryBuilder.java b/src/main/java/org/opensearch/neuralsearch/query/HybridQueryBuilder.java index 46c087894..60d9fd639 100644 --- a/src/main/java/org/opensearch/neuralsearch/query/HybridQueryBuilder.java +++ b/src/main/java/org/opensearch/neuralsearch/query/HybridQueryBuilder.java @@ -53,7 +53,7 @@ public final class HybridQueryBuilder extends AbstractQueryBuilder> queryToIndex; private final DocIdSetIterator approximation; - HybridScoreBlockBoundaryPropagator disjunctionBlockPropagator; + private final HybridScoreBlockBoundaryPropagator disjunctionBlockPropagator; private final TwoPhase twoPhase; - public HybridQueryScorer(Weight weight, List subScorers) throws IOException { + public HybridQueryScorer(final Weight weight, final List subScorers) throws IOException { this(weight, subScorers, ScoreMode.TOP_SCORES); } - public HybridQueryScorer(Weight weight, List subScorers, ScoreMode scoreMode) throws IOException { + HybridQueryScorer(final Weight weight, final List subScorers, final ScoreMode scoreMode) throws IOException { super(weight); this.subScorers = Collections.unmodifiableList(subScorers); subScores = new float[subScorers.size()]; diff --git a/src/main/java/org/opensearch/neuralsearch/query/HybridQueryWeight.java b/src/main/java/org/opensearch/neuralsearch/query/HybridQueryWeight.java index 76bdd5f00..facb79694 100644 --- a/src/main/java/org/opensearch/neuralsearch/query/HybridQueryWeight.java +++ b/src/main/java/org/opensearch/neuralsearch/query/HybridQueryWeight.java @@ -21,6 +21,8 @@ import org.apache.lucene.search.ScorerSupplier; import org.apache.lucene.search.Weight; +import static org.opensearch.neuralsearch.query.HybridQueryBuilder.MAX_NUMBER_OF_SUB_QUERIES; + /** * Calculates query weights and build query scorers for hybrid query. */ @@ -31,8 +33,6 @@ public final class HybridQueryWeight extends Weight { private final ScoreMode scoreMode; - static final int BOOLEAN_REWRITE_TERM_COUNT_THRESHOLD = 16; - /** * Construct the Weight for this Query searched by searcher. Recursively construct subquery weights. */ @@ -108,9 +108,8 @@ public Scorer scorer(LeafReaderContext context) throws IOException { */ @Override public boolean isCacheable(LeafReaderContext ctx) { - if (weights.size() > BOOLEAN_REWRITE_TERM_COUNT_THRESHOLD) { - // Disallow caching large queries to not encourage users - // to build large queries + if (weights.size() > MAX_NUMBER_OF_SUB_QUERIES) { + // this situation should never happen, but in case it do such query will not be cached return false; } return weights.stream().allMatch(w -> w.isCacheable(ctx)); diff --git a/src/main/java/org/opensearch/neuralsearch/search/HybridTopScoreDocCollector.java b/src/main/java/org/opensearch/neuralsearch/search/HybridTopScoreDocCollector.java index 9190bfeac..79b134b38 100644 --- a/src/main/java/org/opensearch/neuralsearch/search/HybridTopScoreDocCollector.java +++ b/src/main/java/org/opensearch/neuralsearch/search/HybridTopScoreDocCollector.java @@ -8,6 +8,7 @@ import java.util.ArrayList; import java.util.List; import java.util.Locale; +import java.util.Objects; import java.util.stream.Collectors; import java.util.stream.IntStream; @@ -79,6 +80,9 @@ private HybridQueryScorer getHybridQueryScorer(final Scorable scorer) throws IOE @Override public void collect(int doc) throws IOException { + if (Objects.isNull(compoundQueryScorer)) { + throw new IllegalArgumentException("scorers are null for all sub-queries in hybrid query"); + } float[] subScoresByQuery = compoundQueryScorer.hybridScores(); // iterate over results for each query if (compoundScores == null) { diff --git a/src/main/java/org/opensearch/neuralsearch/search/query/HybridAggregationProcessor.java b/src/main/java/org/opensearch/neuralsearch/search/query/HybridAggregationProcessor.java new file mode 100644 index 000000000..8b584feea --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/search/query/HybridAggregationProcessor.java @@ -0,0 +1,77 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.search.query; + +import lombok.AllArgsConstructor; +import org.apache.lucene.search.CollectorManager; +import org.opensearch.search.aggregations.AggregationProcessor; +import org.opensearch.search.internal.SearchContext; +import org.opensearch.search.query.QueryPhaseExecutionException; +import org.opensearch.search.query.QuerySearchResult; +import org.opensearch.search.query.ReduceableSearchResult; + +import java.io.IOException; +import java.util.Collection; +import java.util.List; + +import static org.opensearch.neuralsearch.search.query.HybridQueryPhaseSearcher.isHybridQuery; + +/** + * Defines logic for pre- and post-phases of document scores collection. Responsible for registering custom + * collector manager for hybris query (pre phase) and reducing results (post phase) + */ +@AllArgsConstructor +public class HybridAggregationProcessor implements AggregationProcessor { + + private final AggregationProcessor delegateAggsProcessor; + + @Override + public void preProcess(SearchContext context) { + delegateAggsProcessor.preProcess(context); + + if (isHybridQuery(context.query(), context)) { + // adding collector manager for hybrid query + CollectorManager collectorManager; + try { + collectorManager = HybridCollectorManager.createHybridCollectorManager(context); + } catch (IOException e) { + throw new RuntimeException(e); + } + context.queryCollectorManagers().put(HybridCollectorManager.class, collectorManager); + } + } + + @Override + public void postProcess(SearchContext context) { + if (isHybridQuery(context.query(), context)) { + // for case when concurrent search is not enabled (default as of 2.12 release) reduce for collector + // managers is not called, and we have to call it manually. This is required as we format final + // result of hybrid query in {@link HybridTopScoreCollector#reduce} + if (!context.shouldUseConcurrentSearch()) { + reduceCollectorResults(context); + } + updateQueryResult(context.queryResult(), context); + } + + delegateAggsProcessor.postProcess(context); + } + + private void reduceCollectorResults(SearchContext context) { + CollectorManager collectorManager = context.queryCollectorManagers().get(HybridCollectorManager.class); + try { + final Collection collectors = List.of(collectorManager.newCollector()); + collectorManager.reduce(collectors).reduce(context.queryResult()); + } catch (IOException e) { + throw new QueryPhaseExecutionException(context.shardTarget(), "failed to execute hybrid query aggregation processor", e); + } + } + + private void updateQueryResult(final QuerySearchResult queryResult, final SearchContext searchContext) { + boolean isSingleShard = searchContext.numberOfShards() == 1; + if (isSingleShard) { + searchContext.size(queryResult.queryResult().topDocs().topDocs.scoreDocs.length); + } + } +} diff --git a/src/main/java/org/opensearch/neuralsearch/search/query/HybridCollectorManager.java b/src/main/java/org/opensearch/neuralsearch/search/query/HybridCollectorManager.java index 1d715a14c..40b10c5f3 100644 --- a/src/main/java/org/opensearch/neuralsearch/search/query/HybridCollectorManager.java +++ b/src/main/java/org/opensearch/neuralsearch/search/query/HybridCollectorManager.java @@ -85,10 +85,22 @@ Collector getCollector() { return hybridcollector; } + /** + * Reduce the results from hybrid scores collector into a format specific for hybrid search query: + * - start + * - sub-query-delimiter + * - scores + * - stop + * Ignore other collectors if they are present in the context + * @param collectors collection of collectors after they has been executed and collected documents and scores + * @return search results that can be reduced be the caller + */ @Override public ReduceableSearchResult reduce(Collection collectors) { final List hybridTopScoreDocCollectors = new ArrayList<>(); - + // check if collector for hybrid query scores is part of this search context. It can be wrapped into MultiCollectorWrapper + // in case multiple collector managers are registered. We use hybrid scores collector to format scores into + // format specific for hybrid search query: start, sub-query-delimiter, scores, stop for (final Collector collector : collectors) { if (collector instanceof MultiCollectorWrapper) { for (final Collector sub : (((MultiCollectorWrapper) collector).getCollectors())) { diff --git a/src/main/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcher.java b/src/main/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcher.java index 5fc6017f2..6461c698e 100644 --- a/src/main/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcher.java +++ b/src/main/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcher.java @@ -5,17 +5,12 @@ package org.opensearch.neuralsearch.search.query; import java.io.IOException; -import java.util.Collection; import java.util.LinkedList; import java.util.List; -import java.util.Map; import com.google.common.annotations.VisibleForTesting; -import lombok.AllArgsConstructor; import org.apache.lucene.search.BooleanClause; import org.apache.lucene.search.BooleanQuery; -import org.apache.lucene.search.Collector; -import org.apache.lucene.search.CollectorManager; import org.apache.lucene.search.FieldExistsQuery; import org.apache.lucene.search.Query; import org.opensearch.common.settings.Settings; @@ -28,10 +23,7 @@ import org.opensearch.search.internal.SearchContext; import org.opensearch.search.query.QueryCollectorContext; import org.opensearch.search.query.QueryPhase; -import org.opensearch.search.query.QueryPhaseExecutionException; import org.opensearch.search.query.QueryPhaseSearcherWrapper; -import org.opensearch.search.query.QuerySearchResult; -import org.opensearch.search.query.ReduceableSearchResult; import lombok.extern.log4j.Log4j2; @@ -181,58 +173,4 @@ public AggregationProcessor aggregationProcessor(SearchContext searchContext) { AggregationProcessor coreAggProcessor = super.aggregationProcessor(searchContext); return new HybridAggregationProcessor(coreAggProcessor); } - - @AllArgsConstructor - public static class HybridAggregationProcessor implements AggregationProcessor { - - private final AggregationProcessor delegateAggsProcessor; - - @Override - public void preProcess(SearchContext context) { - delegateAggsProcessor.preProcess(context); - - if (isHybridQuery(context.query(), context)) { - // adding collector manager for hybrid query - CollectorManager collectorManager; - try { - collectorManager = HybridCollectorManager.createHybridCollectorManager(context); - } catch (IOException e) { - throw new RuntimeException(e); - } - Map, CollectorManager> collectorManagersByManagerClass = context - .queryCollectorManagers(); - collectorManagersByManagerClass.put(HybridCollectorManager.class, collectorManager); - } - } - - @Override - public void postProcess(SearchContext context) { - if (isHybridQuery(context.query(), context)) { - if (!context.shouldUseConcurrentSearch()) { - reduceCollectorResults(context); - } - updateQueryResult(context.queryResult(), context); - } - - delegateAggsProcessor.postProcess(context); - } - - private void reduceCollectorResults(SearchContext context) { - CollectorManager collectorManager = context.queryCollectorManagers() - .get(HybridCollectorManager.class); - try { - final Collection collectors = List.of(collectorManager.newCollector()); - collectorManager.reduce(collectors).reduce(context.queryResult()); - } catch (IOException e) { - throw new QueryPhaseExecutionException(context.shardTarget(), "failed to execute hybrid query aggregation processor", e); - } - } - - private void updateQueryResult(final QuerySearchResult queryResult, final SearchContext searchContext) { - boolean isSingleShard = searchContext.numberOfShards() == 1; - if (isSingleShard) { - searchContext.size(queryResult.queryResult().topDocs().topDocs.scoreDocs.length); - } - } - } } diff --git a/src/test/java/org/opensearch/neuralsearch/search/query/HybridAggregationProcessorTests.java b/src/test/java/org/opensearch/neuralsearch/search/query/HybridAggregationProcessorTests.java index 1c919b581..f44e762f0 100644 --- a/src/test/java/org/opensearch/neuralsearch/search/query/HybridAggregationProcessorTests.java +++ b/src/test/java/org/opensearch/neuralsearch/search/query/HybridAggregationProcessorTests.java @@ -49,8 +49,7 @@ public class HybridAggregationProcessorTests extends OpenSearchQueryTestCase { @SneakyThrows public void testAggregationProcessorDelegate_whenPreAndPostAreCalled_thenSuccessful() { AggregationProcessor mockAggsProcessorDelegate = mock(AggregationProcessor.class); - HybridQueryPhaseSearcher.HybridAggregationProcessor hybridAggregationProcessor = - new HybridQueryPhaseSearcher.HybridAggregationProcessor(mockAggsProcessorDelegate); + HybridAggregationProcessor hybridAggregationProcessor = new HybridAggregationProcessor(mockAggsProcessorDelegate); SearchContext searchContext = mock(SearchContext.class); hybridAggregationProcessor.preProcess(searchContext); @@ -63,8 +62,7 @@ public void testAggregationProcessorDelegate_whenPreAndPostAreCalled_thenSuccess @SneakyThrows public void testCollectorManager_whenHybridQueryAndNotConcurrentSearch_thenSuccessful() { AggregationProcessor mockAggsProcessorDelegate = mock(AggregationProcessor.class); - HybridQueryPhaseSearcher.HybridAggregationProcessor hybridAggregationProcessor = - new HybridQueryPhaseSearcher.HybridAggregationProcessor(mockAggsProcessorDelegate); + HybridAggregationProcessor hybridAggregationProcessor = new HybridAggregationProcessor(mockAggsProcessorDelegate); SearchContext searchContext = mock(SearchContext.class); QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); @@ -124,8 +122,7 @@ public void testCollectorManager_whenHybridQueryAndNotConcurrentSearch_thenSucce @SneakyThrows public void testCollectorManager_whenHybridQueryAndConcurrentSearch_thenSuccessful() { AggregationProcessor mockAggsProcessorDelegate = mock(AggregationProcessor.class); - HybridQueryPhaseSearcher.HybridAggregationProcessor hybridAggregationProcessor = - new HybridQueryPhaseSearcher.HybridAggregationProcessor(mockAggsProcessorDelegate); + HybridAggregationProcessor hybridAggregationProcessor = new HybridAggregationProcessor(mockAggsProcessorDelegate); SearchContext searchContext = mock(SearchContext.class); QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); @@ -185,8 +182,7 @@ public void testCollectorManager_whenHybridQueryAndConcurrentSearch_thenSuccessf @SneakyThrows public void testCollectorManager_whenNotHybridQueryAndNotConcurrentSearch_thenSuccessful() { AggregationProcessor mockAggsProcessorDelegate = mock(AggregationProcessor.class); - HybridQueryPhaseSearcher.HybridAggregationProcessor hybridAggregationProcessor = - new HybridQueryPhaseSearcher.HybridAggregationProcessor(mockAggsProcessorDelegate); + HybridAggregationProcessor hybridAggregationProcessor = new HybridAggregationProcessor(mockAggsProcessorDelegate); SearchContext searchContext = mock(SearchContext.class); QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); diff --git a/src/test/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcherTests.java b/src/test/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcherTests.java index 602c87440..2aebbb5d8 100644 --- a/src/test/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcherTests.java +++ b/src/test/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcherTests.java @@ -807,7 +807,7 @@ public void testAggsProcessor_whenGettingAggsProcessor_thenSuccess() { SearchContext searchContext = mock(SearchContext.class); AggregationProcessor aggregationProcessor = hybridQueryPhaseSearcher.aggregationProcessor(searchContext); assertNotNull(aggregationProcessor); - assertTrue(aggregationProcessor instanceof HybridQueryPhaseSearcher.HybridAggregationProcessor); + assertTrue(aggregationProcessor instanceof HybridAggregationProcessor); } @SneakyThrows