
Commit 4760545

Changed approach for storing hybrid query results from compound top docs to single list of scores with delimiter

Signed-off-by: Martin Gaievski <gaievski@amazon.com>
martin-gaievski committed Aug 25, 2023
1 parent d12f480 commit 4760545
Showing 24 changed files with 753 additions and 325 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -16,6 +16,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
### Features
* Added Score Normalization and Combination feature ([#241](https://github.com/opensearch-project/neural-search/pull/241/))
### Enhancements
* Changed format for hybrid query results to a single list of scores with delimiter ([#259](https://github.com/opensearch-project/neural-search/pull/259))
### Bug Fixes
### Infrastructure
### Documentation
src/main/java/org/opensearch/neuralsearch/processor/CompoundTopDocs.java
@@ -0,0 +1,107 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.neuralsearch.processor;

import static org.opensearch.neuralsearch.processor.NormalizationProcessor.isHybridQueryDelimiterElement;
import static org.opensearch.neuralsearch.processor.NormalizationProcessor.isHybridQueryStartStopElement;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Objects;

import lombok.AllArgsConstructor;
import lombok.Getter;
import lombok.Setter;
import lombok.ToString;
import lombok.extern.log4j.Log4j2;

import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.TotalHits;

/**
* Class that stores a collection of TopDocs for each sub-query of a hybrid query
*/
@ToString(includeFieldNames = true)
@AllArgsConstructor
@Log4j2
public class CompoundTopDocs {

@Getter
@Setter
private TotalHits totalHits;
@Getter
private final List<TopDocs> compoundTopDocs;
@Getter
@Setter
private ScoreDoc[] scoreDocs;

public CompoundTopDocs(final TotalHits totalHits, final List<TopDocs> compoundTopDocs) {
this.totalHits = totalHits;
this.compoundTopDocs = compoundTopDocs;
scoreDocs = cloneLargestScoreDocs(compoundTopDocs);
}

private static ScoreDoc[] cloneLargestScoreDocs(List<TopDocs> docs) {
if (docs == null) {
return null;
}
ScoreDoc[] maxScoreDocs = new ScoreDoc[0];
int maxLength = -1;
for (TopDocs topDoc : docs) {
if (topDoc == null || topDoc.scoreDocs == null) {
continue;
}
if (topDoc.scoreDocs.length > maxLength) {
maxLength = topDoc.scoreDocs.length;
maxScoreDocs = topDoc.scoreDocs;
}
}
// do deep copy
return Arrays.stream(maxScoreDocs).map(doc -> new ScoreDoc(doc.doc, doc.score, doc.shardIndex)).toArray(ScoreDoc[]::new);
}

/**
* Create a new instance from TopDocs by parsing the scores of sub-queries. The final format looks like:
* doc_id | magic_number_1
* doc_id | magic_number_2
* ...
* doc_id | magic_number_2
* ...
* doc_id | magic_number_2
* ...
* doc_id | magic_number_1
*
* where doc_id is one of the valid document ids from the result, magic_number_1 marks the start/stop
* of the hybrid result, and magic_number_2 delimits the scores of individual sub-queries
*
* @param topDocs object with scores from multiple sub-queries
* @return compound TopDocs object that has results from all sub-queries
*/
public static CompoundTopDocs create(final TopDocs topDocs) {
ScoreDoc[] scoreDocs = topDocs.scoreDocs;
if (Objects.isNull(scoreDocs) || scoreDocs.length < 2) {
return new CompoundTopDocs(topDocs.totalHits, new ArrayList<>());
}
// skip the first two elements: the start-stop element and the delimiter for the first series
List<TopDocs> topDocsList = new ArrayList<>();
List<ScoreDoc> scoreDocList = new ArrayList<>();
for (int index = 2; index < scoreDocs.length; index++) {
// read the next element of the flattened scores list
ScoreDoc scoreDoc = scoreDocs[index];
if (isHybridQueryDelimiterElement(scoreDoc) || isHybridQueryStartStopElement(scoreDoc)) {
ScoreDoc[] subQueryScores = scoreDocList.toArray(new ScoreDoc[0]);
TotalHits totalHits = new TotalHits(subQueryScores.length, TotalHits.Relation.EQUAL_TO);
TopDocs subQueryTopDocs = new TopDocs(totalHits, subQueryScores);
topDocsList.add(subQueryTopDocs);
scoreDocList.clear();
} else {
scoreDocList.add(scoreDoc);
}
}
return new CompoundTopDocs(topDocs.totalHits, topDocsList);
}
}
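
To make the flattened format concrete, here is a minimal sketch of building such a result and parsing it back into per-sub-query TopDocs. It assumes the MAGIC_NUMBER_START_STOP and MAGIC_NUMBER_DELIMITER constants statically imported from HybridQueryPhaseSearcher (as the file below does); the doc ids and scores are made up.

```java
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.TotalHits;

// Hedged sketch: assemble the flattened hybrid result described in the javadoc above,
// then parse it back. Sentinel scores are the assumed constants from HybridQueryPhaseSearcher.
ScoreDoc[] flattened = new ScoreDoc[] {
    new ScoreDoc(0, MAGIC_NUMBER_START_STOP), // start marker
    new ScoreDoc(0, MAGIC_NUMBER_DELIMITER),  // delimiter before sub-query 1
    new ScoreDoc(1, 0.91f),                   // sub-query 1, hit 1
    new ScoreDoc(5, 0.47f),                   // sub-query 1, hit 2
    new ScoreDoc(0, MAGIC_NUMBER_DELIMITER),  // delimiter before sub-query 2
    new ScoreDoc(5, 0.73f),                   // sub-query 2, hit 1
    new ScoreDoc(0, MAGIC_NUMBER_START_STOP)  // stop marker
};
TopDocs flat = new TopDocs(new TotalHits(7, TotalHits.Relation.EQUAL_TO), flattened);
CompoundTopDocs compound = CompoundTopDocs.create(flat);
// compound.getCompoundTopDocs() now holds two TopDocs: [doc 1, doc 5] and [doc 5]
```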
src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessor.java
@@ -5,22 +5,25 @@

package org.opensearch.neuralsearch.processor;

import static org.opensearch.neuralsearch.search.query.HybridQueryPhaseSearcher.MAGIC_NUMBER_DELIMITER;
import static org.opensearch.neuralsearch.search.query.HybridQueryPhaseSearcher.MAGIC_NUMBER_START_STOP;

import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.stream.Collectors;

import lombok.AllArgsConstructor;
import lombok.extern.log4j.Log4j2;

import org.apache.lucene.search.ScoreDoc;
import org.opensearch.action.search.QueryPhaseResultConsumer;
import org.opensearch.action.search.SearchPhaseContext;
import org.opensearch.action.search.SearchPhaseName;
import org.opensearch.action.search.SearchPhaseResults;
import org.opensearch.neuralsearch.processor.combination.ScoreCombinationTechnique;
import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationTechnique;
import org.opensearch.neuralsearch.search.CompoundTopDocs;
import org.opensearch.search.SearchPhaseResult;
import org.opensearch.search.fetch.FetchSearchResult;
import org.opensearch.search.internal.SearchContext;
import org.opensearch.search.pipeline.SearchPhaseResultsProcessor;
import org.opensearch.search.query.QuerySearchResult;
@@ -56,7 +59,14 @@ public <Result extends SearchPhaseResult> void process(
return;
}
List<QuerySearchResult> querySearchResults = getQueryPhaseSearchResults(searchPhaseResult);
normalizationWorkflow.execute(querySearchResults, normalizationTechnique, combinationTechnique);
FetchSearchResult fetchSearchResult = searchPhaseResult.getAtomicArray().asList().get(0).fetchResult();
normalizationWorkflow.execute(
querySearchResults,
fetchSearchResult,
normalizationTechnique,
combinationTechnique,
searchPhaseContext.getNumShards() == 1
);
}

@Override
@@ -95,19 +105,21 @@ private <Result extends SearchPhaseResult> boolean shouldSkipProcessor(SearchPha
}

QueryPhaseResultConsumer queryPhaseResultConsumer = (QueryPhaseResultConsumer) searchPhaseResult;
Optional<SearchPhaseResult> optionalSearchPhaseResult = queryPhaseResultConsumer.getAtomicArray()
.asList()
.stream()
.filter(Objects::nonNull)
.findFirst();
return isNotHybridQuery(optionalSearchPhaseResult);
return queryPhaseResultConsumer.getAtomicArray().asList().stream().filter(Objects::nonNull).noneMatch(this::isHybridQuery);
}

private boolean isNotHybridQuery(final Optional<SearchPhaseResult> optionalSearchPhaseResult) {
return optionalSearchPhaseResult.isEmpty()
|| Objects.isNull(optionalSearchPhaseResult.get().queryResult())
|| Objects.isNull(optionalSearchPhaseResult.get().queryResult().topDocs())
|| !(optionalSearchPhaseResult.get().queryResult().topDocs().topDocs instanceof CompoundTopDocs);
/**
* Return true if results are from hybrid query.
* @param searchPhaseResult search phase result to check
* @return true if results are from hybrid query
*/
private boolean isHybridQuery(final SearchPhaseResult searchPhaseResult) {
// check for the start-stop marker element at the beginning of the score docs
return Objects.nonNull(searchPhaseResult.queryResult())
&& Objects.nonNull(searchPhaseResult.queryResult().topDocs())
&& Objects.nonNull(searchPhaseResult.queryResult().topDocs().topDocs.scoreDocs)
&& searchPhaseResult.queryResult().topDocs().topDocs.scoreDocs.length > 0
&& isHybridQueryStartStopElement(searchPhaseResult.queryResult().topDocs().topDocs.scoreDocs[0]);
}

private <Result extends SearchPhaseResult> List<QuerySearchResult> getQueryPhaseSearchResults(
Expand All @@ -119,4 +131,12 @@ private <Result extends SearchPhaseResult> List<QuerySearchResult> getQueryPhase
.map(result -> result == null ? null : result.queryResult())
.collect(Collectors.toList());
}

public static boolean isHybridQueryStartStopElement(final ScoreDoc scoreDoc) {
return Objects.nonNull(scoreDoc) && scoreDoc.doc >= 0 && Float.compare(scoreDoc.score, MAGIC_NUMBER_START_STOP) == 0;
}

public static boolean isHybridQueryDelimiterElement(final ScoreDoc scoreDoc) {
return Objects.nonNull(scoreDoc) && scoreDoc.doc >= 0 && Float.compare(scoreDoc.score, MAGIC_NUMBER_DELIMITER) == 0;
}
}
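
Because these two predicates are the only way downstream code recognizes marker elements, stripping markers from a raw result reduces to a simple filter. A hedged sketch (the helper name is illustrative; assumes java.util.Arrays is imported):

```java
// Illustrative helper: drop hybrid start/stop and delimiter elements, keeping only real hits.
private static ScoreDoc[] withoutMarkerElements(final ScoreDoc[] scoreDocs) {
    return Arrays.stream(scoreDocs)
        .filter(scoreDoc -> !isHybridQueryStartStopElement(scoreDoc) && !isHybridQueryDelimiterElement(scoreDoc))
        .toArray(ScoreDoc[]::new);
}
```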
src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflow.java
@@ -5,19 +5,25 @@

package org.opensearch.neuralsearch.processor;

import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;

import lombok.AllArgsConstructor;
import lombok.extern.log4j.Log4j2;

import org.apache.lucene.search.TopDocs;
import org.opensearch.common.lucene.search.TopDocsAndMaxScore;
import org.opensearch.neuralsearch.processor.combination.ScoreCombinationTechnique;
import org.opensearch.neuralsearch.processor.combination.ScoreCombiner;
import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationTechnique;
import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizer;
import org.opensearch.neuralsearch.search.CompoundTopDocs;
import org.opensearch.search.SearchHit;
import org.opensearch.search.SearchHits;
import org.opensearch.search.fetch.FetchSearchResult;
import org.opensearch.search.query.QuerySearchResult;

/**
@@ -39,8 +45,10 @@ public class NormalizationProcessorWorkflow {
*/
public void execute(
final List<QuerySearchResult> querySearchResults,
final FetchSearchResult fetchSearchResult,
final ScoreNormalizationTechnique normalizationTechnique,
final ScoreCombinationTechnique combinationTechnique
final ScoreCombinationTechnique combinationTechnique,
final boolean isSingleShard
) {
// pre-process data
log.debug("Pre-process query results");
@@ -56,7 +64,7 @@

// post-process data
log.debug("Post-process query results after score normalization and combination");
updateOriginalQueryResults(querySearchResults, queryTopDocs);
updateOriginalQueryResults(querySearchResults, fetchSearchResult, queryTopDocs, isSingleShard);
}

/**
@@ -67,22 +75,71 @@ private List<CompoundTopDocs> getQueryTopDocs(final List<QuerySearchResult> querySearchResults) {
private List<CompoundTopDocs> getQueryTopDocs(final List<QuerySearchResult> querySearchResults) {
List<CompoundTopDocs> queryTopDocs = querySearchResults.stream()
.filter(searchResult -> Objects.nonNull(searchResult.topDocs()))
.filter(searchResult -> searchResult.topDocs().topDocs instanceof CompoundTopDocs)
.map(searchResult -> (CompoundTopDocs) searchResult.topDocs().topDocs)
.map(querySearchResult -> querySearchResult.topDocs().topDocs)
.map(CompoundTopDocs::create)
.collect(Collectors.toList());
if (queryTopDocs.size() != querySearchResults.size()) {
log.debug("Some of querySearchResults are not produced by hybrid query");

Check warning on line 82 in src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflow.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflow.java#L82

Added line #L82 was not covered by tests
}
return queryTopDocs;
}

private void updateOriginalQueryResults(final List<QuerySearchResult> querySearchResults, final List<CompoundTopDocs> queryTopDocs) {
for (int i = 0; i < querySearchResults.size(); i++) {
QuerySearchResult querySearchResult = querySearchResults.get(i);
if (!(querySearchResult.topDocs().topDocs instanceof CompoundTopDocs) || Objects.isNull(queryTopDocs.get(i))) {
continue;
}
CompoundTopDocs updatedTopDocs = queryTopDocs.get(i);
float maxScore = updatedTopDocs.totalHits.value > 0 ? updatedTopDocs.scoreDocs[0].score : 0.0f;
TopDocsAndMaxScore updatedTopDocsAndMaxScore = new TopDocsAndMaxScore(updatedTopDocs, maxScore);
private void updateOriginalQueryResults(
final List<QuerySearchResult> querySearchResults,
final FetchSearchResult fetchSearchResult,
final List<CompoundTopDocs> queryTopDocs,
final boolean isSingleShard
) {
int queryTopDocsIndex = 0;
for (QuerySearchResult querySearchResult : querySearchResults) {
CompoundTopDocs updatedTopDocs = queryTopDocs.get(queryTopDocsIndex++);
float maxScore = updatedTopDocs.getTotalHits().value > 0 ? updatedTopDocs.getScoreDocs()[0].score : 0.0f;

// create final version of top docs with all updated values
TopDocs topDocs = new TopDocs(updatedTopDocs.getTotalHits(), updatedTopDocs.getScoreDocs());

TopDocsAndMaxScore updatedTopDocsAndMaxScore = new TopDocsAndMaxScore(topDocs, maxScore);
querySearchResult.topDocs(updatedTopDocsAndMaxScore, null);
}
// workaround for the single-shard case: fetch has already happened, so we need to update both
// fetch and query results
if (isSingleShard && querySearchResults.size() == 1) {
updateFetchSearchResults(querySearchResults, fetchSearchResult);
}
}

private void updateFetchSearchResults(final List<QuerySearchResult> querySearchResults, final FetchSearchResult fetchSearchResult) {
if (Objects.isNull(fetchSearchResult)) {
return;
}
// fetch results hold the list of document content, which includes start/stop and
// delimiter elements. The list is in the original order from the query searcher. We need to:
// 1. filter out start/stop and delimiter elements
// 2. filter out duplicates from different sub-queries
// 3. update original scores to normalized and combined values
// 4. order scores based on normalized and combined values
SearchHits searchHits = fetchSearchResult.hits();

// create map of docId to index of search hits
Map<Integer, SearchHit> docIdToSearchHit = new HashMap<>();
for (SearchHit searchHit : searchHits) {
docIdToSearchHit.put(searchHit.docId(), searchHit);
}
QuerySearchResult querySearchResult = querySearchResults.get(0);
TopDocs topDocs = querySearchResult.topDocs().topDocs;
// iterate over the normalized/combined scores; this solves (1), (2) and (4), while (3) is applied per hit below
SearchHit[] updatedSearchHitArray = Arrays.stream(topDocs.scoreDocs).map(scoreDoc -> {
// get fetched hit content by doc_id
SearchHit searchHit = docIdToSearchHit.get(scoreDoc.doc);
// update score to normalized/combined value (3)
searchHit.score(scoreDoc.score);
return searchHit;
}).toArray(SearchHit[]::new);
SearchHits updatedSearchHits = new SearchHits(
updatedSearchHitArray,
querySearchResult.getTotalHits(),
querySearchResult.getMaxScore()
);
fetchSearchResult.hits(updatedSearchHits);
}
}
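
For orientation, a hedged sketch of how the processor invokes this extended entry point; it mirrors the call site in NormalizationProcessor above, with comments marking what each argument is assumed to carry:

```java
// Usage sketch (mirrors NormalizationProcessor.process above; comments are assumptions)
normalizationWorkflow.execute(
    querySearchResults,                     // one QuerySearchResult per shard
    fetchSearchResult,                      // may be null; only consumed in the single-shard case
    normalizationTechnique,                 // e.g. min-max normalization
    combinationTechnique,                   // e.g. arithmetic mean combination
    searchPhaseContext.getNumShards() == 1  // isSingleShard flag
);
```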
src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombiner.java
@@ -6,7 +6,6 @@
package org.opensearch.neuralsearch.processor.combination;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
@@ -18,7 +17,7 @@
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.TotalHits;
import org.opensearch.neuralsearch.search.CompoundTopDocs;
import org.opensearch.neuralsearch.processor.CompoundTopDocs;

/**
* Abstracts combination of scores in query search results.
@@ -48,7 +47,7 @@ public void combineScores(final List<CompoundTopDocs> queryTopDocs, final ScoreC
}

private void combineShardScores(final ScoreCombinationTechnique scoreCombinationTechnique, final CompoundTopDocs compoundQueryTopDocs) {
if (Objects.isNull(compoundQueryTopDocs) || compoundQueryTopDocs.totalHits.value == 0) {
if (Objects.isNull(compoundQueryTopDocs) || compoundQueryTopDocs.getTotalHits().value == 0) {
return;
}
List<TopDocs> topDocsPerSubQuery = compoundQueryTopDocs.getCompoundTopDocs();
@@ -84,7 +83,7 @@ private ScoreDoc[] getCombinedScoreDocs(
) {
ScoreDoc[] finalScoreDocs = new ScoreDoc[maxHits];

int shardId = compoundQueryTopDocs.scoreDocs[0].shardIndex;
int shardId = compoundQueryTopDocs.getScoreDocs()[0].shardIndex;
for (int j = 0; j < maxHits && j < sortedScores.size(); j++) {
int docId = sortedScores.get(j);
finalScoreDocs[j] = new ScoreDoc(docId, combinedNormalizedScoresByDocId.get(docId), shardId);
@@ -100,7 +99,6 @@ public Map<Integer, float[]> getNormalizedScoresPerDocument(final List<TopDocs>
normalizedScoresPerDoc.computeIfAbsent(scoreDoc.doc, key -> {
float[] scores = new float[topDocsPerSubQuery.size()];
// we initialize with -1.0, as after normalization it's possible that score is 0.0
Arrays.fill(scores, -1.0f);
return scores;
});
normalizedScoresPerDoc.get(scoreDoc.doc)[j] = scoreDoc.score;
@@ -127,8 +125,10 @@ private void updateQueryTopDocsWithCombinedScores(
// - count max number of hits among sub-queries
int maxHits = getMaxHits(topDocsPerSubQuery);
// - update query search results with normalized scores
compoundQueryTopDocs.scoreDocs = getCombinedScoreDocs(compoundQueryTopDocs, combinedNormalizedScoresByDocId, sortedScores, maxHits);
compoundQueryTopDocs.totalHits = getTotalHits(topDocsPerSubQuery, maxHits);
compoundQueryTopDocs.setScoreDocs(
getCombinedScoreDocs(compoundQueryTopDocs, combinedNormalizedScoresByDocId, sortedScores, maxHits)
);
compoundQueryTopDocs.setTotalHits(getTotalHits(topDocsPerSubQuery, maxHits));
}

protected int getMaxHits(final List<TopDocs> topDocsPerSubQuery) {
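
To illustrate what the combination step produces from the per-document normalized scores built by getNormalizedScoresPerDocument, a minimal self-contained sketch using an arithmetic mean; this stands in for one possible ScoreCombinationTechnique, and the doc ids and scores are made up:

```java
import java.util.HashMap;
import java.util.Map;

// Hedged sketch: collapse per-sub-query normalized scores into one combined score per doc.
Map<Integer, float[]> normalizedScoresPerDoc = new HashMap<>();
normalizedScoresPerDoc.put(1, new float[] { 0.90f, 0.00f }); // doc 1: scores from two sub-queries
normalizedScoresPerDoc.put(5, new float[] { 0.40f, 0.70f }); // doc 5

Map<Integer, Float> combinedScoreByDocId = new HashMap<>();
normalizedScoresPerDoc.forEach((docId, scores) -> {
    float sum = 0f;
    for (float score : scores) {
        sum += score;
    }
    combinedScoreByDocId.put(docId, sum / scores.length);
});
// combinedScoreByDocId: {1=0.45, 5=0.55}; sorting descending yields doc 5, then doc 1
```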