Fix explain exception in hybrid queries with partial subquery matches #1123

Merged
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -31,6 +31,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
- Fixed document source and score field mismatch in sorted hybrid queries ([#1043](https://github.com/opensearch-project/neural-search/pull/1043))
- Update NeuralQueryBuilder doEquals() and doHashCode() to cater the missing parameters information ([#1045](https://github.com/opensearch-project/neural-search/pull/1045)).
- Fix bug where embedding is missing when ingested document has "." in field name, and mismatches fieldMap config ([#1062](https://github.com/opensearch-project/neural-search/pull/1062))
- Fix explain exception in hybrid queries with partial subquery matches ([#1123](https://github.com/opensearch-project/neural-search/pull/1123))
- Handle pagination_depth when from =0 and removes default value of pagination_depth ([#1132](https://github.com/opensearch-project/neural-search/pull/1132))
### Infrastructure
### Documentation
ExplanationResponseProcessor.java
@@ -6,6 +6,8 @@

import lombok.AllArgsConstructor;
import lombok.Getter;
import lombok.extern.log4j.Log4j2;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.lucene.search.Explanation;
import org.opensearch.action.search.SearchRequest;
import org.opensearch.action.search.SearchResponse;
@@ -21,6 +23,7 @@
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;

@@ -32,6 +35,7 @@
*/
@Getter
@AllArgsConstructor
@Log4j2
public class ExplanationResponseProcessor implements SearchResponseProcessor {

public static final String TYPE = "hybrid_score_explanation";
@@ -99,16 +103,40 @@ public SearchResponse processResponse(
ExplanationDetails normalizationExplanation = combinedExplainDetail.getNormalizationExplanations();
ExplanationDetails combinationExplanation = combinedExplainDetail.getCombinationExplanations();
// Create normalized explanations for each detail
Explanation[] normalizedExplanation = new Explanation[queryLevelExplanation.getDetails().length];
for (int i = 0; i < queryLevelExplanation.getDetails().length; i++) {
normalizedExplanation[i] = Explanation.match(
// normalized score
normalizationExplanation.getScoreDetails().get(i).getKey(),
// description of normalized score
normalizationExplanation.getScoreDetails().get(i).getValue(),
// shard level details
queryLevelExplanation.getDetails()[i]
if (normalizationExplanation.getScoreDetails().size() != queryLevelExplanation.getDetails().length) {
log.error(
String.format(
Locale.ROOT,
"length of query level explanations %d must match length of explanations after normalization %d",
queryLevelExplanation.getDetails().length,
normalizationExplanation.getScoreDetails().size()
)
);
throw new IllegalStateException("mismatch in number of query level explanations and normalization explanations");
}
List<Explanation> normalizedExplanation = new ArrayList<>(queryLevelExplanation.getDetails().length);
int normalizationExplanationIndex = 0;
for (Explanation queryExplanation : queryLevelExplanation.getDetails()) {
// adding only explanations where this hit has matched
if (Float.compare(queryExplanation.getValue().floatValue(), 0.0f) > 0) {
Pair<Float, String> normalizedScoreDetails = normalizationExplanation.getScoreDetails()
.get(normalizationExplanationIndex);
if (Objects.isNull(normalizedScoreDetails)) {
throw new IllegalStateException("normalized score details must not be null");
}
normalizedExplanation.add(
Explanation.match(
// normalized score
normalizedScoreDetails.getKey(),
// description of normalized score
normalizedScoreDetails.getValue(),
// shard level details
queryExplanation
)
);
}
// the index advances for every subquery, matched or not, since query-level explanation scores can be 0.0
normalizationExplanationIndex++;
}
// Create and set final explanation combining all components
Float finalScore = Float.isNaN(searchHit.getScore()) ? 0.0f : searchHit.getScore();
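For illustration, here is a minimal, self-contained sketch of the pairing logic introduced above, outside the processor. The class name and sample scores are hypothetical; it only assumes Lucene's `Explanation` and Commons Lang's `Pair`, which the diff already imports. Query-level details keep one slot per subquery; zero-score slots (subqueries that did not match this hit) produce no normalized explanation, while the index still advances so the two lists stay aligned. The previous loop indexed the normalization details directly by query-level position, which could run past the end of the list when a document matched only a subset of the subqueries.

```java
import java.util.ArrayList;
import java.util.List;

import org.apache.commons.lang3.tuple.Pair;
import org.apache.lucene.search.Explanation;

public class PartialMatchExplanationSketch {
    public static void main(String[] args) {
        // One query-level explanation per subquery; the second subquery did not match this hit.
        Explanation[] queryLevel = new Explanation[] {
            Explanation.match(5.2f, "score of subquery 1"),
            Explanation.noMatch("no matching terms for subquery 2") };
        // Normalization details are index-aligned with the subqueries; unmatched slots hold 0.0f.
        List<Pair<Float, String>> normalized = List.of(
            Pair.of(0.83f, "min_max normalization of: 5.2"),
            Pair.of(0.0f, "min_max normalization of: 0.0"));

        List<Explanation> merged = new ArrayList<>();
        int index = 0;
        for (Explanation perQuery : queryLevel) {
            // Only subqueries that matched (positive score) contribute a normalized explanation ...
            if (Float.compare(perQuery.getValue().floatValue(), 0.0f) > 0) {
                Pair<Float, String> details = normalized.get(index);
                merged.add(Explanation.match(details.getKey(), details.getValue(), perQuery));
            }
            // ... but the index advances for every subquery, keeping both lists aligned.
            index++;
        }
        merged.forEach(System.out::println);
    }
}
```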
L2ScoreNormalizationTechnique.java (file name inferred from the per-subquery norms used below)
@@ -75,12 +75,19 @@ public Map<DocIdAtSearchShard, ExplanationDetails> explain(List<CompoundTopDocs>
continue;
}
List<TopDocs> topDocsPerSubQuery = compoundQueryTopDocs.getTopDocs();
for (int j = 0; j < topDocsPerSubQuery.size(); j++) {
TopDocs subQueryTopDoc = topDocsPerSubQuery.get(j);
int numberOfSubQueries = topDocsPerSubQuery.size();
for (int subQueryIndex = 0; subQueryIndex < numberOfSubQueries; subQueryIndex++) {
TopDocs subQueryTopDoc = topDocsPerSubQuery.get(subQueryIndex);
for (ScoreDoc scoreDoc : subQueryTopDoc.scoreDocs) {
DocIdAtSearchShard docIdAtSearchShard = new DocIdAtSearchShard(scoreDoc.doc, compoundQueryTopDocs.getSearchShard());
float normalizedScore = normalizeSingleScore(scoreDoc.score, normsPerSubquery.get(j));
normalizedScores.computeIfAbsent(docIdAtSearchShard, k -> new ArrayList<>()).add(normalizedScore);
float normalizedScore = normalizeSingleScore(scoreDoc.score, normsPerSubquery.get(subQueryIndex));
ScoreNormalizationUtil.setNormalizedScore(
normalizedScores,
docIdAtSearchShard,
subQueryIndex,
numberOfSubQueries,
normalizedScore
);
scoreDoc.score = normalizedScore;
}
}
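To see why the append-based accumulation was replaced, consider a hedged, self-contained sketch (class name and sample values are hypothetical, and the map is keyed by a plain String instead of DocIdAtSearchShard for brevity). Appending scores in match order gives partially matching documents a shorter list whose positions no longer line up with subquery indices; writing into a pre-sized, zero-filled slot per subquery keeps them aligned, which is what the new `ScoreNormalizationUtil.setNormalizedScore` helper does.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

public class SubqueryAlignmentSketch {
    public static void main(String[] args) {
        int numberOfSubQueries = 3;
        String doc = "doc-42"; // stands in for DocIdAtSearchShard

        // Old behavior: scores are appended in the order subqueries happen to match.
        // This document matched only subqueries 0 and 2, so index 1 now points at subquery 2's score.
        Map<String, List<Float>> appended = new HashMap<>();
        appended.computeIfAbsent(doc, k -> new ArrayList<>()).add(0.7f); // subquery 0
        appended.computeIfAbsent(doc, k -> new ArrayList<>()).add(0.4f); // subquery 2
        System.out.println("appended: " + appended.get(doc)); // [0.7, 0.4] -- misaligned

        // New behavior: pre-size the list with zeros and write by subquery index.
        Map<String, List<Float>> indexed = new HashMap<>();
        setScoreAtIndex(indexed, doc, 0, numberOfSubQueries, 0.7f);
        setScoreAtIndex(indexed, doc, 2, numberOfSubQueries, 0.4f);
        System.out.println("indexed:  " + indexed.get(doc)); // [0.7, 0.0, 0.4] -- one slot per subquery
    }

    // Same idea as the setNormalizedScore helper added in this PR, simplified for the sketch.
    static void setScoreAtIndex(Map<String, List<Float>> scores, String doc, int subQueryIndex, int numberOfSubQueries, float score) {
        List<Float> perSubquery = scores.computeIfAbsent(doc, k -> new ArrayList<>(Collections.nCopies(numberOfSubQueries, 0.0f)));
        perSubquery.set(subQueryIndex, score);
    }
}
```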
MinMaxScoreNormalizationTechnique.java (file name inferred from the per-subquery min/max scores used below)
@@ -4,7 +4,6 @@
*/
package org.opensearch.neuralsearch.processor.normalization;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
@@ -92,16 +91,23 @@ public Map<DocIdAtSearchShard, ExplanationDetails> explain(final List<CompoundTo
continue;
}
List<TopDocs> topDocsPerSubQuery = compoundQueryTopDocs.getTopDocs();
for (int j = 0; j < topDocsPerSubQuery.size(); j++) {
TopDocs subQueryTopDoc = topDocsPerSubQuery.get(j);
int numberOfSubQueries = topDocsPerSubQuery.size();
for (int subQueryIndex = 0; subQueryIndex < numberOfSubQueries; subQueryIndex++) {
TopDocs subQueryTopDoc = topDocsPerSubQuery.get(subQueryIndex);
for (ScoreDoc scoreDoc : subQueryTopDoc.scoreDocs) {
DocIdAtSearchShard docIdAtSearchShard = new DocIdAtSearchShard(scoreDoc.doc, compoundQueryTopDocs.getSearchShard());
float normalizedScore = normalizeSingleScore(
scoreDoc.score,
minMaxScores.getMinScoresPerSubquery()[j],
minMaxScores.getMaxScoresPerSubquery()[j]
minMaxScores.getMinScoresPerSubquery()[subQueryIndex],
minMaxScores.getMaxScoresPerSubquery()[subQueryIndex]
);
ScoreNormalizationUtil.setNormalizedScore(
normalizedScores,
docIdAtSearchShard,
subQueryIndex,
numberOfSubQueries,
normalizedScore
);
normalizedScores.computeIfAbsent(docIdAtSearchShard, k -> new ArrayList<>()).add(normalizedScore);
scoreDoc.score = normalizedScore;
}
}
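As a side note, `normalizeSingleScore` here applies per-subquery min-max scaling. Below is a hedged sketch of the standard formula only; the class name is hypothetical and the plugin's exact edge-case handling, for instance when every score in a subquery is identical, may differ.

```java
public class MinMaxSketch {
    // Standard min-max scaling into [0, 1]; the guard for max == min is an assumption
    // made for this sketch, not necessarily how the plugin handles that case.
    static float minMaxNormalize(float score, float min, float max) {
        if (Float.compare(max, min) == 0) {
            return 1.0f; // degenerate range: all scores for this subquery are equal
        }
        return (score - min) / (max - min);
    }

    public static void main(String[] args) {
        System.out.println(minMaxNormalize(5.2f, 1.0f, 9.0f)); // 0.525
    }
}
```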
RRFNormalizationTechnique.java
@@ -6,22 +6,24 @@

import java.math.BigDecimal;
import java.math.RoundingMode;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Locale;
import java.util.Set;
import java.util.function.BiConsumer;
import java.util.stream.IntStream;

import org.apache.commons.lang3.Range;
import org.apache.commons.lang3.math.NumberUtils;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TopDocs;
import org.opensearch.common.TriConsumer;
import org.opensearch.neuralsearch.processor.CompoundTopDocs;

import lombok.ToString;
import org.opensearch.neuralsearch.processor.NormalizeScoresDTO;
import org.opensearch.neuralsearch.processor.SearchShard;
import org.opensearch.neuralsearch.processor.explain.DocIdAtSearchShard;
import org.opensearch.neuralsearch.processor.explain.ExplainableTechnique;
import org.opensearch.neuralsearch.processor.explain.ExplanationDetails;
@@ -65,7 +67,7 @@ public RRFNormalizationTechnique(final Map<String, Object> params, final ScoreNo
public void normalize(final NormalizeScoresDTO normalizeScoresDTO) {
final List<CompoundTopDocs> queryTopDocs = normalizeScoresDTO.getQueryTopDocs();
for (CompoundTopDocs compoundQueryTopDocs : queryTopDocs) {
processTopDocs(compoundQueryTopDocs, (docId, score) -> {});
processTopDocs(compoundQueryTopDocs, (docId, score, subQueryIndex) -> {});
}
}

@@ -79,31 +81,51 @@ public Map<DocIdAtSearchShard, ExplanationDetails> explain(List<CompoundTopDocs>
Map<DocIdAtSearchShard, List<Float>> normalizedScores = new HashMap<>();

for (CompoundTopDocs compoundQueryTopDocs : queryTopDocs) {
if (Objects.isNull(compoundQueryTopDocs)) {
continue;
}
List<TopDocs> topDocsPerSubQuery = compoundQueryTopDocs.getTopDocs();
int numberOfSubQueries = topDocsPerSubQuery.size();
processTopDocs(
compoundQueryTopDocs,
(docId, score) -> normalizedScores.computeIfAbsent(docId, k -> new ArrayList<>()).add(score)
(docId, score, subQueryIndex) -> ScoreNormalizationUtil.setNormalizedScore(
normalizedScores,
docId,
subQueryIndex,
numberOfSubQueries,
score
)
);
}

return getDocIdAtQueryForNormalization(normalizedScores, this);
}

private void processTopDocs(CompoundTopDocs compoundQueryTopDocs, BiConsumer<DocIdAtSearchShard, Float> scoreProcessor) {
private void processTopDocs(CompoundTopDocs compoundQueryTopDocs, TriConsumer<DocIdAtSearchShard, Float, Integer> scoreProcessor) {
if (Objects.isNull(compoundQueryTopDocs)) {
return;
}

compoundQueryTopDocs.getTopDocs().forEach(topDocs -> {
IntStream.range(0, topDocs.scoreDocs.length).forEach(position -> {
float normalizedScore = calculateNormalizedScore(position);
DocIdAtSearchShard docIdAtSearchShard = new DocIdAtSearchShard(
topDocs.scoreDocs[position].doc,
compoundQueryTopDocs.getSearchShard()
);
scoreProcessor.accept(docIdAtSearchShard, normalizedScore);
topDocs.scoreDocs[position].score = normalizedScore;
});
});
List<TopDocs> topDocsList = compoundQueryTopDocs.getTopDocs();
SearchShard searchShard = compoundQueryTopDocs.getSearchShard();

for (int topDocsIndex = 0; topDocsIndex < topDocsList.size(); topDocsIndex++) {
processTopDocsEntry(topDocsList.get(topDocsIndex), searchShard, topDocsIndex, scoreProcessor);
}
}

private void processTopDocsEntry(
TopDocs topDocs,
SearchShard searchShard,
int topDocsIndex,
TriConsumer<DocIdAtSearchShard, Float, Integer> scoreProcessor
) {
for (ScoreDoc scoreDoc : topDocs.scoreDocs) {
float normalizedScore = calculateNormalizedScore(Arrays.asList(topDocs.scoreDocs).indexOf(scoreDoc));
DocIdAtSearchShard docIdAtSearchShard = new DocIdAtSearchShard(scoreDoc.doc, searchShard);
scoreProcessor.apply(docIdAtSearchShard, normalizedScore, topDocsIndex);
scoreDoc.score = normalizedScore;
}
}

private float calculateNormalizedScore(int position) {
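For orientation, `calculateNormalizedScore` turns a document's position in each subquery's result list into a reciprocal-rank-fusion score. Below is a hedged sketch of the usual formula, assuming the commonly used rank constant of 60 (the technique's configured value may differ), together with the three-argument callback shape that now also carries the subquery index; the class name and sample doc IDs are hypothetical.

```java
import org.opensearch.common.TriConsumer;

public class RrfScoreSketch {
    // Typical reciprocal-rank-fusion score for the document at 0-based `position`
    // in one subquery's result list; 60 is an assumed rank constant for this sketch.
    static float rrfScore(int position, int rankConstant) {
        return 1.0f / (rankConstant + position + 1);
    }

    public static void main(String[] args) {
        // The callback now receives the subquery index as a third argument,
        // so callers can store the score in the right per-subquery slot.
        TriConsumer<String, Float, Integer> scoreProcessor =
            (docId, score, subQueryIndex) -> System.out.printf("%s -> subquery %d: %.4f%n", docId, subQueryIndex, score);

        scoreProcessor.apply("doc-42", rrfScore(0, 60), 0); // first hit of subquery 0
        scoreProcessor.apply("doc-42", rrfScore(4, 60), 2); // fifth hit of subquery 2
    }
}
```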
ScoreNormalizationUtil.java
@@ -5,7 +5,9 @@
package org.opensearch.neuralsearch.processor.normalization;

import lombok.extern.log4j.Log4j2;
import org.opensearch.neuralsearch.processor.explain.DocIdAtSearchShard;

import java.util.ArrayList;
import java.util.List;
import java.util.Locale;
import java.util.Map;
@@ -54,4 +56,30 @@ public void validateParams(final Map<String, Object> actualParams, final Set<Str
}
}
}

/**
* Sets a normalized score for a specific document at a specific subquery index
*
* @param normalizedScores map of document IDs to their list of scores
* @param docIdAtSearchShard document ID
* @param subQueryIndex index of the subquery
* @param numberOfSubQueries total number of subqueries, used to pre-size the per-document score list
* @param normalizedScore normalized score to set
*/
public static void setNormalizedScore(
Map<DocIdAtSearchShard, List<Float>> normalizedScores,
DocIdAtSearchShard docIdAtSearchShard,
int subQueryIndex,
int numberOfSubQueries,
float normalizedScore
) {
List<Float> scores = normalizedScores.get(docIdAtSearchShard);
if (Objects.isNull(scores)) {
scores = new ArrayList<>(numberOfSubQueries);
for (int i = 0; i < numberOfSubQueries; i++) {
scores.add(0.0f);
}
normalizedScores.put(docIdAtSearchShard, scores);
}
scores.set(subQueryIndex, normalizedScore);
}
}
HybridQueryWeight.java
@@ -11,6 +11,8 @@
import java.util.concurrent.Callable;
import java.util.stream.Collectors;

import lombok.AccessLevel;
import lombok.Getter;
import lombok.RequiredArgsConstructor;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.Explanation;
@@ -33,6 +35,7 @@
public final class HybridQueryWeight extends Weight {

// The Weights for our subqueries, in 1-1 correspondence
@Getter(AccessLevel.PACKAGE)
private final List<Weight> weights;

private final ScoreMode scoreMode;
@@ -157,10 +160,13 @@ public Explanation explain(LeafReaderContext context, int doc) throws IOExceptio
if (e.isMatch()) {
match = true;
double score = e.getValue().doubleValue();
subsOnMatch.add(e);
max = Math.max(max, score);
} else if (!match) {
subsOnNoMatch.add(e);
subsOnMatch.add(e);
} else {
if (!match) {
subsOnNoMatch.add(e);
}
subsOnMatch.add(e);
}
}
if (match) {
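The change above keeps one explanation slot per subquery in `subsOnMatch`, whether or not that subquery matched, so downstream consumers (such as the response processor changed in this PR) can line explanation details up with subquery indices. Below is a hedged, simplified sketch of that behavior; the class name and scores are hypothetical, and the real method also handles score modes and builds its description differently.

```java
import java.util.ArrayList;
import java.util.List;

import org.apache.lucene.search.Explanation;

public class HybridExplainAlignmentSketch {
    public static void main(String[] args) {
        // Per-subquery explanations for one document; the second subquery did not match.
        List<Explanation> perSubquery = List.of(
            Explanation.match(5.2f, "subquery 1"),
            Explanation.noMatch("subquery 2"),
            Explanation.match(1.3f, "subquery 3"));

        List<Explanation> subsOnMatch = new ArrayList<>();
        boolean match = false;
        double max = 0.0d;
        for (Explanation e : perSubquery) {
            if (e.isMatch()) {
                match = true;
                max = Math.max(max, e.getValue().doubleValue());
            }
            // Key point of the change: every subquery keeps a slot, matched or not,
            // so the per-subquery order is preserved in the top-level explanation.
            subsOnMatch.add(e);
        }
        if (match) {
            System.out.println(Explanation.match(max, "combined score of:", subsOnMatch));
        }
    }
}
```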