Changed format for hybrid query results to a single list of scores with delimiter #259

Merged
Changes from all commits
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -16,6 +16,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
### Features
* Added Score Normalization and Combination feature ([#241](https://github.com/opensearch-project/neural-search/pull/241/))
### Enhancements
+ * Changed format for hybrid query results to a single list of scores with delimiter ([#259](https://github.com/opensearch-project/neural-search/pull/259))
### Bug Fixes
### Infrastructure
### Documentation
120 changes: 120 additions & 0 deletions src/main/java/org/opensearch/neuralsearch/processor/CompoundTopDocs.java
@@ -0,0 +1,120 @@
/*
 * Copyright OpenSearch Contributors
 * SPDX-License-Identifier: Apache-2.0
 */

package org.opensearch.neuralsearch.processor;

import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.isHybridQueryDelimiterElement;
import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.isHybridQueryStartStopElement;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors;

import lombok.AllArgsConstructor;
import lombok.Getter;
import lombok.Setter;
import lombok.ToString;
import lombok.extern.log4j.Log4j2;

import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.TotalHits;

/**
 * Class stores collection of TopDocs for each sub-query from hybrid query. Collection of results is at shard level. We do store
 * list of TopDocs and list of ScoreDoc as well as total hits for the shard.
 */
@AllArgsConstructor
@Getter
@ToString(includeFieldNames = true)
@Log4j2
public class CompoundTopDocs {

    @Setter
    private TotalHits totalHits;
    private List<TopDocs> topDocs;
    @Setter
    private List<ScoreDoc> scoreDocs;

    public CompoundTopDocs(final TotalHits totalHits, final List<TopDocs> topDocs) {
        initialize(totalHits, topDocs);
    }

    private void initialize(TotalHits totalHits, List<TopDocs> topDocs) {
        this.totalHits = totalHits;
        this.topDocs = topDocs;
        scoreDocs = cloneLargestScoreDocs(topDocs);
    }

    /**
     * Create new instance from TopDocs by parsing scores of sub-queries. Final format looks like:
     * doc_id | magic_number_1
     * doc_id | magic_number_2
     * ...
     * doc_id | magic_number_2
     * ...
     * doc_id | magic_number_2
     * ...
     * doc_id | magic_number_1
     *
     * where doc_id is one of the valid ids from the result. For example, this is a list with results for three sub-queries
     *
     * 0, 9549511920.4881596047
     * 0, 4422440593.9791198149
     * 0, 0.8
     * 2, 0.5
     * 0, 4422440593.9791198149
     * 0, 4422440593.9791198149
     * 2, 0.7
     * 5, 0.65
     * 6, 0.15
     * 0, 9549511920.4881596047
     */
    public CompoundTopDocs(final TopDocs topDocs) {
        ScoreDoc[] scoreDocs = topDocs.scoreDocs;
        if (Objects.isNull(scoreDocs) || scoreDocs.length < 2) {
            initialize(topDocs.totalHits, new ArrayList<>());
            return;
        }
        // skip the first two elements: the start-stop element and the delimiter for the first series
        List<TopDocs> topDocsList = new ArrayList<>();
        List<ScoreDoc> scoreDocList = new ArrayList<>();
        for (int index = 2; index < scoreDocs.length; index++) {
            // getting first element of score's series
            ScoreDoc scoreDoc = scoreDocs[index];
            if (isHybridQueryDelimiterElement(scoreDoc) || isHybridQueryStartStopElement(scoreDoc)) {
                ScoreDoc[] subQueryScores = scoreDocList.toArray(new ScoreDoc[0]);
                TotalHits totalHits = new TotalHits(subQueryScores.length, TotalHits.Relation.EQUAL_TO);
                TopDocs subQueryTopDocs = new TopDocs(totalHits, subQueryScores);
                topDocsList.add(subQueryTopDocs);
                scoreDocList.clear();
            } else {
                scoreDocList.add(scoreDoc);
            }
        }
        initialize(topDocs.totalHits, topDocsList);
    }

    private List<ScoreDoc> cloneLargestScoreDocs(final List<TopDocs> docs) {
        if (docs == null) {
            return null;
        }
        ScoreDoc[] maxScoreDocs = new ScoreDoc[0];
        int maxLength = -1;
        for (TopDocs topDoc : docs) {
            if (topDoc == null || topDoc.scoreDocs == null) {
                continue;
            }
            if (topDoc.scoreDocs.length > maxLength) {
                maxLength = topDoc.scoreDocs.length;
                maxScoreDocs = topDoc.scoreDocs;
            }
        }
        // do deep copy
        return Arrays.stream(maxScoreDocs).map(doc -> new ScoreDoc(doc.doc, doc.score, doc.shardIndex)).collect(Collectors.toList());
    }
}
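
To make the flattened format concrete: below is a minimal, self-contained sketch (not part of the PR) of the walk the CompoundTopDocs(TopDocs) constructor performs. The START_STOP and DELIMITER values and the HybridFormatParsingDemo class are hypothetical stand-ins; the actual marker values and element checks live in HybridSearchResultFormatUtil.

import java.util.ArrayList;
import java.util.List;

import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.TotalHits;

public class HybridFormatParsingDemo {
    // Hypothetical marker scores; the real constants live in HybridSearchResultFormatUtil.
    private static final float START_STOP = 9549511920.4881596047f;
    private static final float DELIMITER = 4422440593.9791198149f;

    public static void main(String[] args) {
        // Flattened shard result for two sub-queries, in the format described above:
        // start-stop, delimiter, <series 1>, delimiter, <series 2>, start-stop
        ScoreDoc[] flat = {
            new ScoreDoc(0, START_STOP),
            new ScoreDoc(0, DELIMITER),
            new ScoreDoc(2, 0.5f),
            new ScoreDoc(4, 0.3f),
            new ScoreDoc(0, DELIMITER),
            new ScoreDoc(5, 0.65f),
            new ScoreDoc(0, START_STOP) };

        // Same walk as the CompoundTopDocs(TopDocs) constructor: skip the first two
        // marker elements, then cut a new sub-query series at every further marker.
        List<TopDocs> perSubQuery = new ArrayList<>();
        List<ScoreDoc> series = new ArrayList<>();
        for (int index = 2; index < flat.length; index++) {
            ScoreDoc scoreDoc = flat[index];
            if (scoreDoc.score == DELIMITER || scoreDoc.score == START_STOP) {
                ScoreDoc[] scores = series.toArray(new ScoreDoc[0]);
                perSubQuery.add(new TopDocs(new TotalHits(scores.length, TotalHits.Relation.EQUAL_TO), scores));
                series.clear();
            } else {
                series.add(scoreDoc);
            }
        }
        System.out.println(perSubQuery.size()); // prints 2: one TopDocs per sub-query
    }
}

Note that the trailing start-stop element is what flushes the last series, which is why the loop needs no extra handling after it ends.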
src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessor.java
@@ -5,6 +5,8 @@

package org.opensearch.neuralsearch.processor;

+ import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.isHybridQueryStartStopElement;

import java.util.List;
import java.util.Objects;
import java.util.Optional;
Expand All @@ -19,8 +21,8 @@
import org.opensearch.action.search.SearchPhaseResults;
import org.opensearch.neuralsearch.processor.combination.ScoreCombinationTechnique;
import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationTechnique;
- import org.opensearch.neuralsearch.search.CompoundTopDocs;
import org.opensearch.search.SearchPhaseResult;
+ import org.opensearch.search.fetch.FetchSearchResult;
import org.opensearch.search.internal.SearchContext;
import org.opensearch.search.pipeline.SearchPhaseResultsProcessor;
import org.opensearch.search.query.QuerySearchResult;
@@ -56,7 +58,8 @@ public <Result extends SearchPhaseResult> void process(
            return;
        }
        List<QuerySearchResult> querySearchResults = getQueryPhaseSearchResults(searchPhaseResult);
-       normalizationWorkflow.execute(querySearchResults, normalizationTechnique, combinationTechnique);
+       Optional<FetchSearchResult> fetchSearchResult = getFetchSearchResults(searchPhaseResult);
+       normalizationWorkflow.execute(querySearchResults, fetchSearchResult, normalizationTechnique, combinationTechnique);
    }

    @Override
@@ -95,19 +98,21 @@ private <Result extends SearchPhaseResult> boolean shouldSkipProcessor(SearchPha
        }

        QueryPhaseResultConsumer queryPhaseResultConsumer = (QueryPhaseResultConsumer) searchPhaseResult;
-       Optional<SearchPhaseResult> optionalSearchPhaseResult = queryPhaseResultConsumer.getAtomicArray()
-           .asList()
-           .stream()
-           .filter(Objects::nonNull)
-           .findFirst();
-       return isNotHybridQuery(optionalSearchPhaseResult);
+       return queryPhaseResultConsumer.getAtomicArray().asList().stream().filter(Objects::nonNull).noneMatch(this::isHybridQuery);
    }

-   private boolean isNotHybridQuery(final Optional<SearchPhaseResult> optionalSearchPhaseResult) {
-       return optionalSearchPhaseResult.isEmpty()
-           || Objects.isNull(optionalSearchPhaseResult.get().queryResult())
-           || Objects.isNull(optionalSearchPhaseResult.get().queryResult().topDocs())
-           || !(optionalSearchPhaseResult.get().queryResult().topDocs().topDocs instanceof CompoundTopDocs);
-   }
+   /**
+    * Return true if results are from hybrid query.
+    * @param searchPhaseResult search phase result to check
+    * @return true if results are from hybrid query
+    */
+   private boolean isHybridQuery(final SearchPhaseResult searchPhaseResult) {
+       // check for the start-stop element at the beginning of the score docs
+       return Objects.nonNull(searchPhaseResult.queryResult())
+           && Objects.nonNull(searchPhaseResult.queryResult().topDocs())
+           && Objects.nonNull(searchPhaseResult.queryResult().topDocs().topDocs.scoreDocs)
+           && searchPhaseResult.queryResult().topDocs().topDocs.scoreDocs.length > 0
+           && isHybridQueryStartStopElement(searchPhaseResult.queryResult().topDocs().topDocs.scoreDocs[0]);
+   }

    private <Result extends SearchPhaseResult> List<QuerySearchResult> getQueryPhaseSearchResults(
@@ -119,4 +124,11 @@ private <Result extends SearchPhaseResult> List<QuerySearchResult> getQueryPhase
            .map(result -> result == null ? null : result.queryResult())
            .collect(Collectors.toList());
    }

+   private <Result extends SearchPhaseResult> Optional<FetchSearchResult> getFetchSearchResults(
+       final SearchPhaseResults<Result> searchPhaseResults
+   ) {
+       Optional<Result> optionalFirstSearchPhaseResult = searchPhaseResults.getAtomicArray().asList().stream().findFirst();
+       return optionalFirstSearchPhaseResult.map(SearchPhaseResult::fetchResult);
+   }
}
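
For context on the new detection path: isHybridQuery no longer does an instanceof check against CompoundTopDocs; it probes the first ScoreDoc for the start-stop marker. Below is a minimal sketch (not part of the PR) of that guard chain, assuming a hypothetical START_STOP constant in place of the real one from HybridSearchResultFormatUtil.

import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.TotalHits;

public class HybridDetectionDemo {
    // Hypothetical marker score; the real constant and check live in HybridSearchResultFormatUtil.
    private static final float START_STOP = 9549511920.4881596047f;

    // Mirrors the guard chain in isHybridQuery: non-null docs, non-empty scoreDocs,
    // and a start-stop marker element in the first position.
    static boolean looksLikeHybridResult(TopDocs topDocs) {
        return topDocs != null
            && topDocs.scoreDocs != null
            && topDocs.scoreDocs.length > 0
            && Float.compare(topDocs.scoreDocs[0].score, START_STOP) == 0;
    }

    public static void main(String[] args) {
        TopDocs hybrid = new TopDocs(
            new TotalHits(3, TotalHits.Relation.EQUAL_TO),
            new ScoreDoc[] { new ScoreDoc(0, START_STOP), new ScoreDoc(2, 0.5f), new ScoreDoc(0, START_STOP) }
        );
        TopDocs plain = new TopDocs(new TotalHits(1, TotalHits.Relation.EQUAL_TO), new ScoreDoc[] { new ScoreDoc(7, 1.2f) });
        System.out.println(looksLikeHybridResult(hybrid)); // true
        System.out.println(looksLikeHybridResult(plain));  // false
    }
}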
src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflow.java
@@ -5,19 +5,28 @@

package org.opensearch.neuralsearch.processor;

+ import java.util.Arrays;
import java.util.List;
+ import java.util.Locale;
+ import java.util.Map;
import java.util.Objects;
+ import java.util.Optional;
+ import java.util.function.Function;
import java.util.stream.Collectors;

import lombok.AllArgsConstructor;
import lombok.extern.log4j.Log4j2;

+ import org.apache.lucene.search.ScoreDoc;
+ import org.apache.lucene.search.TopDocs;
import org.opensearch.common.lucene.search.TopDocsAndMaxScore;
import org.opensearch.neuralsearch.processor.combination.ScoreCombinationTechnique;
import org.opensearch.neuralsearch.processor.combination.ScoreCombiner;
import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationTechnique;
import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizer;
- import org.opensearch.neuralsearch.search.CompoundTopDocs;
+ import org.opensearch.search.SearchHit;
+ import org.opensearch.search.SearchHits;
+ import org.opensearch.search.fetch.FetchSearchResult;
import org.opensearch.search.query.QuerySearchResult;

    /**
@@ -39,6 +48,7 @@
     */
    public void execute(
        final List<QuerySearchResult> querySearchResults,
+       final Optional<FetchSearchResult> fetchSearchResultOptional,
        final ScoreNormalizationTechnique normalizationTechnique,
        final ScoreCombinationTechnique combinationTechnique
    ) {
@@ -57,6 +67,7 @@
        // post-process data
        log.debug("Post-process query results after score normalization and combination");
        updateOriginalQueryResults(querySearchResults, queryTopDocs);
+       updateOriginalFetchResults(querySearchResults, fetchSearchResultOptional);
    }

    /**
@@ -67,22 +78,87 @@
    private List<CompoundTopDocs> getQueryTopDocs(final List<QuerySearchResult> querySearchResults) {
        List<CompoundTopDocs> queryTopDocs = querySearchResults.stream()
            .filter(searchResult -> Objects.nonNull(searchResult.topDocs()))
-           .filter(searchResult -> searchResult.topDocs().topDocs instanceof CompoundTopDocs)
-           .map(searchResult -> (CompoundTopDocs) searchResult.topDocs().topDocs)
+           .map(querySearchResult -> querySearchResult.topDocs().topDocs)
+           .map(CompoundTopDocs::new)
            .collect(Collectors.toList());
+       if (queryTopDocs.size() != querySearchResults.size()) {
+           throw new IllegalStateException(
+               String.format(
+                   Locale.ROOT,
+                   "query results were not formatted correctly by the hybrid query; sizes of querySearchResults [%d] and queryTopDocs [%d] must match",
+                   querySearchResults.size(),
+                   queryTopDocs.size()
+               )
+           );
+       }
        return queryTopDocs;
    }

    private void updateOriginalQueryResults(final List<QuerySearchResult> querySearchResults, final List<CompoundTopDocs> queryTopDocs) {
-       for (int i = 0; i < querySearchResults.size(); i++) {
-           QuerySearchResult querySearchResult = querySearchResults.get(i);
-           if (!(querySearchResult.topDocs().topDocs instanceof CompoundTopDocs) || Objects.isNull(queryTopDocs.get(i))) {
-               continue;
-           }
-           CompoundTopDocs updatedTopDocs = queryTopDocs.get(i);
-           float maxScore = updatedTopDocs.totalHits.value > 0 ? updatedTopDocs.scoreDocs[0].score : 0.0f;
-           TopDocsAndMaxScore updatedTopDocsAndMaxScore = new TopDocsAndMaxScore(updatedTopDocs, maxScore);
+       if (querySearchResults.size() != queryTopDocs.size()) {
+           throw new IllegalStateException(
+               String.format(
+                   Locale.ROOT,
+                   "query results were not formatted correctly by the hybrid query; sizes of querySearchResults [%d] and queryTopDocs [%d] must match",
+                   querySearchResults.size(),
+                   queryTopDocs.size()
+               )
+           );
+       }
+       for (int index = 0; index < querySearchResults.size(); index++) {
+           QuerySearchResult querySearchResult = querySearchResults.get(index);
+           CompoundTopDocs updatedTopDocs = queryTopDocs.get(index);
+           float maxScore = updatedTopDocs.getTotalHits().value > 0 ? updatedTopDocs.getScoreDocs().get(0).score : 0.0f;
+
+           // create final version of top docs with all updated values
+           TopDocs topDocs = new TopDocs(updatedTopDocs.getTotalHits(), updatedTopDocs.getScoreDocs().toArray(new ScoreDoc[0]));
+
+           TopDocsAndMaxScore updatedTopDocsAndMaxScore = new TopDocsAndMaxScore(topDocs, maxScore);
            querySearchResult.topDocs(updatedTopDocsAndMaxScore, null);
        }
    }

+   /**
+    * A workaround for the single shard case: fetch has already happened, so we need to update both fetch and query results
+    */
+   private void updateOriginalFetchResults(
+       final List<QuerySearchResult> querySearchResults,
+       final Optional<FetchSearchResult> fetchSearchResultOptional
+   ) {
+       if (fetchSearchResultOptional.isEmpty()) {
+           return;
+       }
+       // fetch results hold the document content and include the start/stop and
+       // delimiter elements. The list is in the original order from the query searcher. We need to:
+       // 1. filter out start/stop and delimiter elements
+       // 2. filter out duplicates from different sub-queries
+       // 3. update original scores to normalized and combined values
+       // 4. order scores based on normalized and combined values
+       FetchSearchResult fetchSearchResult = fetchSearchResultOptional.get();
+       SearchHits searchHits = fetchSearchResult.hits();
+
+       // create a map of docId to SearchHit. This solves (2): the duplicates come from the
+       // delimiter and start/stop elements, which all share the same valid doc_id, so they
+       // collapse into a single key-value pair.
+       Map<Integer, SearchHit> docIdToSearchHit = Arrays.stream(searchHits.getHits())
+           .collect(Collectors.toMap(SearchHit::docId, Function.identity(), (a1, a2) -> a1));

+       QuerySearchResult querySearchResult = querySearchResults.get(0);
+       TopDocs topDocs = querySearchResult.topDocs().topDocs;
+       // iterate over the normalized/combined scores; that solves (1) and (3)
+       SearchHit[] updatedSearchHitArray = Arrays.stream(topDocs.scoreDocs).map(scoreDoc -> {
+           // get fetched hit content by doc_id
+           SearchHit searchHit = docIdToSearchHit.get(scoreDoc.doc);
+           // update score to normalized/combined value (3)
+           searchHit.score(scoreDoc.score);
+           return searchHit;
+       }).toArray(SearchHit[]::new);
+       SearchHits updatedSearchHits = new SearchHits(
+           updatedSearchHitArray,
+           querySearchResult.getTotalHits(),
+           querySearchResult.getMaxScore()
+       );
+       fetchSearchResult.hits(updatedSearchHits);
+   }
}
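
The four numbered steps in updateOriginalFetchResults can be illustrated without OpenSearch types. Below is a minimal sketch (not part of the PR), assuming a toy Hit class in place of SearchHit and hand-picked normalized scores, both invented for illustration.

import java.util.Arrays;
import java.util.Map;
import java.util.function.Function;
import java.util.stream.Collectors;

public class FetchReorderDemo {
    // Toy stand-in for SearchHit: a doc id plus a mutable score.
    static final class Hit {
        final int docId;
        float score;
        Hit(int docId, float score) { this.docId = docId; this.score = score; }
        @Override
        public String toString() { return docId + ":" + score; }
    }

    public static void main(String[] args) {
        // Fetched hits in original query order; doc 0 repeats because the start-stop
        // and delimiter marker elements all carry the same valid doc id.
        Hit[] fetched = { new Hit(0, 1.0f), new Hit(0, 1.0f), new Hit(2, 0.5f), new Hit(5, 0.65f), new Hit(0, 1.0f) };

        // Step 2: collapse duplicates, keeping one hit per doc id.
        Map<Integer, Hit> docIdToHit = Arrays.stream(fetched)
            .collect(Collectors.toMap(hit -> hit.docId, Function.identity(), (first, second) -> first));

        // Normalized/combined scores in final ranked order, as in topDocs.scoreDocs.
        // The marker elements are already gone at this point, which covers step 1.
        int[] rankedDocIds = { 5, 2, 0 };
        float[] combinedScores = { 0.9f, 0.7f, 0.4f };

        // Steps 3 and 4: overwrite scores with combined values and emit hits in ranked order.
        Hit[] updated = new Hit[rankedDocIds.length];
        for (int i = 0; i < rankedDocIds.length; i++) {
            Hit hit = docIdToHit.get(rankedDocIds[i]);
            hit.score = combinedScores[i];
            updated[i] = hit;
        }
        System.out.println(Arrays.toString(updated)); // [5:0.9, 2:0.7, 0:0.4]
    }
}

The merge function (first, second) -> first mirrors the (a1, a2) -> a1 merge in the PR: when several fetched hits share a doc id, only one copy is kept.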