Skip to content

[8.19] Backporting Linear Retriever MinScore #129368

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Jun 16, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/changelog/129359.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 129359
summary: Add min score linear retriever
area: Search
type: enhancement
issues: []
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,8 @@ public void onFailure(Exception e) {
RankDocsRetrieverBuilder rankDocsRetrieverBuilder = new RankDocsRetrieverBuilder(
rankWindowSize,
newRetrievers.stream().map(s -> s.retriever).toList(),
results::get
results::get,
this.minScore
);
rankDocsRetrieverBuilder.retrieverName(retrieverName());
return rankDocsRetrieverBuilder;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,13 +33,14 @@ public class RankDocsRetrieverBuilder extends RetrieverBuilder {
final List<RetrieverBuilder> sources;
final Supplier<RankDoc[]> rankDocs;

/**
 * Builds a {@code RankDocsRetrieverBuilder} from a window size, the source retrievers,
 * a supplier of rank docs and a minimum score.
 *
 * NOTE(review): this is a diff view with the +/- markers stripped. The first signature line
 * is the pre-change three-argument form; the second is the post-change four-argument form,
 * which is the one the body matches (it assigns {@code minScore}).
 *
 * @param rankWindowSize stored on the builder unchanged
 * @param sources        the source retrievers; must be neither null nor empty
 * @param rankDocs       supplier of the rank docs, stored unchanged
 * @param minScore       minimum score stored on the builder; null is accepted (no check is made here)
 * @throws IllegalArgumentException if {@code sources} is null or empty
 */
public RankDocsRetrieverBuilder(int rankWindowSize, List<RetrieverBuilder> sources, Supplier<RankDoc[]> rankDocs) {
public RankDocsRetrieverBuilder(int rankWindowSize, List<RetrieverBuilder> sources, Supplier<RankDoc[]> rankDocs, Float minScore) {
this.rankWindowSize = rankWindowSize;
this.rankDocs = rankDocs;
// Validation happens after the two assignments above but before sources/minScore are stored.
if (sources == null || sources.isEmpty()) {
throw new IllegalArgumentException("sources must not be null or empty");
}
this.sources = sources;
this.minScore = minScore;
}

@Override
Expand All @@ -48,7 +49,7 @@ public String getName() {
}

/**
 * Returns true when this builder's own min score is non-null, or when any source retriever
 * reports a non-null {@code minScore()}.
 *
 * NOTE(review): the two return lines are the before/after of the diff; they differ only in
 * the explicit {@code this.} qualifier on the field.
 */
private boolean sourceHasMinScore() {
return minScore != null || sources.stream().anyMatch(x -> x.minScore() != null);
return this.minScore != null || sources.stream().anyMatch(x -> x.minScore() != null);
}

private boolean sourceShouldRewrite(QueryRewriteContext ctx) throws IOException {
Expand Down Expand Up @@ -132,7 +133,7 @@ public void extractToSearchSourceBuilder(SearchSourceBuilder searchSourceBuilder
searchSourceBuilder.size(rankWindowSize);
}
if (sourceHasMinScore()) {
searchSourceBuilder.minScore(this.minScore() == null ? Float.MIN_VALUE : this.minScore());
searchSourceBuilder.minScore(this.minScore == null ? Float.MIN_VALUE : this.minScore);
}
if (searchSourceBuilder.size() + searchSourceBuilder.from() > rankDocResults.length) {
searchSourceBuilder.size(Math.max(0, rankDocResults.length - searchSourceBuilder.from()));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ private List<QueryBuilder> preFilters(QueryRewriteContext queryRewriteContext) t
}

/**
 * Test helper: builds a {@code RankDocsRetrieverBuilder} with a window size taken from
 * {@code randomIntBetween(1, 100)}, the inner retrievers for the given rewrite context and
 * the rank-docs supplier.
 *
 * NOTE(review): the two return lines are the before/after of the diff; the post-change call
 * passes {@code null} as the new fourth (min score) constructor argument.
 */
private RankDocsRetrieverBuilder createRandomRankDocsRetrieverBuilder(QueryRewriteContext queryRewriteContext) throws IOException {
return new RankDocsRetrieverBuilder(randomIntBetween(1, 100), innerRetrievers(queryRewriteContext), rankDocsSupplier());
return new RankDocsRetrieverBuilder(randomIntBetween(1, 100), innerRetrievers(queryRewriteContext), rankDocsSupplier(), null);
}

public void testExtractToSearchSourceBuilder() throws IOException {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ public Set<NodeFeature> getTestFeatures() {
SemanticInferenceMetadataFieldsMapper.EXPLICIT_NULL_FIXES,
SEMANTIC_KNN_VECTOR_QUERY_REWRITE_INTERCEPTION_SUPPORTED,
TextSimilarityRankRetrieverBuilder.TEXT_SIMILARITY_RERANKER_ALIAS_HANDLING_FIX,
TextSimilarityRankRetrieverBuilder.TEXT_SIMILARITY_RERANKER_MINSCORE_FIX,
SemanticInferenceMetadataFieldsMapper.INFERENCE_METADATA_FIELDS_ENABLED_BY_DEFAULT,
SEMANTIC_TEXT_HIGHLIGHTER_DEFAULT,
SEMANTIC_KNN_FILTER_FIX,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import org.elasticsearch.xcontent.XContentParser;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;

Expand All @@ -49,6 +50,7 @@ public class TextSimilarityRankRetrieverBuilder extends CompoundRetrieverBuilder
"text_similarity_reranker_alias_handling_fix",
true
);
public static final NodeFeature TEXT_SIMILARITY_RERANKER_MINSCORE_FIX = new NodeFeature("text_similarity_reranker_minscore_fix");

public static final ParseField RETRIEVER_FIELD = new ParseField("retriever");
public static final ParseField INFERENCE_ID_FIELD = new ParseField("inference_id");
Expand Down Expand Up @@ -175,23 +177,21 @@ protected TextSimilarityRankRetrieverBuilder clone(
/**
 * Converts the score docs of the single inner retriever into {@code TextSimilarityRankDoc}s.
 *
 * Exactly one result list is expected (asserted). When {@code explain} is true each rank doc
 * also carries {@code inferenceId} and {@code field}. In the post-change lines a document is
 * kept only when {@code minScore} is null or its score is {@code >= minScore}, so the returned
 * array can be shorter than the input.
 *
 * NOTE(review): this is a diff view with the +/- markers stripped. Pre-change lines (the
 * fixed-size {@code textSimilarityRankDocs} array, its indexed assignments and its return) are
 * interleaved with the post-change lines (the {@code filteredDocs} list and its return), so
 * the text below is not compilable as shown.
 *
 * @param rankResults results of the inner retrievers; must contain exactly one entry
 * @param explain     whether to build rank docs that include inference id and field
 * @return the rank docs that passed the min score filter
 */
protected RankDoc[] combineInnerRetrieverResults(List<ScoreDoc[]> rankResults, boolean explain) {
assert rankResults.size() == 1;
ScoreDoc[] scoreDocs = rankResults.get(0);
TextSimilarityRankDoc[] textSimilarityRankDocs = new TextSimilarityRankDoc[scoreDocs.length];
List<TextSimilarityRankDoc> filteredDocs = new ArrayList<>();
// Filtering by min_score must be done here, after reranking.
// Applying min_score in the child retriever could prematurely exclude documents that would receive high scores from the reranker.
for (int i = 0; i < scoreDocs.length; i++) {
ScoreDoc scoreDoc = scoreDocs[i];
assert scoreDoc.score >= 0;
if (explain) {
textSimilarityRankDocs[i] = new TextSimilarityRankDoc(
scoreDoc.doc,
scoreDoc.score,
scoreDoc.shardIndex,
inferenceId,
field
);
} else {
textSimilarityRankDocs[i] = new TextSimilarityRankDoc(scoreDoc.doc, scoreDoc.score, scoreDoc.shardIndex);
// Inclusive threshold: a score equal to minScore is kept.
if (minScore == null || scoreDoc.score >= minScore) {
if (explain) {
filteredDocs.add(new TextSimilarityRankDoc(scoreDoc.doc, scoreDoc.score, scoreDoc.shardIndex, inferenceId, field));
} else {
filteredDocs.add(new TextSimilarityRankDoc(scoreDoc.doc, scoreDoc.score, scoreDoc.shardIndex));
}
}
}
return textSimilarityRankDocs;
return filteredDocs.toArray(new TextSimilarityRankDoc[0]);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -381,3 +381,111 @@ setup:
- match: { hits.total.value: 1 }
- length: { hits.hits: 1 }
- match: { hits.hits.0._id: "doc_1" }

---
# Indexes doc_2 with inference_text_field "10", then runs a text_similarity_reranker search
# with min_score: 10 and expects doc_2 to be the only hit (total 1).
# The inner bool query boosts "technology" (10) above "astronomy" (1), so the expectation
# only holds if min_score is applied to the reranked score rather than the inner query score.
# NOTE(review): presumably the test rerank model derives its score from inference_text_field;
# confirm against the setup section, which is outside this view. Indentation was lost in
# extraction, so the nesting shown here is not the real YAML structure.
"Text similarity reranker respects min_score":

- requires:
cluster_features: "text_similarity_reranker_minscore_fix"
reason: test min score functionality

- do:
index:
index: test-index
id: doc_2
body:
text: "The phases of the Moon come from the position of the Moon relative to the Earth and Sun."
topic: [ "science" ]
subtopic: [ "astronomy" ]
inference_text_field: "10"
refresh: true

- do:
search:
index: test-index
body:
track_total_hits: true
fields: [ "text", "topic" ]
retriever:
text_similarity_reranker:
retriever:
standard:
query:
bool:
should:
- constant_score:
filter:
term: { subtopic: "technology" }
boost: 10
- constant_score:
filter:
term: { subtopic: "astronomy" }
boost: 1
rank_window_size: 10
inference_id: my-rerank-model
inference_text: "How often does the moon hide the sun?"
field: inference_text_field
min_score: 10
size: 10

- match: { hits.total.value: 1 }
- length: { hits.hits: 1 }
- match: { hits.hits.0._id: "doc_2" }

---
# Runs a match_all inner query through the reranker with min_score: 0 and expects all
# documents back (total 3, 3 hits).
# NOTE(review): the expected count of 3 depends on documents indexed by the setup section,
# which is outside this view — confirm there. Indentation was lost in extraction.
"Text similarity reranker with min_score zero includes all docs":

- requires:
cluster_features: "text_similarity_reranker_minscore_fix"
reason: test min score functionality

- do:
search:
index: test-index
body:
track_total_hits: true
fields: [ "text", "topic" ]
retriever:
text_similarity_reranker:
retriever:
standard:
query:
match_all: {}
rank_window_size: 10
inference_id: my-rerank-model
inference_text: "How often does the moon hide the sun?"
field: inference_text_field
min_score: 0
size: 10

- match: { hits.total.value: 3 }
- length: { hits.hits: 3 }

---
# Runs a match_all inner query through the reranker with min_score: 1000 and expects
# no hits at all (total 0, 0 hits).
# NOTE(review): this assumes no document indexed by the setup section reranks to a score of
# 1000 or more — confirm there. Indentation was lost in extraction.
"Text similarity reranker with high min_score excludes all docs":

- requires:
cluster_features: "text_similarity_reranker_minscore_fix"
reason: test min score functionality

- do:
search:
index: test-index
body:
track_total_hits: true
fields: [ "text", "topic" ]
retriever:
text_similarity_reranker:
retriever:
standard:
query:
match_all: {}
rank_window_size: 10
inference_id: my-rerank-model
inference_text: "How often does the moon hide the sun?"
field: inference_text_field
min_score: 1000
size: 10

- match: { hits.total.value: 0 }
- length: { hits.hits: 0 }
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

import static org.elasticsearch.search.retriever.CompoundRetrieverBuilder.INNER_RETRIEVERS_FILTER_SUPPORT;
import static org.elasticsearch.xpack.rank.linear.L2ScoreNormalizer.LINEAR_RETRIEVER_L2_NORM;
import static org.elasticsearch.xpack.rank.linear.LinearRetrieverBuilder.LINEAR_RETRIEVER_MINSCORE_FIX;
import static org.elasticsearch.xpack.rank.linear.MinMaxScoreNormalizer.LINEAR_RETRIEVER_MINMAX_SINGLE_DOC_FIX;
import static org.elasticsearch.xpack.rank.rrf.RRFRetrieverBuilder.RRF_RETRIEVER_COMPOSITION_SUPPORTED;

Expand All @@ -32,6 +33,11 @@ public Set<NodeFeature> getFeatures() {

/**
 * Returns the set of {@code NodeFeature}s this class reports as test features.
 *
 * NOTE(review): this is a diff view with the +/- markers stripped. The one-line return is the
 * pre-change version; the multi-line return that follows is the post-change version, which
 * adds {@code LINEAR_RETRIEVER_MINSCORE_FIX} to the same three features.
 */
@Override
public Set<NodeFeature> getTestFeatures() {
return Set.of(INNER_RETRIEVERS_FILTER_SUPPORT, LINEAR_RETRIEVER_MINMAX_SINGLE_DOC_FIX, LINEAR_RETRIEVER_L2_NORM);
return Set.of(
INNER_RETRIEVERS_FILTER_SUPPORT,
LINEAR_RETRIEVER_MINMAX_SINGLE_DOC_FIX,
LINEAR_RETRIEVER_L2_NORM,
LINEAR_RETRIEVER_MINSCORE_FIX
);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import org.apache.lucene.search.ScoreDoc;
import org.elasticsearch.common.ParsingException;
import org.elasticsearch.common.util.Maps;
import org.elasticsearch.features.NodeFeature;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.license.LicenseUtils;
import org.elasticsearch.search.builder.SearchSourceBuilder;
Expand Down Expand Up @@ -46,6 +47,7 @@
*/
public final class LinearRetrieverBuilder extends CompoundRetrieverBuilder<LinearRetrieverBuilder> {

public static final NodeFeature LINEAR_RETRIEVER_MINSCORE_FIX = new NodeFeature("linear_retriever_minscore_fix");
public static final String NAME = "linear";

public static final ParseField RETRIEVERS_FIELD = new ParseField("retrievers");
Expand Down Expand Up @@ -125,12 +127,35 @@ public LinearRetrieverBuilder(
this.normalizers = normalizers;
}

public LinearRetrieverBuilder(
List<RetrieverSource> innerRetrievers,
int rankWindowSize,
float[] weights,
ScoreNormalizer[] normalizers,
Float minScore,
String retrieverName,
List<QueryBuilder> preFilterQueryBuilders
) {
this(innerRetrievers, rankWindowSize, weights, normalizers);
this.minScore = minScore;
if (minScore != null && minScore < 0) {
throw new IllegalArgumentException("[min_score] must be greater than or equal to 0, was: [" + minScore + "]");
}
this.retrieverName = retrieverName;
this.preFilterQueryBuilders = preFilterQueryBuilders;
}

/**
 * Creates a copy of this builder that uses the given child retrievers and pre-filter queries,
 * carrying over {@code rankWindowSize}, {@code weights}, {@code normalizers} and
 * {@code retrieverName} from this instance.
 *
 * NOTE(review): this is a diff view with the +/- markers stripped. The first four body lines
 * are the pre-change version (four-argument constructor plus field assignments, which does
 * not copy {@code minScore}); the single return statement that follows is the post-change
 * version, which also passes {@code minScore} to the seven-argument constructor.
 *
 * @param newChildRetrievers        child retrievers for the copy
 * @param newPreFilterQueryBuilders pre-filter queries for the copy
 * @return the new builder
 */
@Override
protected LinearRetrieverBuilder clone(List<RetrieverSource> newChildRetrievers, List<QueryBuilder> newPreFilterQueryBuilders) {
LinearRetrieverBuilder clone = new LinearRetrieverBuilder(newChildRetrievers, rankWindowSize, weights, normalizers);
clone.preFilterQueryBuilders = newPreFilterQueryBuilders;
clone.retrieverName = retrieverName;
return clone;
return new LinearRetrieverBuilder(
newChildRetrievers,
rankWindowSize,
weights,
normalizers,
minScore,
retrieverName,
newPreFilterQueryBuilders
);
}

@Override
Expand Down Expand Up @@ -181,6 +206,10 @@ protected RankDoc[] combineInnerRetrieverResults(List<ScoreDoc[]> rankResults, b
topResults[rank] = sortedResults[rank];
topResults[rank].rank = rank + 1;
}
// Filter by minScore if set (inclusive)
if (minScore != null) {
topResults = Arrays.stream(topResults).filter(doc -> doc.score >= minScore).toArray(LinearRankDoc[]::new);
}
return topResults;
}

Expand Down
Loading