|
19 | 19 | import org.elasticsearch.inference.SimilarityMeasure;
|
20 | 20 | import org.elasticsearch.inference.TaskType;
|
21 | 21 | import org.elasticsearch.test.ESTestCase;
|
22 |
| -import org.elasticsearch.test.http.MockResponse; |
23 | 22 | import org.elasticsearch.test.http.MockWebServer;
|
24 | 23 | import org.elasticsearch.threadpool.ThreadPool;
|
25 | 24 | import org.elasticsearch.xpack.core.inference.action.InferenceAction;
|
26 |
| -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; |
27 | 25 | import org.elasticsearch.xpack.inference.external.http.HttpClientManager;
|
28 | 26 | import org.elasticsearch.xpack.inference.logging.ThrottlerManager;
|
29 | 27 | import org.elasticsearch.xpack.inference.services.custom.CustomModel;
|
30 |
| -import org.elasticsearch.xpack.inference.services.custom.response.TextEmbeddingResponseParser; |
31 | 28 | import org.junit.After;
|
32 | 29 | import org.junit.Assume;
|
33 | 30 | import org.junit.Before;
|
|
45 | 42 | import static org.elasticsearch.xpack.inference.Utils.getRequestConfigMap;
|
46 | 43 | import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool;
|
47 | 44 | import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;
|
48 |
| -import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl; |
49 | 45 | import static org.hamcrest.Matchers.containsString;
|
50 |
| -import static org.hamcrest.Matchers.instanceOf; |
51 | 46 | import static org.hamcrest.Matchers.is;
|
52 | 47 | import static org.mockito.Mockito.mock;
|
53 | 48 |
|
| 49 | +/** |
| 50 | + * Base class for testing inference services. |
| 51 | + * <p> |
| 52 | + * This class provides common unit tests for inference services, such as testing the model creation, and calling the infer method. |
| 53 | + * |
| 54 | + * To use this class, extend it and pass the constructor a configuration. |
| 55 | + * </p> |
| 56 | + */ |
54 | 57 | public abstract class AbstractServiceTests extends ESTestCase {
|
55 | 58 |
|
56 | 59 | protected final MockWebServer webServer = new MockWebServer();
|
@@ -125,8 +128,6 @@ public CommonConfig(TaskType taskType, @Nullable TaskType unsupportedTaskType) {
|
125 | 128 |
|
126 | 129 | protected abstract Map<String, Object> createSecretSettingsMap();
|
127 | 130 |
|
128 |
| - protected abstract CustomModel createEmbeddingModel(TextEmbeddingResponseParser embeddingResponseParser, String url); |
129 |
| - |
130 | 131 | protected abstract void assertModel(Model model, TaskType taskType);
|
131 | 132 |
|
132 | 133 | protected abstract EnumSet<TaskType> supportedStreamingTasks();
|
@@ -434,62 +435,6 @@ public void testParsePersistedConfigWithSecrets_ThrowsWhenAnExtraKeyExistsInSecr
|
434 | 435 | }
|
435 | 436 | }
|
436 | 437 |
|
437 |
| - // infer tests |
438 |
| - |
439 |
| - public void testInfer_SendsRequest() throws IOException { |
440 |
| - try (var service = testConfiguration.commonConfig.createService(threadPool, clientManager)) { |
441 |
| - String responseJson = """ |
442 |
| - { |
443 |
| - "object": "list", |
444 |
| - "data": [ |
445 |
| - { |
446 |
| - "object": "embedding", |
447 |
| - "index": 0, |
448 |
| - "embedding": [ |
449 |
| - 0.0123, |
450 |
| - -0.0123 |
451 |
| - ] |
452 |
| - } |
453 |
| - ], |
454 |
| - "model": "text-embedding-ada-002-v2", |
455 |
| - "usage": { |
456 |
| - "prompt_tokens": 8, |
457 |
| - "total_tokens": 8 |
458 |
| - } |
459 |
| - } |
460 |
| - """; |
461 |
| - |
462 |
| - webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); |
463 |
| - |
464 |
| - var model = testConfiguration.commonConfig.createEmbeddingModel( |
465 |
| - new TextEmbeddingResponseParser("$.data[*].embedding"), |
466 |
| - getUrl(webServer) |
467 |
| - ); |
468 |
| - PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>(); |
469 |
| - service.infer( |
470 |
| - model, |
471 |
| - null, |
472 |
| - null, |
473 |
| - null, |
474 |
| - List.of("test input"), |
475 |
| - false, |
476 |
| - new HashMap<>(), |
477 |
| - InputType.INTERNAL_SEARCH, |
478 |
| - InferenceAction.Request.DEFAULT_TIMEOUT, |
479 |
| - listener |
480 |
| - ); |
481 |
| - |
482 |
| - InferenceServiceResults results = listener.actionGet(TIMEOUT); |
483 |
| - assertThat(results, instanceOf(TextEmbeddingFloatResults.class)); |
484 |
| - |
485 |
| - var embeddingResults = (TextEmbeddingFloatResults) results; |
486 |
| - assertThat( |
487 |
| - embeddingResults.embeddings(), |
488 |
| - is(List.of(new TextEmbeddingFloatResults.Embedding(new float[] { 0.0123F, -0.0123F }))) |
489 |
| - ); |
490 |
| - } |
491 |
| - } |
492 |
| - |
493 | 438 | public void testInfer_ThrowsErrorWhenModelIsNotValid() throws IOException {
|
494 | 439 | try (var service = testConfiguration.commonConfig.createService(threadPool, clientManager)) {
|
495 | 440 | var listener = new PlainActionFuture<InferenceServiceResults>();
|
|
0 commit comments