Skip to content

Commit

Permalink
Addressing Jacks and Navneets comments
Browse files Browse the repository at this point in the history
Signed-off-by: Martin Gaievski <gaievski@amazon.com>
  • Loading branch information
martin-gaievski committed Mar 7, 2024
1 parent 253781c commit 67df195
Show file tree
Hide file tree
Showing 10 changed files with 108 additions and 82 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
### Features
### Enhancements
### Bug Fixes
- Adding two phase iterator for hybrid query ([#624](https://github.com/opensearch-project/neural-search/pull/624))
- Fix runtime exceptions in hybrid query for the case when a sub-query scorer returns a TwoPhase iterator that is incompatible with the DISI iterator ([#624](https://github.com/opensearch-project/neural-search/pull/624))
### Infrastructure
### Documentation
### Maintenance
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ public final class HybridQueryBuilder extends AbstractQueryBuilder<HybridQueryBu

private String fieldName;

private static final int MAX_NUMBER_OF_SUB_QUERIES = 5;
static final int MAX_NUMBER_OF_SUB_QUERIES = 5;

public HybridQueryBuilder(StreamInput in) throws IOException {
super(in);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,14 +45,14 @@ public final class HybridQueryScorer extends Scorer {
private final Map<Query, List<Integer>> queryToIndex;

private final DocIdSetIterator approximation;
HybridScoreBlockBoundaryPropagator disjunctionBlockPropagator;
private final HybridScoreBlockBoundaryPropagator disjunctionBlockPropagator;
private final TwoPhase twoPhase;

public HybridQueryScorer(Weight weight, List<Scorer> subScorers) throws IOException {
public HybridQueryScorer(final Weight weight, final List<Scorer> subScorers) throws IOException {
this(weight, subScorers, ScoreMode.TOP_SCORES);
}

public HybridQueryScorer(Weight weight, List<Scorer> subScorers, ScoreMode scoreMode) throws IOException {
HybridQueryScorer(final Weight weight, final List<Scorer> subScorers, final ScoreMode scoreMode) throws IOException {
super(weight);
this.subScorers = Collections.unmodifiableList(subScorers);
subScores = new float[subScorers.size()];
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
import org.apache.lucene.search.ScorerSupplier;
import org.apache.lucene.search.Weight;

import static org.opensearch.neuralsearch.query.HybridQueryBuilder.MAX_NUMBER_OF_SUB_QUERIES;

/**
* Calculates query weights and build query scorers for hybrid query.
*/
Expand All @@ -31,8 +33,6 @@ public final class HybridQueryWeight extends Weight {

private final ScoreMode scoreMode;

static final int BOOLEAN_REWRITE_TERM_COUNT_THRESHOLD = 16;

/**
* Construct the Weight for this Query searched by searcher. Recursively construct subquery weights.
*/
Expand Down Expand Up @@ -108,9 +108,8 @@ public Scorer scorer(LeafReaderContext context) throws IOException {
*/
@Override
public boolean isCacheable(LeafReaderContext ctx) {
if (weights.size() > BOOLEAN_REWRITE_TERM_COUNT_THRESHOLD) {
// Disallow caching large queries to not encourage users
// to build large queries
if (weights.size() > MAX_NUMBER_OF_SUB_QUERIES) {
// this situation should never happen, but in case it does, such a query will not be cached
return false;

Check warning on line 113 in src/main/java/org/opensearch/neuralsearch/query/HybridQueryWeight.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/neuralsearch/query/HybridQueryWeight.java#L113

Added line #L113 was not covered by tests
}
return weights.stream().allMatch(w -> w.isCacheable(ctx));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import java.util.ArrayList;
import java.util.List;
import java.util.Locale;
import java.util.Objects;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

Expand Down Expand Up @@ -79,6 +80,9 @@ private HybridQueryScorer getHybridQueryScorer(final Scorable scorer) throws IOE

@Override
public void collect(int doc) throws IOException {
if (Objects.isNull(compoundQueryScorer)) {
throw new IllegalArgumentException("scorers are null for all sub-queries in hybrid query");

Check warning on line 84 in src/main/java/org/opensearch/neuralsearch/search/HybridTopScoreDocCollector.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/neuralsearch/search/HybridTopScoreDocCollector.java#L84

Added line #L84 was not covered by tests
}
float[] subScoresByQuery = compoundQueryScorer.hybridScores();
// iterate over results for each query
if (compoundScores == null) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/
package org.opensearch.neuralsearch.search.query;

import lombok.AllArgsConstructor;
import org.apache.lucene.search.CollectorManager;
import org.opensearch.search.aggregations.AggregationProcessor;
import org.opensearch.search.internal.SearchContext;
import org.opensearch.search.query.QueryPhaseExecutionException;
import org.opensearch.search.query.QuerySearchResult;
import org.opensearch.search.query.ReduceableSearchResult;

import java.io.IOException;
import java.util.Collection;
import java.util.List;

import static org.opensearch.neuralsearch.search.query.HybridQueryPhaseSearcher.isHybridQuery;

/**
* Defines logic for pre- and post-phases of document scores collection. Responsible for registering custom
* collector manager for hybrid query (pre phase) and reducing results (post phase)
*/
@AllArgsConstructor
public class HybridAggregationProcessor implements AggregationProcessor {

// Wrapped processor: its preProcess/postProcess are always invoked, in addition to the hybrid-specific steps below.
private final AggregationProcessor delegateAggsProcessor;

/**
 * Runs the delegate's pre-processing and then, if the context's query is a hybrid query, creates a
 * hybrid collector manager and registers it in the context's query collector managers under the
 * {@code HybridCollectorManager.class} key.
 *
 * @param context search context of the current request
 * @throws RuntimeException wrapping the {@link IOException} raised while creating the collector manager
 */
@Override
public void preProcess(SearchContext context) {
delegateAggsProcessor.preProcess(context);

if (isHybridQuery(context.query(), context)) {
// adding collector manager for hybrid query
// NOTE(review): raw CollectorManager type; parameterize it if createHybridCollectorManager's return type allows - confirm
CollectorManager collectorManager;
try {
collectorManager = HybridCollectorManager.createHybridCollectorManager(context);
} catch (IOException e) {
// NOTE(review): rethrown as a bare RuntimeException, while reduceCollectorResults uses
// QueryPhaseExecutionException for the same kind of failure - consider aligning the two
throw new RuntimeException(e);

Check warning on line 40 in src/main/java/org/opensearch/neuralsearch/search/query/HybridAggregationProcessor.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/neuralsearch/search/query/HybridAggregationProcessor.java#L39-L40

Added lines #L39 - L40 were not covered by tests
}
context.queryCollectorManagers().put(HybridCollectorManager.class, collectorManager);
}
}

/**
 * For hybrid queries, reduces the collector results manually when concurrent search is not used and
 * then updates the query result; the delegate's post-processing runs afterwards in every case.
 *
 * @param context search context of the current request
 */
@Override
public void postProcess(SearchContext context) {
if (isHybridQuery(context.query(), context)) {
// for case when concurrent search is not enabled (default as of 2.12 release) reduce for collector
// managers is not called, and we have to call it manually. This is required as we format final
// result of hybrid query in {@link HybridTopScoreCollector#reduce}
if (!context.shouldUseConcurrentSearch()) {
reduceCollectorResults(context);
}
updateQueryResult(context.queryResult(), context);
}

delegateAggsProcessor.postProcess(context);
}

/**
 * Fetches the collector manager registered under {@code HybridCollectorManager.class}, obtains one new
 * collector from it, reduces that single-element collection and applies the reduced result to the
 * context's query result.
 *
 * @param context search context of the current request
 * @throws QueryPhaseExecutionException if an {@link IOException} is raised inside the try block
 */
private void reduceCollectorResults(SearchContext context) {
// NOTE(review): the looked-up manager is dereferenced without a null check; presumably preProcess
// has always registered it for hybrid queries - confirm
CollectorManager<?, ReduceableSearchResult> collectorManager = context.queryCollectorManagers().get(HybridCollectorManager.class);
try {
// raw Collection is passed to reduce(...); NOTE(review): likely needed because of the wildcard collector type - confirm
final Collection collectors = List.of(collectorManager.newCollector());
collectorManager.reduce(collectors).reduce(context.queryResult());
} catch (IOException e) {
throw new QueryPhaseExecutionException(context.shardTarget(), "failed to execute hybrid query aggregation processor", e);

Check warning on line 67 in src/main/java/org/opensearch/neuralsearch/search/query/HybridAggregationProcessor.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/neuralsearch/search/query/HybridAggregationProcessor.java#L66-L67

Added lines #L66 - L67 were not covered by tests
}
}

/**
 * When the search context reports exactly one shard, sets the context's size to the number of score
 * docs held in the query result's top docs. Does nothing otherwise.
 *
 * @param queryResult   query result whose top docs are inspected
 * @param searchContext search context whose size is overwritten
 */
private void updateQueryResult(final QuerySearchResult queryResult, final SearchContext searchContext) {
boolean isSingleShard = searchContext.numberOfShards() == 1;
if (isSingleShard) {
// NOTE(review): the reason for resizing only in the single-shard case is not visible here - confirm with callers
searchContext.size(queryResult.queryResult().topDocs().topDocs.scoreDocs.length);

Check warning on line 74 in src/main/java/org/opensearch/neuralsearch/search/query/HybridAggregationProcessor.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/neuralsearch/search/query/HybridAggregationProcessor.java#L74

Added line #L74 was not covered by tests
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -85,10 +85,22 @@ Collector getCollector() {
return hybridcollector;
}

/**
* Reduce the results from hybrid scores collector into a format specific for hybrid search query:
* - start
* - sub-query-delimiter
* - scores
* - stop
* Ignore other collectors if they are present in the context
* @param collectors collection of collectors after they have been executed and collected documents and scores
* @return search results that can be reduced by the caller
*/
@Override
public ReduceableSearchResult reduce(Collection<Collector> collectors) {
final List<HybridTopScoreDocCollector> hybridTopScoreDocCollectors = new ArrayList<>();

// check if collector for hybrid query scores is part of this search context. It can be wrapped into MultiCollectorWrapper
// in case multiple collector managers are registered. We use hybrid scores collector to format scores into
// format specific for hybrid search query: start, sub-query-delimiter, scores, stop
for (final Collector collector : collectors) {
if (collector instanceof MultiCollectorWrapper) {
for (final Collector sub : (((MultiCollectorWrapper) collector).getCollectors())) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,12 @@
package org.opensearch.neuralsearch.search.query;

import java.io.IOException;
import java.util.Collection;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;

import com.google.common.annotations.VisibleForTesting;
import lombok.AllArgsConstructor;
import org.apache.lucene.search.BooleanClause;
import org.apache.lucene.search.BooleanQuery;
import org.apache.lucene.search.Collector;
import org.apache.lucene.search.CollectorManager;
import org.apache.lucene.search.FieldExistsQuery;
import org.apache.lucene.search.Query;
import org.opensearch.common.settings.Settings;
Expand All @@ -28,10 +23,7 @@
import org.opensearch.search.internal.SearchContext;
import org.opensearch.search.query.QueryCollectorContext;
import org.opensearch.search.query.QueryPhase;
import org.opensearch.search.query.QueryPhaseExecutionException;
import org.opensearch.search.query.QueryPhaseSearcherWrapper;
import org.opensearch.search.query.QuerySearchResult;
import org.opensearch.search.query.ReduceableSearchResult;

import lombok.extern.log4j.Log4j2;

Expand Down Expand Up @@ -181,58 +173,4 @@ public AggregationProcessor aggregationProcessor(SearchContext searchContext) {
AggregationProcessor coreAggProcessor = super.aggregationProcessor(searchContext);
return new HybridAggregationProcessor(coreAggProcessor);
}

@AllArgsConstructor
public static class HybridAggregationProcessor implements AggregationProcessor {

    // Wrapped processor; always invoked alongside the hybrid-specific handling.
    private final AggregationProcessor delegateAggsProcessor;

    /**
     * Delegates pre-processing, then registers a hybrid collector manager in the context when the
     * context's query is a hybrid query.
     *
     * @param context search context of the current request
     * @throws RuntimeException wrapping the {@link IOException} raised while creating the collector manager
     */
    @Override
    public void preProcess(SearchContext context) {
        delegateAggsProcessor.preProcess(context);

        if (!isHybridQuery(context.query(), context)) {
            return;
        }

        // build the collector manager dedicated to the hybrid query
        final CollectorManager hybridManager;
        try {
            hybridManager = HybridCollectorManager.createHybridCollectorManager(context);
        } catch (IOException ioException) {
            throw new RuntimeException(ioException);
        }

        // register it under the manager class key
        final Map<Class<?>, CollectorManager<? extends Collector, ReduceableSearchResult>> managersByClass = context
            .queryCollectorManagers();
        managersByClass.put(HybridCollectorManager.class, hybridManager);
    }

    /**
     * For hybrid queries, reduces collector results manually when concurrent search is not used and
     * updates the query result; the delegate's post-processing runs last in every case.
     *
     * @param context search context of the current request
     */
    @Override
    public void postProcess(SearchContext context) {
        final boolean hybridQuery = isHybridQuery(context.query(), context);
        if (hybridQuery) {
            final boolean concurrentSearch = context.shouldUseConcurrentSearch();
            if (concurrentSearch == false) {
                reduceCollectorResults(context);
            }
            updateQueryResult(context.queryResult(), context);
        }

        delegateAggsProcessor.postProcess(context);
    }

    /**
     * Reduces a single freshly created collector of the registered hybrid collector manager and applies
     * the reduced result to the context's query result.
     *
     * @param context search context of the current request
     * @throws QueryPhaseExecutionException if an {@link IOException} is raised while creating or reducing
     */
    private void reduceCollectorResults(SearchContext context) {
        final CollectorManager<?, ReduceableSearchResult> hybridManager = context.queryCollectorManagers()
            .get(HybridCollectorManager.class);
        try {
            final Collection singleCollector = List.of(hybridManager.newCollector());
            final ReduceableSearchResult reduced = hybridManager.reduce(singleCollector);
            reduced.reduce(context.queryResult());
        } catch (IOException ioException) {
            throw new QueryPhaseExecutionException(
                context.shardTarget(),
                "failed to execute hybrid query aggregation processor",
                ioException
            );
        }
    }

    /**
     * Sets the context's size to the number of score docs in the query result's top docs, but only
     * when the context reports exactly one shard.
     *
     * @param queryResult   query result whose top docs are inspected
     * @param searchContext search context whose size is overwritten
     */
    private void updateQueryResult(final QuerySearchResult queryResult, final SearchContext searchContext) {
        if (searchContext.numberOfShards() != 1) {
            return;
        }
        final int scoreDocCount = queryResult.queryResult().topDocs().topDocs.scoreDocs.length;
        searchContext.size(scoreDocCount);
    }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,7 @@ public class HybridAggregationProcessorTests extends OpenSearchQueryTestCase {
@SneakyThrows
public void testAggregationProcessorDelegate_whenPreAndPostAreCalled_thenSuccessful() {
AggregationProcessor mockAggsProcessorDelegate = mock(AggregationProcessor.class);
HybridQueryPhaseSearcher.HybridAggregationProcessor hybridAggregationProcessor =
new HybridQueryPhaseSearcher.HybridAggregationProcessor(mockAggsProcessorDelegate);
HybridAggregationProcessor hybridAggregationProcessor = new HybridAggregationProcessor(mockAggsProcessorDelegate);

SearchContext searchContext = mock(SearchContext.class);
hybridAggregationProcessor.preProcess(searchContext);
Expand All @@ -63,8 +62,7 @@ public void testAggregationProcessorDelegate_whenPreAndPostAreCalled_thenSuccess
@SneakyThrows
public void testCollectorManager_whenHybridQueryAndNotConcurrentSearch_thenSuccessful() {
AggregationProcessor mockAggsProcessorDelegate = mock(AggregationProcessor.class);
HybridQueryPhaseSearcher.HybridAggregationProcessor hybridAggregationProcessor =
new HybridQueryPhaseSearcher.HybridAggregationProcessor(mockAggsProcessorDelegate);
HybridAggregationProcessor hybridAggregationProcessor = new HybridAggregationProcessor(mockAggsProcessorDelegate);

SearchContext searchContext = mock(SearchContext.class);
QueryShardContext mockQueryShardContext = mock(QueryShardContext.class);
Expand Down Expand Up @@ -124,8 +122,7 @@ public void testCollectorManager_whenHybridQueryAndNotConcurrentSearch_thenSucce
@SneakyThrows
public void testCollectorManager_whenHybridQueryAndConcurrentSearch_thenSuccessful() {
AggregationProcessor mockAggsProcessorDelegate = mock(AggregationProcessor.class);
HybridQueryPhaseSearcher.HybridAggregationProcessor hybridAggregationProcessor =
new HybridQueryPhaseSearcher.HybridAggregationProcessor(mockAggsProcessorDelegate);
HybridAggregationProcessor hybridAggregationProcessor = new HybridAggregationProcessor(mockAggsProcessorDelegate);

SearchContext searchContext = mock(SearchContext.class);
QueryShardContext mockQueryShardContext = mock(QueryShardContext.class);
Expand Down Expand Up @@ -185,8 +182,7 @@ public void testCollectorManager_whenHybridQueryAndConcurrentSearch_thenSuccessf
@SneakyThrows
public void testCollectorManager_whenNotHybridQueryAndNotConcurrentSearch_thenSuccessful() {
AggregationProcessor mockAggsProcessorDelegate = mock(AggregationProcessor.class);
HybridQueryPhaseSearcher.HybridAggregationProcessor hybridAggregationProcessor =
new HybridQueryPhaseSearcher.HybridAggregationProcessor(mockAggsProcessorDelegate);
HybridAggregationProcessor hybridAggregationProcessor = new HybridAggregationProcessor(mockAggsProcessorDelegate);

SearchContext searchContext = mock(SearchContext.class);
QueryShardContext mockQueryShardContext = mock(QueryShardContext.class);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -807,7 +807,7 @@ public void testAggsProcessor_whenGettingAggsProcessor_thenSuccess() {
SearchContext searchContext = mock(SearchContext.class);
AggregationProcessor aggregationProcessor = hybridQueryPhaseSearcher.aggregationProcessor(searchContext);
assertNotNull(aggregationProcessor);
assertTrue(aggregationProcessor instanceof HybridQueryPhaseSearcher.HybridAggregationProcessor);
assertTrue(aggregationProcessor instanceof HybridAggregationProcessor);
}

@SneakyThrows
Expand Down

0 comments on commit 67df195

Please sign in to comment.