Skip to content

Commit

Permalink
Fix score count validation in reranker response (#111212)
Browse files Browse the repository at this point in the history
* Fix rerank score validation

* Update docs/changelog/111212.yaml

* Add test case for invalid document indices in reranker result

* Preemptive top_n config check

* Reorg code + refine tests

* Add support for Google Vertex AI task settings

* Spotless

* Make top N eval async

* Update test

* Fix broken unit test

* Clean up tests

* Spotless

* Add size check + compare against rankWindowSize

* Fix import
  • Loading branch information
demjened authored Jul 29, 2024
1 parent a4e6cf9 commit c722ceb
Show file tree
Hide file tree
Showing 5 changed files with 188 additions and 54 deletions.
6 changes: 6 additions & 0 deletions docs/changelog/111212.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
pr: 111212
summary: Fix score count validation in reranker response
area: Ranking
type: bug
issues:
- 111202
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,11 @@
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.search.rank.context.RankFeaturePhaseRankCoordinatorContext;
import org.elasticsearch.search.rank.feature.RankFeatureDoc;
import org.elasticsearch.xpack.core.inference.action.GetInferenceModelAction;
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
import org.elasticsearch.xpack.core.inference.results.RankedDocsResults;
import org.elasticsearch.xpack.inference.services.cohere.rerank.CohereRerankTaskSettings;
import org.elasticsearch.xpack.inference.services.googlevertexai.rerank.GoogleVertexAiRerankTaskSettings;

import java.util.Arrays;
import java.util.Comparator;
Expand Down Expand Up @@ -53,24 +56,77 @@ public TextSimilarityRankFeaturePhaseRankCoordinatorContext(
protected void computeScores(RankFeatureDoc[] featureDocs, ActionListener<float[]> scoreListener) {
// Wrap the provided rankListener to an ActionListener that would handle the response from the inference service
// and then pass the results
final ActionListener<InferenceAction.Response> actionListener = scoreListener.delegateFailureAndWrap((l, r) -> {
float[] scores = extractScoresFromResponse(r);
if (scores.length != featureDocs.length) {
final ActionListener<InferenceAction.Response> inferenceListener = scoreListener.delegateFailureAndWrap((l, r) -> {
InferenceServiceResults results = r.getResults();
assert results instanceof RankedDocsResults;

// Ensure we get exactly as many scores as the number of docs we passed, otherwise we may return incorrect results
List<RankedDocsResults.RankedDoc> rankedDocs = ((RankedDocsResults) results).getRankedDocs();
if (rankedDocs.size() != featureDocs.length) {
l.onFailure(
new IllegalStateException("Document and score count mismatch: [" + featureDocs.length + "] vs [" + scores.length + "]")
new IllegalStateException(
"Reranker input document count and returned score count mismatch: ["
+ featureDocs.length
+ "] vs ["
+ rankedDocs.size()
+ "]"
)
);
} else {
float[] scores = extractScoresFromRankedDocs(rankedDocs);
l.onResponse(scores);
}
});

List<String> featureData = Arrays.stream(featureDocs).map(x -> x.featureData).toList();
InferenceAction.Request request = generateRequest(featureData);
try {
client.execute(InferenceAction.INSTANCE, request, actionListener);
} finally {
request.decRef();
}
// top N listener
ActionListener<GetInferenceModelAction.Response> topNListener = scoreListener.delegateFailureAndWrap((l, r) -> {
// The rerank inference endpoint may have an override to return top N documents only, in that case let's fail fast to avoid
// assigning scores to the wrong input
Integer configuredTopN = null;
if (r.getEndpoints().isEmpty() == false
&& r.getEndpoints().get(0).getTaskSettings() instanceof CohereRerankTaskSettings cohereTaskSettings) {
configuredTopN = cohereTaskSettings.getTopNDocumentsOnly();
} else if (r.getEndpoints().isEmpty() == false
&& r.getEndpoints().get(0).getTaskSettings() instanceof GoogleVertexAiRerankTaskSettings googleVertexAiTaskSettings) {
configuredTopN = googleVertexAiTaskSettings.topN();
}
if (configuredTopN != null && configuredTopN < rankWindowSize) {
l.onFailure(
new IllegalArgumentException(
"Inference endpoint ["
+ inferenceId
+ "] is configured to return the top ["
+ configuredTopN
+ "] results, but rank_window_size is ["
+ rankWindowSize
+ "]. Reduce rank_window_size to be less than or equal to the configured top N value."
)
);
return;
}
List<String> featureData = Arrays.stream(featureDocs).map(x -> x.featureData).toList();
InferenceAction.Request inferenceRequest = generateRequest(featureData);
try {
client.execute(InferenceAction.INSTANCE, inferenceRequest, inferenceListener);
} finally {
inferenceRequest.decRef();
}
});

GetInferenceModelAction.Request getModelRequest = new GetInferenceModelAction.Request(inferenceId, TaskType.RERANK);
client.execute(GetInferenceModelAction.INSTANCE, getModelRequest, topNListener);
}

/**
 * Filters out documents whose score falls below {@code minScore} (when a threshold is
 * configured) and returns the survivors ordered from highest to lowest score.
 *
 * @param originalDocs documents to process
 * @return remaining documents, sorted by descending score
 */
@Override
protected RankFeatureDoc[] preprocess(RankFeatureDoc[] originalDocs) {
    // No threshold configured means every document passes the filter.
    Comparator<RankFeatureDoc> byScoreDescending = Comparator.comparing((RankFeatureDoc doc) -> doc.score).reversed();
    return Arrays.stream(originalDocs)
        .filter(doc -> (minScore == null) || (doc.score >= minScore))
        .sorted(byScoreDescending)
        .toArray(RankFeatureDoc[]::new);
}

protected InferenceAction.Request generateRequest(List<String> docFeatures) {
Expand All @@ -85,28 +141,12 @@ protected InferenceAction.Request generateRequest(List<String> docFeatures) {
);
}

private float[] extractScoresFromResponse(InferenceAction.Response response) {
InferenceServiceResults results = response.getResults();
assert results instanceof RankedDocsResults;

List<RankedDocsResults.RankedDoc> rankedDocs = ((RankedDocsResults) results).getRankedDocs();
/**
 * Maps the reranker's results to a dense score array, placing each relevance score at the
 * position given by that result's document index.
 * <p>
 * NOTE(review): the array is sized to {@code rankedDocs.size()} and each
 * {@code rankedDoc.index()} is trusted to fall within {@code [0, rankedDocs.size())}; an
 * out-of-range index in the inference response would surface here as an
 * {@link ArrayIndexOutOfBoundsException}.
 *
 * @param rankedDocs ranked documents returned by the inference service
 * @return scores arranged by the reranker-reported document index
 */
private float[] extractScoresFromRankedDocs(List<RankedDocsResults.RankedDoc> rankedDocs) {
    float[] scores = new float[rankedDocs.size()];
    for (RankedDocsResults.RankedDoc rankedDoc : rankedDocs) {
        scores[rankedDoc.index()] = rankedDoc.relevanceScore();
    }

    return scores;
}

/**
* Sorts documents by score descending and discards those with a score less than minScore.
* @param originalDocs documents to process
*/
@Override
protected RankFeatureDoc[] preprocess(RankFeatureDoc[] originalDocs) {
return Arrays.stream(originalDocs)
.filter(doc -> minScore == null || doc.score >= minScore)
.sorted(Comparator.comparing((RankFeatureDoc doc) -> doc.score).reversed())
.toArray(RankFeatureDoc[]::new);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.search.rank.feature.RankFeatureDoc;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
import org.elasticsearch.xpack.core.inference.action.GetInferenceModelAction;

import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.argThat;
Expand Down Expand Up @@ -54,10 +54,9 @@ public void onFailure(Exception e) {
fail();
}
});

verify(mockClient).execute(
eq(InferenceAction.INSTANCE),
argThat(actionRequest -> ((InferenceAction.Request) actionRequest).getTaskType().equals(TaskType.RERANK)),
eq(GetInferenceModelAction.INSTANCE),
argThat(actionRequest -> ((GetInferenceModelAction.Request) actionRequest).getTaskType().equals(TaskType.RERANK)),
any()
);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

package org.elasticsearch.xpack.inference.rank.textsimilarity;

import org.elasticsearch.action.search.SearchPhaseExecutionException;
import org.elasticsearch.client.internal.Client;
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.inference.InputType;
Expand All @@ -29,22 +30,46 @@
import java.util.Objects;

import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.equalTo;

public class TextSimilarityRankTests extends ESSingleNodeTestCase {

/**
* {@code TextSimilarityRankBuilder} that simulates an inference call that returns a different number of results than the input.
* {@code TextSimilarityRankBuilder} that sets top_n in the inference endpoint's task settings.
* See {@code TextSimilarityTestPlugin -> TestFilter -> handleGetInferenceModelActionRequest} for the logic that extracts the top_n
* value.
*/
public static class InvalidInferenceResultCountProvidingTextSimilarityRankBuilder extends TextSimilarityRankBuilder {
public static class TopNConfigurationAcceptingTextSimilarityRankBuilder extends TextSimilarityRankBuilder {

public InvalidInferenceResultCountProvidingTextSimilarityRankBuilder(
public TopNConfigurationAcceptingTextSimilarityRankBuilder(
String field,
String inferenceId,
String inferenceText,
int rankWindowSize,
Float minScore
Float minScore,
int topN
) {
super(field, inferenceId + "-task-settings-top-" + topN, inferenceText, rankWindowSize, minScore);
}
}

/**
* {@code TextSimilarityRankBuilder} that simulates an inference call returning N results.
*/
public static class InferenceResultCountAcceptingTextSimilarityRankBuilder extends TextSimilarityRankBuilder {

private final int inferenceResultCount;

/**
 * Creates a rank builder whose simulated inference call returns a fixed number of results.
 *
 * @param field field to use for reranking (forwarded to the base builder)
 * @param inferenceId inference endpoint id (forwarded to the base builder)
 * @param inferenceText query text for the reranker (forwarded to the base builder)
 * @param rankWindowSize rank window size (forwarded to the base builder)
 * @param minScore score threshold (forwarded to the base builder)
 * @param inferenceResultCount number of results the simulated inference call will return,
 *                             regardless of how many documents are passed in
 */
public InferenceResultCountAcceptingTextSimilarityRankBuilder(
    String field,
    String inferenceId,
    String inferenceText,
    int rankWindowSize,
    Float minScore,
    int inferenceResultCount
) {
    super(field, inferenceId, inferenceText, rankWindowSize, minScore);
    this.inferenceResultCount = inferenceResultCount;
}

@Override
Expand All @@ -62,10 +87,10 @@ public RankFeaturePhaseRankCoordinatorContext buildRankFeaturePhaseCoordinatorCo
protected InferenceAction.Request generateRequest(List<String> docFeatures) {
return new InferenceAction.Request(
TaskType.RERANK,
inferenceId,
this.inferenceId,
inferenceText,
docFeatures,
Map.of("invalidInferenceResultCount", true),
Map.of("inferenceResultCount", inferenceResultCount),
InputType.SEARCH,
InferenceAction.Request.DEFAULT_TIMEOUT
);
Expand Down Expand Up @@ -151,17 +176,38 @@ public void testRerankInferenceFailure() {
);
}

public void testRerankInferenceResultMismatch() {
ElasticsearchAssertions.assertFailures(
public void testRerankTopNConfigurationAndRankWindowSizeMismatch() {
SearchPhaseExecutionException ex = expectThrows(
SearchPhaseExecutionException.class,
// Execute search with text similarity reranking
client.prepareSearch()
.setRankBuilder(
new InvalidInferenceResultCountProvidingTextSimilarityRankBuilder("text", "my-rerank-model", "my query", 100, 1.5f)
// Simulate reranker configuration with top_n=3 in task_settings, which is different from rank_window_size=10
// (Note: top_n comes from inferenceId, there's no other easy way of passing this to the mocked get model request)
new TopNConfigurationAcceptingTextSimilarityRankBuilder("text", "my-rerank-model", "my query", 100, 1.5f, 3)
)
.setQuery(QueryBuilders.matchAllQuery()),
RestStatus.INTERNAL_SERVER_ERROR,
containsString("Failed to execute phase [rank-feature], Computing updated ranks for results failed")
.setQuery(QueryBuilders.matchAllQuery())
);
assertThat(ex.status(), equalTo(RestStatus.BAD_REQUEST));
assertThat(
ex.getDetailedMessage(),
containsString("Reduce rank_window_size to be less than or equal to the configured top N value")
);
}

/**
 * Verifies that a reranker returning a different number of scores than the number of input
 * documents fails the search with a 500 and the score-count-mismatch message, rather than
 * silently assigning scores to the wrong documents.
 */
public void testRerankInputSizeAndInferenceResultsMismatch() {
    SearchPhaseExecutionException ex = expectThrows(
        SearchPhaseExecutionException.class,
        // Execute search with text similarity reranking
        client.prepareSearch()
            .setRankBuilder(
                // Simulate reranker returning different number of results from input
                new InferenceResultCountAcceptingTextSimilarityRankBuilder("text", "my-rerank-model", "my query", 100, 1.5f, 4)
            )
            .setQuery(QueryBuilders.matchAllQuery())
    );
    // Server-side validation failure: count mismatch is an internal error, not a bad request.
    assertThat(ex.status(), equalTo(RestStatus.INTERNAL_SERVER_ERROR));
    assertThat(ex.getDetailedMessage(), containsString("Reranker input document count and returned score count mismatch"));
}

private static void assertHitHasRankScoreAndText(SearchHit hit, int expectedRank, float expectedScore, String expectedText) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,9 @@
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.inference.EmptyTaskSettings;
import org.elasticsearch.inference.InputType;
import org.elasticsearch.inference.ModelConfigurations;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.plugins.ActionPlugin;
import org.elasticsearch.plugins.Plugin;
Expand All @@ -39,15 +41,21 @@
import org.elasticsearch.xcontent.ParseField;
import org.elasticsearch.xcontent.ToXContent;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xpack.core.inference.action.GetInferenceModelAction;
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
import org.elasticsearch.xpack.core.inference.results.RankedDocsResults;
import org.elasticsearch.xpack.inference.services.cohere.CohereService;
import org.elasticsearch.xpack.inference.services.cohere.rerank.CohereRerankServiceSettings;
import org.elasticsearch.xpack.inference.services.cohere.rerank.CohereRerankTaskSettings;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

import static java.util.Collections.singletonList;
import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg;
Expand Down Expand Up @@ -100,31 +108,66 @@ public int order() {
}

@Override
@SuppressWarnings("unchecked")
public <Request extends ActionRequest, Response extends ActionResponse> void apply(
Task task,
String action,
Request request,
ActionListener<Response> listener,
ActionFilterChain<Request, Response> chain
) {
// For any other action than inference, execute normally
if (action.equals(InferenceAction.INSTANCE.name()) == false) {
if (action.equals(GetInferenceModelAction.INSTANCE.name())) {
assert request instanceof GetInferenceModelAction.Request;
handleGetInferenceModelActionRequest((GetInferenceModelAction.Request) request, listener);
} else if (action.equals(InferenceAction.INSTANCE.name())) {
assert request instanceof InferenceAction.Request;
handleInferenceActionRequest((InferenceAction.Request) request, listener);
} else {
// For any other action than get model and inference, execute normally
chain.proceed(task, action, request, listener);
return;
}
}

assert request instanceof InferenceAction.Request;
boolean shouldThrow = (boolean) ((InferenceAction.Request) request).getTaskSettings().getOrDefault("throwing", false);
boolean hasInvalidInferenceResultCount = (boolean) ((InferenceAction.Request) request).getTaskSettings()
.getOrDefault("invalidInferenceResultCount", false);
/**
 * Answers a get-inference-model request with a mock Cohere rerank model configuration.
 * <p>
 * If the requested inference entity id embeds a {@code task-settings-top-N} marker, the
 * returned model carries {@code CohereRerankTaskSettings} with top N set to that value;
 * otherwise the task settings are empty. This is how tests inject a top_n override into the
 * mocked endpoint (the id is built by {@code TopNConfigurationAcceptingTextSimilarityRankBuilder}).
 *
 * @param request the intercepted get-model request
 * @param listener listener that receives the synthesized response
 */
@SuppressWarnings("unchecked")
private <Response extends ActionResponse> void handleGetInferenceModelActionRequest(
    GetInferenceModelAction.Request request,
    ActionListener<Response> listener
) {
    String inferenceEntityId = request.getInferenceEntityId();

    // Capture the digits directly rather than matching the whole marker and stripping
    // non-digit characters afterwards; find() already scans, so no enclosing .* is needed.
    Integer topN = null;
    Matcher extractTopN = Pattern.compile("task-settings-top-(\\d+)").matcher(inferenceEntityId);
    if (extractTopN.find()) {
        topN = Integer.parseInt(extractTopN.group(1));
    }

    ActionResponse response = new GetInferenceModelAction.Response(
        List.of(
            new ModelConfigurations(
                request.getInferenceEntityId(),
                request.getTaskType(),
                CohereService.NAME,
                new CohereRerankServiceSettings("uri", "model", null),
                topN == null ? new EmptyTaskSettings() : new CohereRerankTaskSettings(topN, null, null)
            )
        )
    );
    listener.onResponse((Response) response);
}

@SuppressWarnings("unchecked")
private <Response extends ActionResponse> void handleInferenceActionRequest(
InferenceAction.Request request,
ActionListener<Response> listener
) {
Map<String, Object> taskSettings = request.getTaskSettings();
boolean shouldThrow = (boolean) taskSettings.getOrDefault("throwing", false);
Integer inferenceResultCount = (Integer) taskSettings.get("inferenceResultCount");

if (shouldThrow) {
listener.onFailure(new UnsupportedOperationException("simulated failure"));
} else {
List<RankedDocsResults.RankedDoc> rankedDocsResults = new ArrayList<>();
List<String> inputs = ((InferenceAction.Request) request).getInput();
int resultCount = hasInvalidInferenceResultCount ? inputs.size() - 1 : inputs.size();
List<String> inputs = request.getInput();
int resultCount = inferenceResultCount == null ? inputs.size() : inferenceResultCount;
for (int i = 0; i < resultCount; i++) {
rankedDocsResults.add(new RankedDocsResults.RankedDoc(i, Float.parseFloat(inputs.get(i)), inputs.get(i)));
}
Expand Down

0 comments on commit c722ceb

Please sign in to comment.