
Fix score count validation in reranker response #111212

Merged

Conversation

Contributor

@demjened demjened commented Jul 23, 2024

Fixes #111202.

The text_similarity_reranker retriever query fails if rank_window_size is greater than top_n in the rerank inference endpoint's task settings. More details and steps to reproduce are in the above-mentioned issue.

The cause is this line. The rerank inference response contains document indices that refer to the rank_window_size inputs. However, if task_settings.top_n is specified on the inference endpoint, it returns only the top N of those. For example, with rank_window_size==20 and top_n==10 this is a valid response:

[RankedDoc{index='1', relevanceScore='0.9961155', text='null', hashcode=-1335808012},
RankedDoc{index='15', relevanceScore='0.9865199', text='null', hashcode=-1340785186},
RankedDoc{index='3', relevanceScore='0.049773447', text='null', hashcode=1815090117}, 
... (7 more RankedDoc-s)]

The problematic line creates a scores array with a length of 10, but processing the second item (index='15') triggers the out-of-bounds exception.

Validation that the reranker input count matches the score output count is already in place (and covered by the TextSimilarityRankTests#testRerankInferenceFailure test); however, due to the above bug, the process never reaches that check and instead fails with an index-out-of-bounds exception. The solution is to move the validation logic before the scores are extracted.
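The failing pattern and the fix can be sketched in plain Java. The names below (RankedDoc, ScoreExtractor) are simplified stand-ins for illustration, not the actual Elasticsearch classes:

```java
import java.util.List;

// Simplified stand-in for the reranker's per-document result: the index
// refers to the ORIGINAL input position, not the position in the response.
record RankedDoc(int index, float relevanceScore) {}

class ScoreExtractor {
    // Buggy version: sizes the array to the response length, so with top_n set
    // a returned index >= rankedDocs.size() (e.g. index 15 in a 10-element
    // response) throws ArrayIndexOutOfBoundsException before any validation runs.
    static float[] extractScoresBuggy(List<RankedDoc> rankedDocs) {
        float[] scores = new float[rankedDocs.size()];
        for (RankedDoc doc : rankedDocs) {
            scores[doc.index()] = doc.relevanceScore(); // IOOB happens here
        }
        return scores;
    }

    // Fixed version: validate input count == output count BEFORE extracting,
    // mirroring this PR's fix of moving the validation ahead of extraction.
    static float[] extractScoresFixed(List<RankedDoc> rankedDocs, int inputCount) {
        if (rankedDocs.size() != inputCount) {
            throw new IllegalStateException(
                "Document and score count mismatch: [" + inputCount + "] vs [" + rankedDocs.size() + "]");
        }
        float[] scores = new float[inputCount];
        for (RankedDoc doc : rankedDocs) {
            scores[doc.index()] = doc.relevanceScore();
        }
        return scores;
    }
}
```

With the fix, the same mismatched response fails with the descriptive IllegalStateException instead of the opaque out-of-bounds error.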

@elasticsearchmachine elasticsearchmachine added the Team:Search Relevance Meta label for the Search Relevance team in Elasticsearch label Jul 23, 2024
@elasticsearchmachine (Collaborator)

Pinging @elastic/es-search-relevance (Team:Search Relevance)

@elasticsearchmachine (Collaborator)

Hi @demjened, I've created a changelog YAML for you.

  l.onFailure(
-     new IllegalStateException("Document and score count mismatch: [" + featureDocs.length + "] vs [" + scores.length + "]")
+     new IllegalStateException(
@pmpailis (Contributor) Jul 24, 2024

Could we also add a test that throws this exception instead of the IOOB?

Contributor Author

@benwtrent @pmpailis I added a test case for simulating invalid indices (7e3f153).

Note I don't think it actually covers the use case, because it's a sub-case of the reranker input/output count mismatch. With the bugfix this is now caught before the actual doc index assignment happens that would trigger the IOOB. (We do not have specific handling of N inputs -> N outputs -> index pointing outside 0..N-1, I'd consider that a reranker error.)

Member

@demjened AH, ok, is there a way to parse this top_n setting and validate that it matches the window size earlier in the request? That seems like we should return "Bad Request" in that scenario.

Contributor Author

@benwtrent Not without issuing a GET call before each inference call to check the inference endpoint's configuration, I'm afraid. The rerank retriever framework only exposes hooks for creation and submission of an inference request, and parsing the results, but it doesn't know about the config.

Member

Not without issuing a GET call before each inference call to check the inference endpoint's configuration, I'm afraid... but it doesn't know about the config.

We need to fix this. If there are things that are invalid for a configuration & a search request, we need to fix this.

Calling a local GET to retrieve some minor documentation is way cheaper than calling some external service when we know the request will fail. I would hold off on merging or progressing this until we can eagerly validate the search request, as we know ahead of time that it won't work.

Contributor Author

@benwtrent @Mikep86 @pmpailis I'll summarize the latest changes here, which accommodate all your suggestions (thanks for those):

  • There is now handling logic for two failure scenarios: 1. we detect that task_settings.top_n < rank_window_size and preemptively fail before running the inference; 2. the reranker returns a number of outputs other than N for N inputs. These are tested in testRerankTopNConfigurationAndRankWindowSizeMismatch() and testRerankInputSizeAndInferenceResultsMismatch() respectively.
  • The first scenario results in a 400, the second in a 500 (as it's not a client-side error if the reranker misbehaves).
  • The extraction logic of top_n from the fetched inference endpoint config is vendor-specific. I'm not happy about this, but the empty TaskSettings interface doesn't give me many options to get the top N other than casting after a type check.
  • In the test class, the mock configuration for scenario 1 requires passing the expected top N as part of the inference ID and parsing it (i.e. my-inference-id-task-settings-top-3 -> top_n=3). Again, I'm not proud of this implementation, but GetInferenceModelAction has a very limited interface for passing control parameters in order to mock behavior in a test.
  • Reworded the error messages to make them clearer and more actionable. The failure tests now also check the exact message, based on Panos' suggestion.
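The two validation paths can be sketched as plain exception choices. This is a simplified stand-in for illustration, not the PR's actual code; the exception types matter because Elasticsearch derives the HTTP status from the cause (IllegalArgumentException surfaces as 400, IllegalStateException as 500):

```java
// Sketch of the two failure scenarios (hypothetical helper names).
class RerankValidation {
    // Scenario 1: preemptive check before calling the inference service.
    // A client-fixable misconfiguration, so it throws IllegalArgumentException (HTTP 400).
    static void validateTopN(Integer endpointTopN, int rankWindowSize, String inferenceId) {
        if (endpointTopN != null && endpointTopN < rankWindowSize) {
            throw new IllegalArgumentException(
                "Inference endpoint [" + inferenceId + "] is configured to return the top ["
                    + endpointTopN + "] results. Reduce rank_window_size to be less than or equal to this value.");
        }
    }

    // Scenario 2: the reranker returned a different number of scores than inputs.
    // A server-side/reranker misbehavior, so it throws IllegalStateException (HTTP 500).
    static void validateResultCount(int inputCount, int resultCount) {
        if (inputCount != resultCount) {
            throw new IllegalStateException(
                "Document and score count mismatch: [" + inputCount + "] vs [" + resultCount + "]");
        }
    }
}
```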

Let me know if you feel there are still things to improve that must go into this PR; I want to timebox the remaining effort, as the bug has been fixed and I want to make sure this gets merged by 8.15.0 BC5.

Contributor

Heads up that @demjened & I discussed offline yesterday and we're going to investigate doing the task_settings.top_n < rank_window_size check in ML code

Contributor Author

@Mikep86 I've been thinking about this; while moving this logic to the ML code would make it cleaner and would remove coupling, I'm worried it's not a good idea from a product perspective.

For reference, here's the code you suggested that would replace the preemptive top_n check from the reranker:

// CohereService#doInfer
...
        CohereModel cohereModel = (CohereModel) model;
        var actionCreator = new CohereActionCreator(getSender(), getServiceComponents());

        if (model instanceof CohereRerankModel rerankModel) {
            Integer topN = rerankModel.getTaskSettings().getTopNDocumentsOnly();
            if (topN != null && topN < input.size()) {
                listener.onFailure(new IllegalArgumentException("top_n < doc count"));
                return;
            }
        }
...

The problem is we're invalidating a normal use case. It is a valid scenario for a rerank endpoint to be configured to return a maximum of N hits for >N inputs. It's only an issue from the reranker retriever's perspective, where model.top_n < retriever.rank_window_size can lead to partial or incorrect results.

So we have 3 options to implement the check:

  1. (Current) Fetch task settings and validate in retriever context. Problem: coupling with ML code, need to maintain.
  2. (Proposal in this comment) Move validation to XxxService-s in ML code. Problem: invalidating valid use case.
  3. Maybe a control flag to toggle this validation in ML code - doInfer(..., inputSizeMustMatchTopN) + the code above? Problem: complex, requires refactoring of the interface and transport objects.

WDYT?

Member

It is valid for the reranker API to return fewer docs, as that is the whole point of topN. Otherwise, why is it even configurable?

The retriever also sets window size. For now, we should enforce that window size is <= topN. Maybe we can adjust this in the future, but for retrievers, this is a sane default & safe behavior.

Contributor

Got it, wasn't thinking of rerank endpoint usage outside of this context.

@benwtrent (Member) left a comment

Nice catch! Let's have a test that would fail with an IOOB exception without this fix and then throws the better exception now.

@Mikep86 (Contributor) commented Jul 24, 2024

Is it OK if top_n > rank_window_size?

@Mikep86 (Contributor) left a comment

Good catch on the bug!

Comment on lines 198 to 199
RestStatus.INTERNAL_SERVER_ERROR,
containsString("Failed to execute phase [rank-feature], Computing updated ranks for results failed")
Contributor

Two things here:

  • We should return a 400 response here; the error is ultimately due to a bad value provided by the user
  • Can we test that the returned error message contains the root cause of the problem (i.e. "Document and score count mismatch")?

Contributor Author

The error response is produced here, which responds to an IllegalStateException. @pmpailis do you know of any way to turn this into a 400?

I'll look into using different matchers in the test to verify the cause.

Contributor

We should work on making the HTTP response code configurable. 500 == rest suppressed errors == things to investigate during serverless on-call

@demjened (Contributor Author) Jul 24, 2024

Update: there's no logic in ElasticsearchAssertions#assertFailures that we could call to check the cause of a search phase exception 🙁 (in this case e.getCause()).

Contributor

I think we could slightly rewrite this to capture the cause's message (as there won't be any response to decref). E.g.:

        SearchPhaseExecutionException ex = expectThrows(
            SearchPhaseExecutionException.class,
            client.prepareSearch()
                .setRankBuilder(
                    // Simulate reranker returning different number of results from input, also invalid document indices in results
                    new InvalidInferenceResultCountProvidingTextSimilarityRankBuilder(
                        "text",
                        "my-rerank-model",
                        "my query",
                        100,
                        1.5f,
                        true
                    )
                )
                .setQuery(QueryBuilders.matchAllQuery())::get
        );
        assertThat(ex.getDetailedMessage(), containsString("Document and score count mismatch"));

@pmpailis (Contributor) Jul 25, 2024

@pmpailis do you know of any way to turn this into a 400?

The status of the error is determined through SearchPhaseExecutionException#status, which in turn uses the exception's cause via org.elasticsearch.ExceptionsHelper#status.

I think converting the IllegalStateException to an IllegalArgumentException should be enough to have a 400 response.
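A toy sketch of that status mapping follows. This is not the actual ExceptionsHelper implementation (which handles many more exception types); it only illustrates why swapping the exception class changes the response code:

```java
// Hypothetical, simplified stand-in for the cause-to-status mapping:
// IllegalArgumentException means the client sent bad parameters (400),
// while an unexpected state like a misbehaving reranker stays a 500.
class StatusMapper {
    static int status(Throwable cause) {
        if (cause instanceof IllegalArgumentException) {
            return 400; // Bad Request
        }
        return 500; // Internal Server Error
    }
}
```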

@benwtrent (Member) left a comment

Looking way better!

Let's keep it async, but this is looking nice.

Comment on lines 142 to 150
GetInferenceModelAction.Request request = new GetInferenceModelAction.Request(inferenceId, TaskType.RERANK);
ActionFuture<GetInferenceModelAction.Response> response = client.execute(GetInferenceModelAction.INSTANCE, request);
ModelConfigurations modelConfigurations = response.actionGet().getEndpoints().get(0);

return modelConfigurations.getService().equals(CohereService.NAME)
&& modelConfigurations.getTaskType().equals(TaskType.RERANK)
&& modelConfigurations.getTaskSettings() instanceof CohereRerankTaskSettings
? ((CohereRerankTaskSettings) modelConfigurations.getTaskSettings()).getTopNDocumentsOnly()
: null;
Member

Let's not use an ActionFuture; it will take a thread and block on it. We should instead keep it all asynchronous.

Comment on lines 61 to 73
Integer inferenceEndpointTopN = getTopNFromInferenceEndpointTaskSettings();
if (inferenceEndpointTopN != null && inferenceEndpointTopN < featureDocs.length) {
scoreListener.onFailure(
new IllegalArgumentException(
"Inference endpoint ["
+ inferenceId
+ "] is configured to return the top ["
+ inferenceEndpointTopN
+ "] results. Reduce rank_window_size to be less than or equal to this value."
)
);
return;
}
Member

We should definitely keep this asynchronous. Rule of thumb: never use an ActionFuture or make an async call synchronous when you can easily avoid it.

It should be something like this (obviously with things better filled in)

        // The rerank inference endpoint may have an override to return top N documents only, in that case let's fail fast to avoid
        // assigning scores to the wrong inputs
        // rank Listener
        final ActionListener<InferenceAction.Response> actionListener = scoreListener.delegateFailureAndWrap((l, r) -> {
            // rank task
        }); 

        // top N listener
        GetInferenceModelAction.Request request = new GetInferenceModelAction.Request(inferenceId, TaskType.RERANK);
        ActionListener<GetInferenceModelAction.Response> topNListener = ActionListener.wrap(
            topN -> {
                Integer topNDocs = null;
                if (topN.getEndpoints().get(0).getTaskSettings() instanceof CohereRerankTaskSettings cohere) {
                    topNDocs = cohere.getTopNDocumentsOnly();
                } else if (topN.getEndpoints().get(0).getTaskSettings() instanceof GoogleVertexAiRerankTaskSettings googl) {
                    topNDocs = googl.topN();
                }
                if (topNDocs != null && topNDocs < featureDocs.length) {
                    scoreListener.onFailure(
                        new IllegalArgumentException(
                            "Inference endpoint ["
                                + inferenceId
                                + "] is configured to return the top ["
                                + topNDocs 
                                // indicate what the rank window size is
                                + "] results. Reduce rank_window_size to be less than or equal to this value."
                        )
                    );
                    return;
                }
                List<String> featureData = Arrays.stream(featureDocs).map(x -> x.featureData).toList();
                InferenceAction.Request inferenceRequest = generateRequest(featureData);
                try {
                    client.execute(InferenceAction.INSTANCE, inferenceRequest, actionListener);
                } finally {
                    request.decRef();
                }
            },
            f -> {
                // throw appropriate errors.
            }
        );
client.execute(GetInferenceModelAction.INSTANCE, request, topNListener);

Contributor Author

@benwtrent Thanks for the code snippet, I updated it.

One question regarding your comment:

            f -> {
                // throw appropriate errors.
            }

Isn't it sufficient to just delegate to the outer listener, i.e. scoreListener.onFailure(f)?

@Mikep86 (Contributor) left a comment

Nice work, I like the new approach 🙌


@demjened demjened requested a review from benwtrent July 26, 2024 20:37
@demjened demjened merged commit c722ceb into elastic:main Jul 29, 2024
15 checks passed
@demjened demjened deleted the demjened/fix-check-rerank-score-vs-input branch July 29, 2024 19:03
@demjened demjened added auto-backport Automatically create backport pull requests when merged and removed auto-backport Automatically create backport pull requests when merged labels Jul 29, 2024
weizijun added a commit to weizijun/elasticsearch that referenced this pull request Jul 30, 2024
* upstream/main: (105 commits)
  Removing the use of watcher stats from WatchAcTests (elastic#111435)
  Mute org.elasticsearch.xpack.restart.FullClusterRestartIT testSingleDoc {cluster=UPGRADED} elastic#111434
  Make `EnrichPolicyRunner` more properly async (elastic#111321)
  Mute org.elasticsearch.xpack.restart.FullClusterRestartIT testSingleDoc {cluster=OLD} elastic#111430
  Mute org.elasticsearch.xpack.esql.expression.function.aggregate.ValuesTests testGroupingAggregate {TestCase=<long unicode KEYWORDs>} elastic#111428
  Mute org.elasticsearch.xpack.esql.expression.function.aggregate.ValuesTests testGroupingAggregate {TestCase=<long unicode TEXTs>} elastic#111429
  Mute org.elasticsearch.xpack.repositories.metering.azure.AzureRepositoriesMeteringIT org.elasticsearch.xpack.repositories.metering.azure.AzureRepositoriesMeteringIT elastic#111307
  Update semantic_text field to support indexing numeric and boolean data types (elastic#111284)
  Mute org.elasticsearch.repositories.blobstore.testkit.AzureSnapshotRepoTestKitIT testRepositoryAnalysis elastic#111280
  Ensure vector similarity correctly limits inner_hits returned for nested kNN (elastic#111363)
  Fix LogsIndexModeFullClusterRestartIT (elastic#111362)
  Remove 4096 bool query max limit from docs (elastic#111421)
  Fix score count validation in reranker response (elastic#111212)
  Integrate data generator in LogsDB mode challenge test (elastic#111303)
  ESQL: Add COUNT and COUNT_DISTINCT aggregation tests (elastic#111409)
  [Service Account] Add AutoOps account (elastic#111316)
  [ML] Fix failing test DetectionRulesTests.testEqualsAndHashcode (elastic#111351)
  [ML] Create and inject APM Inference Metrics (elastic#111293)
  [DOCS] Additional reranking docs updates (elastic#111350)
  Mute org.elasticsearch.repositories.azure.RepositoryAzureClientYamlTestSuiteIT org.elasticsearch.repositories.azure.RepositoryAzureClientYamlTestSuiteIT elastic#111345
  ...

# Conflicts:
#	server/src/main/java/org/elasticsearch/TransportVersions.java
Labels
>bug :Search Relevance/Ranking Scoring, rescoring, rank evaluation. Team:Search Relevance Meta label for the Search Relevance team in Elasticsearch v8.15.0 v8.15.1 v8.16.0
Development

Successfully merging this pull request may close these issues.

Reranker retriever query fails if window size > top N in inference endpoint
5 participants