Skip to content

Commit 36f7399

Browse files
committed
[ML] Fix request format for Cohere V2 completions (elastic#131091)
1 parent 6e2b561 commit 36f7399

File tree

8 files changed

+223
-8
lines changed

8 files changed

+223
-8
lines changed

x-pack/plugin/inference/qa/rolling-upgrade/src/javaRestTest/java/org/elasticsearch/xpack/application/CohereServiceUpgradeIT.java

Lines changed: 162 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,10 +39,12 @@ public class CohereServiceUpgradeIT extends InferenceUpgradeTestCase {
3939
// TODO: replace with proper test features
4040
private static final String COHERE_EMBEDDINGS_ADDED_TEST_FEATURE = "gte_v8.13.0";
4141
private static final String COHERE_RERANK_ADDED_TEST_FEATURE = "gte_v8.14.0";
42+
private static final String COHERE_COMPLETIONS_ADDED_TEST_FEATURE = "gte_v8.15.0";
4243
private static final String COHERE_V2_API_ADDED_TEST_FEATURE = "inference.cohere.v2";
4344

4445
private static MockWebServer cohereEmbeddingsServer;
4546
private static MockWebServer cohereRerankServer;
47+
private static MockWebServer cohereCompletionsServer;
4648

4749
private enum ApiVersion {
4850
V1,
@@ -60,12 +62,16 @@ public static void startWebServer() throws IOException {
6062

6163
cohereRerankServer = new MockWebServer();
6264
cohereRerankServer.start();
65+
66+
cohereCompletionsServer = new MockWebServer();
67+
cohereCompletionsServer.start();
6368
}
6469

6570
@AfterClass
6671
public static void shutdown() {
6772
cohereEmbeddingsServer.close();
6873
cohereRerankServer.close();
74+
cohereCompletionsServer.close();
6975
}
7076

7177
@SuppressWarnings("unchecked")
@@ -326,6 +332,80 @@ private void assertRerank(String inferenceId) throws IOException {
326332
assertThat(inferenceMap.entrySet(), not(empty()));
327333
}
328334

335+
@SuppressWarnings("unchecked")
336+
public void testCohereCompletions() throws IOException {
337+
var completionsSupported = oldClusterHasFeature(COHERE_COMPLETIONS_ADDED_TEST_FEATURE);
338+
assumeTrue("Cohere completions not supported", completionsSupported);
339+
340+
ApiVersion oldClusterApiVersion = oldClusterHasFeature(COHERE_V2_API_ADDED_TEST_FEATURE) ? ApiVersion.V2 : ApiVersion.V1;
341+
342+
final String oldClusterId = "old-cluster-completions";
343+
344+
if (isOldCluster()) {
345+
// queue a response as PUT will call the service
346+
cohereCompletionsServer.enqueue(new MockResponse().setResponseCode(200).setBody(completionsResponse(oldClusterApiVersion)));
347+
put(oldClusterId, completionsConfig(getUrl(cohereCompletionsServer)), TaskType.COMPLETION);
348+
349+
var configs = (List<Map<String, Object>>) get(TaskType.COMPLETION, oldClusterId).get("endpoints");
350+
assertThat(configs, hasSize(1));
351+
assertEquals("cohere", configs.get(0).get("service"));
352+
var serviceSettings = (Map<String, Object>) configs.get(0).get("service_settings");
353+
assertThat(serviceSettings, hasEntry("model_id", "command"));
354+
} else if (isMixedCluster()) {
355+
var configs = (List<Map<String, Object>>) get(TaskType.COMPLETION, oldClusterId).get("endpoints");
356+
assertThat(configs, hasSize(1));
357+
assertEquals("cohere", configs.get(0).get("service"));
358+
var serviceSettings = (Map<String, Object>) configs.get(0).get("service_settings");
359+
assertThat(serviceSettings, hasEntry("model_id", "command"));
360+
} else if (isUpgradedCluster()) {
361+
// check old cluster model
362+
var configs = (List<Map<String, Object>>) get(TaskType.COMPLETION, oldClusterId).get("endpoints");
363+
var serviceSettings = (Map<String, Object>) configs.get(0).get("service_settings");
364+
assertThat(serviceSettings, hasEntry("model_id", "command"));
365+
366+
final String newClusterId = "new-cluster-completions";
367+
{
368+
cohereCompletionsServer.enqueue(new MockResponse().setResponseCode(200).setBody(completionsResponse(oldClusterApiVersion)));
369+
var inferenceMap = inference(oldClusterId, TaskType.COMPLETION, "some text");
370+
assertThat(inferenceMap.entrySet(), not(empty()));
371+
assertVersionInPath(cohereCompletionsServer.requests().getLast(), "chat", oldClusterApiVersion);
372+
}
373+
{
374+
// new cluster uses the V2 API
375+
cohereCompletionsServer.enqueue(new MockResponse().setResponseCode(200).setBody(completionsResponse(ApiVersion.V2)));
376+
put(newClusterId, completionsConfig(getUrl(cohereCompletionsServer)), TaskType.COMPLETION);
377+
378+
cohereCompletionsServer.enqueue(new MockResponse().setResponseCode(200).setBody(completionsResponse(ApiVersion.V2)));
379+
var inferenceMap = inference(newClusterId, TaskType.COMPLETION, "some text");
380+
assertThat(inferenceMap.entrySet(), not(empty()));
381+
assertVersionInPath(cohereCompletionsServer.requests().getLast(), "chat", ApiVersion.V2);
382+
}
383+
384+
{
385+
// new endpoints use the V2 API which require the model to be set
386+
final String upgradedClusterNoModel = "upgraded-cluster-missing-model-id";
387+
var jsonBody = Strings.format("""
388+
{
389+
"service": "cohere",
390+
"service_settings": {
391+
"url": "%s",
392+
"api_key": "XXXX"
393+
}
394+
}
395+
""", getUrl(cohereEmbeddingsServer));
396+
397+
var e = expectThrows(ResponseException.class, () -> put(upgradedClusterNoModel, jsonBody, TaskType.COMPLETION));
398+
assertThat(
399+
e.getMessage(),
400+
containsString("Validation Failed: 1: The [service_settings.model_id] field is required for the Cohere V2 API.")
401+
);
402+
}
403+
404+
delete(oldClusterId);
405+
delete(newClusterId);
406+
}
407+
}
408+
329409
private String embeddingConfigByte(String url) {
330410
return embeddingConfigTemplate(url, "byte");
331411
}
@@ -451,4 +531,86 @@ private String rerankResponse() {
451531
""";
452532
}
453533

534+
private String completionsConfig(String url) {
535+
return Strings.format("""
536+
{
537+
"service": "cohere",
538+
"service_settings": {
539+
"api_key": "XXXX",
540+
"model_id": "command",
541+
"url": "%s"
542+
}
543+
}
544+
""", url);
545+
}
546+
547+
private String completionsResponse(ApiVersion version) {
548+
return switch (version) {
549+
case V1 -> v1CompletionsResponse();
550+
case V2 -> v2CompletionsResponse();
551+
};
552+
}
553+
554+
private String v1CompletionsResponse() {
555+
return """
556+
{
557+
"response_id": "some id",
558+
"text": "result",
559+
"generation_id": "some id",
560+
"chat_history": [
561+
{
562+
"role": "USER",
563+
"message": "some input"
564+
},
565+
{
566+
"role": "CHATBOT",
567+
"message": "v1 response from the llm"
568+
}
569+
],
570+
"finish_reason": "COMPLETE",
571+
"meta": {
572+
"api_version": {
573+
"version": "1"
574+
},
575+
"billed_units": {
576+
"input_tokens": 4,
577+
"output_tokens": 191
578+
},
579+
"tokens": {
580+
"input_tokens": 70,
581+
"output_tokens": 191
582+
}
583+
}
584+
}
585+
""";
586+
}
587+
588+
private String v2CompletionsResponse() {
589+
return """
590+
{
591+
"id": "c14c80c3-18eb-4519-9460-6c92edd8cfb4",
592+
"finish_reason": "COMPLETE",
593+
"message": {
594+
"role": "assistant",
595+
"content": [
596+
{
597+
"type": "text",
598+
"text": "v2 response from the LLM"
599+
}
600+
]
601+
},
602+
"usage": {
603+
"billed_units": {
604+
"input_tokens": 1,
605+
"output_tokens": 2
606+
},
607+
"tokens": {
608+
"input_tokens": 3,
609+
"output_tokens": 4
610+
}
611+
}
612+
}
613+
""";
614+
}
615+
454616
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/CohereUtils.java

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,13 +28,16 @@ public class CohereUtils {
2828
public static final String DOCUMENTS_FIELD = "documents";
2929
public static final String EMBEDDING_TYPES_FIELD = "embedding_types";
3030
public static final String INPUT_TYPE_FIELD = "input_type";
31-
public static final String MESSAGE_FIELD = "message";
31+
public static final String V1_MESSAGE_FIELD = "message";
32+
public static final String V2_MESSAGES_FIELD = "messages";
3233
public static final String MODEL_FIELD = "model";
3334
public static final String QUERY_FIELD = "query";
35+
public static final String V2_ROLE_FIELD = "role";
3436
public static final String SEARCH_DOCUMENT = "search_document";
3537
public static final String SEARCH_QUERY = "search_query";
36-
public static final String TEXTS_FIELD = "texts";
3738
public static final String STREAM_FIELD = "stream";
39+
public static final String TEXTS_FIELD = "texts";
40+
public static final String USER_FIELD = "user";
3841

3942
public static Header createRequestSourceHeader() {
4043
return new BasicHeader(REQUEST_SOURCE_HEADER, ELASTIC_REQUEST_SOURCE);

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/v1/CohereV1CompletionRequest.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ public CohereV1CompletionRequest(List<String> input, CohereCompletionModel model
3030
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
3131
builder.startObject();
3232
// we only allow one input for completion, so always get the first one
33-
builder.field(CohereUtils.MESSAGE_FIELD, input.getFirst());
33+
builder.field(CohereUtils.V1_MESSAGE_FIELD, input.getFirst());
3434
if (getModelId() != null) {
3535
builder.field(CohereUtils.MODEL_FIELD, getModelId());
3636
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/v2/CohereV2CompletionRequest.java

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,13 @@ public CohereV2CompletionRequest(List<String> input, CohereCompletionModel model
2929
@Override
3030
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
3131
builder.startObject();
32+
builder.startArray(CohereUtils.V2_MESSAGES_FIELD);
33+
builder.startObject();
34+
builder.field(CohereUtils.V2_ROLE_FIELD, CohereUtils.USER_FIELD);
3235
// we only allow one input for completion, so always get the first one
33-
builder.field(CohereUtils.MESSAGE_FIELD, input.getFirst());
36+
builder.field("content", input.getFirst());
37+
builder.endObject();
38+
builder.endArray();
3439
builder.field(CohereUtils.MODEL_FIELD, getModelId());
3540
builder.field(CohereUtils.STREAM_FIELD, isStreaming());
3641
builder.endObject();

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/action/CohereActionCreatorTests.java

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -209,7 +209,10 @@ public void testCreate_CohereCompletionModel_WithModelSpecified() throws IOExcep
209209
assertThat(webServer.requests().get(0).getHeader(HttpHeaders.AUTHORIZATION), is("Bearer secret"));
210210

211211
var requestMap = entityAsMap(webServer.requests().get(0).getBody());
212-
assertThat(requestMap, is(Map.of("message", "abc", "model", "model", "stream", false)));
212+
assertThat(
213+
requestMap,
214+
is(Map.of("messages", List.of(Map.of("role", "user", "content", "abc")), "model", "model", "stream", false))
215+
);
213216
}
214217
}
215218
}

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/action/CohereCompletionActionTests.java

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,10 @@ public void testExecute_ReturnsSuccessfulResponse_WithModelSpecified() throws IO
132132
);
133133

134134
var requestMap = entityAsMap(webServer.requests().get(0).getBody());
135-
assertThat(requestMap, is(Map.of("message", "abc", "model", "model", "stream", false)));
135+
assertThat(
136+
requestMap,
137+
is(Map.of("messages", List.of(Map.of("role", "user", "content", "abc")), "model", "model", "stream", false))
138+
);
136139
}
137140
}
138141

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/request/v2/CohereV2CompletionRequestTests.java

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,10 @@ public void testCreateRequest() throws IOException {
4646
assertThat(httpPost.getLastHeader(CohereUtils.REQUEST_SOURCE_HEADER).getValue(), is(CohereUtils.ELASTIC_REQUEST_SOURCE));
4747

4848
var requestMap = entityAsMap(httpPost.getEntity().getContent());
49-
assertThat(requestMap, is(Map.of("message", "abc", "model", "required model id", "stream", false)));
49+
assertThat(
50+
requestMap,
51+
is(Map.of("messages", List.of(Map.of("role", "user", "content", "abc")), "model", "required model id", "stream", false))
52+
);
5053
}
5154

5255
public void testDefaultUrl() {
@@ -88,6 +91,6 @@ public void testXContents() throws IOException {
8891
String xContentResult = Strings.toString(builder);
8992

9093
assertThat(xContentResult, CoreMatchers.is("""
91-
{"message":"some input","model":"model","stream":false}"""));
94+
{"messages":[{"role":"user","content":"some input"}],"model":"model","stream":false}"""));
9295
}
9396
}

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/response/CohereCompletionResponseEntityTests.java

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,42 @@ public void testFromResponse_CreatesResponseEntityForText() throws IOException {
6464
assertThat(chatCompletionResults.getResults().get(0).content(), is("result"));
6565
}
6666

67+
public void testFromResponseV2() throws IOException {
68+
String responseJson = """
69+
{
70+
"id": "abc123",
71+
"finish_reason": "COMPLETE",
72+
"message": {
73+
"role": "assistant",
74+
"content": [
75+
{
76+
"type": "text",
77+
"text": "Response from the llm"
78+
}
79+
]
80+
},
81+
"usage": {
82+
"billed_units": {
83+
"input_tokens": 1,
84+
"output_tokens": 4
85+
},
86+
"tokens": {
87+
"input_tokens": 2,
88+
"output_tokens": 5
89+
}
90+
}
91+
}
92+
""";
93+
94+
ChatCompletionResults chatCompletionResults = CohereCompletionResponseEntity.fromResponse(
95+
mock(Request.class),
96+
new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
97+
);
98+
99+
assertThat(chatCompletionResults.getResults().size(), is(1));
100+
assertThat(chatCompletionResults.getResults().get(0).content(), is("Response from the llm"));
101+
}
102+
67103
public void testFromResponse_FailsWhenTextIsNotPresent() {
68104
String responseJson = """
69105
{

0 commit comments

Comments
 (0)