Skip to content

Commit 59f75b9

Browse files
Cleaning up
1 parent dc02425 commit 59f75b9

File tree

7 files changed

+323
-81
lines changed

7 files changed

+323
-81
lines changed

docs/changelog/125679.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
pr: 125679
2-
summary: Custom Inference Service
2+
summary: Adding support for generic Inference services
33
area: Machine Learning
44
type: enhancement
55
issues: []

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -214,7 +214,6 @@ public static ElasticsearchStatusException unknownSettingsError(Map<String, Obje
214214
}
215215

216216
public static ElasticsearchStatusException unknownSettingsError(Map<String, Object> config, String field, String scope) {
217-
// TODO map as JSON
218217
return new ElasticsearchStatusException(
219218
"Model configuration contains unknown settings [{}] while parsing field [{}] for settings [{}]",
220219
RestStatus.BAD_REQUEST,
@@ -546,15 +545,15 @@ public static List<Tuple<String, String>> extractOptionalListOfStringTuples(
546545

547546
var firstElement = listEntry.get(0);
548547
var secondElement = listEntry.get(1);
549-
validateTuple(firstElement, settingName, scope, "the first element", tuplesIndex, validationException);
550-
validateTuple(secondElement, settingName, scope, "the second element", tuplesIndex, validationException);
548+
validateString(firstElement, settingName, scope, "the first element", tuplesIndex, validationException);
549+
validateString(secondElement, settingName, scope, "the second element", tuplesIndex, validationException);
551550
tuples.add(new Tuple<>((String) firstElement, (String) secondElement));
552551
}
553552

554553
return tuples;
555554
}
556555

557-
private static void validateTuple(
556+
private static void validateString(
558557
Object tupleValue,
559558
String settingName,
560559
String scope,

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomResponseHandler.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
import org.elasticsearch.xpack.inference.services.custom.response.ErrorResponseParser;
1919

2020
/**
21-
* Defines how to handle various errors returned from the custom integration.
21+
* Defines how to handle various response types returned from the custom integration.
2222
*/
2323
public class CustomResponseHandler extends BaseResponseHandler {
2424
public CustomResponseHandler(String requestType, ResponseParser parseFunction, ErrorResponseParser errorParser) {

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceSettings.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -197,7 +197,7 @@ public CustomServiceSettings(StreamInput in) throws IOException {
197197

198198
@Override
199199
public String modelId() {
200-
// returning null because the model id is embedded in the url
200+
// returning null because the model id is embedded in the url or the request body
201201
return null;
202202
}
203203

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/AbstractServiceTests.java

Lines changed: 8 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -19,15 +19,12 @@
1919
import org.elasticsearch.inference.SimilarityMeasure;
2020
import org.elasticsearch.inference.TaskType;
2121
import org.elasticsearch.test.ESTestCase;
22-
import org.elasticsearch.test.http.MockResponse;
2322
import org.elasticsearch.test.http.MockWebServer;
2423
import org.elasticsearch.threadpool.ThreadPool;
2524
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
26-
import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
2725
import org.elasticsearch.xpack.inference.external.http.HttpClientManager;
2826
import org.elasticsearch.xpack.inference.logging.ThrottlerManager;
2927
import org.elasticsearch.xpack.inference.services.custom.CustomModel;
30-
import org.elasticsearch.xpack.inference.services.custom.response.TextEmbeddingResponseParser;
3128
import org.junit.After;
3229
import org.junit.Assume;
3330
import org.junit.Before;
@@ -45,12 +42,18 @@
4542
import static org.elasticsearch.xpack.inference.Utils.getRequestConfigMap;
4643
import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool;
4744
import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;
48-
import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl;
4945
import static org.hamcrest.Matchers.containsString;
50-
import static org.hamcrest.Matchers.instanceOf;
5146
import static org.hamcrest.Matchers.is;
5247
import static org.mockito.Mockito.mock;
5348

49+
/**
50+
* Base class for testing inference services.
51+
* <p>
52+
* This class provides common unit tests for inference services, such as testing the model creation, and calling the infer method.
53+
*
54+
* To use this class, extend it and pass the constructor a configuration.
55+
* </p>
56+
*/
5457
public abstract class AbstractServiceTests extends ESTestCase {
5558

5659
protected final MockWebServer webServer = new MockWebServer();
@@ -125,8 +128,6 @@ public CommonConfig(TaskType taskType, @Nullable TaskType unsupportedTaskType) {
125128

126129
protected abstract Map<String, Object> createSecretSettingsMap();
127130

128-
protected abstract CustomModel createEmbeddingModel(TextEmbeddingResponseParser embeddingResponseParser, String url);
129-
130131
protected abstract void assertModel(Model model, TaskType taskType);
131132

132133
protected abstract EnumSet<TaskType> supportedStreamingTasks();
@@ -434,62 +435,6 @@ public void testParsePersistedConfigWithSecrets_ThrowsWhenAnExtraKeyExistsInSecr
434435
}
435436
}
436437

437-
// infer tests
438-
439-
public void testInfer_SendsRequest() throws IOException {
440-
try (var service = testConfiguration.commonConfig.createService(threadPool, clientManager)) {
441-
String responseJson = """
442-
{
443-
"object": "list",
444-
"data": [
445-
{
446-
"object": "embedding",
447-
"index": 0,
448-
"embedding": [
449-
0.0123,
450-
-0.0123
451-
]
452-
}
453-
],
454-
"model": "text-embedding-ada-002-v2",
455-
"usage": {
456-
"prompt_tokens": 8,
457-
"total_tokens": 8
458-
}
459-
}
460-
""";
461-
462-
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
463-
464-
var model = testConfiguration.commonConfig.createEmbeddingModel(
465-
new TextEmbeddingResponseParser("$.data[*].embedding"),
466-
getUrl(webServer)
467-
);
468-
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
469-
service.infer(
470-
model,
471-
null,
472-
null,
473-
null,
474-
List.of("test input"),
475-
false,
476-
new HashMap<>(),
477-
InputType.INTERNAL_SEARCH,
478-
InferenceAction.Request.DEFAULT_TIMEOUT,
479-
listener
480-
);
481-
482-
InferenceServiceResults results = listener.actionGet(TIMEOUT);
483-
assertThat(results, instanceOf(TextEmbeddingFloatResults.class));
484-
485-
var embeddingResults = (TextEmbeddingFloatResults) results;
486-
assertThat(
487-
embeddingResults.embeddings(),
488-
is(List.of(new TextEmbeddingFloatResults.Embedding(new float[] { 0.0123F, -0.0123F })))
489-
);
490-
}
491-
}
492-
493438
public void testInfer_ThrowsErrorWhenModelIsNotValid() throws IOException {
494439
try (var service = testConfiguration.commonConfig.createService(threadPool, clientManager)) {
495440
var listener = new PlainActionFuture<InferenceServiceResults>();

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomModelTests.java

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -25,13 +25,12 @@
2525
import static org.hamcrest.Matchers.is;
2626

2727
public class CustomModelTests extends ESTestCase {
28-
public static String taskSettingsKey = "test_taskSettings_key";
29-
public static String taskSettingsValue = "test_taskSettings_value";
28+
private static final String taskSettingsKey = "test_taskSettings_key";
29+
private static final String taskSettingsValue = "test_taskSettings_value";
3030

31-
public static String secretSettingsKey = "test_secret_key";
32-
public static SerializableSecureString secretSettingsValue = new SerializableSecureString("test_secret_value");
33-
public static String url = "http://www.abc.com";
34-
public static String path = "/endpoint";
31+
private static final String secretSettingsKey = "test_secret_key";
32+
private static final SerializableSecureString secretSettingsValue = new SerializableSecureString("test_secret_value");
33+
private static final String url = "http://www.abc.com";
3534

3635
public void testOverride_DoesNotModifiedFields_TaskSettingsIsEmpty() {
3736
var model = createModel(

0 commit comments

Comments
 (0)