
Commit eb2e29c
Addressed multiple review comments
Signed-off-by: Martin Gaievski <gaievski@amazon.com>
martin-gaievski committed Aug 26, 2023
1 parent 4760545 commit eb2e29c
Showing 14 changed files with 227 additions and 135 deletions.
src/main/java/org/opensearch/neuralsearch/processor/CompoundTopDocs.java
@@ -5,13 +5,14 @@

package org.opensearch.neuralsearch.processor;

import static org.opensearch.neuralsearch.processor.NormalizationProcessor.isHybridQueryDelimiterElement;
import static org.opensearch.neuralsearch.processor.NormalizationProcessor.isHybridQueryStartStopElement;
import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.isHybridQueryDelimiterElement;
import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.isHybridQueryStartStopElement;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors;

import lombok.AllArgsConstructor;
import lombok.Getter;
@@ -35,36 +36,21 @@ public class CompoundTopDocs {
@Setter
private TotalHits totalHits;
@Getter
private final List<TopDocs> compoundTopDocs;
private List<TopDocs> compoundTopDocs;
@Getter
@Setter
private ScoreDoc[] scoreDocs;
private List<ScoreDoc> scoreDocs;

public CompoundTopDocs(final TotalHits totalHits, final List<TopDocs> compoundTopDocs) {
initialize(totalHits, compoundTopDocs);
}

private void initialize(TotalHits totalHits, List<TopDocs> compoundTopDocs) {
this.totalHits = totalHits;
this.compoundTopDocs = compoundTopDocs;
scoreDocs = cloneLargestScoreDocs(compoundTopDocs);
}

private static ScoreDoc[] cloneLargestScoreDocs(List<TopDocs> docs) {
if (docs == null) {
return null;
}
ScoreDoc[] maxScoreDocs = new ScoreDoc[0];
int maxLength = -1;
for (TopDocs topDoc : docs) {
if (topDoc == null || topDoc.scoreDocs == null) {
continue;
}
if (topDoc.scoreDocs.length > maxLength) {
maxLength = topDoc.scoreDocs.length;
maxScoreDocs = topDoc.scoreDocs;
}
}
// do deep copy
return Arrays.stream(maxScoreDocs).map(doc -> new ScoreDoc(doc.doc, doc.score, doc.shardIndex)).toArray(ScoreDoc[]::new);
}

/**
* Create new instance from TopDocs by parsing scores of sub-queries. Final format looks like:
* doc_id | magic_number_1
@@ -76,15 +62,24 @@ private static ScoreDoc[] cloneLargestScoreDocs(List<TopDocs> docs) {
* ...
* doc_id | magic_number_1
*
* where doc_id is one of valid ids from result
where doc_id is one of valid ids from result. For example, this is a list with results for three sub-queries
*
* @param topDocs object with scores from multiple sub-queries
* @return compound TopDocs object that has results from all sub-queries
* 0, 9549511920.4881596047
* 0, 4422440593.9791198149
* 0, 0.8
* 2, 0.5
* 0, 4422440593.9791198149
* 0, 4422440593.9791198149
* 2, 0.7
* 5, 0.65
* 6, 0.15
* 0, 9549511920.4881596047
*/
public static CompoundTopDocs create(final TopDocs topDocs) {
public CompoundTopDocs(final TopDocs topDocs) {
ScoreDoc[] scoreDocs = topDocs.scoreDocs;
if (Objects.isNull(scoreDocs) || scoreDocs.length < 2) {
return new CompoundTopDocs(topDocs.totalHits, new ArrayList<>());
initialize(topDocs.totalHits, new ArrayList<>());
return;

}
// skip the first two elements: the start-stop element and the delimiter for the first series
List<TopDocs> topDocsList = new ArrayList<>();
@@ -102,6 +97,25 @@ public static CompoundTopDocs create(final TopDocs topDocs) {
scoreDocList.add(scoreDoc);
}
}
return new CompoundTopDocs(topDocs.totalHits, topDocsList);
initialize(topDocs.totalHits, topDocsList);
}

private List<ScoreDoc> cloneLargestScoreDocs(final List<TopDocs> docs) {
if (docs == null) {
return null;
}
ScoreDoc[] maxScoreDocs = new ScoreDoc[0];
int maxLength = -1;
for (TopDocs topDoc : docs) {
if (topDoc == null || topDoc.scoreDocs == null) {
continue;
}
if (topDoc.scoreDocs.length > maxLength) {
maxLength = topDoc.scoreDocs.length;
maxScoreDocs = topDoc.scoreDocs;
}
}
// do deep copy
return Arrays.stream(maxScoreDocs).map(doc -> new ScoreDoc(doc.doc, doc.score, doc.shardIndex)).collect(Collectors.toList());
}
}
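
To make the new constructor's parsing concrete, here is a minimal sketch (not part of this commit) that feeds it the three-sub-query example from the Javadoc above. It assumes static imports of the element factories from HybridSearchResultFormatUtil; the TotalHits value is illustrative:

    ScoreDoc[] encoded = new ScoreDoc[] {
        createStartStopElementForHybridSearchResults(0), // 9549511920.4881596047, start
        createDelimiterElementForHybridSearchResults(0), // 4422440593.9791198149, sub-query 1
        new ScoreDoc(0, 0.8f),
        new ScoreDoc(2, 0.5f),
        createDelimiterElementForHybridSearchResults(0), // sub-query 2, no hits
        createDelimiterElementForHybridSearchResults(0), // sub-query 3
        new ScoreDoc(2, 0.7f),
        new ScoreDoc(5, 0.65f),
        new ScoreDoc(6, 0.15f),
        createStartStopElementForHybridSearchResults(0)  // stop
    };
    TopDocs hybridTopDocs = new TopDocs(new TotalHits(4, TotalHits.Relation.EQUAL_TO), encoded);
    CompoundTopDocs compound = new CompoundTopDocs(hybridTopDocs);
    // compound.getCompoundTopDocs() now holds three TopDocs objects:
    // [doc 0, doc 2], [] and [doc 2, doc 5, doc 6]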
src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessor.java
@@ -5,8 +5,7 @@

package org.opensearch.neuralsearch.processor;

import static org.opensearch.neuralsearch.search.query.HybridQueryPhaseSearcher.MAGIC_NUMBER_DELIMITER;
import static org.opensearch.neuralsearch.search.query.HybridQueryPhaseSearcher.MAGIC_NUMBER_START_STOP;
import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.isHybridQueryStartStopElement;

import java.util.List;
import java.util.Objects;
@@ -15,7 +14,6 @@
import lombok.AllArgsConstructor;
import lombok.extern.log4j.Log4j2;

import org.apache.lucene.search.ScoreDoc;
import org.opensearch.action.search.QueryPhaseResultConsumer;
import org.opensearch.action.search.SearchPhaseContext;
import org.opensearch.action.search.SearchPhaseName;
@@ -131,12 +129,4 @@ private <Result extends SearchPhaseResult> List<QuerySearchResult> getQueryPhase
.map(result -> result == null ? null : result.queryResult())
.collect(Collectors.toList());
}

public static boolean isHybridQueryStartStopElement(final ScoreDoc scoreDoc) {
return Objects.nonNull(scoreDoc) && scoreDoc.doc >= 0 && Float.compare(scoreDoc.score, MAGIC_NUMBER_START_STOP) == 0;
}

public static boolean isHybridQueryDelimiterElement(final ScoreDoc scoreDoc) {
return Objects.nonNull(scoreDoc) && scoreDoc.doc >= 0 && Float.compare(scoreDoc.score, MAGIC_NUMBER_DELIMITER) == 0;
}
}
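
The two predicates removed here, together with the factory methods removed from HybridQueryPhaseSearcher further down, were consolidated into HybridSearchResultFormatUtil. That file is not part of this excerpt; the following sketch reconstructs it from the removed code, assuming a plain static-utility class:

    package org.opensearch.neuralsearch.search.util;

    import java.util.Objects;

    import org.apache.lucene.search.ScoreDoc;

    public class HybridSearchResultFormatUtil {
        public static final Float MAGIC_NUMBER_START_STOP = 9549511920.4881596047f;
        public static final Float MAGIC_NUMBER_DELIMITER = 4422440593.9791198149f;

        public static ScoreDoc createStartStopElementForHybridSearchResults(final int docId) {
            return new ScoreDoc(docId, MAGIC_NUMBER_START_STOP);
        }

        public static ScoreDoc createDelimiterElementForHybridSearchResults(final int docId) {
            return new ScoreDoc(docId, MAGIC_NUMBER_DELIMITER);
        }

        public static boolean isHybridQueryStartStopElement(final ScoreDoc scoreDoc) {
            return Objects.nonNull(scoreDoc) && scoreDoc.doc >= 0 && Float.compare(scoreDoc.score, MAGIC_NUMBER_START_STOP) == 0;
        }

        public static boolean isHybridQueryDelimiterElement(final ScoreDoc scoreDoc) {
            return Objects.nonNull(scoreDoc) && scoreDoc.doc >= 0 && Float.compare(scoreDoc.score, MAGIC_NUMBER_DELIMITER) == 0;
        }
    }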
src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflow.java
@@ -8,13 +8,15 @@
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;

import lombok.AllArgsConstructor;
import lombok.extern.log4j.Log4j2;

import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TopDocs;
import org.opensearch.common.lucene.search.TopDocsAndMaxScore;
import org.opensearch.neuralsearch.processor.combination.ScoreCombinationTechnique;
@@ -64,7 +66,8 @@ public void execute(

// post-process data
log.debug("Post-process query results after score normalization and combination");
updateOriginalQueryResults(querySearchResults, fetchSearchResult, queryTopDocs, isSingleShard);
updateOriginalQueryResults(querySearchResults, fetchSearchResult, queryTopDocs);
updateOriginalFetchResults(querySearchResults, fetchSearchResult, isSingleShard);
}

/**
@@ -76,40 +79,52 @@ private List<CompoundTopDocs> getQueryTopDocs(final List<QuerySearchResult> quer
List<CompoundTopDocs> queryTopDocs = querySearchResults.stream()
.filter(searchResult -> Objects.nonNull(searchResult.topDocs()))
.map(querySearchResult -> querySearchResult.topDocs().topDocs)
.map(CompoundTopDocs::create)
.map(CompoundTopDocs::new)
.collect(Collectors.toList());
if (queryTopDocs.size() != querySearchResults.size()) {
log.debug("Some of querySearchResults are not produced by hybrid query");
log.warn("Some of querySearchResults are not produced by hybrid query");

}
return queryTopDocs;
}

private void updateOriginalQueryResults(
final List<QuerySearchResult> querySearchResults,
final FetchSearchResult fetchSearchResult,
final List<CompoundTopDocs> queryTopDocs,
final boolean isSingleShard
final List<CompoundTopDocs> queryTopDocs
) {
int queryTopDocsIndex = 0;
for (QuerySearchResult querySearchResult : querySearchResults) {
CompoundTopDocs updatedTopDocs = queryTopDocs.get(queryTopDocsIndex++);
float maxScore = updatedTopDocs.getTotalHits().value > 0 ? updatedTopDocs.getScoreDocs()[0].score : 0.0f;
if (querySearchResults.size() != queryTopDocs.size()) {
log.error(
String.format(

Locale.ROOT,
"sizes of querySearchResults [%d] and queryTopDocs [%d] must match",
querySearchResults.size(),
queryTopDocs.size()

)
);
throw new IllegalStateException("found inconsistent system state while processing score normalization and combination");

}
for (int index = 0; index < querySearchResults.size(); index++) {
QuerySearchResult querySearchResult = querySearchResults.get(index);
CompoundTopDocs updatedTopDocs = queryTopDocs.get(index);
float maxScore = updatedTopDocs.getTotalHits().value > 0 ? updatedTopDocs.getScoreDocs().get(0).score : 0.0f;

// create final version of top docs with all updated values
TopDocs topDocs = new TopDocs(updatedTopDocs.getTotalHits(), updatedTopDocs.getScoreDocs());
TopDocs topDocs = new TopDocs(updatedTopDocs.getTotalHits(), updatedTopDocs.getScoreDocs().toArray(new ScoreDoc[0]));

TopDocsAndMaxScore updatedTopDocsAndMaxScore = new TopDocsAndMaxScore(topDocs, maxScore);
querySearchResult.topDocs(updatedTopDocsAndMaxScore, null);
}
// a workaround for a single shard case, fetch has happened, and we need to update both fetch and
// query results
if (isSingleShard && querySearchResults.size() == 1) {
updateFetchSearchResults(querySearchResults, fetchSearchResult);
}
}

private void updateFetchSearchResults(final List<QuerySearchResult> querySearchResults, final FetchSearchResult fetchSearchResult) {
if (Objects.isNull(fetchSearchResult)) {
/**
* A workaround for the single shard case: fetch has already happened, and we need to update both fetch and query results
*/
private void updateOriginalFetchResults(
final List<QuerySearchResult> querySearchResults,
final FetchSearchResult fetchSearchResult,
final boolean isSingleShard
) {
if (!isSingleShard || querySearchResults.size() != 1 || Objects.isNull(fetchSearchResult)) {
return;
}
// fetch results have list of document content, that includes start/stop and
@@ -120,14 +135,15 @@ private void updateFetchSearchResults(final List<QuerySearchResult> querySearchR
// 4. order scores based on normalized and combined values
SearchHits searchHits = fetchSearchResult.hits();

// create map of docId to index of search hits
// create map of docId to index of search hits, handles (2)
Map<Integer, SearchHit> docIdToSearchHit = new HashMap<>();
for (SearchHit searchHit : searchHits) {
docIdToSearchHit.put(searchHit.docId(), searchHit);
}

QuerySearchResult querySearchResult = querySearchResults.get(0);
TopDocs topDocs = querySearchResult.topDocs().topDocs;
// iterate over the normalized/combined scores, that solves (1), (2) and (3)
// iterate over the normalized/combined scores, that solves (1) and (3)
SearchHit[] updatedSearchHitArray = Arrays.stream(topDocs.scoreDocs).map(scoreDoc -> {
// get fetched hit content by doc_id
SearchHit searchHit = docIdToSearchHit.get(scoreDoc.doc);
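
The diff is truncated here; condensed, and using only names that appear above, the reordering in updateOriginalFetchResults amounts to the following sketch:

    // index fetched hits by doc id, since fetch results may arrive in a different order, handles (2)
    Map<Integer, SearchHit> docIdToSearchHit = new HashMap<>();
    for (SearchHit searchHit : fetchSearchResult.hits()) {
        docIdToSearchHit.put(searchHit.docId(), searchHit);
    }
    // re-emit hits in the normalized/combined score order, solves (1) and (3)
    TopDocs topDocs = querySearchResults.get(0).topDocs().topDocs;
    SearchHit[] updatedSearchHitArray = Arrays.stream(topDocs.scoreDocs)
        .map(scoreDoc -> docIdToSearchHit.get(scoreDoc.doc))
        .toArray(SearchHit[]::new);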
src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombiner.java
@@ -6,6 +6,7 @@
package org.opensearch.neuralsearch.processor.combination;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
@@ -75,20 +76,20 @@ private List<Integer> getSortedDocIds(final Map<Integer, Float> combinedNormaliz
return sortedDocsIds;
}

private ScoreDoc[] getCombinedScoreDocs(
private List<ScoreDoc> getCombinedScoreDocs(
final CompoundTopDocs compoundQueryTopDocs,
final Map<Integer, Float> combinedNormalizedScoresByDocId,
final List<Integer> sortedScores,
final int maxHits
) {
ScoreDoc[] finalScoreDocs = new ScoreDoc[maxHits];

int shardId = compoundQueryTopDocs.getScoreDocs()[0].shardIndex;
int shardId = compoundQueryTopDocs.getScoreDocs().get(0).shardIndex;
for (int j = 0; j < maxHits && j < sortedScores.size(); j++) {
int docId = sortedScores.get(j);
finalScoreDocs[j] = new ScoreDoc(docId, combinedNormalizedScoresByDocId.get(docId), shardId);
}
return finalScoreDocs;
return Arrays.stream(finalScoreDocs).collect(Collectors.toList());
}

public Map<Integer, float[]> getNormalizedScoresPerDocument(final List<TopDocs> topDocsPerSubQuery) {
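
For concreteness, a hypothetical call to getCombinedScoreDocs with maxHits = 3 (all values invented, compoundQueryTopDocs assumed in scope):

    Map<Integer, Float> combinedNormalizedScoresByDocId = Map.of(2, 0.95f, 5, 0.60f, 6, 0.10f);
    List<Integer> sortedScores = List.of(2, 5, 6);
    List<ScoreDoc> combined = getCombinedScoreDocs(compoundQueryTopDocs, combinedNormalizedScoresByDocId, sortedScores, 3);
    // combined -> [ScoreDoc(doc=2, score=0.95), ScoreDoc(doc=5, score=0.60), ScoreDoc(doc=6, score=0.10)],
    // each carrying the shardIndex of the first ScoreDoc in compoundQueryTopDocs

Note that when sortedScores has fewer than maxHits entries, the trailing slots of finalScoreDocs remain null and are carried into the returned list.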
src/main/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcher.java
@@ -5,6 +5,8 @@

package org.opensearch.neuralsearch.search.query;

import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.createDelimiterElementForHybridSearchResults;
import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.createStartStopElementForHybridSearchResults;
import static org.opensearch.search.query.TopDocsCollectorContext.createTopDocsCollectorContext;

import java.io.IOException;
@@ -45,9 +47,6 @@
@Log4j2
public class HybridQueryPhaseSearcher extends QueryPhase.DefaultQueryPhaseSearcher {

public static final Float MAGIC_NUMBER_START_STOP = 9549511920.4881596047f;
public static final Float MAGIC_NUMBER_DELIMITER = 4422440593.9791198149f;

public boolean searchWith(
final SearchContext searchContext,
final ContextIndexSearcher searcher,
@@ -115,13 +114,13 @@ private void setTopDocsInQueryResult(
) {
final List<TopDocs> topDocs = collector.topDocs();
final float maxScore = getMaxScore(topDocs);
boolean isSingleShard = searchContext.numberOfShards() == 1;
final boolean isSingleShard = searchContext.numberOfShards() == 1;
final TopDocs newTopDocs = getNewTopDocs(getTotalHits(searchContext, topDocs, isSingleShard), topDocs);
final TopDocsAndMaxScore topDocsAndMaxScore = new TopDocsAndMaxScore(newTopDocs, maxScore);
queryResult.topDocs(topDocsAndMaxScore, getSortValueFormats(searchContext.sort()));
}

TopDocs getNewTopDocs(final TotalHits totalHits, final List<TopDocs> topDocs) {
private TopDocs getNewTopDocs(final TotalHits totalHits, final List<TopDocs> topDocs) {
ScoreDoc[] scoreDocs = new ScoreDoc[0];
if (Objects.nonNull(topDocs)) {
// for a single shard case we need to do score processing at coordinator level.
@@ -205,22 +204,4 @@ private float getMaxScore(final List<TopDocs> topDocs) {
private DocValueFormat[] getSortValueFormats(final SortAndFormats sortAndFormats) {
return sortAndFormats == null ? null : sortAndFormats.formats;
}

/**
* Create ScoreDoc object that is a start/stop element in case of hybrid search query results
* @param docId id of one of docs from actual result object, or -1 if there are no matches
* @return
*/
public static ScoreDoc createStartStopElementForHybridSearchResults(final int docId) {
return new ScoreDoc(docId, MAGIC_NUMBER_START_STOP);
}

/**
* Create ScoreDoc object that is a delimiter element between sub-query results in hybrid search query results
* @param docId id of one of docs from actual result object, or -1 if there are no matches
* @return
*/
public static ScoreDoc createDelimiterElementForHybridSearchResults(final int docId) {
return new ScoreDoc(docId, MAGIC_NUMBER_DELIMITER);
}
}
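
The factory methods removed above now live in HybridSearchResultFormatUtil and are what getNewTopDocs (not fully shown in this diff) uses to flatten per-sub-query results into a single TopDocs. A sketch of that encoding, assuming the format documented in CompoundTopDocs, with docId being one of the valid ids from the result or -1 when there are no matches:

    List<ScoreDoc> result = new ArrayList<>();
    result.add(createStartStopElementForHybridSearchResults(docId));     // start element
    for (TopDocs subQueryTopDocs : topDocs) {
        result.add(createDelimiterElementForHybridSearchResults(docId)); // delimiter before each sub-query
        result.addAll(Arrays.asList(subQueryTopDocs.scoreDocs));
    }
    result.add(createStartStopElementForHybridSearchResults(docId));     // stop element
    ScoreDoc[] scoreDocs = result.toArray(new ScoreDoc[0]);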
