Skip to content

Commit 637a299

Browse files
[ML] Add queue_capacity setting to start deployment API (#79369)
1 parent 2b4fe8f commit 637a299

File tree

12 files changed

+108
-27
lines changed

12 files changed

+108
-27
lines changed

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/StartTrainedModelDeploymentAction.java

Lines changed: 41 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ public static class Request extends MasterNodeRequest<Request> implements ToXCon
6060
public static final ParseField WAIT_FOR = new ParseField("wait_for");
6161
public static final ParseField INFERENCE_THREADS = TaskParams.INFERENCE_THREADS;
6262
public static final ParseField MODEL_THREADS = TaskParams.MODEL_THREADS;
63+
public static final ParseField QUEUE_CAPACITY = TaskParams.QUEUE_CAPACITY;
6364

6465
public static final ObjectParser<Request, Void> PARSER = new ObjectParser<>(NAME, Request::new);
6566

@@ -69,6 +70,7 @@ public static class Request extends MasterNodeRequest<Request> implements ToXCon
6970
PARSER.declareString((request, waitFor) -> request.setWaitForState(AllocationStatus.State.fromString(waitFor)), WAIT_FOR);
7071
PARSER.declareInt(Request::setInferenceThreads, INFERENCE_THREADS);
7172
PARSER.declareInt(Request::setModelThreads, MODEL_THREADS);
73+
PARSER.declareInt(Request::setQueueCapacity, QUEUE_CAPACITY);
7274
}
7375

7476
public static Request parseRequest(String modelId, XContentParser parser) {
@@ -87,6 +89,7 @@ public static Request parseRequest(String modelId, XContentParser parser) {
8789
private AllocationStatus.State waitForState = AllocationStatus.State.STARTED;
8890
private int modelThreads = 1;
8991
private int inferenceThreads = 1;
92+
private int queueCapacity = 1024;
9093

9194
private Request() {}
9295

@@ -101,6 +104,7 @@ public Request(StreamInput in) throws IOException {
101104
waitForState = in.readEnum(AllocationStatus.State.class);
102105
modelThreads = in.readVInt();
103106
inferenceThreads = in.readVInt();
107+
queueCapacity = in.readVInt();
104108
}
105109

106110
public final void setModelId(String modelId) {
@@ -144,6 +148,14 @@ public void setInferenceThreads(int inferenceThreads) {
144148
this.inferenceThreads = inferenceThreads;
145149
}
146150

151+
public int getQueueCapacity() {
152+
return queueCapacity;
153+
}
154+
155+
public void setQueueCapacity(int queueCapacity) {
156+
this.queueCapacity = queueCapacity;
157+
}
158+
147159
@Override
148160
public void writeTo(StreamOutput out) throws IOException {
149161
super.writeTo(out);
@@ -152,6 +164,7 @@ public void writeTo(StreamOutput out) throws IOException {
152164
out.writeEnum(waitForState);
153165
out.writeVInt(modelThreads);
154166
out.writeVInt(inferenceThreads);
167+
out.writeVInt(queueCapacity);
155168
}
156169

157170
@Override
@@ -162,6 +175,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
162175
builder.field(WAIT_FOR.getPreferredName(), waitForState);
163176
builder.field(MODEL_THREADS.getPreferredName(), modelThreads);
164177
builder.field(INFERENCE_THREADS.getPreferredName(), inferenceThreads);
178+
builder.field(QUEUE_CAPACITY.getPreferredName(), queueCapacity);
165179
builder.endObject();
166180
return builder;
167181
}
@@ -183,12 +197,15 @@ public ActionRequestValidationException validate() {
183197
if (inferenceThreads < 1) {
184198
validationException.addValidationError("[" + INFERENCE_THREADS + "] must be a positive integer");
185199
}
200+
if (queueCapacity < 1 || queueCapacity > 10000) {
201+
validationException.addValidationError("[" + QUEUE_CAPACITY + "] must be in [1, 10000]");
202+
}
186203
return validationException.validationErrors().isEmpty() ? null : validationException;
187204
}
188205

189206
@Override
190207
public int hashCode() {
191-
return Objects.hash(modelId, timeout, waitForState, modelThreads, inferenceThreads);
208+
return Objects.hash(modelId, timeout, waitForState, modelThreads, inferenceThreads, queueCapacity);
192209
}
193210

194211
@Override
@@ -204,7 +221,8 @@ public boolean equals(Object obj) {
204221
&& Objects.equals(timeout, other.timeout)
205222
&& Objects.equals(waitForState, other.waitForState)
206223
&& modelThreads == other.modelThreads
207-
&& inferenceThreads == other.inferenceThreads;
224+
&& inferenceThreads == other.inferenceThreads
225+
&& queueCapacity == other.queueCapacity;
208226
}
209227

210228
@Override
@@ -226,16 +244,20 @@ public static boolean mayAllocateToNode(DiscoveryNode node) {
226244
private static final ParseField MODEL_BYTES = new ParseField("model_bytes");
227245
public static final ParseField MODEL_THREADS = new ParseField("model_threads");
228246
public static final ParseField INFERENCE_THREADS = new ParseField("inference_threads");
247+
public static final ParseField QUEUE_CAPACITY = new ParseField("queue_capacity");
248+
229249
private static final ConstructingObjectParser<TaskParams, Void> PARSER = new ConstructingObjectParser<>(
230250
"trained_model_deployment_params",
231251
true,
232-
a -> new TaskParams((String)a[0], (Long)a[1], (int) a[2], (int) a[3])
252+
a -> new TaskParams((String)a[0], (Long)a[1], (int) a[2], (int) a[3], (int) a[4])
233253
);
254+
234255
static {
235256
PARSER.declareString(ConstructingObjectParser.constructorArg(), TrainedModelConfig.MODEL_ID);
236257
PARSER.declareLong(ConstructingObjectParser.constructorArg(), MODEL_BYTES);
237258
PARSER.declareInt(ConstructingObjectParser.constructorArg(), INFERENCE_THREADS);
238259
PARSER.declareInt(ConstructingObjectParser.constructorArg(), MODEL_THREADS);
260+
PARSER.declareInt(ConstructingObjectParser.constructorArg(), QUEUE_CAPACITY);
239261
}
240262

241263
public static TaskParams fromXContent(XContentParser parser) {
@@ -253,8 +275,9 @@ public static TaskParams fromXContent(XContentParser parser) {
253275
private final long modelBytes;
254276
private final int inferenceThreads;
255277
private final int modelThreads;
278+
private final int queueCapacity;
256279

257-
public TaskParams(String modelId, long modelBytes, int inferenceThreads, int modelThreads) {
280+
public TaskParams(String modelId, long modelBytes, int inferenceThreads, int modelThreads, int queueCapacity) {
258281
this.modelId = Objects.requireNonNull(modelId);
259282
this.modelBytes = modelBytes;
260283
if (modelBytes < 0) {
@@ -268,13 +291,18 @@ public TaskParams(String modelId, long modelBytes, int inferenceThreads, int mod
268291
if (modelThreads < 1) {
269292
throw new IllegalArgumentException(MODEL_THREADS + " must be positive");
270293
}
294+
this.queueCapacity = queueCapacity;
295+
if (queueCapacity < 1 || queueCapacity > 10000) {
296+
throw new IllegalArgumentException(QUEUE_CAPACITY + " must be in [1, 10000]");
297+
}
271298
}
272299

273300
public TaskParams(StreamInput in) throws IOException {
274301
this.modelId = in.readString();
275302
this.modelBytes = in.readVLong();
276303
this.inferenceThreads = in.readVInt();
277304
this.modelThreads = in.readVInt();
305+
this.queueCapacity = in.readVInt();
278306
}
279307

280308
public String getModelId() {
@@ -296,6 +324,7 @@ public void writeTo(StreamOutput out) throws IOException {
296324
out.writeVLong(modelBytes);
297325
out.writeVInt(inferenceThreads);
298326
out.writeVInt(modelThreads);
327+
out.writeVInt(queueCapacity);
299328
}
300329

301330
@Override
@@ -305,13 +334,14 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
305334
builder.field(MODEL_BYTES.getPreferredName(), modelBytes);
306335
builder.field(INFERENCE_THREADS.getPreferredName(), inferenceThreads);
307336
builder.field(MODEL_THREADS.getPreferredName(), modelThreads);
337+
builder.field(QUEUE_CAPACITY.getPreferredName(), queueCapacity);
308338
builder.endObject();
309339
return builder;
310340
}
311341

312342
@Override
313343
public int hashCode() {
314-
return Objects.hash(modelId, modelBytes, inferenceThreads, modelThreads);
344+
return Objects.hash(modelId, modelBytes, inferenceThreads, modelThreads, queueCapacity);
315345
}
316346

317347
@Override
@@ -323,7 +353,8 @@ public boolean equals(Object o) {
323353
return Objects.equals(modelId, other.modelId)
324354
&& modelBytes == other.modelBytes
325355
&& inferenceThreads == other.inferenceThreads
326-
&& modelThreads == other.modelThreads;
356+
&& modelThreads == other.modelThreads
357+
&& queueCapacity == other.queueCapacity;
327358
}
328359

329360
@Override
@@ -342,6 +373,10 @@ public int getInferenceThreads() {
342373
public int getModelThreads() {
343374
return modelThreads;
344375
}
376+
377+
public int getQueueCapacity() {
378+
return queueCapacity;
379+
}
345380
}
346381

347382
public interface TaskMatcher {

x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/CreateTrainedModelAllocationActionRequestTests.java

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,7 @@ public class CreateTrainedModelAllocationActionRequestTests extends AbstractWire
1414

1515
@Override
1616
protected Request createTestInstance() {
17-
return new Request(
18-
new StartTrainedModelDeploymentAction.TaskParams(
19-
randomAlphaOfLength(10),
20-
randomNonNegativeLong(),
21-
randomIntBetween(1, 8),
22-
randomIntBetween(1, 8)
23-
)
24-
);
17+
return new Request(StartTrainedModelDeploymentTaskParamsTests.createRandom());
2518
}
2619

2720
@Override

x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/StartTrainedModelDeploymentRequestTests.java

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import java.io.IOException;
1919

2020
import static org.hamcrest.Matchers.containsString;
21+
import static org.hamcrest.Matchers.equalTo;
2122
import static org.hamcrest.Matchers.is;
2223
import static org.hamcrest.Matchers.not;
2324
import static org.hamcrest.Matchers.nullValue;
@@ -53,6 +54,9 @@ public static Request createRandom() {
5354
if (randomBoolean()) {
5455
request.setModelThreads(randomIntBetween(1, 8));
5556
}
57+
if (randomBoolean()) {
58+
request.setQueueCapacity(randomIntBetween(1, 10000));
59+
}
5660
return request;
5761
}
5862

@@ -95,4 +99,43 @@ public void testValidate_GivenModelThreadsIsNegative() {
9599
assertThat(e, is(not(nullValue())));
96100
assertThat(e.getMessage(), containsString("[model_threads] must be a positive integer"));
97101
}
102+
103+
public void testValidate_GivenQueueCapacityIsZero() {
104+
Request request = createRandom();
105+
request.setQueueCapacity(0);
106+
107+
ActionRequestValidationException e = request.validate();
108+
109+
assertThat(e, is(not(nullValue())));
110+
assertThat(e.getMessage(), containsString("[queue_capacity] must be in [1, 10000]"));
111+
}
112+
113+
public void testValidate_GivenQueueCapacityIsNegative() {
114+
Request request = createRandom();
115+
request.setQueueCapacity(randomIntBetween(Integer.MIN_VALUE, -1));
116+
117+
ActionRequestValidationException e = request.validate();
118+
119+
assertThat(e, is(not(nullValue())));
120+
assertThat(e.getMessage(), containsString("[queue_capacity] must be in [1, 10000]"));
121+
}
122+
123+
public void testValidate_GivenQueueCapacityIsGreaterThan10000() {
124+
Request request = createRandom();
125+
request.setQueueCapacity(randomIntBetween(10001, Integer.MAX_VALUE));
126+
127+
ActionRequestValidationException e = request.validate();
128+
129+
assertThat(e, is(not(nullValue())));
130+
assertThat(e.getMessage(), containsString("[queue_capacity] must be in [1, 10000]"));
131+
}
132+
133+
public void testDefaults() {
134+
Request request = new Request(randomAlphaOfLength(10));
135+
assertThat(request.getTimeout(), equalTo(TimeValue.timeValueSeconds(20)));
136+
assertThat(request.getWaitForState(), equalTo(AllocationStatus.State.STARTED));
137+
assertThat(request.getInferenceThreads(), equalTo(1));
138+
assertThat(request.getModelThreads(), equalTo(1));
139+
assertThat(request.getQueueCapacity(), equalTo(1024));
140+
}
98141
}

x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/StartTrainedModelDeploymentTaskParamsTests.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,8 @@ public static StartTrainedModelDeploymentAction.TaskParams createRandom() {
3636
randomAlphaOfLength(10),
3737
randomNonNegativeLong(),
3838
randomIntBetween(1, 8),
39-
randomIntBetween(1, 8)
39+
randomIntBetween(1, 8),
40+
randomIntBetween(1, 10000)
4041
);
4142
}
4243
}

x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/allocation/TrainedModelAllocationTests.java

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,10 @@
1313
import org.elasticsearch.cluster.node.DiscoveryNode;
1414
import org.elasticsearch.cluster.node.DiscoveryNodeRole;
1515
import org.elasticsearch.common.io.stream.Writeable;
16-
import org.elasticsearch.xcontent.XContentParser;
1716
import org.elasticsearch.test.AbstractSerializingTestCase;
17+
import org.elasticsearch.xcontent.XContentParser;
1818
import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction;
19+
import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentTaskParamsTests;
1920

2021
import java.io.IOException;
2122
import java.util.List;
@@ -31,9 +32,7 @@
3132
public class TrainedModelAllocationTests extends AbstractSerializingTestCase<TrainedModelAllocation> {
3233

3334
public static TrainedModelAllocation randomInstance() {
34-
TrainedModelAllocation.Builder builder = TrainedModelAllocation.Builder.empty(
35-
new StartTrainedModelDeploymentAction.TaskParams(randomAlphaOfLength(10), randomNonNegativeLong(), 1, 1)
36-
);
35+
TrainedModelAllocation.Builder builder = TrainedModelAllocation.Builder.empty(randomParams());
3736
List<String> nodes = Stream.generate(() -> randomAlphaOfLength(10)).limit(randomInt(5)).collect(Collectors.toList());
3837
for (String node : nodes) {
3938
if (randomBoolean()) {
@@ -249,7 +248,7 @@ private static DiscoveryNode buildNode() {
249248
}
250249

251250
private static StartTrainedModelDeploymentAction.TaskParams randomParams() {
252-
return new StartTrainedModelDeploymentAction.TaskParams(randomAlphaOfLength(10), randomNonNegativeLong(), 1, 1);
251+
return StartTrainedModelDeploymentTaskParamsTests.createRandom();
253252
}
254253

255254
private static void assertUnchanged(

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStartTrainedModelDeploymentAction.java

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@
2626
import org.elasticsearch.cluster.service.ClusterService;
2727
import org.elasticsearch.common.inject.Inject;
2828
import org.elasticsearch.common.settings.Settings;
29-
import org.elasticsearch.xcontent.NamedXContentRegistry;
3029
import org.elasticsearch.core.TimeValue;
3130
import org.elasticsearch.license.LicenseUtils;
3231
import org.elasticsearch.license.XPackLicenseState;
@@ -35,6 +34,7 @@
3534
import org.elasticsearch.tasks.Task;
3635
import org.elasticsearch.threadpool.ThreadPool;
3736
import org.elasticsearch.transport.TransportService;
37+
import org.elasticsearch.xcontent.NamedXContentRegistry;
3838
import org.elasticsearch.xpack.core.XPackField;
3939
import org.elasticsearch.xpack.core.ml.MachineLearningField;
4040
import org.elasticsearch.xpack.core.ml.action.CreateTrainedModelAllocationAction;
@@ -161,7 +161,8 @@ protected void masterOperation(Task task, StartTrainedModelDeploymentAction.Requ
161161
trainedModelConfig.getModelId(),
162162
modelBytes,
163163
request.getInferenceThreads(),
164-
request.getModelThreads()
164+
request.getModelThreads(),
165+
request.getQueueCapacity()
165166
);
166167
PersistentTasksCustomMetadata persistentTasks = clusterService.state().getMetadata().custom(
167168
PersistentTasksCustomMetadata.TYPE);

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/DeploymentManager.java

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -307,6 +307,7 @@ public void onFailure(Exception e) {
307307

308308
@Override
309309
protected void doRun() throws Exception {
310+
logger.info("Request [{}] running", requestId);
310311
final String requestIdStr = String.valueOf(requestId);
311312
try {
312313
// The request builder expect a list of inputs which are then batched.
@@ -392,7 +393,11 @@ class ProcessContext {
392393
this.task = Objects.requireNonNull(task);
393394
resultProcessor = new PyTorchResultProcessor(task.getModelId());
394395
this.stateStreamer = new PyTorchStateStreamer(client, executorService, xContentRegistry);
395-
this.executorService = new ProcessWorkerExecutorService(threadPool.getThreadContext(), "pytorch_inference", 1024);
396+
this.executorService = new ProcessWorkerExecutorService(
397+
threadPool.getThreadContext(),
398+
"pytorch_inference",
399+
task.getParams().getQueueCapacity()
400+
);
396401
}
397402

398403
PyTorchResultProcessor getResultProcessor() {

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestStartTrainedModelDeploymentAction.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import static org.elasticsearch.rest.RestRequest.Method.POST;
2424
import static org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction.Request.INFERENCE_THREADS;
2525
import static org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction.Request.MODEL_THREADS;
26+
import static org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction.Request.QUEUE_CAPACITY;
2627
import static org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction.Request.TIMEOUT;
2728
import static org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction.Request.WAIT_FOR;
2829

@@ -59,6 +60,7 @@ protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient
5960
));
6061
request.setInferenceThreads(restRequest.paramAsInt(INFERENCE_THREADS.getPreferredName(), request.getInferenceThreads()));
6162
request.setModelThreads(restRequest.paramAsInt(MODEL_THREADS.getPreferredName(), request.getModelThreads()));
63+
request.setQueueCapacity(restRequest.paramAsInt(QUEUE_CAPACITY.getPreferredName(), request.getQueueCapacity()));
6264
}
6365

6466
return channel -> client.execute(StartTrainedModelDeploymentAction.INSTANCE, request, new RestToXContentListener<>(channel));

x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/allocation/TrainedModelAllocationClusterServiceTests.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -940,7 +940,7 @@ private static DiscoveryNode buildOldNode(String name, boolean isML, long native
940940
}
941941

942942
private static StartTrainedModelDeploymentAction.TaskParams newParams(String modelId, long modelSize) {
943-
return new StartTrainedModelDeploymentAction.TaskParams(modelId, modelSize, 1, 1);
943+
return new StartTrainedModelDeploymentAction.TaskParams(modelId, modelSize, 1, 1, 1024);
944944
}
945945

946946
private static void assertNodeState(TrainedModelAllocationMetadata metadata, String modelId, String nodeId, RoutingState routingState) {

x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/allocation/TrainedModelAllocationMetadataTests.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,8 @@ private static StartTrainedModelDeploymentAction.TaskParams randomParams(String
9999
modelId,
100100
randomNonNegativeLong(),
101101
randomIntBetween(1, 8),
102-
randomIntBetween(1, 8)
102+
randomIntBetween(1, 8),
103+
randomIntBetween(1, 10000)
103104
);
104105
}
105106

0 commit comments

Comments
 (0)