Skip to content

Fix score count validation in reranker response #111424

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions docs/changelog/111424.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
pr: 111424
summary: Fix score count validation in reranker response
area: Ranking
type: bug
issues:
- 111202
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,11 @@
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.search.rank.context.RankFeaturePhaseRankCoordinatorContext;
import org.elasticsearch.search.rank.feature.RankFeatureDoc;
import org.elasticsearch.xpack.core.inference.action.GetInferenceModelAction;
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
import org.elasticsearch.xpack.core.inference.results.RankedDocsResults;
import org.elasticsearch.xpack.inference.services.cohere.rerank.CohereRerankTaskSettings;
import org.elasticsearch.xpack.inference.services.googlevertexai.rerank.GoogleVertexAiRerankTaskSettings;

import java.util.Arrays;
import java.util.Comparator;
Expand Down Expand Up @@ -53,24 +56,77 @@ public TextSimilarityRankFeaturePhaseRankCoordinatorContext(
protected void computeScores(RankFeatureDoc[] featureDocs, ActionListener<float[]> scoreListener) {
// Wrap the provided rankListener to an ActionListener that would handle the response from the inference service
// and then pass the results
final ActionListener<InferenceAction.Response> actionListener = scoreListener.delegateFailureAndWrap((l, r) -> {
float[] scores = extractScoresFromResponse(r);
if (scores.length != featureDocs.length) {
final ActionListener<InferenceAction.Response> inferenceListener = scoreListener.delegateFailureAndWrap((l, r) -> {
InferenceServiceResults results = r.getResults();
assert results instanceof RankedDocsResults;

// Ensure we get exactly as many scores as the number of docs we passed, otherwise we may return incorrect results
List<RankedDocsResults.RankedDoc> rankedDocs = ((RankedDocsResults) results).getRankedDocs();
if (rankedDocs.size() != featureDocs.length) {
l.onFailure(
new IllegalStateException("Document and score count mismatch: [" + featureDocs.length + "] vs [" + scores.length + "]")
new IllegalStateException(
"Reranker input document count and returned score count mismatch: ["
+ featureDocs.length
+ "] vs ["
+ rankedDocs.size()
+ "]"
)
);
} else {
float[] scores = extractScoresFromRankedDocs(rankedDocs);
l.onResponse(scores);
}
});

List<String> featureData = Arrays.stream(featureDocs).map(x -> x.featureData).toList();
InferenceAction.Request request = generateRequest(featureData);
try {
client.execute(InferenceAction.INSTANCE, request, actionListener);
} finally {
request.decRef();
}
// top N listener
ActionListener<GetInferenceModelAction.Response> topNListener = scoreListener.delegateFailureAndWrap((l, r) -> {
// The rerank inference endpoint may have an override to return top N documents only, in that case let's fail fast to avoid
// assigning scores to the wrong input
Integer configuredTopN = null;
if (r.getEndpoints().isEmpty() == false
&& r.getEndpoints().get(0).getTaskSettings() instanceof CohereRerankTaskSettings cohereTaskSettings) {
configuredTopN = cohereTaskSettings.getTopNDocumentsOnly();
} else if (r.getEndpoints().isEmpty() == false
&& r.getEndpoints().get(0).getTaskSettings() instanceof GoogleVertexAiRerankTaskSettings googleVertexAiTaskSettings) {
configuredTopN = googleVertexAiTaskSettings.topN();
}
if (configuredTopN != null && configuredTopN < rankWindowSize) {
l.onFailure(
new IllegalArgumentException(
"Inference endpoint ["
+ inferenceId
+ "] is configured to return the top ["
+ configuredTopN
+ "] results, but rank_window_size is ["
+ rankWindowSize
+ "]. Reduce rank_window_size to be less than or equal to the configured top N value."
)
);
return;
}
List<String> featureData = Arrays.stream(featureDocs).map(x -> x.featureData).toList();
InferenceAction.Request inferenceRequest = generateRequest(featureData);
try {
client.execute(InferenceAction.INSTANCE, inferenceRequest, inferenceListener);
} finally {
inferenceRequest.decRef();
}
});

GetInferenceModelAction.Request getModelRequest = new GetInferenceModelAction.Request(inferenceId, TaskType.RERANK);
client.execute(GetInferenceModelAction.INSTANCE, getModelRequest, topNListener);
}

/**
 * Sorts documents by score descending and discards those with a score less than minScore.
 * @param originalDocs documents to process
 */
@Override
protected RankFeatureDoc[] preprocess(RankFeatureDoc[] originalDocs) {
    return Arrays.stream(originalDocs)
        // Keep a doc when no minimum is configured, or when it meets the minimum score
        .filter(doc -> minScore == null || doc.score >= minScore)
        // Highest score first
        .sorted((first, second) -> Float.compare(second.score, first.score))
        .toArray(RankFeatureDoc[]::new);
}

protected InferenceAction.Request generateRequest(List<String> docFeatures) {
Expand All @@ -85,28 +141,12 @@ protected InferenceAction.Request generateRequest(List<String> docFeatures) {
);
}

private float[] extractScoresFromResponse(InferenceAction.Response response) {
InferenceServiceResults results = response.getResults();
assert results instanceof RankedDocsResults;

List<RankedDocsResults.RankedDoc> rankedDocs = ((RankedDocsResults) results).getRankedDocs();
private float[] extractScoresFromRankedDocs(List<RankedDocsResults.RankedDoc> rankedDocs) {
float[] scores = new float[rankedDocs.size()];
for (RankedDocsResults.RankedDoc rankedDoc : rankedDocs) {
scores[rankedDoc.index()] = rankedDoc.relevanceScore();
}

return scores;
}

/**
* Sorts documents by score descending and discards those with a score less than minScore.
* @param originalDocs documents to process
*/
@Override
protected RankFeatureDoc[] preprocess(RankFeatureDoc[] originalDocs) {
return Arrays.stream(originalDocs)
.filter(doc -> minScore == null || doc.score >= minScore)
.sorted(Comparator.comparing((RankFeatureDoc doc) -> doc.score).reversed())
.toArray(RankFeatureDoc[]::new);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.search.rank.feature.RankFeatureDoc;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
import org.elasticsearch.xpack.core.inference.action.GetInferenceModelAction;

import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.argThat;
Expand Down Expand Up @@ -54,10 +54,9 @@ public void onFailure(Exception e) {
fail();
}
});

verify(mockClient).execute(
eq(InferenceAction.INSTANCE),
argThat(actionRequest -> ((InferenceAction.Request) actionRequest).getTaskType().equals(TaskType.RERANK)),
eq(GetInferenceModelAction.INSTANCE),
argThat(actionRequest -> ((GetInferenceModelAction.Request) actionRequest).getTaskType().equals(TaskType.RERANK)),
any()
);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

package org.elasticsearch.xpack.inference.rank.textsimilarity;

import org.elasticsearch.action.search.SearchPhaseExecutionException;
import org.elasticsearch.client.internal.Client;
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.inference.InputType;
Expand All @@ -29,22 +30,46 @@
import java.util.Objects;

import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.equalTo;

public class TextSimilarityRankTests extends ESSingleNodeTestCase {

/**
* {@code TextSimilarityRankBuilder} that simulates an inference call that returns a different number of results as the input.
* {@code TextSimilarityRankBuilder} that sets top_n in the inference endpoint's task settings.
* See {@code TextSimilarityTestPlugin -> TestFilter -> handleGetInferenceModelActionRequest} for the logic that extracts the top_n
* value.
*/
public static class InvalidInferenceResultCountProvidingTextSimilarityRankBuilder extends TextSimilarityRankBuilder {
public static class TopNConfigurationAcceptingTextSimilarityRankBuilder extends TextSimilarityRankBuilder {

public InvalidInferenceResultCountProvidingTextSimilarityRankBuilder(
public TopNConfigurationAcceptingTextSimilarityRankBuilder(
String field,
String inferenceId,
String inferenceText,
int rankWindowSize,
Float minScore
Float minScore,
int topN
) {
super(field, inferenceId + "-task-settings-top-" + topN, inferenceText, rankWindowSize, minScore);
}
}

/**
* {@code TextSimilarityRankBuilder} that simulates an inference call returning N results.
*/
public static class InferenceResultCountAcceptingTextSimilarityRankBuilder extends TextSimilarityRankBuilder {

private final int inferenceResultCount;

public InferenceResultCountAcceptingTextSimilarityRankBuilder(
String field,
String inferenceId,
String inferenceText,
int rankWindowSize,
Float minScore,
int inferenceResultCount
) {
super(field, inferenceId, inferenceText, rankWindowSize, minScore);
this.inferenceResultCount = inferenceResultCount;
}

@Override
Expand All @@ -62,10 +87,10 @@ public RankFeaturePhaseRankCoordinatorContext buildRankFeaturePhaseCoordinatorCo
protected InferenceAction.Request generateRequest(List<String> docFeatures) {
return new InferenceAction.Request(
TaskType.RERANK,
inferenceId,
this.inferenceId,
inferenceText,
docFeatures,
Map.of("invalidInferenceResultCount", true),
Map.of("inferenceResultCount", inferenceResultCount),
InputType.SEARCH,
InferenceAction.Request.DEFAULT_TIMEOUT
);
Expand Down Expand Up @@ -151,17 +176,38 @@ public void testRerankInferenceFailure() {
);
}

public void testRerankInferenceResultMismatch() {
ElasticsearchAssertions.assertFailures(
public void testRerankTopNConfigurationAndRankWindowSizeMismatch() {
SearchPhaseExecutionException ex = expectThrows(
SearchPhaseExecutionException.class,
// Execute search with text similarity reranking
client.prepareSearch()
.setRankBuilder(
new InvalidInferenceResultCountProvidingTextSimilarityRankBuilder("text", "my-rerank-model", "my query", 100, 1.5f)
// Simulate reranker configuration with top_n=3 in task_settings, which is different from rank_window_size=10
// (Note: top_n comes from inferenceId, there's no other easy way of passing this to the mocked get model request)
new TopNConfigurationAcceptingTextSimilarityRankBuilder("text", "my-rerank-model", "my query", 100, 1.5f, 3)
)
.setQuery(QueryBuilders.matchAllQuery()),
RestStatus.INTERNAL_SERVER_ERROR,
containsString("Failed to execute phase [rank-feature], Computing updated ranks for results failed")
.setQuery(QueryBuilders.matchAllQuery())
);
assertThat(ex.status(), equalTo(RestStatus.BAD_REQUEST));
assertThat(
ex.getDetailedMessage(),
containsString("Reduce rank_window_size to be less than or equal to the configured top N value")
);
}

/**
 * Verifies that a mismatch between the reranker's input document count and the number of
 * returned scores fails the search with an internal server error and a descriptive message.
 */
public void testRerankInputSizeAndInferenceResultsMismatch() {
    // Simulate a reranker that always returns 4 results, regardless of how many documents were sent
    InferenceResultCountAcceptingTextSimilarityRankBuilder rankBuilder =
        new InferenceResultCountAcceptingTextSimilarityRankBuilder("text", "my-rerank-model", "my query", 100, 1.5f, 4);

    SearchPhaseExecutionException ex = expectThrows(
        SearchPhaseExecutionException.class,
        client.prepareSearch().setRankBuilder(rankBuilder).setQuery(QueryBuilders.matchAllQuery())
    );
    assertThat(ex.status(), equalTo(RestStatus.INTERNAL_SERVER_ERROR));
    assertThat(ex.getDetailedMessage(), containsString("Reranker input document count and returned score count mismatch"));
}

private static void assertHitHasRankScoreAndText(SearchHit hit, int expectedRank, float expectedScore, String expectedText) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,9 @@
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.inference.EmptyTaskSettings;
import org.elasticsearch.inference.InputType;
import org.elasticsearch.inference.ModelConfigurations;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.plugins.ActionPlugin;
import org.elasticsearch.plugins.Plugin;
Expand All @@ -39,15 +41,22 @@
import org.elasticsearch.xcontent.ParseField;
import org.elasticsearch.xcontent.ToXContent;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xpack.core.inference.action.GetInferenceModelAction;
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
import org.elasticsearch.xpack.core.inference.results.RankedDocsResults;
import org.elasticsearch.xpack.inference.services.cohere.CohereService;
import org.elasticsearch.xpack.inference.services.cohere.CohereServiceSettings;
import org.elasticsearch.xpack.inference.services.cohere.rerank.CohereRerankServiceSettings;
import org.elasticsearch.xpack.inference.services.cohere.rerank.CohereRerankTaskSettings;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

import static java.util.Collections.singletonList;
import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg;
Expand Down Expand Up @@ -100,31 +109,66 @@ public int order() {
}

@Override
@SuppressWarnings("unchecked")
public <Request extends ActionRequest, Response extends ActionResponse> void apply(
Task task,
String action,
Request request,
ActionListener<Response> listener,
ActionFilterChain<Request, Response> chain
) {
// For any other action than inference, execute normally
if (action.equals(InferenceAction.INSTANCE.name()) == false) {
if (action.equals(GetInferenceModelAction.INSTANCE.name())) {
assert request instanceof GetInferenceModelAction.Request;
handleGetInferenceModelActionRequest((GetInferenceModelAction.Request) request, listener);
} else if (action.equals(InferenceAction.INSTANCE.name())) {
assert request instanceof InferenceAction.Request;
handleInferenceActionRequest((InferenceAction.Request) request, listener);
} else {
// For any other action than get model and inference, execute normally
chain.proceed(task, action, request, listener);
return;
}
}

assert request instanceof InferenceAction.Request;
boolean shouldThrow = (boolean) ((InferenceAction.Request) request).getTaskSettings().getOrDefault("throwing", false);
boolean hasInvalidInferenceResultCount = (boolean) ((InferenceAction.Request) request).getTaskSettings()
.getOrDefault("invalidInferenceResultCount", false);
@SuppressWarnings("unchecked")
private <Response extends ActionResponse> void handleGetInferenceModelActionRequest(
GetInferenceModelAction.Request request,
ActionListener<Response> listener
) {
String inferenceEntityId = request.getInferenceEntityId();
Integer topN = null;
Matcher extractTopN = Pattern.compile(".*(task-settings-top-\\d+).*").matcher(inferenceEntityId);
if (extractTopN.find()) {
topN = Integer.parseInt(extractTopN.group(1).replaceAll("\\D", ""));
}

ActionResponse response = new GetInferenceModelAction.Response(
List.of(
new ModelConfigurations(
request.getInferenceEntityId(),
request.getTaskType(),
CohereService.NAME,
new CohereRerankServiceSettings(new CohereServiceSettings()),
Copy link
Contributor Author

@demjened demjened Jul 29, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I had to modify this line because this constructor doesn't exist on 8.15. There's no difference in the behavior of the test plugin.

topN == null ? new EmptyTaskSettings() : new CohereRerankTaskSettings(topN, null, null)
)
)
);
listener.onResponse((Response) response);
}

@SuppressWarnings("unchecked")
private <Response extends ActionResponse> void handleInferenceActionRequest(
InferenceAction.Request request,
ActionListener<Response> listener
) {
Map<String, Object> taskSettings = request.getTaskSettings();
boolean shouldThrow = (boolean) taskSettings.getOrDefault("throwing", false);
Integer inferenceResultCount = (Integer) taskSettings.get("inferenceResultCount");

if (shouldThrow) {
listener.onFailure(new UnsupportedOperationException("simulated failure"));
} else {
List<RankedDocsResults.RankedDoc> rankedDocsResults = new ArrayList<>();
List<String> inputs = ((InferenceAction.Request) request).getInput();
int resultCount = hasInvalidInferenceResultCount ? inputs.size() - 1 : inputs.size();
List<String> inputs = request.getInput();
int resultCount = inferenceResultCount == null ? inputs.size() : inferenceResultCount;
for (int i = 0; i < resultCount; i++) {
rankedDocsResults.add(new RankedDocsResults.RankedDoc(i, Float.parseFloat(inputs.get(i)), inputs.get(i)));
}
Expand Down