Skip to content

Commit

Permalink
Add test case for invalid document indices in reranker result
Browse files Browse the repository at this point in the history
  • Loading branch information
demjened committed Jul 24, 2024
1 parent fb5a671 commit 7e3f153
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,20 @@ public class TextSimilarityRankTests extends ESSingleNodeTestCase {
*/
public static class InvalidInferenceResultCountProvidingTextSimilarityRankBuilder extends TextSimilarityRankBuilder {

private boolean hasInvalidDocumentIndices = false;

public InvalidInferenceResultCountProvidingTextSimilarityRankBuilder(
String field,
String inferenceId,
String inferenceText,
int rankWindowSize,
Float minScore,
boolean hasInvalidDocumentIndices
) {
this(field, inferenceId, inferenceText, rankWindowSize, minScore);
this.hasInvalidDocumentIndices = hasInvalidDocumentIndices;
}

public InvalidInferenceResultCountProvidingTextSimilarityRankBuilder(
String field,
String inferenceId,
Expand Down Expand Up @@ -65,7 +79,7 @@ protected InferenceAction.Request generateRequest(List<String> docFeatures) {
inferenceId,
inferenceText,
docFeatures,
Map.of("invalidInferenceResultCount", true),
Map.of("invalidInferenceResultCount", true, "invalidDocumentIndices", hasInvalidDocumentIndices),
InputType.SEARCH,
InferenceAction.Request.DEFAULT_TIMEOUT
);
Expand Down Expand Up @@ -151,11 +165,12 @@ public void testRerankInferenceFailure() {
);
}

public void testRerankInferenceResultMismatch() {
public void testRerankInferenceResultCountMismatch() {
ElasticsearchAssertions.assertFailures(
// Execute search with text similarity reranking
client.prepareSearch()
.setRankBuilder(
// Simulate reranker returning different number of results from input
new InvalidInferenceResultCountProvidingTextSimilarityRankBuilder("text", "my-rerank-model", "my query", 100, 1.5f)
)
.setQuery(QueryBuilders.matchAllQuery()),
Expand All @@ -164,6 +179,27 @@ public void testRerankInferenceResultMismatch() {
);
}

public void testRerankInvalidDocumentIndices() {
ElasticsearchAssertions.assertFailures(
// Execute search with text similarity reranking
client.prepareSearch()
.setRankBuilder(
// Simulate reranker returning different number of results from input, also invalid document indices in results
new InvalidInferenceResultCountProvidingTextSimilarityRankBuilder(
"text",
"my-rerank-model",
"my query",
100,
1.5f,
true
)
)
.setQuery(QueryBuilders.matchAllQuery()),
RestStatus.INTERNAL_SERVER_ERROR,
containsString("Failed to execute phase [rank-feature], Computing updated ranks for results failed")
);
}

private static void assertHitHasRankScoreAndText(SearchHit hit, int expectedRank, float expectedScore, String expectedText) {
assertEquals(expectedRank, hit.getRank());
assertEquals(expectedScore, hit.getScore(), 0.0f);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -115,9 +115,10 @@ public <Request extends ActionRequest, Response extends ActionResponse> void app
}

assert request instanceof InferenceAction.Request;
boolean shouldThrow = (boolean) ((InferenceAction.Request) request).getTaskSettings().getOrDefault("throwing", false);
boolean hasInvalidInferenceResultCount = (boolean) ((InferenceAction.Request) request).getTaskSettings()
.getOrDefault("invalidInferenceResultCount", false);
Map<String, Object> taskSettings = ((InferenceAction.Request) request).getTaskSettings();
boolean shouldThrow = (boolean) taskSettings.getOrDefault("throwing", false);
boolean hasInvalidInferenceResultCount = (boolean) taskSettings.getOrDefault("invalidInferenceResultCount", false);
boolean hasInvalidDocumentIndices = (boolean) taskSettings.getOrDefault("invalidDocumentIndices", false);

if (shouldThrow) {
listener.onFailure(new UnsupportedOperationException("simulated failure"));
Expand All @@ -126,7 +127,13 @@ public <Request extends ActionRequest, Response extends ActionResponse> void app
List<String> inputs = ((InferenceAction.Request) request).getInput();
int resultCount = hasInvalidInferenceResultCount ? inputs.size() - 1 : inputs.size();
for (int i = 0; i < resultCount; i++) {
rankedDocsResults.add(new RankedDocsResults.RankedDoc(i, Float.parseFloat(inputs.get(i)), inputs.get(i)));
rankedDocsResults.add(
new RankedDocsResults.RankedDoc(
hasInvalidDocumentIndices ? i * 2 : i,
Float.parseFloat(inputs.get(i)),
inputs.get(i)
)
);
}
ActionResponse response = new InferenceAction.Response(new RankedDocsResults(rankedDocsResults));
listener.onResponse((Response) response);
Expand Down

0 comments on commit 7e3f153

Please sign in to comment.