Skip to content

Commit

Permalink
Use lazy initialization for priority queue of hits and scores to impr…
Browse files Browse the repository at this point in the history
…ove latencies by 20% (#746) (#755)

* Optimize PQ for hybrid scores

Signed-off-by: Martin Gaievski <gaievski@amazon.com>
(cherry picked from commit 940a7ea)

Co-authored-by: Martin Gaievski <gaievski@amazon.com>
  • Loading branch information
1 parent b0753b8 commit c54a797
Show file tree
Hide file tree
Showing 5 changed files with 27 additions and 60 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
### Enhancements
- Pass empty doc collector instead of top docs collector to improve hybrid query latencies by 20% ([#731](https://github.com/opensearch-project/neural-search/pull/731))
- Optimize parameter parsing in text chunking processor ([#733](https://github.com/opensearch-project/neural-search/pull/733))
- Use lazy initialization for priority queue of hits and scores to improve latencies by 20% ([#746](https://github.com/opensearch-project/neural-search/pull/746))
### Bug Fixes
### Infrastructure
### Documentation
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,7 @@
*/
package org.opensearch.neuralsearch.query;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Locale;
import java.util.Objects;

import lombok.Getter;
import lombok.extern.log4j.Log4j2;
import org.apache.lucene.search.DisiPriorityQueue;
import org.apache.lucene.search.DisiWrapper;
Expand All @@ -21,11 +14,16 @@
import org.apache.lucene.search.Scorer;
import org.apache.lucene.search.TwoPhaseIterator;
import org.apache.lucene.search.Weight;

import lombok.Getter;
import org.apache.lucene.util.PriorityQueue;
import org.opensearch.neuralsearch.search.HybridDisiWrapper;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Objects;

/**
* Class abstracts functionality of Scorer for hybrid query. When iterating over documents in increasing
* order of doc id, this class fills up array of scores per sub-query for each doc id. Order in array of scores
Expand All @@ -40,11 +38,10 @@ public final class HybridQueryScorer extends Scorer {

private final DisiPriorityQueue subScorersPQ;

private final float[] subScores;

private final DocIdSetIterator approximation;
private final HybridScoreBlockBoundaryPropagator disjunctionBlockPropagator;
private final TwoPhase twoPhase;
private final int numSubqueries;

public HybridQueryScorer(final Weight weight, final List<Scorer> subScorers) throws IOException {
this(weight, subScorers, ScoreMode.TOP_SCORES);
Expand All @@ -53,7 +50,7 @@ public HybridQueryScorer(final Weight weight, final List<Scorer> subScorers) thr
HybridQueryScorer(final Weight weight, final List<Scorer> subScorers, final ScoreMode scoreMode) throws IOException {
super(weight);
this.subScorers = Collections.unmodifiableList(subScorers);
subScores = new float[subScorers.size()];
this.numSubqueries = subScorers.size();
this.subScorersPQ = initializeSubScorersPQ();
boolean needsScores = scoreMode != ScoreMode.COMPLETE_NO_SCORES;

Expand Down Expand Up @@ -100,13 +97,8 @@ public int advanceShallow(int target) throws IOException {
*/
@Override
public float score() throws IOException {
return score(getSubMatches());
}

private float score(DisiWrapper topList) throws IOException {
float totalScore = 0.0f;
for (DisiWrapper disiWrapper = topList; disiWrapper != null; disiWrapper = disiWrapper.next) {
// check if this doc has match in the subQuery. If not, add score as 0.0 and continue
for (DisiWrapper disiWrapper : subScorersPQ) {
if (disiWrapper.scorer.docID() == DocIdSetIterator.NO_MORE_DOCS) {
continue;
}
Expand Down Expand Up @@ -189,21 +181,8 @@ public int docID() {
* @throws IOException
*/
public float[] hybridScores() throws IOException {
float[] scores = new float[subScores.length];
float[] scores = new float[numSubqueries];
DisiWrapper topList = subScorersPQ.topList();
if (topList instanceof HybridDisiWrapper == false) {
log.error(
String.format(
Locale.ROOT,
"Unexpected type of DISI wrapper, expected [%s] but found [%s]",
HybridDisiWrapper.class.getSimpleName(),
subScorersPQ.topList().getClass().getSimpleName()
)
);
throw new IllegalStateException(
"Unable to collect scores for one of the sub-queries, encountered an unexpected type of score iterator."
);
}
for (HybridDisiWrapper disiWrapper = (HybridDisiWrapper) topList; disiWrapper != null; disiWrapper =
(HybridDisiWrapper) disiWrapper.next) {
// check if this doc has match in the subQuery. If not, add score as 0.0 and continue
Expand All @@ -219,9 +198,8 @@ public float[] hybridScores() throws IOException {
private DisiPriorityQueue initializeSubScorersPQ() {
Objects.requireNonNull(subScorers, "should not be null");
// we need to count this way in order to include all identical sub-queries
int numOfSubQueries = subScorers.size();
DisiPriorityQueue subScorersPQ = new DisiPriorityQueue(numOfSubQueries);
for (int idx = 0; idx < subScorers.size(); idx++) {
DisiPriorityQueue subScorersPQ = new DisiPriorityQueue(numSubqueries);
for (int idx = 0; idx < numSubqueries; idx++) {
Scorer scorer = subScorers.get(idx);
if (scorer == null) {
continue;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,9 @@
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.TotalHits;
import org.apache.lucene.util.PriorityQueue;
import org.opensearch.neuralsearch.query.HybridQueryScorer;

import lombok.Getter;
import lombok.extern.log4j.Log4j2;
import org.opensearch.neuralsearch.query.HybridQueryScorer;

/**
* Collects the TopDocs after executing hybrid query. Uses HybridQueryTopDocs as DTO to handle each sub query results
Expand All @@ -38,7 +37,6 @@ public class HybridTopScoreDocCollector implements Collector {
private TotalHits.Relation totalHitsRelation = TotalHits.Relation.EQUAL_TO;
private int[] totalHits;
private final int numOfHits;
@Getter
private PriorityQueue<ScoreDoc>[] compoundScores;

public HybridTopScoreDocCollector(int numHits, HitsThresholdChecker hitsThresholdChecker) {
Expand Down Expand Up @@ -96,12 +94,13 @@ public void collect(int doc) throws IOException {
if (Objects.isNull(compoundQueryScorer)) {
throw new IllegalArgumentException("scorers are null for all sub-queries in hybrid query");
}

float[] subScoresByQuery = compoundQueryScorer.hybridScores();
// iterate over results for each query
if (compoundScores == null) {
compoundScores = new PriorityQueue[subScoresByQuery.length];
for (int i = 0; i < subScoresByQuery.length; i++) {
compoundScores[i] = new HitQueue(numOfHits, true);
compoundScores[i] = new HitQueue(numOfHits, false);
}
totalHits = new int[subScoresByQuery.length];
}
Expand All @@ -113,10 +112,10 @@ public void collect(int doc) throws IOException {
}
totalHits[i]++;
PriorityQueue<ScoreDoc> pq = compoundScores[i];
ScoreDoc topDoc = pq.top();
topDoc.doc = doc + docBase;
topDoc.score = score;
pq.updateTop();
ScoreDoc currentDoc = new ScoreDoc(doc + docBase, score);
// this way we're inserting into heap and do nothing else unless we reach the capacity
// after that we pull out the lowest score element on each insert
pq.insertWithOverflow(currentDoc);
}
}
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,11 @@ public void testWithRandomDocumentsAndCombinedScore_whenMultipleScorers_thenRetu
int idx = Arrays.binarySearch(docs2, doc);
expectedScore += scores2[idx];
}
assertEquals(expectedScore, hybridQueryScorer.score(), 0.001f);
float hybridScore = 0.0f;
for (float score : hybridQueryScorer.hybridScores()) {
hybridScore += score;
}
assertEquals(expectedScore, hybridScore, 0.001f);
numOfActualDocs++;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@
import org.apache.lucene.search.Weight;
import org.apache.lucene.store.Directory;
import org.apache.lucene.tests.analysis.MockAnalyzer;
import org.apache.lucene.util.PriorityQueue;
import org.opensearch.index.mapper.TextFieldMapper;
import org.opensearch.index.query.QueryBuilders;
import org.opensearch.index.query.QueryShardContext;
Expand Down Expand Up @@ -446,13 +445,6 @@ public void testCompoundScorer_whenHybridScorerIsChildScorer_thenSuccessful() {
int nextDoc = hybridQueryScorer.iterator().nextDoc();
leafCollector.collect(nextDoc);

assertNotNull(hybridTopScoreDocCollector.getCompoundScores());
PriorityQueue<ScoreDoc>[] compoundScoresPQ = hybridTopScoreDocCollector.getCompoundScores();
assertEquals(1, compoundScoresPQ.length);
PriorityQueue<ScoreDoc> scoreDoc = compoundScoresPQ[0];
assertNotNull(scoreDoc);
assertNotNull(scoreDoc.top());

w.close();
reader.close();
directory.close();
Expand Down Expand Up @@ -497,13 +489,6 @@ public void testCompoundScorer_whenHybridScorerIsTopLevelScorer_thenSuccessful()
int nextDoc = hybridQueryScorer.iterator().nextDoc();
leafCollector.collect(nextDoc);

assertNotNull(hybridTopScoreDocCollector.getCompoundScores());
PriorityQueue<ScoreDoc>[] compoundScoresPQ = hybridTopScoreDocCollector.getCompoundScores();
assertEquals(1, compoundScoresPQ.length);
PriorityQueue<ScoreDoc> scoreDoc = compoundScoresPQ[0];
assertNotNull(scoreDoc);
assertNotNull(scoreDoc.top());

w.close();
reader.close();
directory.close();
Expand Down

0 comments on commit c54a797

Please sign in to comment.