Skip to content

Commit 91646ed

Browse files
committed
address review comments
1 parent 539f468 commit 91646ed

File tree

7 files changed

+59
-45
lines changed

7 files changed

+59
-45
lines changed

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/GetRerankerAction.java renamed to x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/GetRerankerWindowSizeAction.java

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,12 @@
1717
import java.io.IOException;
1818
import java.util.Objects;
1919

20-
public class GetRerankerAction extends ActionType<GetRerankerAction.Response> {
20+
public class GetRerankerWindowSizeAction extends ActionType<GetRerankerWindowSizeAction.Response> {
2121

22-
public static final GetRerankerAction INSTANCE = new GetRerankerAction();
23-
public static final String NAME = "cluster:internal/xpack/inference/rerank/get";
22+
public static final GetRerankerWindowSizeAction INSTANCE = new GetRerankerWindowSizeAction();
23+
public static final String NAME = "cluster:internal/xpack/inference/rerankwindowsize/get";
2424

25-
public GetRerankerAction() {
25+
public GetRerankerWindowSizeAction() {
2626
super(NAME);
2727
}
2828

x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/RerankWindowSizeIT.java

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
import org.elasticsearch.plugins.Plugin;
1212
import org.elasticsearch.test.ESIntegTestCase;
1313
import org.elasticsearch.test.ESTestCase;
14-
import org.elasticsearch.xpack.core.inference.action.GetRerankerAction;
14+
import org.elasticsearch.xpack.core.inference.action.GetRerankerWindowSizeAction;
1515
import org.elasticsearch.xpack.inference.LocalStateInferencePlugin;
1616
import org.elasticsearch.xpack.inference.Utils;
1717
import org.elasticsearch.xpack.inference.mock.TestInferenceServicePlugin;
@@ -39,15 +39,17 @@ protected Collection<Class<? extends Plugin>> nodePlugins() {
3939
}
4040

4141
public void testRerankWindowSizeAction() {
42-
var response = client().execute(GetRerankerAction.INSTANCE, new GetRerankerAction.Request("rerank-endpoint")).actionGet();
42+
var response = client().execute(GetRerankerWindowSizeAction.INSTANCE, new GetRerankerWindowSizeAction.Request("rerank-endpoint"))
43+
.actionGet();
4344
assertEquals(333, response.getWindowSize());
4445
}
4546

4647
public void testActionNotAReranker() {
4748
var e = expectThrows(
4849
ElasticsearchStatusException.class,
49-
() -> client().execute(GetRerankerAction.INSTANCE, new GetRerankerAction.Request("sparse-endpoint")).actionGet()
50+
() -> client().execute(GetRerankerWindowSizeAction.INSTANCE, new GetRerankerWindowSizeAction.Request("sparse-endpoint"))
51+
.actionGet()
5052
);
51-
assertThat(e.getMessage(), containsString("Inference endpoint [sparse-endpoint] is not a reranker"));
53+
assertThat(e.getMessage(), containsString("Inference endpoint [sparse-endpoint] does not have the rerank task type"));
5254
}
5355
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@
6262
import org.elasticsearch.xpack.core.inference.action.GetInferenceDiagnosticsAction;
6363
import org.elasticsearch.xpack.core.inference.action.GetInferenceModelAction;
6464
import org.elasticsearch.xpack.core.inference.action.GetInferenceServicesAction;
65-
import org.elasticsearch.xpack.core.inference.action.GetRerankerAction;
65+
import org.elasticsearch.xpack.core.inference.action.GetRerankerWindowSizeAction;
6666
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
6767
import org.elasticsearch.xpack.core.inference.action.InferenceActionProxy;
6868
import org.elasticsearch.xpack.core.inference.action.PutInferenceModelAction;
@@ -73,7 +73,7 @@
7373
import org.elasticsearch.xpack.inference.action.TransportGetInferenceDiagnosticsAction;
7474
import org.elasticsearch.xpack.inference.action.TransportGetInferenceModelAction;
7575
import org.elasticsearch.xpack.inference.action.TransportGetInferenceServicesAction;
76-
import org.elasticsearch.xpack.inference.action.TransportGetRerankerAction;
76+
import org.elasticsearch.xpack.inference.action.TransportGetRerankerWindowSizeAction;
7777
import org.elasticsearch.xpack.inference.action.TransportInferenceAction;
7878
import org.elasticsearch.xpack.inference.action.TransportInferenceActionProxy;
7979
import org.elasticsearch.xpack.inference.action.TransportInferenceUsageAction;
@@ -236,7 +236,7 @@ public List<ActionHandler> getActions() {
236236
new ActionHandler(GetInferenceDiagnosticsAction.INSTANCE, TransportGetInferenceDiagnosticsAction.class),
237237
new ActionHandler(GetInferenceServicesAction.INSTANCE, TransportGetInferenceServicesAction.class),
238238
new ActionHandler(UnifiedCompletionAction.INSTANCE, TransportUnifiedCompletionInferenceAction.class),
239-
new ActionHandler(GetRerankerAction.INSTANCE, TransportGetRerankerAction.class)
239+
new ActionHandler(GetRerankerWindowSizeAction.INSTANCE, TransportGetRerankerWindowSizeAction.class)
240240
);
241241
}
242242

Lines changed: 31 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -22,42 +22,50 @@
2222
import org.elasticsearch.tasks.Task;
2323
import org.elasticsearch.threadpool.ThreadPool;
2424
import org.elasticsearch.transport.TransportService;
25-
import org.elasticsearch.xpack.core.inference.action.GetRerankerAction;
26-
import org.elasticsearch.xpack.inference.InferencePlugin;
25+
import org.elasticsearch.xpack.core.inference.action.GetRerankerWindowSizeAction;
2726
import org.elasticsearch.xpack.inference.registry.ModelRegistry;
2827

29-
import java.util.concurrent.Executor;
30-
31-
public class TransportGetRerankerAction extends HandledTransportAction<GetRerankerAction.Request, GetRerankerAction.Response> {
28+
public class TransportGetRerankerWindowSizeAction extends HandledTransportAction<
29+
GetRerankerWindowSizeAction.Request,
30+
GetRerankerWindowSizeAction.Response> {
3231

3332
private final ModelRegistry modelRegistry;
3433
private final InferenceServiceRegistry serviceRegistry;
35-
private final Executor executor;
3634

3735
@Inject
38-
public TransportGetRerankerAction(
36+
public TransportGetRerankerWindowSizeAction(
3937
TransportService transportService,
4038
ActionFilters actionFilters,
4139
ThreadPool threadPool,
4240
ModelRegistry modelRegistry,
4341
InferenceServiceRegistry serviceRegistry
4442
) {
45-
super(GetRerankerAction.NAME, transportService, actionFilters, GetRerankerAction.Request::new, EsExecutors.DIRECT_EXECUTOR_SERVICE);
43+
super(
44+
GetRerankerWindowSizeAction.NAME,
45+
transportService,
46+
actionFilters,
47+
GetRerankerWindowSizeAction.Request::new,
48+
EsExecutors.DIRECT_EXECUTOR_SERVICE
49+
);
4650
this.modelRegistry = modelRegistry;
4751
this.serviceRegistry = serviceRegistry;
48-
this.executor = threadPool.executor(InferencePlugin.UTILITY_THREAD_POOL_NAME);
4952
}
5053

5154
@Override
52-
protected void doExecute(Task task, GetRerankerAction.Request request, ActionListener<GetRerankerAction.Response> listener) {
55+
protected void doExecute(
56+
Task task,
57+
GetRerankerWindowSizeAction.Request request,
58+
ActionListener<GetRerankerWindowSizeAction.Response> listener
59+
) {
5360

5461
SubscribableListener.<UnparsedModel>newForked(l -> modelRegistry.getModel(request.getInferenceEntityId(), l)).<
55-
GetRerankerAction.Response>andThen((l, unparsedModel) -> {
62+
GetRerankerWindowSizeAction.Response>andThen((l, unparsedModel) -> {
5663
if (unparsedModel.taskType() != TaskType.RERANK) {
5764
throw new ElasticsearchStatusException(
58-
"Inference endpoint [{}] is not a reranker",
65+
"Inference endpoint [{}] does not have the {} task type",
5966
RestStatus.BAD_REQUEST,
60-
request.getInferenceEntityId()
67+
request.getInferenceEntityId(),
68+
TaskType.RERANK
6169
);
6270
}
6371

@@ -71,26 +79,30 @@ protected void doExecute(Task task, GetRerankerAction.Request request, ActionLis
7179
);
7280
}
7381

74-
var model = service.get()
75-
.parsePersistedConfig(unparsedModel.inferenceEntityId(), unparsedModel.taskType(), unparsedModel.settings());
76-
7782
if (service.get() instanceof RerankingInferenceService rerankingInferenceService) {
83+
var model = service.get()
84+
.parsePersistedConfig(unparsedModel.inferenceEntityId(), unparsedModel.taskType(), unparsedModel.settings());
85+
7886
l.onResponse(
79-
new GetRerankerAction.Response(rerankWindowSize(rerankingInferenceService, model.getServiceSettings().modelId()))
87+
new GetRerankerWindowSizeAction.Response(
88+
rerankWindowSize(rerankingInferenceService, model.getServiceSettings().modelId())
89+
)
8090
);
8191
} else {
8292
throw new IllegalStateException(
8393
"Inference endpoint ["
8494
+ request.getInferenceEntityId()
85-
+ "] is a reranker but the service ["
95+
+ "] has task type ["
96+
+ TaskType.RERANK
97+
+ "] but the service ["
8698
+ service.get().name()
8799
+ "] does not support reranking"
88100
);
89101
}
90102
}).addListener(listener);
91103
}
92104

93-
public int rerankWindowSize(RerankingInferenceService service, String modelId) {
105+
private int rerankWindowSize(RerankingInferenceService service, String modelId) {
94106
return service.rerankerWindowSize(modelId);
95107
}
96108
}
Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,23 +9,23 @@
99

1010
import org.elasticsearch.common.io.stream.Writeable;
1111
import org.elasticsearch.test.AbstractWireSerializingTestCase;
12-
import org.elasticsearch.xpack.core.inference.action.GetRerankerAction;
12+
import org.elasticsearch.xpack.core.inference.action.GetRerankerWindowSizeAction;
1313

1414
import java.io.IOException;
1515

16-
public class GetRerankerActionRequestTests extends AbstractWireSerializingTestCase<GetRerankerAction.Request> {
16+
public class GetRerankerWindowSizeActionRequestTests extends AbstractWireSerializingTestCase<GetRerankerWindowSizeAction.Request> {
1717
@Override
18-
protected Writeable.Reader<GetRerankerAction.Request> instanceReader() {
19-
return GetRerankerAction.Request::new;
18+
protected Writeable.Reader<GetRerankerWindowSizeAction.Request> instanceReader() {
19+
return GetRerankerWindowSizeAction.Request::new;
2020
}
2121

2222
@Override
23-
protected GetRerankerAction.Request createTestInstance() {
24-
return new GetRerankerAction.Request(randomAlphaOfLength(8));
23+
protected GetRerankerWindowSizeAction.Request createTestInstance() {
24+
return new GetRerankerWindowSizeAction.Request(randomAlphaOfLength(8));
2525
}
2626

2727
@Override
28-
protected GetRerankerAction.Request mutateInstance(GetRerankerAction.Request instance) throws IOException {
28+
protected GetRerankerWindowSizeAction.Request mutateInstance(GetRerankerWindowSizeAction.Request instance) throws IOException {
2929
return randomValueOtherThan(instance, this::createTestInstance);
3030
}
3131
}
Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,23 +9,23 @@
99

1010
import org.elasticsearch.common.io.stream.Writeable;
1111
import org.elasticsearch.test.AbstractWireSerializingTestCase;
12-
import org.elasticsearch.xpack.core.inference.action.GetRerankerAction;
12+
import org.elasticsearch.xpack.core.inference.action.GetRerankerWindowSizeAction;
1313

1414
import java.io.IOException;
1515

16-
public class GetRerankerActionResponseTests extends AbstractWireSerializingTestCase<GetRerankerAction.Response> {
16+
public class GetRerankerWindowSizeActionResponseTests extends AbstractWireSerializingTestCase<GetRerankerWindowSizeAction.Response> {
1717
@Override
18-
protected Writeable.Reader<GetRerankerAction.Response> instanceReader() {
19-
return GetRerankerAction.Response::new;
18+
protected Writeable.Reader<GetRerankerWindowSizeAction.Response> instanceReader() {
19+
return GetRerankerWindowSizeAction.Response::new;
2020
}
2121

2222
@Override
23-
protected GetRerankerAction.Response createTestInstance() {
24-
return new GetRerankerAction.Response(randomNonNegativeInt());
23+
protected GetRerankerWindowSizeAction.Response createTestInstance() {
24+
return new GetRerankerWindowSizeAction.Response(randomNonNegativeInt());
2525
}
2626

2727
@Override
28-
protected GetRerankerAction.Response mutateInstance(GetRerankerAction.Response instance) throws IOException {
28+
protected GetRerankerWindowSizeAction.Response mutateInstance(GetRerankerWindowSizeAction.Response instance) throws IOException {
2929
return randomValueOtherThan(instance, this::createTestInstance);
3030
}
3131
}

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@
104104
import static org.mockito.Mockito.verifyNoMoreInteractions;
105105
import static org.mockito.Mockito.when;
106106

107-
public class ElasticInferenceServiceTests extends ESSingleNodeTestCase { // TODO why is this this a single node test?
107+
public class ElasticInferenceServiceTests extends ESSingleNodeTestCase {
108108

109109
private static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS);
110110
private final MockWebServer webServer = new MockWebServer();

0 commit comments

Comments
 (0)