Skip to content

[8.19] Add Hugging Face Chat Completion support to Inference Plugin (#127254) #128152

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/changelog/127254.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 127254
summary: "[ML] Add HuggingFace Chat Completion support to the Inference Plugin"
area: Machine Learning
type: enhancement
issues: []
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,7 @@ static TransportVersion def(int id) {
public static final TransportVersion INFERENCE_ADD_TIMEOUT_PUT_ENDPOINT_8_19 = def(8_841_0_28);
public static final TransportVersion ESQL_REPORT_SHARD_PARTITIONING_8_19 = def(8_841_0_29);
public static final TransportVersion ESQL_DRIVER_TASK_DESCRIPTION_8_19 = def(8_841_0_30);
public static final TransportVersion ML_INFERENCE_HUGGING_FACE_CHAT_COMPLETION_ADDED_8_19 = def(8_841_0_31);

/*
* STOP! READ THIS FIRST! No, really,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ public void testGetServicesWithRerankTaskType() throws IOException {

public void testGetServicesWithCompletionTaskType() throws IOException {
List<Object> services = getServices(TaskType.COMPLETION);
assertThat(services.size(), equalTo(10));
assertThat(services.size(), equalTo(11));

var providers = providers(services);

Expand All @@ -140,19 +140,23 @@ public void testGetServicesWithCompletionTaskType() throws IOException {
"deepseek",
"googleaistudio",
"openai",
"streaming_completion_test_service"
"streaming_completion_test_service",
"hugging_face"
).toArray()
)
);
}

public void testGetServicesWithChatCompletionTaskType() throws IOException {
List<Object> services = getServices(TaskType.CHAT_COMPLETION);
assertThat(services.size(), equalTo(4));
assertThat(services.size(), equalTo(5));

var providers = providers(services);

assertThat(providers, containsInAnyOrder(List.of("deepseek", "elastic", "openai", "streaming_completion_test_service").toArray()));
assertThat(
providers,
containsInAnyOrder(List.of("deepseek", "elastic", "openai", "streaming_completion_test_service", "hugging_face").toArray())
);
}

public void testGetServicesWithSparseEmbeddingTaskType() throws IOException {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@
import org.elasticsearch.xpack.inference.services.googlevertexai.rerank.GoogleVertexAiRerankServiceSettings;
import org.elasticsearch.xpack.inference.services.googlevertexai.rerank.GoogleVertexAiRerankTaskSettings;
import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceServiceSettings;
import org.elasticsearch.xpack.inference.services.huggingface.completion.HuggingFaceChatCompletionServiceSettings;
import org.elasticsearch.xpack.inference.services.huggingface.elser.HuggingFaceElserServiceSettings;
import org.elasticsearch.xpack.inference.services.ibmwatsonx.embeddings.IbmWatsonxEmbeddingsServiceSettings;
import org.elasticsearch.xpack.inference.services.ibmwatsonx.rerank.IbmWatsonxRerankServiceSettings;
Expand Down Expand Up @@ -357,6 +358,13 @@ private static void addHuggingFaceNamedWriteables(List<NamedWriteableRegistry.En
namedWriteables.add(
new NamedWriteableRegistry.Entry(ServiceSettings.class, HuggingFaceServiceSettings.NAME, HuggingFaceServiceSettings::new)
);
namedWriteables.add(
new NamedWriteableRegistry.Entry(
ServiceSettings.class,
HuggingFaceChatCompletionServiceSettings.NAME,
HuggingFaceChatCompletionServiceSettings::new
)
);
}

private static void addGoogleAiStudioNamedWritables(List<NamedWriteableRegistry.Entry> namedWriteables) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,171 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.services.huggingface;

import org.elasticsearch.core.Nullable;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.xcontent.ConstructingObjectParser;
import org.elasticsearch.xcontent.ParseField;
import org.elasticsearch.xcontent.XContentFactory;
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xcontent.XContentParserConfiguration;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.core.inference.results.UnifiedChatCompletionException;
import org.elasticsearch.xpack.inference.external.http.HttpResult;
import org.elasticsearch.xpack.inference.external.http.retry.ErrorResponse;
import org.elasticsearch.xpack.inference.external.http.retry.ResponseParser;
import org.elasticsearch.xpack.inference.external.request.Request;
import org.elasticsearch.xpack.inference.services.huggingface.response.HuggingFaceErrorResponseEntity;
import org.elasticsearch.xpack.inference.services.openai.OpenAiUnifiedChatCompletionResponseHandler;

import java.util.Locale;
import java.util.Optional;

import static org.elasticsearch.core.Strings.format;

/**
* Handles streaming chat completion responses and error parsing for Hugging Face inference endpoints.
* Adapts the OpenAI handler to support Hugging Face's simpler error schema with fields like "message" and "http_status_code".
*/
public class HuggingFaceChatCompletionResponseHandler extends OpenAiUnifiedChatCompletionResponseHandler {

private static final String HUGGING_FACE_ERROR = "hugging_face_error";

public HuggingFaceChatCompletionResponseHandler(String requestType, ResponseParser parseFunction) {
super(requestType, parseFunction, HuggingFaceErrorResponseEntity::fromResponse);
}

@Override
protected Exception buildError(String message, Request request, HttpResult result, ErrorResponse errorResponse) {
assert request.isStreaming() : "Only streaming requests support this format";
var responseStatusCode = result.response().getStatusLine().getStatusCode();
if (request.isStreaming()) {
var errorMessage = errorMessage(message, request, result, errorResponse, responseStatusCode);
var restStatus = toRestStatus(responseStatusCode);
return errorResponse instanceof HuggingFaceErrorResponseEntity
? new UnifiedChatCompletionException(
restStatus,
errorMessage,
HUGGING_FACE_ERROR,
restStatus.name().toLowerCase(Locale.ROOT)
)
: new UnifiedChatCompletionException(
restStatus,
errorMessage,
createErrorType(errorResponse),
restStatus.name().toLowerCase(Locale.ROOT)
);
} else {
return super.buildError(message, request, result, errorResponse);
}
}

@Override
protected Exception buildMidStreamError(Request request, String message, Exception e) {
var errorResponse = StreamingHuggingFaceErrorResponseEntity.fromString(message);
if (errorResponse instanceof StreamingHuggingFaceErrorResponseEntity streamingHuggingFaceErrorResponseEntity) {
return new UnifiedChatCompletionException(
RestStatus.INTERNAL_SERVER_ERROR,
format(
"%s for request from inference entity id [%s]. Error message: [%s]",
SERVER_ERROR_OBJECT,
request.getInferenceEntityId(),
errorResponse.getErrorMessage()
),
HUGGING_FACE_ERROR,
extractErrorCode(streamingHuggingFaceErrorResponseEntity)
);
} else if (e != null) {
return UnifiedChatCompletionException.fromThrowable(e);
} else {
return new UnifiedChatCompletionException(
RestStatus.INTERNAL_SERVER_ERROR,
format("%s for request from inference entity id [%s]", SERVER_ERROR_OBJECT, request.getInferenceEntityId()),
createErrorType(errorResponse),
"stream_error"
);
}
}

private static String extractErrorCode(StreamingHuggingFaceErrorResponseEntity streamingHuggingFaceErrorResponseEntity) {
return streamingHuggingFaceErrorResponseEntity.httpStatusCode() != null
? String.valueOf(streamingHuggingFaceErrorResponseEntity.httpStatusCode())
: null;
}

/**
* Represents a structured error response specifically for streaming operations
* using HuggingFace APIs. This is separate from non-streaming error responses,
* which are handled by {@link HuggingFaceErrorResponseEntity}.
* An example error response for failed field validation for streaming operation would look like
* <code>
* {
* "error": "Input validation error: cannot compile regex from schema",
* "http_status_code": 422
* }
* </code>
*/
private static class StreamingHuggingFaceErrorResponseEntity extends ErrorResponse {
private static final ConstructingObjectParser<Optional<ErrorResponse>, Void> ERROR_PARSER = new ConstructingObjectParser<>(
HUGGING_FACE_ERROR,
true,
args -> Optional.ofNullable((StreamingHuggingFaceErrorResponseEntity) args[0])
);
private static final ConstructingObjectParser<StreamingHuggingFaceErrorResponseEntity, Void> ERROR_BODY_PARSER =
new ConstructingObjectParser<>(
HUGGING_FACE_ERROR,
true,
args -> new StreamingHuggingFaceErrorResponseEntity(args[0] != null ? (String) args[0] : "unknown", (Integer) args[1])
);

static {
ERROR_BODY_PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), new ParseField("message"));
ERROR_BODY_PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), new ParseField("http_status_code"));

ERROR_PARSER.declareObjectOrNull(
ConstructingObjectParser.optionalConstructorArg(),
ERROR_BODY_PARSER,
null,
new ParseField("error")
);
}

/**
* Parses a streaming HuggingFace error response from a JSON string.
*
* @param response the raw JSON string representing an error
* @return a parsed {@link ErrorResponse} or {@link ErrorResponse#UNDEFINED_ERROR} if parsing fails
*/
private static ErrorResponse fromString(String response) {
try (
XContentParser parser = XContentFactory.xContent(XContentType.JSON)
.createParser(XContentParserConfiguration.EMPTY, response)
) {
return ERROR_PARSER.apply(parser, null).orElse(ErrorResponse.UNDEFINED_ERROR);
} catch (Exception e) {
// swallow the error
}

return ErrorResponse.UNDEFINED_ERROR;
}

@Nullable
private final Integer httpStatusCode;

StreamingHuggingFaceErrorResponseEntity(String errorMessage, @Nullable Integer httpStatusCode) {
super(errorMessage);
this.httpStatusCode = httpStatusCode;
}

@Nullable
public Integer httpStatusCode() {
return httpStatusCode;
}

}
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,17 +9,18 @@

import org.elasticsearch.common.settings.SecureString;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.inference.Model;
import org.elasticsearch.inference.ModelConfigurations;
import org.elasticsearch.inference.ModelSecrets;
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
import org.elasticsearch.xpack.inference.services.RateLimitGroupingModel;
import org.elasticsearch.xpack.inference.services.ServiceUtils;
import org.elasticsearch.xpack.inference.services.huggingface.action.HuggingFaceActionVisitor;
import org.elasticsearch.xpack.inference.services.settings.ApiKeySecrets;
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;

import java.util.Objects;

public abstract class HuggingFaceModel extends Model {
public abstract class HuggingFaceModel extends RateLimitGroupingModel {
private final HuggingFaceRateLimitServiceSettings rateLimitServiceSettings;
private final SecureString apiKey;

Expand All @@ -38,6 +39,16 @@ public HuggingFaceRateLimitServiceSettings rateLimitServiceSettings() {
return rateLimitServiceSettings;
}

@Override
public int rateLimitGroupingHash() {
return Objects.hash(rateLimitServiceSettings.uri(), apiKey);
}

@Override
public RateLimitSettings rateLimitSettings() {
return rateLimitServiceSettings.rateLimitSettings();
}

public SecureString apiKey() {
return apiKey;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput;
import org.elasticsearch.xpack.inference.external.http.sender.ExecutableInferenceRequest;
import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs;
import org.elasticsearch.xpack.inference.services.huggingface.request.HuggingFaceInferenceRequest;
import org.elasticsearch.xpack.inference.services.huggingface.request.embeddings.HuggingFaceEmbeddingsRequest;

import java.util.List;
import java.util.Objects;
Expand Down Expand Up @@ -64,7 +64,7 @@ public void execute(
) {
List<String> docsInput = EmbeddingsInput.of(inferenceInputs).getStringInputs();
var truncatedInput = truncate(docsInput, model.getTokenLimit());
var request = new HuggingFaceInferenceRequest(truncator, truncatedInput, model);
var request = new HuggingFaceEmbeddingsRequest(truncator, truncatedInput, model);

execute(new ExecutableInferenceRequest(requestSender, logger, request, responseHandler, hasRequestCompletedFunction, listener));
}
Expand Down
Loading