
Fix score count validation in reranker response #111212

Merged
docs/changelog/111212.yaml (6 additions, 0 deletions)
@@ -0,0 +1,6 @@
+pr: 111212
+summary: Fix score count validation in reranker response
+area: Ranking
+type: bug
+issues:
+ - 111202
@@ -54,12 +54,24 @@ protected void computeScores(RankFeatureDoc[] featureDocs, ActionListener<float[
         // Wrap the provided rankListener to an ActionListener that would handle the response from the inference service
         // and then pass the results
         final ActionListener<InferenceAction.Response> actionListener = scoreListener.delegateFailureAndWrap((l, r) -> {
-            float[] scores = extractScoresFromResponse(r);
-            if (scores.length != featureDocs.length) {
+            InferenceServiceResults results = r.getResults();
+            assert results instanceof RankedDocsResults;
+
+            // Ensure we get exactly as many scores as the number of docs we passed, otherwise we may return incorrect results
+            List<RankedDocsResults.RankedDoc> rankedDocs = ((RankedDocsResults) results).getRankedDocs();
+            if (rankedDocs.size() != featureDocs.length) {
                 l.onFailure(
-                    new IllegalStateException("Document and score count mismatch: [" + featureDocs.length + "] vs [" + scores.length + "]")
+                    new IllegalStateException(
pmpailis (Contributor), Jul 24, 2024:

Could we also add a test that throws this exception instead of the IOOB (IndexOutOfBoundsException)?

demjened (Contributor Author):

@benwtrent @pmpailis I added a test case for simulating invalid indices (7e3f153).

Note I don't think it actually covers the use case, because it's a sub-case of the reranker input/output count mismatch. With the bugfix, this is now caught before the actual doc index assignment that would trigger the IOOB. (We do not have specific handling of N inputs -> N outputs -> an index pointing outside 0..N-1; I'd consider that a reranker error.)

benwtrent (Member):


@demjened AH, ok, is there a way to parse this top_n setting and validate that it matches the window size earlier in the request? That seems like we should return "Bad Request" in that scenario.

demjened (Contributor Author):


@benwtrent Not without issuing a GET call before each inference call to check the inference endpoint's configuration, I'm afraid. The rerank retriever framework only exposes hooks for creation and submission of an inference request, and parsing the results, but it doesn't know about the config.

benwtrent (Member):


> Not without issuing a GET call before each inference call to check the inference endpoint's configuration, I'm afraid... but it doesn't know about the config.

We need to fix this. If certain combinations of an endpoint configuration and a search request are invalid, we need to catch them.

Calling a local GET to retrieve some minor documentation is way cheaper than calling some external service when we know the request will fail. I would hold off on merging or progressing this until we can eagerly validate the search request, as we know ahead of time that it won't work.

demjened (Contributor Author):


@benwtrent @Mikep86 @pmpailis I'll summarize the latest changes here that accommodate all your suggestions (thanks for those):

  • There is now handling logic for two failure scenarios: 1. we find out that task_settings.top_n < rank_window_size and preemptively fail before running the inference; 2. the reranker returns a different number of outputs than the N inputs it received. These are tested in testRerankTopNConfigurationAndRankWindowSizeMismatch() and testRerankInputSizeAndInferenceResultsMismatch() respectively.
  • The 1st scenario results in a 400, the 2nd one in a 500 (as it's not a client-side error if the reranker misbehaves).
  • The extraction logic of top_n from the fetched inference endpoint config is vendor-specific. I'm not happy about this, but the empty TaskSettings interface doesn't give me many options to get the top N other than casting after a type check.
  • In the test class, the mock configuration for scenario 1 requires passing the expected top N as part of the inference ID and parsing it (i.e. my-inference-id-task-settings-top-3 -> top_n=3). Again I'm not proud of this implementation, but GetInferenceModelAction has a very limited interface for passing control parameters to mock behavior in a test.
  • Reworded the error messages to make them clearer and more actionable. Also, the failure tests now check the exact message, based on Panos' suggestion.

Let me know if you feel there are still things to improve and must go in this PR; I want to timebox the remaining effort as the bug has been fixed and I want to make sure this gets merged by 8.15.0 BC5.
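The two failure scenarios summarized above can be sketched in isolation as follows. This is a minimal standalone illustration; the class and method names are invented for the sketch and are not the actual Elasticsearch code:

```java
import java.util.List;

public class RerankValidationSketch {

    // Scenario 1: fail before calling the inference service when the endpoint's
    // top_n is smaller than the retriever's rank_window_size (client error -> 400).
    static void validateTopN(Integer topN, int rankWindowSize) {
        if (topN != null && topN < rankWindowSize) {
            throw new IllegalArgumentException(
                "Inference endpoint's task_settings.top_n [" + topN
                    + "] must be greater than or equal to rank_window_size [" + rankWindowSize + "]"
            );
        }
    }

    // Scenario 2: fail after inference when the reranker returned a different
    // number of scores than the number of input documents (server error -> 500).
    static void validateResultCount(int inputCount, List<Float> scores) {
        if (scores.size() != inputCount) {
            throw new IllegalStateException(
                "Document and score count mismatch: [" + inputCount + "] vs [" + scores.size() + "]"
            );
        }
    }

    public static void main(String[] args) {
        validateTopN(10, 5);                          // ok: top_n >= rank_window_size
        validateResultCount(2, List.of(0.9f, 0.1f));  // ok: counts match

        try {
            validateTopN(3, 100);                     // rejected before inference runs
        } catch (IllegalArgumentException e) {
            System.out.println("400-style: " + e.getMessage());
        }
        try {
            validateResultCount(3, List.of(0.5f));    // reranker misbehaved
        } catch (IllegalStateException e) {
            System.out.println("500-style: " + e.getMessage());
        }
    }
}
```

The exception types mirror the 400/500 split: an IllegalArgumentException for the user-fixable configuration mismatch, an IllegalStateException for the reranker misbehaving.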

Mikep86 (Contributor):


Heads up that @demjened & I discussed offline yesterday and we're going to investigate doing the task_settings.top_n < rank_window_size check in ML code

demjened (Contributor Author):


@Mikep86 I've been thinking about this; while moving this logic to the ML code would make it cleaner and would remove coupling, I'm worried it's not a good idea from a product perspective.

For reference, here's the code you suggested that would replace the preemptive top_n check from the reranker:

    // CohereService#doInfer
    ...
    CohereModel cohereModel = (CohereModel) model;
    var actionCreator = new CohereActionCreator(getSender(), getServiceComponents());

    if (model instanceof CohereRerankModel rerankModel) {
        Integer topN = rerankModel.getTaskSettings().getTopNDocumentsOnly();
        if (topN != null && topN < input.size()) {
            listener.onFailure(new IllegalArgumentException("top_n < doc count"));
            return;
        }
    }
    ...

The problem is we're invalidating a normal use case. It is a valid scenario that a rerank endpoint is configured to return a maximum of N hits for >N inputs. It's only an issue from the reranker retriever's perspective, where model.top_n < retriever.rank_window_size can lead to partial or incorrect results.

So we have 3 options to implement the check:

  1. (Current) Fetch the task settings and validate in the retriever context. Problem: coupling with ML code that we need to maintain.
  2. (Proposal in this comment) Move validation to the XxxService classes in ML code. Problem: invalidates a valid use case.
  3. Add a control flag to toggle this validation in ML code, i.e. doInfer(..., inputSizeMustMatchTopN) plus the code above. Problem: complex; requires refactoring of the interface and transport objects.

WDYT?
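For illustration, option 3 could look roughly like the following. This is a hypothetical sketch; the method signature and flag name are assumptions for discussion, not actual Elasticsearch interfaces:

```java
public class ControlFlagSketch {

    // Hypothetical inference entry point with an opt-in validation flag.
    static void doInfer(int topN, int inputSize, boolean inputSizeMustMatchTopN) {
        // Only reject when the caller (the rerank retriever) opted in;
        // direct users of the rerank endpoint keep the truncating top_n behavior.
        if (inputSizeMustMatchTopN && topN < inputSize) {
            throw new IllegalArgumentException(
                "top_n [" + topN + "] is less than input size [" + inputSize + "]"
            );
        }
        // ... dispatch the inference request ...
    }

    public static void main(String[] args) {
        doInfer(3, 10, false); // direct API use: allowed, reranker truncates to top 3
        try {
            doInfer(3, 10, true); // retriever context: rejected up front
        } catch (IllegalArgumentException e) {
            System.out.println("rejected: " + e.getMessage());
        }
    }
}
```

This preserves the truncating use case for direct API callers while letting the retriever enforce its stricter contract, at the cost of threading the flag through the transport layer.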

Member:


It is valid for the reranker API to return fewer docs, as that is the whole point of topN. Otherwise, why is it even configurable?

The retriever also sets window size. For now, we should enforce that window size is <= topN. Maybe we can adjust this in the future, but for retrievers, this is a sane default & safe behavior.

Mikep86 (Contributor):


Got it, wasn't thinking of rerank endpoint usage outside of this context.

+                        "Document and score count mismatch: ["
+                            + featureDocs.length
+                            + "] vs ["
+                            + rankedDocs.size()
+                            + "]. Check your rerank inference endpoint configuration and ensure it returns rank_window_size scores for "
+                            + "rank_window_size input documents."
+                    )
                 );
             } else {
+                float[] scores = extractScoresFromRankedDocs(rankedDocs);
                 l.onResponse(scores);
             }
         });
@@ -85,11 +97,7 @@ protected InferenceAction.Request generateRequest(List<String> docFeatures) {
);
}

-    private float[] extractScoresFromResponse(InferenceAction.Response response) {
-        InferenceServiceResults results = response.getResults();
-        assert results instanceof RankedDocsResults;
-
-        List<RankedDocsResults.RankedDoc> rankedDocs = ((RankedDocsResults) results).getRankedDocs();
+    private float[] extractScoresFromRankedDocs(List<RankedDocsResults.RankedDoc> rankedDocs) {
         float[] scores = new float[rankedDocs.size()];
         for (RankedDocsResults.RankedDoc rankedDoc : rankedDocs) {
             scores[rankedDoc.index()] = rankedDoc.relevanceScore();
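As an aside for readers following the IOOB discussion above: score extraction writes each relevance score into a fixed-size array at the index the reranker reports, so a misbehaving reranker that emits an index outside 0..N-1 fails with an IndexOutOfBoundsException. A minimal standalone sketch of that mechanic; the RankedDoc record here is an illustrative stand-in, not the actual Elasticsearch class:

```java
import java.util.List;

public class ScoreExtractionSketch {

    // Stand-in for RankedDocsResults.RankedDoc: the reranker reports which
    // input document (index) each relevance score belongs to.
    record RankedDoc(int index, float relevanceScore) {}

    static float[] extractScores(List<RankedDoc> rankedDocs) {
        float[] scores = new float[rankedDocs.size()];
        for (RankedDoc doc : rankedDocs) {
            // Throws ArrayIndexOutOfBoundsException if index >= rankedDocs.size()
            scores[doc.index()] = doc.relevanceScore();
        }
        return scores;
    }

    public static void main(String[] args) {
        // Well-behaved reranker: indices stay within 0..N-1, in any order.
        float[] ok = extractScores(List.of(new RankedDoc(1, 0.2f), new RankedDoc(0, 0.9f)));
        System.out.println(ok[0] + " " + ok[1]);

        // Misbehaving reranker: index 5 on a 2-element result set.
        try {
            extractScores(List.of(new RankedDoc(0, 0.9f), new RankedDoc(5, 0.2f)));
        } catch (ArrayIndexOutOfBoundsException e) {
            System.out.println("IOOB: " + e.getMessage());
        }
    }
}
```

The count-mismatch check added in this PR runs before this assignment loop, which is why the out-of-range index sub-case is caught earlier in practice.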
@@ -37,6 +37,20 @@ public class TextSimilarityRankTests extends ESSingleNodeTestCase {
*/
public static class InvalidInferenceResultCountProvidingTextSimilarityRankBuilder extends TextSimilarityRankBuilder {

+        private boolean hasInvalidDocumentIndices = false;
+
+        public InvalidInferenceResultCountProvidingTextSimilarityRankBuilder(
+            String field,
+            String inferenceId,
+            String inferenceText,
+            int rankWindowSize,
+            Float minScore,
+            boolean hasInvalidDocumentIndices
+        ) {
+            this(field, inferenceId, inferenceText, rankWindowSize, minScore);
+            this.hasInvalidDocumentIndices = hasInvalidDocumentIndices;
+        }

public InvalidInferenceResultCountProvidingTextSimilarityRankBuilder(
String field,
String inferenceId,
@@ -65,7 +79,7 @@ protected InferenceAction.Request generateRequest(List<String> docFeatures) {
                 inferenceId,
                 inferenceText,
                 docFeatures,
-                Map.of("invalidInferenceResultCount", true),
+                Map.of("invalidInferenceResultCount", true, "invalidDocumentIndices", hasInvalidDocumentIndices),
                 InputType.SEARCH,
                 InferenceAction.Request.DEFAULT_TIMEOUT
             );
@@ -151,11 +165,12 @@ public void testRerankInferenceFailure() {
);
}

-    public void testRerankInferenceResultMismatch() {
+    public void testRerankInferenceResultCountMismatch() {
         ElasticsearchAssertions.assertFailures(
             // Execute search with text similarity reranking
             client.prepareSearch()
                 .setRankBuilder(
+                    // Simulate reranker returning different number of results from input
                     new InvalidInferenceResultCountProvidingTextSimilarityRankBuilder("text", "my-rerank-model", "my query", 100, 1.5f)
                 )
                 .setQuery(QueryBuilders.matchAllQuery()),
@@ -164,6 +179,27 @@ public void testRerankInferenceResultMismatch() {
);
}

+    public void testRerankInvalidDocumentIndices() {
+        ElasticsearchAssertions.assertFailures(
+            // Execute search with text similarity reranking
+            client.prepareSearch()
+                .setRankBuilder(
+                    // Simulate reranker returning different number of results from input, also invalid document indices in results
+                    new InvalidInferenceResultCountProvidingTextSimilarityRankBuilder(
+                        "text",
+                        "my-rerank-model",
+                        "my query",
+                        100,
+                        1.5f,
+                        true
+                    )
+                )
+                .setQuery(QueryBuilders.matchAllQuery()),
+            RestStatus.INTERNAL_SERVER_ERROR,
+            containsString("Failed to execute phase [rank-feature], Computing updated ranks for results failed")
pmpailis (Contributor):


Two things here:

  • We should return a 400 response here; the error is ultimately due to a bad value provided by the user.
  • Can we test that the returned error message contains the root cause of the problem (i.e. "Document and score count mismatch")?

demjened (Contributor Author):


The error response is produced here, in response to an IllegalStateException. @pmpailis do you know of any way to turn this into a 400?

I'll look into using different matchers in the test to verify the cause.

Contributor:


We should work on making the HTTP response code configurable. 500 == rest suppressed errors == things to investigate during serverless on-call

demjened (Contributor Author), Jul 24, 2024:


Update: there's no logic in ElasticsearchAssertions#assertFailures that we could call to check the cause of a search phase exception 🙁 (in this case e.getCause()).

Contributor:


I think that we could slightly rewrite this to capture the cause's message (as there won't be any response to decref). E.g.:

        SearchPhaseExecutionException ex = expectThrows(
            SearchPhaseExecutionException.class,
            client.prepareSearch()
                .setRankBuilder(
                    // Simulate reranker returning different number of results from input, also invalid document indices in results
                    new InvalidInferenceResultCountProvidingTextSimilarityRankBuilder(
                        "text",
                        "my-rerank-model",
                        "my query",
                        100,
                        1.5f,
                        true
                    )
                )
                .setQuery(QueryBuilders.matchAllQuery())::get
        );
        assertThat(ex.getDetailedMessage(), containsString("Document and score count mismatch"));

pmpailis (Contributor), Jul 25, 2024:


> @pmpailis do you know of any way to turn this into a 400?

The status of the error is determined through SearchPhaseExecutionException#status, which in turn derives the status from the exception's cause via org.elasticsearch.ExceptionsHelper#status.

I think converting the IllegalStateException to an IllegalArgumentException should be enough to have a 400 response.
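The mapping pmpailis describes can be approximated by this simplified stand-in. The real org.elasticsearch.ExceptionsHelper#status handles many more exception types; this sketch only shows why swapping the exception type changes the HTTP response code:

```java
public class StatusMappingSketch {

    // Simplified stand-in for the cause-to-status mapping: IllegalArgumentException
    // is treated as a client error, anything unrecognized falls back to 500.
    static int status(Throwable t) {
        if (t instanceof IllegalArgumentException) {
            return 400; // bad request: caller supplied invalid input
        }
        return 500; // internal error: unexpected server-side failure
    }

    public static void main(String[] args) {
        System.out.println(status(new IllegalArgumentException("top_n < rank_window_size"))); // 400
        System.out.println(status(new IllegalStateException("count mismatch")));              // 500
    }
}
```

This is why changing the thrown type from IllegalStateException to IllegalArgumentException flips the search response from a 500 to a 400 without touching the REST layer.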

+        );
+    }

private static void assertHitHasRankScoreAndText(SearchHit hit, int expectedRank, float expectedScore, String expectedText) {
assertEquals(expectedRank, hit.getRank());
assertEquals(expectedScore, hit.getScore(), 0.0f);
@@ -115,9 +115,10 @@ public <Request extends ActionRequest, Response extends ActionResponse> void app
}

         assert request instanceof InferenceAction.Request;
-        boolean shouldThrow = (boolean) ((InferenceAction.Request) request).getTaskSettings().getOrDefault("throwing", false);
-        boolean hasInvalidInferenceResultCount = (boolean) ((InferenceAction.Request) request).getTaskSettings()
-            .getOrDefault("invalidInferenceResultCount", false);
+        Map<String, Object> taskSettings = ((InferenceAction.Request) request).getTaskSettings();
+        boolean shouldThrow = (boolean) taskSettings.getOrDefault("throwing", false);
+        boolean hasInvalidInferenceResultCount = (boolean) taskSettings.getOrDefault("invalidInferenceResultCount", false);
+        boolean hasInvalidDocumentIndices = (boolean) taskSettings.getOrDefault("invalidDocumentIndices", false);

if (shouldThrow) {
listener.onFailure(new UnsupportedOperationException("simulated failure"));
@@ -126,7 +127,13 @@
             List<String> inputs = ((InferenceAction.Request) request).getInput();
             int resultCount = hasInvalidInferenceResultCount ? inputs.size() - 1 : inputs.size();
             for (int i = 0; i < resultCount; i++) {
-                rankedDocsResults.add(new RankedDocsResults.RankedDoc(i, Float.parseFloat(inputs.get(i)), inputs.get(i)));
+                rankedDocsResults.add(
+                    new RankedDocsResults.RankedDoc(
+                        hasInvalidDocumentIndices ? i * 2 : i,
+                        Float.parseFloat(inputs.get(i)),
+                        inputs.get(i)
+                    )
+                );
}
ActionResponse response = new InferenceAction.Response(new RankedDocsResults(rankedDocsResults));
listener.onResponse((Response) response);