Skip to content

Commit

Permalink
Optimize PQ for hybrid scores
Browse files Browse the repository at this point in the history
Signed-off-by: Martin Gaievski <gaievski@amazon.com>
  • Loading branch information
martin-gaievski committed May 10, 2024
1 parent 7c54c86 commit 8ccee40
Show file tree
Hide file tree
Showing 6 changed files with 27 additions and 61 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
### Features
### Enhancements
- Pass empty doc collector instead of top docs collector to improve hybrid query latencies by 20% ([#731](https://github.com/opensearch-project/neural-search/pull/731))
- Use lazy initialization for priority queue of hits and scores to improve latencies by 20% ([#746](https://github.com/opensearch-project/neural-search/pull/746))
### Bug Fixes
### Infrastructure
### Documentation
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,7 @@
*/
package org.opensearch.neuralsearch.query;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Locale;
import java.util.Objects;

import lombok.Getter;
import lombok.extern.log4j.Log4j2;
import org.apache.lucene.search.DisiPriorityQueue;
import org.apache.lucene.search.DisiWrapper;
Expand All @@ -21,11 +14,16 @@
import org.apache.lucene.search.Scorer;
import org.apache.lucene.search.TwoPhaseIterator;
import org.apache.lucene.search.Weight;

import lombok.Getter;
import org.apache.lucene.util.PriorityQueue;
import org.opensearch.neuralsearch.search.HybridDisiWrapper;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Objects;

/**
* Class abstracts functionality of Scorer for hybrid query. When iterating over documents in increasing
* order of doc id, this class fills up array of scores per sub-query for each doc id. Order in array of scores
Expand All @@ -40,11 +38,10 @@ public final class HybridQueryScorer extends Scorer {

private final DisiPriorityQueue subScorersPQ;

private final float[] subScores;

private final DocIdSetIterator approximation;
private final HybridScoreBlockBoundaryPropagator disjunctionBlockPropagator;
private final TwoPhase twoPhase;
private final int numSubqueries;

public HybridQueryScorer(final Weight weight, final List<Scorer> subScorers) throws IOException {
this(weight, subScorers, ScoreMode.TOP_SCORES);
Expand All @@ -53,7 +50,7 @@ public HybridQueryScorer(final Weight weight, final List<Scorer> subScorers) thr
HybridQueryScorer(final Weight weight, final List<Scorer> subScorers, final ScoreMode scoreMode) throws IOException {
super(weight);
this.subScorers = Collections.unmodifiableList(subScorers);
subScores = new float[subScorers.size()];
this.numSubqueries = subScorers.size();
this.subScorersPQ = initializeSubScorersPQ();
boolean needsScores = scoreMode != ScoreMode.COMPLETE_NO_SCORES;

Expand Down Expand Up @@ -100,13 +97,8 @@ public int advanceShallow(int target) throws IOException {
*/
@Override
public float score() throws IOException {
return score(getSubMatches());
}

private float score(DisiWrapper topList) throws IOException {
float totalScore = 0.0f;
for (DisiWrapper disiWrapper = topList; disiWrapper != null; disiWrapper = disiWrapper.next) {
// check if this doc has match in the subQuery. If not, add score as 0.0 and continue
for (DisiWrapper disiWrapper : subScorersPQ) {
if (disiWrapper.scorer.docID() == DocIdSetIterator.NO_MORE_DOCS) {
continue;
}
Expand Down Expand Up @@ -189,21 +181,8 @@ public int docID() {
* @throws IOException
*/
public float[] hybridScores() throws IOException {
float[] scores = new float[subScores.length];
float[] scores = new float[numSubqueries];
DisiWrapper topList = subScorersPQ.topList();
if (topList instanceof HybridDisiWrapper == false) {
log.error(
String.format(
Locale.ROOT,
"Unexpected type of DISI wrapper, expected [%s] but found [%s]",
HybridDisiWrapper.class.getSimpleName(),
subScorersPQ.topList().getClass().getSimpleName()
)
);
throw new IllegalStateException(
"Unable to collect scores for one of the sub-queries, encountered an unexpected type of score iterator."
);
}
for (HybridDisiWrapper disiWrapper = (HybridDisiWrapper) topList; disiWrapper != null; disiWrapper =
(HybridDisiWrapper) disiWrapper.next) {
// check if this doc has match in the subQuery. If not, add score as 0.0 and continue
Expand All @@ -219,9 +198,8 @@ public float[] hybridScores() throws IOException {
private DisiPriorityQueue initializeSubScorersPQ() {
Objects.requireNonNull(subScorers, "should not be null");
// we need to count this way in order to include all identical sub-queries
int numOfSubQueries = subScorers.size();
DisiPriorityQueue subScorersPQ = new DisiPriorityQueue(numOfSubQueries);
for (int idx = 0; idx < subScorers.size(); idx++) {
DisiPriorityQueue subScorersPQ = new DisiPriorityQueue(numSubqueries);
for (int idx = 0; idx < numSubqueries; idx++) {
Scorer scorer = subScorers.get(idx);
if (scorer == null) {
continue;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,9 @@
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.TotalHits;
import org.apache.lucene.util.PriorityQueue;
import org.opensearch.neuralsearch.query.HybridQueryScorer;

import lombok.Getter;
import lombok.extern.log4j.Log4j2;
import org.opensearch.neuralsearch.query.HybridQueryScorer;

/**
* Collects the TopDocs after executing hybrid query. Uses HybridQueryTopDocs as DTO to handle each sub query results
Expand All @@ -38,7 +37,6 @@ public class HybridTopScoreDocCollector implements Collector {
private TotalHits.Relation totalHitsRelation = TotalHits.Relation.EQUAL_TO;
private int[] totalHits;
private final int numOfHits;
@Getter
private PriorityQueue<ScoreDoc>[] compoundScores;

public HybridTopScoreDocCollector(int numHits, HitsThresholdChecker hitsThresholdChecker) {
Expand Down Expand Up @@ -96,12 +94,13 @@ public void collect(int doc) throws IOException {
if (Objects.isNull(compoundQueryScorer)) {
throw new IllegalArgumentException("scorers are null for all sub-queries in hybrid query");
}

float[] subScoresByQuery = compoundQueryScorer.hybridScores();
// iterate over results for each query
if (compoundScores == null) {
compoundScores = new PriorityQueue[subScoresByQuery.length];
for (int i = 0; i < subScoresByQuery.length; i++) {
compoundScores[i] = new HitQueue(numOfHits, true);
compoundScores[i] = new HitQueue(numOfHits, false);
}
totalHits = new int[subScoresByQuery.length];
}
Expand All @@ -113,10 +112,10 @@ public void collect(int doc) throws IOException {
}
totalHits[i]++;
PriorityQueue<ScoreDoc> pq = compoundScores[i];
ScoreDoc topDoc = pq.top();
topDoc.doc = doc + docBase;
topDoc.score = score;
pq.updateTop();
ScoreDoc topDoc = new ScoreDoc(doc + docBase, score);
// this way we're inserting into heap and do nothing else unless we reach the capacity
// after that we pull out the lowest score element on each insert
pq.insertWithOverflow(topDoc);
}
}
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,11 @@ public void testWithRandomDocumentsAndCombinedScore_whenMultipleScorers_thenRetu
int idx = Arrays.binarySearch(docs2, doc);
expectedScore += scores2[idx];
}
assertEquals(expectedScore, hybridQueryScorer.score(), 0.001f);
float hybridScore = 0.0f;
for (float score : hybridQueryScorer.hybridScores()) {
hybridScore += score;
}
assertEquals(expectedScore, hybridScore, 0.001f);
numOfActualDocs++;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@
import org.apache.lucene.search.Weight;
import org.apache.lucene.store.Directory;
import org.apache.lucene.tests.analysis.MockAnalyzer;
import org.apache.lucene.util.PriorityQueue;
import org.opensearch.index.mapper.TextFieldMapper;
import org.opensearch.index.query.QueryBuilders;
import org.opensearch.index.query.QueryShardContext;
Expand Down Expand Up @@ -446,13 +445,6 @@ public void testCompoundScorer_whenHybridScorerIsChildScorer_thenSuccessful() {
int nextDoc = hybridQueryScorer.iterator().nextDoc();
leafCollector.collect(nextDoc);

assertNotNull(hybridTopScoreDocCollector.getCompoundScores());
PriorityQueue<ScoreDoc>[] compoundScoresPQ = hybridTopScoreDocCollector.getCompoundScores();
assertEquals(1, compoundScoresPQ.length);
PriorityQueue<ScoreDoc> scoreDoc = compoundScoresPQ[0];
assertNotNull(scoreDoc);
assertNotNull(scoreDoc.top());

w.close();
reader.close();
directory.close();
Expand Down Expand Up @@ -497,13 +489,6 @@ public void testCompoundScorer_whenHybridScorerIsTopLevelScorer_thenSuccessful()
int nextDoc = hybridQueryScorer.iterator().nextDoc();
leafCollector.collect(nextDoc);

assertNotNull(hybridTopScoreDocCollector.getCompoundScores());
PriorityQueue<ScoreDoc>[] compoundScoresPQ = hybridTopScoreDocCollector.getCompoundScores();
assertEquals(1, compoundScoresPQ.length);
PriorityQueue<ScoreDoc> scoreDoc = compoundScoresPQ[0];
assertNotNull(scoreDoc);
assertNotNull(scoreDoc.top());

w.close();
reader.close();
directory.close();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
import java.util.LinkedList;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ExecutorService;

import org.apache.lucene.document.Document;
import org.apache.lucene.document.Field;
Expand Down

0 comments on commit 8ccee40

Please sign in to comment.