Fix score count validation in reranker response #111212
New changelog file:

```diff
@@ -0,0 +1,6 @@
+pr: 111212
+summary: Fix score count validation in reranker response
+area: Ranking
+type: bug
+issues:
+ - 111202
```
`TextSimilarityRankFeaturePhaseRankCoordinatorContext.java`:

```diff
@@ -7,15 +7,22 @@
 package org.elasticsearch.xpack.inference.rank.textsimilarity;
 
+import org.elasticsearch.action.ActionFuture;
 import org.elasticsearch.action.ActionListener;
 import org.elasticsearch.client.internal.Client;
 import org.elasticsearch.inference.InferenceServiceResults;
 import org.elasticsearch.inference.InputType;
+import org.elasticsearch.inference.ModelConfigurations;
 import org.elasticsearch.inference.TaskType;
 import org.elasticsearch.search.rank.context.RankFeaturePhaseRankCoordinatorContext;
 import org.elasticsearch.search.rank.feature.RankFeatureDoc;
+import org.elasticsearch.xpack.core.inference.action.GetInferenceModelAction;
 import org.elasticsearch.xpack.core.inference.action.InferenceAction;
 import org.elasticsearch.xpack.core.inference.results.RankedDocsResults;
+import org.elasticsearch.xpack.inference.services.cohere.CohereService;
+import org.elasticsearch.xpack.inference.services.cohere.rerank.CohereRerankTaskSettings;
+import org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiService;
+import org.elasticsearch.xpack.inference.services.googlevertexai.rerank.GoogleVertexAiRerankTaskSettings;
 
 import java.util.Arrays;
 import java.util.Comparator;
```
```diff
@@ -51,15 +58,42 @@ public TextSimilarityRankFeaturePhaseRankCoordinatorContext(
 
     @Override
     protected void computeScores(RankFeatureDoc[] featureDocs, ActionListener<float[]> scoreListener) {
+        // The rerank inference endpoint may have an override to return top N documents only, in that case let's fail fast to avoid
+        // assigning scores to the wrong inputs
+        Integer inferenceEndpointTopN = getTopNFromInferenceEndpointTaskSettings();
+        if (inferenceEndpointTopN != null && inferenceEndpointTopN < featureDocs.length) {
+            scoreListener.onFailure(
+                new IllegalArgumentException(
+                    "Inference endpoint ["
+                        + inferenceId
+                        + "] is configured to return the top ["
+                        + inferenceEndpointTopN
+                        + "] results. Reduce rank_window_size to be less than or equal to this value."
+                )
+            );
+            return;
+        }
+
         // Wrap the provided rankListener to an ActionListener that would handle the response from the inference service
         // and then pass the results
         final ActionListener<InferenceAction.Response> actionListener = scoreListener.delegateFailureAndWrap((l, r) -> {
-            float[] scores = extractScoresFromResponse(r);
-            if (scores.length != featureDocs.length) {
-                l.onFailure(
-                    new IllegalStateException("Document and score count mismatch: [" + featureDocs.length + "] vs [" + scores.length + "]")
-                );
+            InferenceServiceResults results = r.getResults();
+            assert results instanceof RankedDocsResults;
+
+            // Ensure we get exactly as many scores as the number of docs we passed, otherwise we may return incorrect results
+            List<RankedDocsResults.RankedDoc> rankedDocs = ((RankedDocsResults) results).getRankedDocs();
+            if (rankedDocs.size() != featureDocs.length) {
+                l.onFailure(
+                    new IllegalStateException(
+                        "Reranker input document count and returned score count mismatch: ["
+                            + featureDocs.length
+                            + "] vs ["
+                            + rankedDocs.size()
+                            + "]"
+                    )
+                );
             } else {
+                float[] scores = extractScoresFromRankedDocs(rankedDocs);
                 l.onResponse(scores);
             }
         });
```

Could we also add a test that throws this exception instead of the IOOB?

@benwtrent @pmpailis I added a test case for simulating invalid indices (7e3f153). Note I don't think it actually covers the use case, because it's a sub-case of the reranker input/output count mismatch. With the bugfix this is now caught before the actual doc index assignment happens that would trigger the IOOB. (We do not have specific handling of N inputs -> N outputs -> index pointing outside 0..N-1; I'd consider that a reranker error.)

@demjened Ah, ok, is there a way to parse this

@benwtrent Not without issuing a GET call before each inference call to check the inference endpoint's configuration, I'm afraid. The rerank retriever framework only exposes hooks for creation and submission of an inference request, and parsing the results, but it doesn't know about the config.

We need to fix this. If there are things that are invalid for a configuration & a search request, we need to fix this. Calling a local GET to retrieve some minor documentation is way cheaper than calling some external service when we know the request will fail. I would hold off on merging or progressing this until we can eagerly validate the search request, as we know ahead of time that it won't work.

@benwtrent @Mikep86 @pmpailis I'll summarize the latest changes here that accommodate all your suggestions (thanks for those):

Let me know if you feel there are still things to improve and must go in this PR; I want to timebox the remaining effort, as the bug has been fixed and I want to make sure this gets merged by 8.15.0 BC5.

Heads up that @demjened & I discussed offline yesterday and we're going to investigate doing the

@Mikep86 I've been thinking about this; while moving this logic to the ML code would make it cleaner and would remove coupling, I'm worried it's not a good idea from a product perspective. For reference, here's the code you suggested that would replace the preemptive top_n check from the reranker:

```java
// CohereService#doInfer
...
CohereModel cohereModel = (CohereModel) model;
var actionCreator = new CohereActionCreator(getSender(), getServiceComponents());
if (model instanceof CohereRerankModel rerankModel) {
    Integer topN = rerankModel.getTaskSettings().getTopNDocumentsOnly();
    if (topN != null && topN < input.size()) {
        listener.onFailure(new IllegalArgumentException("top_n < doc count"));
        return;
    }
}
...
```

The problem is we're invalidating a normal use case. It is a valid scenario that a rerank endpoint is configured to return a maximum of N hits for >N inputs. It's only an issue from the reranker retriever's perspective, where the a

So we have 3 options to implement the check:

WDYT?

It is valid for the reranker API to return fewer docs, as that is the whole point of

The retriever also sets window size. For now, we should enforce that window size is <= topN. Maybe we can adjust this in the future, but for retrievers, this is a sane default & safe behavior.

Got it, wasn't thinking of rerank endpoint usage outside of this context.
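The fail-fast branch above reduces to a small predicate. Here is a self-contained sketch of that check; `TopNCheck` and `validate` are hypothetical names for illustration only (the real code reports `inferenceId` and fails via `scoreListener.onFailure`):

```java
// Minimal model of the fail-fast top_n validation in computeScores.
// TopNCheck and validate() are hypothetical names, not from the codebase.
public class TopNCheck {

    // Returns an error message when the endpoint's top_n override would
    // truncate the reranker input, or null when the request is safe to send.
    static String validate(Integer endpointTopN, int featureDocCount) {
        if (endpointTopN != null && endpointTopN < featureDocCount) {
            return "Endpoint is configured to return the top [" + endpointTopN
                + "] results. Reduce rank_window_size to be less than or equal to this value.";
        }
        return null; // no override, or the override covers all docs
    }

    public static void main(String[] args) {
        System.out.println(validate(5, 10) != null);    // true: 5 < 10, fail fast
        System.out.println(validate(10, 10) == null);   // true: equal counts are fine
        System.out.println(validate(null, 10) == null); // true: no top_n override
    }
}
```

Failing before the inference call means no scores can ever be assigned to the wrong inputs, which is the core of the bugfix.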
```diff
@@ -73,6 +107,18 @@ protected void computeScores(RankFeatureDoc[] featureDocs, ActionListener<float[
         }
     }
 
+    /**
+     * Sorts documents by score descending and discards those with a score less than minScore.
+     * @param originalDocs documents to process
+     */
+    @Override
+    protected RankFeatureDoc[] preprocess(RankFeatureDoc[] originalDocs) {
+        return Arrays.stream(originalDocs)
+            .filter(doc -> minScore == null || doc.score >= minScore)
+            .sorted(Comparator.comparing((RankFeatureDoc doc) -> doc.score).reversed())
+            .toArray(RankFeatureDoc[]::new);
+    }
+
     protected InferenceAction.Request generateRequest(List<String> docFeatures) {
         return new InferenceAction.Request(
             TaskType.RERANK,
```
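The `preprocess` override is a plain stream pipeline: drop docs below `minScore` (when set), then sort by score descending. A runnable sketch with a minimal stand-in for `RankFeatureDoc` (the `Doc` class below is a simplified stand-in, not the real type):

```java
import java.util.Arrays;
import java.util.Comparator;

// Self-contained sketch of the preprocess step: filter by minScore, sort
// by score descending. Doc stands in for RankFeatureDoc.
public class PreprocessSketch {
    static class Doc {
        final float score;
        Doc(float score) { this.score = score; }
    }

    static Doc[] preprocess(Doc[] docs, Float minScore) {
        return Arrays.stream(docs)
            .filter(d -> minScore == null || d.score >= minScore)         // discard below-threshold docs
            .sorted(Comparator.comparing((Doc d) -> d.score).reversed())  // highest score first
            .toArray(Doc[]::new);
    }

    public static void main(String[] args) {
        Doc[] out = preprocess(new Doc[] { new Doc(0.2f), new Doc(0.9f), new Doc(0.5f) }, 0.3f);
        // The 0.2 doc is dropped; the rest are ordered highest score first.
        System.out.println(out.length);   // 2
        System.out.println(out[0].score); // 0.9
    }
}
```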
```diff
@@ -85,11 +131,7 @@ protected InferenceAction.Request generateRequest(List<String> docFeatures) {
         );
     }
 
-    private float[] extractScoresFromResponse(InferenceAction.Response response) {
-        InferenceServiceResults results = response.getResults();
-        assert results instanceof RankedDocsResults;
-
-        List<RankedDocsResults.RankedDoc> rankedDocs = ((RankedDocsResults) results).getRankedDocs();
+    private float[] extractScoresFromRankedDocs(List<RankedDocsResults.RankedDoc> rankedDocs) {
         float[] scores = new float[rankedDocs.size()];
         for (RankedDocsResults.RankedDoc rankedDoc : rankedDocs) {
             scores[rankedDoc.index()] = rankedDoc.relevanceScore();
@@ -98,15 +140,22 @@ private float[] extractScoresFromResponse(InferenceAction.Response response) {
         return scores;
     }
 
-    /**
-     * Sorts documents by score descending and discards those with a score less than minScore.
-     * @param originalDocs documents to process
-     */
-    @Override
-    protected RankFeatureDoc[] preprocess(RankFeatureDoc[] originalDocs) {
-        return Arrays.stream(originalDocs)
-            .filter(doc -> minScore == null || doc.score >= minScore)
-            .sorted(Comparator.comparing((RankFeatureDoc doc) -> doc.score).reversed())
-            .toArray(RankFeatureDoc[]::new);
-    }
+    private Integer getTopNFromInferenceEndpointTaskSettings() {
+        GetInferenceModelAction.Request request = new GetInferenceModelAction.Request(inferenceId, TaskType.RERANK);
+        ActionFuture<GetInferenceModelAction.Response> response = client.execute(GetInferenceModelAction.INSTANCE, request);
+        ModelConfigurations modelConfigurations = response.actionGet().getEndpoints().get(0);
+
+        if (modelConfigurations.getService().equals(CohereService.NAME)
+            && modelConfigurations.getTaskType().equals(TaskType.RERANK)
+            && modelConfigurations.getTaskSettings() instanceof CohereRerankTaskSettings) {
+            return ((CohereRerankTaskSettings) modelConfigurations.getTaskSettings()).getTopNDocumentsOnly();
+        } else if (modelConfigurations.getService().equals(GoogleVertexAiService.NAME)
+            && modelConfigurations.getTaskType().equals(TaskType.RERANK)
+            && modelConfigurations.getTaskSettings() instanceof GoogleVertexAiRerankTaskSettings) {
+            return ((GoogleVertexAiRerankTaskSettings) modelConfigurations.getTaskSettings()).topN();
+        }
+
+        return null;
+    }
 }
```
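The index-based write in `extractScoresFromRankedDocs` is what made the original bug surface as an `ArrayIndexOutOfBoundsException`: the service returns `(index, score)` pairs whose indices refer to positions in the original input, so when fewer docs come back than were sent, an index can point past the smaller array. A minimal model (the `RankedDoc` record here is a simplified stand-in for `RankedDocsResults.RankedDoc`):

```java
import java.util.List;

// Models the score write-back by returned index. RankedDoc is a stand-in
// for RankedDocsResults.RankedDoc (index + relevanceScore).
public class ScoreExtraction {
    record RankedDoc(int index, float relevanceScore) {}

    static float[] extractScores(List<RankedDoc> rankedDocs) {
        float[] scores = new float[rankedDocs.size()];
        for (RankedDoc doc : rankedDocs) {
            // Indices refer to positions in the ORIGINAL input, so this is
            // only safe when result count == input count.
            scores[doc.index()] = doc.relevanceScore();
        }
        return scores;
    }

    public static void main(String[] args) {
        // Happy path: as many results as inputs, indices in bounds.
        float[] scores = extractScores(List.of(new RankedDoc(1, 0.9f), new RankedDoc(0, 0.4f), new RankedDoc(2, 0.1f)));
        System.out.println(scores[1]); // 0.9

        // Mismatch path: only 2 results came back, but one index still refers
        // to original position 3; without the count check this is the IOOB
        // discussed in the review thread.
        try {
            extractScores(List.of(new RankedDoc(0, 0.9f), new RankedDoc(3, 0.4f)));
        } catch (ArrayIndexOutOfBoundsException e) {
            System.out.println("out of bounds");
        }
    }
}
```

This is why the PR validates `rankedDocs.size() == featureDocs.length` before the write-back ever runs.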
We should definitely keep this asynchronous. Rule of thumb: never use an ActionFuture or make an async call synchronous when you can easily avoid it. It should be something like this (obviously with things better filled in)

@benwtrent Thanks for the code snippet, I updated it. One question regarding your comment: isn't it sufficient to just delegate to the outer listener, i.e. `scoreListener.onFailure(f)`?
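The reviewer's suggested snippet is elided above. As a hedged sketch of the listener-based shape being described, using simplified stand-in types (the `Listener` interface and `getEndpointTopN` below are illustrative, not Elasticsearch's real `ActionListener`/`Client` signatures):

```java
// Sketch of resolving the endpoint's top_n asynchronously via a listener
// chain instead of blocking on an ActionFuture. All types here are
// simplified stand-ins for illustration.
public class AsyncTopNSketch {
    interface Listener<T> {
        void onResponse(T value);
        void onFailure(Exception e);
    }

    // Stand-in for an async GET of the endpoint config; invokes the listener
    // with the configured top_n (null would mean no override). Hard-coded to 5
    // here purely for demonstration.
    static void getEndpointTopN(Listener<Integer> listener) {
        listener.onResponse(5);
    }

    static void computeScores(int featureDocCount, Listener<float[]> scoreListener) {
        getEndpointTopN(new Listener<Integer>() {
            @Override
            public void onResponse(Integer topN) {
                if (topN != null && topN < featureDocCount) {
                    scoreListener.onFailure(new IllegalArgumentException("rank_window_size > top_n"));
                    return;
                }
                // ... the inference call would go here, eventually delivering
                // real scores; a zero-filled array stands in for them.
                scoreListener.onResponse(new float[featureDocCount]);
            }

            @Override
            public void onFailure(Exception e) {
                // Delegating straight to the outer listener, per the question above.
                scoreListener.onFailure(e);
            }
        });
    }

    public static void main(String[] args) {
        computeScores(10, new Listener<float[]>() {
            @Override public void onResponse(float[] scores) { System.out.println("scores: " + scores.length); }
            @Override public void onFailure(Exception e) { System.out.println("failed: " + e.getMessage()); }
        });
    }
}
```

No thread ever blocks waiting on a future here: the validation runs inside the GET's response callback, and failures on either step flow to the single outer listener.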