Skip to content

Add AI21 support to Inference Plugin #131238

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 20 commits into
base: main
Choose a base branch
from

Conversation

Jan-Kazlouski-elastic
Copy link
Contributor

@Jan-Kazlouski-elastic Jan-Kazlouski-elastic commented Jul 14, 2025

Creation of new AI21 inference provider integration allowing completion (both streaming and non-streaming) and chat_completion (only streaming) to be executed as part of inference API.

Changes were tested locally against next models:
jamba-large
jamba-mini

Create Completion Endpoint

Success:

PUT {{base-url}}/_inference/completion/ai21-completion
RQ
{
    "service": "ai21",
    "service_settings": {
        "api_key": "{{ai21-api-key}}",
        "model_id": "jamba-mini"
    }
}
RS
{
    "inference_id": "ai21-completion",
    "task_type": "completion",
    "service": "ai21",
    "service_settings": {
        "model_id": "jamba-mini",
        "rate_limit": {
            "requests_per_minute": 200
        }
    }
}

Invalid Model:

PUT {{base-url}}/_inference/completion/ai21-completion
RQ
{
    "service": "ai21",
    "service_settings": {
        "api_key": "{{ai21-api-key}}",
        "model_id": "invalid-model"
    }
}
RS
{
    "error": {
        "root_cause": [
            {
                "type": "status_exception",
                "reason": "Received an input validation error response for request from inference entity id [ai21-completion] status [422]. Error message: [{\"detail\":\"The provided model is not supported. See https://docs.ai21.com/docs/jamba-foundation-models#api-versioning for a list of supported models\"}]"
            }
        ],
        "type": "status_exception",
        "reason": "Could not complete inference endpoint creation as validation call to service threw an exception.",
        "caused_by": {
            "type": "status_exception",
            "reason": "Received an input validation error response for request from inference entity id [ai21-completion] status [422]. Error message: [{\"detail\":\"The provided model is not supported. See https://docs.ai21.com/docs/jamba-foundation-models#api-versioning for a list of supported models\"}]"
        }
    },
    "status": 400
}

Auth Failed:

PUT {{base-url}}/_inference/completion/ai21-completion
RQ
{
    "service": "ai21",
    "service_settings": {
        "api_key": "invalid-key",
        "model_id": "jamba-mini"
    }
}
RS
{
    "error": {
        "root_cause": [
            {
                "type": "status_exception",
                "reason": "Received an authentication error status code for request from inference entity id [ai21-completion] status [401]. Error message: [{\"detail\":\"Forbidden: Bad or missing Apikey/JWT.\"}]"
            }
        ],
        "type": "status_exception",
        "reason": "Could not complete inference endpoint creation as validation call to service threw an exception.",
        "caused_by": {
            "type": "status_exception",
            "reason": "Received an authentication error status code for request from inference entity id [ai21-completion] status [401]. Error message: [{\"detail\":\"Forbidden: Bad or missing Apikey/JWT.\"}]"
        }
    },
    "status": 400
}
Perform Completion

Success Non Streaming:

POST {{base-url}}/_inference/completion/ai21-completion
RQ
{
    "input": "The sky above the port was the color of television tuned to a dead channel."
}
RS
{
    "completion": [
        {
            "result": "That's a striking opening line from William Gibson's **Neuromancer**. It paints a vivid, dystopian image, evoking a sense of decay and the eerie beauty of a world on the edge of technological collapse. What drew you to this line? Are you exploring cyberpunk themes or looking for inspiration for writing?"
        }
    ]
}

Success Streaming:

POST {{base-url}}/_inference/completion/ai21-completion/_stream
RQ
{
    "input": "The sky above the port was the color of television tuned to a dead channel."
}
RS
event: message
data: {"completion":[{"delta":"That"},{"delta":"'"},{"delta":"s"},{"delta":" a"},{"delta":" memorable"}]}

event: message
data: {"completion":[{"delta":" opening"},{"delta":" line"},{"delta":" from"},{"delta":" William"},{"delta":" Gibson"},{"delta":"'"}]}

event: message
data: {"completion":[{"delta":"s"},{"delta":" groundbreaking"},{"delta":" novel"},{"delta":" Ne"},{"delta":"ur"}]}

event: message
data: {"completion":[{"delta":"om"},{"delta":"ancer"},{"delta":"."},{"delta":" It"},{"delta":" sets"},{"delta":" a"}]}

event: message
data: {"completion":[{"delta":" vivid"},{"delta":","},{"delta":" dyst"},{"delta":"opian"}]}

event: message
data: {"completion":[{"delta":" tone"},{"delta":" for"},{"delta":" the"},{"delta":" cyber"},{"delta":"punk"},{"delta":" world"}]}

event: message
data: {"completion":[{"delta":" he"}]}

event: message
data: {"completion":[{"delta":" created"},{"delta":"."},{"delta":" Would"},{"delta":" you"},{"delta":" like"}]}

event: message
data: {"completion":[{"delta":" to"},{"delta":" discuss"},{"delta":" the"},{"delta":" book"}]}

event: message
data: {"completion":[{"delta":","},{"delta":" the"},{"delta":" genre"},{"delta":","},{"delta":" or"}]}

event: message
data: {"completion":[{"delta":" something"},{"delta":" else"},{"delta":" related"},{"delta":" to"},{"delta":" it"}]}

event: message
data: {"completion":[{"delta":"?"}]}

event: message
data: [DONE]

Create Completion Endpoint

Success:

PUT {{base-url}}/_inference/chat_completion/ai21-chat-completion
RQ:
{
    "service": "ai21",
    "service_settings": {
        "api_key": "{{ai21-api-key}}",
        "model_id": "jamba-mini"
    }
}
RS:
{
    "inference_id": "ai21-chat-completion",
    "task_type": "chat_completion",
    "service": "ai21",
    "service_settings": {
        "model_id": "jamba-mini",
        "rate_limit": {
            "requests_per_minute": 200
        }
    }
}

Invalid Model:

PUT {{base-url}}/_inference/chat_completion/ai21-chat-completion
RQ:
{
    "service": "ai21",
    "service_settings": {
        "api_key": "{{ai21-api-key}}",
        "model_id": "invalid-model"
    }
}
RS:
{
    "error": {
        "root_cause": [
            {
                "type": "unified_chat_completion_exception",
                "reason": "Received an input validation error response for request from inference entity id [ai21-chat-completion] status [422]. Error message: [{\"detail\":\"The provided model is not supported. See https://docs.ai21.com/docs/jamba-foundation-models#api-versioning for a list of supported models\"}]"
            }
        ],
        "type": "status_exception",
        "reason": "Could not complete inference endpoint creation as validation call to service threw an exception.",
        "caused_by": {
            "type": "unified_chat_completion_exception",
            "reason": "Received an input validation error response for request from inference entity id [ai21-chat-completion] status [422]. Error message: [{\"detail\":\"The provided model is not supported. See https://docs.ai21.com/docs/jamba-foundation-models#api-versioning for a list of supported models\"}]"
        }
    },
    "status": 400
}

Auth Failed:

PUT {{base-url}}/_inference/chat_completion/ai21-chat-completion
RQ:
{
    "service": "ai21",
    "service_settings": {
        "api_key": "invalid-key",
        "model_id": "jamba-mini"
    }
}
RS:
{
    "error": {
        "root_cause": [
            {
                "type": "unified_chat_completion_exception",
                "reason": "Received an authentication error status code for request from inference entity id [ai21-chat-completion] status [401]. Error message: [{\"detail\":\"Forbidden: Bad or missing Apikey/JWT.\"}]"
            }
        ],
        "type": "status_exception",
        "reason": "Could not complete inference endpoint creation as validation call to service threw an exception.",
        "caused_by": {
            "type": "unified_chat_completion_exception",
            "reason": "Received an authentication error status code for request from inference entity id [ai21-chat-completion] status [401]. Error message: [{\"detail\":\"Forbidden: Bad or missing Apikey/JWT.\"}]"
        }
    },
    "status": 400
}
Perform Chat Completion

Success Simple:

POST {{base-url}}/_inference/chat_completion/ai21-chat-completion/_stream
RQ
{
    "model": "jamba-mini",
    "messages": [
        {
            "role": "user",
            "content": "What is deep learning?"
        }
    ],
    "max_completion_tokens": 10
}
RS
event: message
data: {"id":"chatcmpl-ca8f0d76-4633-50b1-971b-3821e4a9eea1","choices":[{"delta":{"role":"assistant"},"index":0}],"model":null,"object":null}

event: message
data: {"id":"chatcmpl-ca8f0d76-4633-50b1-971b-3821e4a9eea1","choices":[{"delta":{"content":"Deep"},"index":0}],"model":null,"object":null}

event: message
data: {"id":"chatcmpl-ca8f0d76-4633-50b1-971b-3821e4a9eea1","choices":[{"delta":{"content":" learning"},"index":0}],"model":null,"object":null}

event: message
data: {"id":"chatcmpl-ca8f0d76-4633-50b1-971b-3821e4a9eea1","choices":[{"delta":{"content":" is"},"index":0}],"model":null,"object":null}

event: message
data: {"id":"chatcmpl-ca8f0d76-4633-50b1-971b-3821e4a9eea1","choices":[{"delta":{"content":" a"},"index":0}],"model":null,"object":null}

event: message
data: {"id":"chatcmpl-ca8f0d76-4633-50b1-971b-3821e4a9eea1","choices":[{"delta":{"content":" subset"},"index":0}],"model":null,"object":null}

event: message
data: {"id":"chatcmpl-ca8f0d76-4633-50b1-971b-3821e4a9eea1","choices":[{"delta":{"content":" of"},"index":0}],"model":null,"object":null}

event: message
data: {"id":"chatcmpl-ca8f0d76-4633-50b1-971b-3821e4a9eea1","choices":[{"delta":{"content":" machine"},"index":0}],"model":null,"object":null}

event: message
data: {"id":"chatcmpl-ca8f0d76-4633-50b1-971b-3821e4a9eea1","choices":[{"delta":{"content":" learning"},"index":0}],"model":null,"object":null}

event: message
data: {"id":"chatcmpl-ca8f0d76-4633-50b1-971b-3821e4a9eea1","choices":[{"delta":{"content":" that"},"index":0}],"model":null,"object":null}

event: message
data: {"id":"chatcmpl-ca8f0d76-4633-50b1-971b-3821e4a9eea1","choices":[{"delta":{"content":" involves"},"finish_reason":"length","index":0}],"model":null,"object":null,"usage":{"completion_tokens":10,"prompt_tokens":15,"total_tokens":25}}

event: message
data: [DONE]


Success Complex is not supported due to error on AI21 side, requesting content to be string type and not anything else.:

POST https://api.ai21.com/studio/v1/chat/completions
RQ:
{
    "model": "jamba-mini",
    "max_tokens": 10,
    "messages": [{
            "role": "user",
            "content": [{
                    "type": "text",
                    "text": "What's the price of a scarf?"
                }
            ]
        }
    ],
    "tools": [{
            "type": "function",
            "function": {
                "name": "get_current_price",
                "description": "Get the current price of a item",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "item": {
                            "id": "123"
                        }
                    }
                }
            }
        }
    ],
    "tool_choice": {
        "type": "function",
        "function": {
            "name": "get_current_price"
        }
    }
}
RS:
{
    "detail": [
        {
            "loc": [
                "body",
                "messages",
                0,
                "UserMessage",
                "content"
            ],
            "msg": "str type expected",
            "type": "type_error.str"
        }
    ]
}

ES Inference:

POST {{base-url}}/_inference/chat_completion/ai21-chat-completion/_stream
RQ:
{
    "model": "llama3.2:3b",
    "messages": [{
            "role": "user",
            "content": [{
                    "type": "text",
                    "text": "What's the price of a scarf?"
                }
            ]
        }
    ],
    "tools": [{
            "type": "function",
            "function": {
                "name": "get_current_price",
                "description": "Get the current price of a item",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "item": {
                            "id": "123"
                        }
                    }
                }
            }
        }
    ],
    "tool_choice": {
        "type": "function",
        "function": {
            "name": "get_current_price"
        }
    }
}
RS:
{
    "error": {
        "code": "unprocessable_entity",
        "message": "Received an input validation error response for request from inference entity id [ai21-chat-completion] status [422]. Error message: [{\"detail\":[{\"loc\":[\"body\",\"messages\",0,\"UserMessage\",\"content\"],\"msg\":\"str type expected\",\"type\":\"type_error.str\"}]}]",
        "type": "ai21_error"
    }
}
  • - Have you signed the contributor license agreement?
  • - Have you followed the contributor guidelines?
  • - If submitting code, have you built your formula locally prior to submission with gradle check?
  • - If submitting code, is your pull request against main? Unless there is a good reason otherwise, we prefer pull requests against main and will backport as needed.
  • - If submitting code, have you checked that your submission is for an OS and architecture that we support?
  • - If you are submitting this code for a class then read our policy for that.

@Jan-Kazlouski-elastic Jan-Kazlouski-elastic changed the title Ai21 chat completion Add AI21 support to Inference Plugin Jul 14, 2025
@elasticsearchmachine elasticsearchmachine added v9.2.0 external-contributor Pull request authored by a developer outside the Elasticsearch team labels Jul 14, 2025
…etionRequestEntity and Ai21ChatCompletionRequest
@Jan-Kazlouski-elastic Jan-Kazlouski-elastic marked this pull request as ready for review July 16, 2025 15:35
@elasticsearchmachine elasticsearchmachine added the needs:triage Requires assignment of a team area label label Jul 16, 2025
# Conflicts:
#	server/src/main/java/org/elasticsearch/TransportVersions.java
#	x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java
#	x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiUnifiedChatCompletionResponseHandler.java
#	x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralUnifiedChatCompletionResponseHandler.java
@gbanasiak gbanasiak added the :ml Machine learning label Jul 21, 2025
@elasticsearchmachine elasticsearchmachine added Team:ML Meta label for the ML team and removed needs:triage Requires assignment of a team area label labels Jul 21, 2025
@elasticsearchmachine
Copy link
Collaborator

Pinging @elastic/ml-core (Team:ML)

Copy link
Contributor

@jonathan-buttner jonathan-buttner left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looking good, left a few suggestions

* @return an ErrorResponse instance
*/
public static ErrorResponse fromString(String response) {
if (Objects.nonNull(response) && response.isBlank() == false) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's use Strings.isEmpty() == false

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Used isNullOrBlank(). Let me know if it is acceptable.

) {
var actionCreator = new Ai21ActionCreator(getSender(), getServiceComponents());

if (Objects.requireNonNull(model) instanceof Ai21ChatCompletionModel mistralChatCompletionModel) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's rename this to ai21

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch. Missed that. Fixed now!

public InferenceServiceConfiguration getConfiguration() {
return Configuration.get();
}

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's disable this one from showing up in Kibana like we did for llama.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

* This class extends RateLimitGroupingModel to handle rate limiting based on model and API key.
*/
public class Ai21ChatCompletionModel extends Ai21Model {
public static final String API_COMPLETIONS_PATH = "https://api.ai21.com/studio/v1/chat/completions";
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: How about we define this as parts:

URIBuilder().setScheme("https")
            .setHost(OpenAiUtils.HOST)
            .setPathSegments(OpenAiUtils.VERSION_1, OpenAiUtils.EMBEDDINGS_PATH)
            .build();

Does this need to be public?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is being used in Ai21ChatCompletionRequestTests, that is causing it being public.
I replaced constant url with url building.

* This class is responsible for creating a request to the AI21 chat completion model.
* It constructs an HTTP POST request with the necessary headers and body content.
*/
public class Ai21ChatCompletionRequest implements Request {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this the same as openai? If so how about we just use the OpenAI request and omit this class?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is not the same as OpenAI. OpenAI has organizationId logic in it. Model is different and inherits from different class hierarchy.

* Ai21ChatCompletionRequestEntity is responsible for creating the request entity for Ai21 chat completion.
* It implements ToXContentObject to allow serialization to XContent format.
*/
public class Ai21ChatCompletionRequestEntity implements ToXContentObject {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same as the openai question, let's see if we can just the openai class.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OpenAI is not usable here. OpenAI has logic for "user" field in task settings and also uses different model that inherits from different class hierarchy.

@@ -124,7 +124,7 @@ public static UnifiedChatCompletionErrorResponse fromResponse(HttpResult respons
}
}

static UnifiedChatCompletionErrorResponse fromString(String response) {
public static UnifiedChatCompletionErrorResponse fromString(String response) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just curious why we need this as public?

Copy link
Contributor Author

@Jan-Kazlouski-elastic Jan-Kazlouski-elastic Jul 22, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Because we're adding fromString method to base ErrorResponse class and it is public and having this method here as package private would assign weaker access privileges which is not allowed. We would get compilation error.

@@ -141,7 +141,7 @@ private static class StreamingHuggingFaceErrorResponseEntity extends ErrorRespon
* @param response the raw JSON string representing an error
* @return a parsed {@link ErrorResponse} or {@link ErrorResponse#UNDEFINED_ERROR} if parsing fails
*/
private static ErrorResponse fromString(String response) {
public static ErrorResponse fromString(String response) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just curious why this needs to be public?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same thing here. We need to keep access privileges consistent.

@@ -1,29 +0,0 @@
/*
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just a reminder to try separate refactoring and new functionality PRs. If the refactoring is required, then including in the new functionality PR is generally ok if the refactoring is small. I'd still say if possible, default to creating a dedicated refactoring PR.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Understood. Thank you.

@@ -138,8 +138,8 @@ public static class ChatCompletionChunkParser {
(p, c) -> ChatCompletionChunkParser.ChoiceParser.parse(p),
new ParseField(CHOICES_FIELD)
);
PARSER.declareString(ConstructingObjectParser.constructorArg(), new ParseField(MODEL_FIELD));
PARSER.declareString(ConstructingObjectParser.constructorArg(), new ParseField(OBJECT_FIELD));
PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), new ParseField(MODEL_FIELD));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just wondering if you could comment on the change here. I assume AI21's response format doesn't always include these?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That is true. AI21's streaming response doesn't set these fields.

# Conflicts:
#	server/src/main/java/org/elasticsearch/TransportVersions.java
#	x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java
#	x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java
@Jan-Kazlouski-elastic
Copy link
Contributor Author

@jonathan-buttner your comments are addressed.

@jonathan-buttner
Copy link
Contributor

@elasticmachine test this please

# Conflicts:
#	server/src/main/java/org/elasticsearch/TransportVersions.java
# Conflicts:
#	server/src/main/java/org/elasticsearch/TransportVersions.java
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
>enhancement external-contributor Pull request authored by a developer outside the Elasticsearch team :ml Machine learning Team:ML Meta label for the ML team v9.2.0
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants