Skip to content

Commit b34279c

Browse files
benwtrent authored and Adam Locke committed
[ML] track inference model feature usage per node (elastic#79752)
This adds feature usage tracking for deployed inference models. The models are tracked under the existing inference feature and contain context related to the model ID. I decided to track the feature via the allocation task to keep the logic similar between allocation tasks and licensed persistent tasks. closes: elastic#76452
1 parent 932523e commit b34279c

File tree

6 files changed

+114
-7
lines changed

6 files changed

+114
-7
lines changed

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -468,6 +468,11 @@ public class MachineLearning extends Plugin implements SystemIndexPlugin,
468468
"model-inference",
469469
License.OperationMode.PLATINUM
470470
);
471+
public static final LicensedFeature.Persistent ML_PYTORCH_MODEL_INFERENCE_FEATURE = LicensedFeature.persistent(
472+
MachineLearningField.ML_FEATURE_FAMILY,
473+
"pytorch-model-inference",
474+
License.OperationMode.PLATINUM
475+
);
471476

472477
@Override
473478
public Map<String, Processor.Factory> getProcessors(Processor.Parameters parameters) {

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportCreateTrainedModelAllocationAction.java

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver;
1717
import org.elasticsearch.cluster.service.ClusterService;
1818
import org.elasticsearch.common.inject.Inject;
19+
import org.elasticsearch.license.XPackLicenseState;
1920
import org.elasticsearch.tasks.Task;
2021
import org.elasticsearch.threadpool.ThreadPool;
2122
import org.elasticsearch.transport.TransportService;
@@ -40,7 +41,8 @@ public TransportCreateTrainedModelAllocationAction(
4041
ClusterService clusterService,
4142
ThreadPool threadPool,
4243
ActionFilters actionFilters,
43-
IndexNameExpressionResolver indexNameExpressionResolver
44+
IndexNameExpressionResolver indexNameExpressionResolver,
45+
XPackLicenseState licenseState
4446
) {
4547
super(
4648
CreateTrainedModelAllocationAction.NAME,
@@ -62,7 +64,8 @@ public TransportCreateTrainedModelAllocationAction(
6264
clusterService,
6365
deploymentManager,
6466
transportService.getTaskManager(),
65-
threadPool
67+
threadPool,
68+
licenseState
6669
)
6770
);
6871
}

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/allocation/TrainedModelAllocationNodeService.java

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import org.elasticsearch.common.component.LifecycleListener;
2222
import org.elasticsearch.common.util.set.Sets;
2323
import org.elasticsearch.core.TimeValue;
24+
import org.elasticsearch.license.XPackLicenseState;
2425
import org.elasticsearch.tasks.Task;
2526
import org.elasticsearch.tasks.TaskAwareRequest;
2627
import org.elasticsearch.tasks.TaskId;
@@ -52,6 +53,7 @@
5253

5354
import static org.elasticsearch.xpack.core.ml.MlTasks.TRAINED_MODEL_ALLOCATION_TASK_NAME_PREFIX;
5455
import static org.elasticsearch.xpack.core.ml.MlTasks.TRAINED_MODEL_ALLOCATION_TASK_TYPE;
56+
import static org.elasticsearch.xpack.ml.MachineLearning.ML_PYTORCH_MODEL_INFERENCE_FEATURE;
5557

5658
public class TrainedModelAllocationNodeService implements ClusterStateListener {
5759

@@ -65,6 +67,7 @@ public class TrainedModelAllocationNodeService implements ClusterStateListener {
6567
private final Map<String, TrainedModelDeploymentTask> modelIdToTask;
6668
private final ThreadPool threadPool;
6769
private final Deque<TrainedModelDeploymentTask> loadingModels;
70+
private final XPackLicenseState licenseState;
6871
private volatile Scheduler.Cancellable scheduledFuture;
6972
private volatile boolean stopped;
7073
private volatile String nodeId;
@@ -74,14 +77,16 @@ public TrainedModelAllocationNodeService(
7477
ClusterService clusterService,
7578
DeploymentManager deploymentManager,
7679
TaskManager taskManager,
77-
ThreadPool threadPool
80+
ThreadPool threadPool,
81+
XPackLicenseState licenseState
7882
) {
7983
this.trainedModelAllocationService = trainedModelAllocationService;
8084
this.deploymentManager = deploymentManager;
8185
this.taskManager = taskManager;
8286
this.modelIdToTask = new ConcurrentHashMap<>();
8387
this.loadingModels = new ConcurrentLinkedDeque<>();
8488
this.threadPool = threadPool;
89+
this.licenseState = licenseState;
8590
clusterService.addLifecycleListener(new LifecycleListener() {
8691
@Override
8792
public void afterStart() {
@@ -102,7 +107,8 @@ public void beforeStop() {
102107
DeploymentManager deploymentManager,
103108
TaskManager taskManager,
104109
ThreadPool threadPool,
105-
String nodeId
110+
String nodeId,
111+
XPackLicenseState licenseState
106112
) {
107113
this.trainedModelAllocationService = trainedModelAllocationService;
108114
this.deploymentManager = deploymentManager;
@@ -111,6 +117,7 @@ public void beforeStop() {
111117
this.loadingModels = new ConcurrentLinkedDeque<>();
112118
this.threadPool = threadPool;
113119
this.nodeId = nodeId;
120+
this.licenseState = licenseState;
114121
clusterService.addLifecycleListener(new LifecycleListener() {
115122
@Override
116123
public void afterStart() {
@@ -265,7 +272,17 @@ public TaskId getParentTask() {
265272

266273
@Override
267274
public Task createTask(long id, String type, String action, TaskId parentTaskId, Map<String, String> headers) {
268-
return new TrainedModelDeploymentTask(id, type, action, parentTaskId, headers, params, trainedModelAllocationNodeService);
275+
return new TrainedModelDeploymentTask(
276+
id,
277+
type,
278+
action,
279+
parentTaskId,
280+
headers,
281+
params,
282+
trainedModelAllocationNodeService,
283+
licenseState,
284+
ML_PYTORCH_MODEL_INFERENCE_FEATURE
285+
);
269286
}
270287
};
271288
}

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/TrainedModelDeploymentTask.java

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
import org.apache.lucene.util.SetOnce;
1313
import org.elasticsearch.action.ActionListener;
1414
import org.elasticsearch.core.TimeValue;
15+
import org.elasticsearch.license.LicensedFeature;
16+
import org.elasticsearch.license.XPackLicenseState;
1517
import org.elasticsearch.tasks.CancellableTask;
1618
import org.elasticsearch.tasks.TaskId;
1719
import org.elasticsearch.xpack.core.ml.MlTasks;
@@ -26,6 +28,7 @@
2628
import java.util.Map;
2729
import java.util.Optional;
2830

31+
2932
public class TrainedModelDeploymentTask extends CancellableTask implements StartTrainedModelDeploymentAction.TaskMatcher {
3033

3134
private static final Logger logger = LogManager.getLogger(TrainedModelDeploymentTask.class);
@@ -35,6 +38,8 @@ public class TrainedModelDeploymentTask extends CancellableTask implements Start
3538
private volatile boolean stopped;
3639
private final SetOnce<String> stoppedReason = new SetOnce<>();
3740
private final SetOnce<InferenceConfig> inferenceConfig = new SetOnce<>();
41+
private final XPackLicenseState licenseState;
42+
private final LicensedFeature.Persistent licensedFeature;
3843

3944
public TrainedModelDeploymentTask(
4045
long id,
@@ -43,18 +48,23 @@ public TrainedModelDeploymentTask(
4348
TaskId parentTask,
4449
Map<String, String> headers,
4550
TaskParams taskParams,
46-
TrainedModelAllocationNodeService trainedModelAllocationNodeService
51+
TrainedModelAllocationNodeService trainedModelAllocationNodeService,
52+
XPackLicenseState licenseState,
53+
LicensedFeature.Persistent licensedFeature
4754
) {
4855
super(id, type, action, MlTasks.trainedModelDeploymentTaskId(taskParams.getModelId()), parentTask, headers);
4956
this.params = taskParams;
5057
this.trainedModelAllocationNodeService = ExceptionsHelper.requireNonNull(
5158
trainedModelAllocationNodeService,
5259
"trainedModelAllocationNodeService"
5360
);
61+
this.licenseState = licenseState;
62+
this.licensedFeature = licensedFeature;
5463
}
5564

5665
void init(InferenceConfig inferenceConfig) {
5766
this.inferenceConfig.set(inferenceConfig);
67+
licensedFeature.startTracking(licenseState, "model-" + params.getModelId());
5868
}
5969

6070
public String getModelId() {
@@ -71,12 +81,14 @@ public TaskParams getParams() {
7181

7282
public void stop(String reason) {
7383
logger.debug("[{}] Stopping due to reason [{}]", getModelId(), reason);
84+
licensedFeature.stopTracking(licenseState, "model-" + params.getModelId());
7485
stopped = true;
7586
stoppedReason.trySet(reason);
7687
trainedModelAllocationNodeService.stopDeploymentAndNotify(this, reason);
7788
}
7889

7990
public void stopWithoutNotification(String reason) {
91+
licensedFeature.stopTracking(licenseState, "model-" + params.getModelId());
8092
logger.debug("[{}] Stopping due to reason [{}]", getModelId(), reason);
8193
stoppedReason.trySet(reason);
8294
stopped = true;

x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/allocation/TrainedModelAllocationNodeServiceTests.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import org.elasticsearch.cluster.service.ClusterService;
2424
import org.elasticsearch.common.settings.Settings;
2525
import org.elasticsearch.core.TimeValue;
26+
import org.elasticsearch.license.XPackLicenseState;
2627
import org.elasticsearch.tasks.TaskManager;
2728
import org.elasticsearch.test.ESTestCase;
2829
import org.elasticsearch.threadpool.ScalingExecutorBuilder;
@@ -507,7 +508,8 @@ private TrainedModelAllocationNodeService createService() {
507508
deploymentManager,
508509
taskManager,
509510
threadPool,
510-
NODE_ID
511+
NODE_ID,
512+
mock(XPackLicenseState.class)
511513
);
512514
}
513515

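x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/deployment/TrainedModelDeploymentTaskTests.java

Original file line numberDiff line numberDiff line change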
@@ -0,0 +1,68 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.ml.inference.deployment;
9+
10+
import org.elasticsearch.license.LicensedFeature;
11+
import org.elasticsearch.license.XPackLicenseState;
12+
import org.elasticsearch.tasks.TaskId;
13+
import org.elasticsearch.test.ESTestCase;
14+
import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction;
15+
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.PassThroughConfig;
16+
import org.elasticsearch.xpack.ml.inference.allocation.TrainedModelAllocationNodeService;
17+
18+
import java.util.Map;
19+
import java.util.function.Consumer;
20+
21+
import static org.elasticsearch.xpack.core.ml.MlTasks.TRAINED_MODEL_ALLOCATION_TASK_NAME_PREFIX;
22+
import static org.elasticsearch.xpack.core.ml.MlTasks.TRAINED_MODEL_ALLOCATION_TASK_TYPE;
23+
import static org.mockito.Mockito.mock;
24+
import static org.mockito.Mockito.times;
25+
import static org.mockito.Mockito.verify;
26+
27+
public class TrainedModelDeploymentTaskTests extends ESTestCase {
28+
29+
void assertTrackingComplete(Consumer<TrainedModelDeploymentTask> method, String modelId) {
30+
XPackLicenseState licenseState = mock(XPackLicenseState.class);
31+
LicensedFeature.Persistent feature = mock(LicensedFeature.Persistent.class);
32+
TrainedModelDeploymentTask task = new TrainedModelDeploymentTask(
33+
0,
34+
TRAINED_MODEL_ALLOCATION_TASK_TYPE,
35+
TRAINED_MODEL_ALLOCATION_TASK_NAME_PREFIX + modelId,
36+
TaskId.EMPTY_TASK_ID,
37+
Map.of(),
38+
new StartTrainedModelDeploymentAction.TaskParams(
39+
modelId,
40+
randomLongBetween(1, Long.MAX_VALUE),
41+
randomInt(5),
42+
randomInt(5),
43+
randomInt(5)
44+
),
45+
mock(TrainedModelAllocationNodeService.class),
46+
licenseState,
47+
feature
48+
);
49+
50+
task.init(new PassThroughConfig(null, null, null));
51+
verify(feature, times(1)).startTracking(licenseState, "model-" + modelId);
52+
method.accept(task);
53+
verify(feature, times(1)).stopTracking(licenseState, "model-" + modelId);
54+
}
55+
56+
public void testOnStopWithoutNotification() {
57+
assertTrackingComplete(t -> t.stopWithoutNotification("foo"), randomAlphaOfLength(10));
58+
}
59+
60+
public void testOnStop() {
61+
assertTrackingComplete(t -> t.stop("foo"), randomAlphaOfLength(10));
62+
}
63+
64+
public void testCancelled() {
65+
assertTrackingComplete(TrainedModelDeploymentTask::onCancelled, randomAlphaOfLength(10));
66+
}
67+
68+
}

0 commit comments

Comments
 (0)