-
Notifications
You must be signed in to change notification settings - Fork 25.3k
Add AI21 support to Inference Plugin #131238
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Add AI21 support to Inference Plugin #131238
Conversation
# Conflicts: # server/src/main/java/org/elasticsearch/TransportVersions.java
…etionRequestEntity and Ai21ChatCompletionRequest
# Conflicts: # server/src/main/java/org/elasticsearch/TransportVersions.java # x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java # x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiUnifiedChatCompletionResponseHandler.java # x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralUnifiedChatCompletionResponseHandler.java
…arity and functionality
Pinging @elastic/ml-core (Team:ML) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looking good, left a few suggestions
* @return an ErrorResponse instance | ||
*/ | ||
public static ErrorResponse fromString(String response) { | ||
if (Objects.nonNull(response) && response.isBlank() == false) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let's use Strings.isEmpty() == false
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Used isNullOrBlank(). Let me know if it is acceptable.
) { | ||
var actionCreator = new Ai21ActionCreator(getSender(), getServiceComponents()); | ||
|
||
if (Objects.requireNonNull(model) instanceof Ai21ChatCompletionModel mistralChatCompletionModel) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let's rename this to ai21
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good catch. Missed that. Fixed now!
public InferenceServiceConfiguration getConfiguration() { | ||
return Configuration.get(); | ||
} | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let's disable this one from showing up in Kibana like we did for llama.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
* This class extends RateLimitGroupingModel to handle rate limiting based on model and API key. | ||
*/ | ||
public class Ai21ChatCompletionModel extends Ai21Model { | ||
public static final String API_COMPLETIONS_PATH = "https://api.ai21.com/studio/v1/chat/completions"; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: How about we define this as parts:
URIBuilder().setScheme("https")
.setHost(OpenAiUtils.HOST)
.setPathSegments(OpenAiUtils.VERSION_1, OpenAiUtils.EMBEDDINGS_PATH)
.build();
Does this need to be public?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It is being used in Ai21ChatCompletionRequestTests, that is causing it being public.
I replaced constant url with url building.
* This class is responsible for creating a request to the AI21 chat completion model. | ||
* It constructs an HTTP POST request with the necessary headers and body content. | ||
*/ | ||
public class Ai21ChatCompletionRequest implements Request { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this the same as openai? If so how about we just use the OpenAI request and omit this class?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It is not the same as OpenAI. OpenAI has organizationId logic in it. Model is different and inherits from different class hierarchy.
* Ai21ChatCompletionRequestEntity is responsible for creating the request entity for Ai21 chat completion. | ||
* It implements ToXContentObject to allow serialization to XContent format. | ||
*/ | ||
public class Ai21ChatCompletionRequestEntity implements ToXContentObject { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Same as the openai question, let's see if we can just the openai class.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
OpenAI is not usable here. OpenAI has logic for "user" field in task settings and also uses different model that inherits from different class hierarchy.
@@ -124,7 +124,7 @@ public static UnifiedChatCompletionErrorResponse fromResponse(HttpResult respons | |||
} | |||
} | |||
|
|||
static UnifiedChatCompletionErrorResponse fromString(String response) { | |||
public static UnifiedChatCompletionErrorResponse fromString(String response) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just curious why we need this as public?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Because we're adding fromString method to base ErrorResponse class and it is public and having this method here as package private would assign weaker access privileges which is not allowed. We would get compilation error.
@@ -141,7 +141,7 @@ private static class StreamingHuggingFaceErrorResponseEntity extends ErrorRespon | |||
* @param response the raw JSON string representing an error | |||
* @return a parsed {@link ErrorResponse} or {@link ErrorResponse#UNDEFINED_ERROR} if parsing fails | |||
*/ | |||
private static ErrorResponse fromString(String response) { | |||
public static ErrorResponse fromString(String response) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just curious why this needs to be public?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Same thing here. We need to keep access privileges consistent.
@@ -1,29 +0,0 @@ | |||
/* |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just a reminder to try separate refactoring and new functionality PRs. If the refactoring is required, then including in the new functionality PR is generally ok if the refactoring is small. I'd still say if possible, default to creating a dedicated refactoring PR.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Understood. Thank you.
@@ -138,8 +138,8 @@ public static class ChatCompletionChunkParser { | |||
(p, c) -> ChatCompletionChunkParser.ChoiceParser.parse(p), | |||
new ParseField(CHOICES_FIELD) | |||
); | |||
PARSER.declareString(ConstructingObjectParser.constructorArg(), new ParseField(MODEL_FIELD)); | |||
PARSER.declareString(ConstructingObjectParser.constructorArg(), new ParseField(OBJECT_FIELD)); | |||
PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), new ParseField(MODEL_FIELD)); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just wondering if you could comment on the change here. I assume AI21's response format doesn't always include these?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That is true. AI21's streaming response doesn't set these fields.
# Conflicts: # server/src/main/java/org/elasticsearch/TransportVersions.java # x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java # x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java
@jonathan-buttner your comments are addressed. |
@elasticmachine test this please |
# Conflicts: # server/src/main/java/org/elasticsearch/TransportVersions.java
# Conflicts: # server/src/main/java/org/elasticsearch/TransportVersions.java
Creation of new AI21 inference provider integration allowing completion (both streaming and non-streaming) and chat_completion (only streaming) to be executed as part of inference API.
Changes were tested locally against next models:
jamba-large
jamba-mini
Create Completion Endpoint
Success:
Invalid Model:
Auth Failed:
Perform Completion
Success Non Streaming:
Success Streaming:
Create Completion Endpoint
Success:
Invalid Model:
Auth Failed:
Perform Chat Completion
Success Simple:
Success Complex is not supported due to error on AI21 side, requesting content to be string type and not anything else.:
ES Inference:
gradle check
?