-
Notifications
You must be signed in to change notification settings - Fork 25.3k
[ES|QL] COMPLETION command - Inference Operator implementation #127409
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
b85e704
e6ac175
7fe8adc
757fbe4
39ad919
6f5a8b3
bbd1f69
71ad3b8
33a289d
d5797d3
62f39eb
2229c94
3a59b96
7423cfc
7acd6dd
1c1c003
505dbdc
815d479
1657d2c
3d48b05
06b99ed
a5ea05a
66563b6
5384db0
04f0d36
e0020e2
9e57399
65da3ef
9f88520
cb20ce7
bf73b40
1698f5f
d107767
e77a8bb
58aa070
1c858c6
c37d3dc
b9c22bd
89cb900
190f7d7
72541c0
1d486b6
88a63a8
c70b0a1
1d6b589
3d0819d
115ee49
1e95722
340c189
3a8422d
d3f7a0e
d3a47a2
e64b81c
1901bde
12ab742
d99971f
7463dad
63f47a8
20773db
2387d14
035ae04
0a6ef3f
cc7d8fa
04191cb
7582956
2f751cc
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -67,14 +67,11 @@ | |
import static org.elasticsearch.xpack.esql.CsvTestUtils.isEnabled; | ||
import static org.elasticsearch.xpack.esql.CsvTestUtils.loadCsvSpecValues; | ||
import static org.elasticsearch.xpack.esql.CsvTestsDataLoader.availableDatasetsForEs; | ||
import static org.elasticsearch.xpack.esql.CsvTestsDataLoader.clusterHasInferenceEndpoint; | ||
import static org.elasticsearch.xpack.esql.CsvTestsDataLoader.clusterHasRerankInferenceEndpoint; | ||
import static org.elasticsearch.xpack.esql.CsvTestsDataLoader.createInferenceEndpoint; | ||
import static org.elasticsearch.xpack.esql.CsvTestsDataLoader.createRerankInferenceEndpoint; | ||
import static org.elasticsearch.xpack.esql.CsvTestsDataLoader.deleteInferenceEndpoint; | ||
import static org.elasticsearch.xpack.esql.CsvTestsDataLoader.deleteRerankInferenceEndpoint; | ||
import static org.elasticsearch.xpack.esql.CsvTestsDataLoader.createInferenceEndpoints; | ||
import static org.elasticsearch.xpack.esql.CsvTestsDataLoader.deleteInferenceEndpoints; | ||
import static org.elasticsearch.xpack.esql.CsvTestsDataLoader.loadDataSetIntoEs; | ||
import static org.elasticsearch.xpack.esql.EsqlTestUtils.classpathResources; | ||
import static org.elasticsearch.xpack.esql.action.EsqlCapabilities.Cap.COMPLETION; | ||
import static org.elasticsearch.xpack.esql.action.EsqlCapabilities.Cap.METRICS_COMMAND; | ||
import static org.elasticsearch.xpack.esql.action.EsqlCapabilities.Cap.RERANK; | ||
import static org.elasticsearch.xpack.esql.action.EsqlCapabilities.Cap.SEMANTIC_TEXT_FIELD_CAPS; | ||
|
@@ -138,12 +135,8 @@ protected EsqlSpecTestCase( | |
|
||
@Before | ||
public void setup() throws IOException { | ||
if (supportsInferenceTestService() && clusterHasInferenceEndpoint(client()) == false) { | ||
createInferenceEndpoint(client()); | ||
} | ||
|
||
if (supportsInferenceTestService() && clusterHasRerankInferenceEndpoint(client()) == false) { | ||
createRerankInferenceEndpoint(client()); | ||
if (supportsInferenceTestService()) { | ||
createInferenceEndpoints(adminClient()); | ||
} | ||
|
||
boolean supportsLookup = supportsIndexModeLookup(); | ||
|
@@ -164,8 +157,8 @@ public static void wipeTestData() throws IOException { | |
} | ||
} | ||
|
||
deleteInferenceEndpoint(client()); | ||
deleteRerankInferenceEndpoint(client()); | ||
deleteInferenceEndpoints(adminClient()); | ||
|
||
} | ||
|
||
public boolean logResults() { | ||
|
@@ -254,7 +247,7 @@ protected boolean supportsInferenceTestService() { | |
} | ||
|
||
protected boolean requiresInferenceEndpoint() { | ||
return Stream.of(SEMANTIC_TEXT_FIELD_CAPS.capabilityName(), RERANK.capabilityName()) | ||
return Stream.of(SEMANTIC_TEXT_FIELD_CAPS.capabilityName(), RERANK.capabilityName(), COMPLETION.capabilityName()) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. ℹ️ Completion cannot be tested in multi_cluster because the inference test plugin is not available. |
||
.anyMatch(testCase.requiredCapabilities::contains); | ||
} | ||
|
||
|
@@ -372,6 +365,11 @@ private Object valueMapper(CsvTestUtils.Type type, Object value) { | |
return new BigDecimal(s).round(new MathContext(7, RoundingMode.DOWN)).doubleValue(); | ||
} | ||
} | ||
if (type == CsvTestUtils.Type.TEXT || type == CsvTestUtils.Type.KEYWORD || type == CsvTestUtils.Type.SEMANTIC_TEXT) { | ||
if (value instanceof String s) { | ||
value = s.replaceAll("\\\\n", "\n"); | ||
} | ||
} | ||
return value.toString(); | ||
} | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -27,6 +27,7 @@ | |
import org.elasticsearch.common.settings.Settings; | ||
import org.elasticsearch.common.xcontent.XContentHelper; | ||
import org.elasticsearch.core.Nullable; | ||
import org.elasticsearch.inference.TaskType; | ||
import org.elasticsearch.logging.LogManager; | ||
import org.elasticsearch.logging.Logger; | ||
import org.elasticsearch.test.rest.ESRestTestCase; | ||
|
@@ -317,7 +318,7 @@ public static Set<TestDataset> availableDatasetsForEs( | |
boolean supportsIndexModeLookup, | ||
boolean supportsSourceFieldMapping | ||
) throws IOException { | ||
boolean inferenceEnabled = clusterHasInferenceEndpoint(client); | ||
boolean inferenceEnabled = clusterHasSparseEmbeddingInferenceEndpoint(client); | ||
|
||
Set<TestDataset> testDataSets = new HashSet<>(); | ||
|
||
|
@@ -379,77 +380,90 @@ private static void loadDataSetIntoEs( | |
} | ||
} | ||
|
||
public static void createInferenceEndpoints(RestClient client) throws IOException { | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. ℹ️ Added a new inference endpoint, "test_completion", that is available in CSV tests. |
||
if (clusterHasSparseEmbeddingInferenceEndpoint(client) == false) { | ||
createSparseEmbeddingInferenceEndpoint(client); | ||
} | ||
|
||
if (clusterHasRerankInferenceEndpoint(client) == false) { | ||
createRerankInferenceEndpoint(client); | ||
} | ||
|
||
if (clusterHasCompletionInferenceEndpoint(client) == false) { | ||
createCompletionInferenceEndpoint(client); | ||
} | ||
} | ||
|
||
// Removes the test inference endpoints by calling the sparse-embedding, rerank and completion delete | ||
// helpers in that order. NOTE(review): an exception from an earlier call would skip the later ones. | ||
public static void deleteInferenceEndpoints(RestClient client) throws IOException { | ||
deleteSparseEmbeddingInferenceEndpoint(client); | ||
deleteRerankInferenceEndpoint(client); | ||
deleteCompletionInferenceEndpoint(client); | ||
} | ||
|
||
/** The semantic_text mapping type requires an inference endpoint that needs to be set up before creating the index. */ | ||
public static void createInferenceEndpoint(RestClient client) throws IOException { | ||
Request request = new Request("PUT", "_inference/sparse_embedding/test_sparse_inference"); | ||
request.setJsonEntity(""" | ||
public static void createSparseEmbeddingInferenceEndpoint(RestClient client) throws IOException { | ||
createInferenceEndpoint(client, TaskType.SPARSE_EMBEDDING, "test_sparse_inference", """ | ||
{ | ||
"service": "test_service", | ||
"service_settings": { | ||
"model": "my_model", | ||
"api_key": "abc64" | ||
}, | ||
"task_settings": { | ||
} | ||
"service_settings": { "model": "my_model", "api_key": "abc64" }, | ||
"task_settings": { } | ||
} | ||
"""); | ||
client.performRequest(request); | ||
} | ||
|
||
public static void deleteInferenceEndpoint(RestClient client) throws IOException { | ||
try { | ||
client.performRequest(new Request("DELETE", "_inference/test_sparse_inference")); | ||
} catch (ResponseException e) { | ||
// 404 here means the endpoint was not created | ||
if (e.getResponse().getStatusLine().getStatusCode() != 404) { | ||
throw e; | ||
} | ||
} | ||
public static void deleteSparseEmbeddingInferenceEndpoint(RestClient client) throws IOException { | ||
deleteInferenceEndpoint(client, "test_sparse_inference"); | ||
} | ||
|
||
public static boolean clusterHasInferenceEndpoint(RestClient client) throws IOException { | ||
Request request = new Request("GET", "_inference/sparse_embedding/test_sparse_inference"); | ||
try { | ||
client.performRequest(request); | ||
} catch (ResponseException e) { | ||
if (e.getResponse().getStatusLine().getStatusCode() == 404) { | ||
return false; | ||
} | ||
throw e; | ||
} | ||
return true; | ||
public static boolean clusterHasSparseEmbeddingInferenceEndpoint(RestClient client) throws IOException { | ||
return clusterHasInferenceEndpoint(client, TaskType.SPARSE_EMBEDDING, "test_sparse_inference"); | ||
} | ||
|
||
public static void createRerankInferenceEndpoint(RestClient client) throws IOException { | ||
Request request = new Request("PUT", "_inference/rerank/test_reranker"); | ||
request.setJsonEntity(""" | ||
createInferenceEndpoint(client, TaskType.RERANK, "test_reranker", """ | ||
{ | ||
"service": "test_reranking_service", | ||
"service_settings": { | ||
"model_id": "my_model", | ||
"api_key": "abc64" | ||
}, | ||
"task_settings": { | ||
"use_text_length": true | ||
} | ||
"service_settings": { "model_id": "my_model", "api_key": "abc64" }, | ||
"task_settings": { "use_text_length": true } | ||
} | ||
"""); | ||
client.performRequest(request); | ||
} | ||
|
||
public static void deleteRerankInferenceEndpoint(RestClient client) throws IOException { | ||
try { | ||
client.performRequest(new Request("DELETE", "_inference/rerank/test_reranker")); | ||
} catch (ResponseException e) { | ||
// 404 here means the endpoint was not created | ||
if (e.getResponse().getStatusLine().getStatusCode() != 404) { | ||
throw e; | ||
} | ||
} | ||
deleteInferenceEndpoint(client, "test_reranker"); | ||
} | ||
|
||
public static boolean clusterHasRerankInferenceEndpoint(RestClient client) throws IOException { | ||
Request request = new Request("GET", "_inference/rerank/test_reranker"); | ||
return clusterHasInferenceEndpoint(client, TaskType.RERANK, "test_reranker"); | ||
} | ||
|
||
// Registers the completion inference endpoint used by the tests: delegates to createInferenceEndpoint | ||
// with TaskType.COMPLETION, the id "test_completion" and the JSON settings given in the text block. | ||
public static void createCompletionInferenceEndpoint(RestClient client) throws IOException { | ||
createInferenceEndpoint(client, TaskType.COMPLETION, "test_completion", """ | ||
{ | ||
"service": "completion_test_service", | ||
"service_settings": { "model": "my_model", "api_key": "abc64" }, | ||
"task_settings": { "temperature": 3 } | ||
} | ||
"""); | ||
} | ||
|
||
// Deletes the completion test endpoint by passing the id "test_completion" to deleteInferenceEndpoint. | ||
public static void deleteCompletionInferenceEndpoint(RestClient client) throws IOException { | ||
deleteInferenceEndpoint(client, "test_completion"); | ||
} | ||
|
||
// Returns the result of clusterHasInferenceEndpoint for TaskType.COMPLETION and the id "test_completion". | ||
public static boolean clusterHasCompletionInferenceEndpoint(RestClient client) throws IOException { | ||
return clusterHasInferenceEndpoint(client, TaskType.COMPLETION, "test_completion"); | ||
} | ||
|
||
// Sends a PUT to "_inference/<taskType.name()>/<inferenceId>" with modelSettings as the JSON body. | ||
// The response is not inspected and nothing is caught here; the method declares IOException. | ||
// NOTE(review): name() is the enum constant as declared (e.g. SPARSE_EMBEDDING), whereas the code this | ||
// replaced used lower-case path segments such as "sparse_embedding" -- confirm the casing is accepted. | ||
private static void createInferenceEndpoint(RestClient client, TaskType taskType, String inferenceId, String modelSettings) | ||
throws IOException { | ||
Request request = new Request("PUT", "_inference/" + taskType.name() + "/" + inferenceId); | ||
request.setJsonEntity(modelSettings); | ||
client.performRequest(request); | ||
} | ||
|
||
private static boolean clusterHasInferenceEndpoint(RestClient client, TaskType taskType, String inferenceId) throws IOException { | ||
Request request = new Request("GET", "_inference/" + taskType.name() + "/" + inferenceId); | ||
try { | ||
client.performRequest(request); | ||
} catch (ResponseException e) { | ||
|
@@ -461,6 +475,17 @@ public static boolean clusterHasRerankInferenceEndpoint(RestClient client) throw | |
return true; | ||
} | ||
|
||
private static void deleteInferenceEndpoint(RestClient client, String inferenceId) throws IOException { | ||
try { | ||
client.performRequest(new Request("DELETE", "_inference/" + inferenceId)); | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Do we give the right path here? Is it supposed to be […]? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. In fact, both endpoints are valid and can be used interchangeably. |
||
} catch (ResponseException e) { | ||
// 404 here means the endpoint was not created | ||
if (e.getResponse().getStatusLine().getStatusCode() != 404) { | ||
throw e; | ||
} | ||
} | ||
} | ||
|
||
private static void loadEnrichPolicy(RestClient client, String policyName, String policyFileName, Logger logger) throws IOException { | ||
URL policyMapping = getResource("/" + policyFileName); | ||
String entity = readTextFile(policyMapping); | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
// Note: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ℹ️ Added CSV tests for the completion command |
||
// The "test_completion" service returns the prompt in uppercase, making the output easy to guess. | ||
|
||
|
||
completion using a ROW source operator | ||
required_capability: completion | ||
|
||
ROW prompt="Who is Victor Hugo?" | ||
| COMPLETION prompt WITH test_completion AS completion_output | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. what happens if There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If the prompt is multi-valued, the So the multi-value input: Will be translated into the following prompt:
I built this as a reasonably good alternative to CONCAT in some cases. I also added a CSV test case for it. |
||
; | ||
|
||
prompt:keyword | completion_output:keyword | ||
Who is Victor Hugo? | WHO IS VICTOR HUGO? | ||
; | ||
|
||
|
||
completion using a ROW source operator and prompt is a multi-valued field | ||
required_capability: completion | ||
|
||
ROW prompt=["Answer the following question:", "Who is Victor Hugo?"] | ||
| COMPLETION prompt WITH test_completion AS completion_output | ||
; | ||
|
||
prompt:keyword | completion_output:keyword | ||
[Answer the following question:, Who is Victor Hugo?] | ANSWER THE FOLLOWING QUESTION:\nWHO IS VICTOR HUGO? | ||
; | ||
|
||
|
||
completion after a search | ||
required_capability: completion | ||
required_capability: match_operator_colon | ||
|
||
FROM books METADATA _score | ||
| WHERE title:"war and peace" AND author:"Tolstoy" | ||
| SORT _score DESC | ||
| LIMIT 2 | ||
| COMPLETION title WITH test_completion | ||
| KEEP title, completion | ||
; | ||
|
||
title:text | completion:keyword | ||
War and Peace | WAR AND PEACE | ||
War and Peace (Signet Classics) | WAR AND PEACE (SIGNET CLASSICS) | ||
; | ||
|
||
completion using a function as a prompt | ||
required_capability: completion | ||
required_capability: match_operator_colon | ||
|
||
FROM books METADATA _score | ||
| WHERE title:"war and peace" AND author:"Tolstoy" | ||
| SORT _score DESC | ||
| LIMIT 2 | ||
| COMPLETION CONCAT("This is a prompt: ", title) WITH test_completion | ||
| KEEP title, completion | ||
; | ||
|
||
title:text | completion:keyword | ||
War and Peace | THIS IS A PROMPT: WAR AND PEACE | ||
War and Peace (Signet Classics) | THIS IS A PROMPT: WAR AND PEACE (SIGNET CLASSICS) | ||
; |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -617,7 +617,7 @@ private LogicalPlan resolveCompletion(Completion p, List<Attribute> childrenOutp | |
Expression prompt = p.prompt(); | ||
|
||
if (targetField instanceof UnresolvedAttribute ua) { | ||
targetField = new ReferenceAttribute(ua.source(), ua.name(), TEXT); | ||
targetField = new ReferenceAttribute(ua.source(), ua.name(), KEYWORD); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ℹ️ |
||
} | ||
|
||
if (prompt.resolved() == false) { | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ℹ️ Re-enabled the flaky RerankOperator tests because they are now fixed.