Skip to content

Commit

Permalink
renaming metrics (#1224)
Browse files Browse the repository at this point in the history
* renaming metrics

Signed-off-by: Dhrubo Saha <dhrubo@amazon.com>

* updating tests

Signed-off-by: Dhrubo Saha <dhrubo@amazon.com>

* updating test cases

Signed-off-by: Dhrubo Saha <dhrubo@amazon.com>

* removing the ML_NODE checking for node level stats

Signed-off-by: Dhrubo Saha <dhrubo@amazon.com>

* updating constructing new set

Signed-off-by: Dhrubo Saha <dhrubo@amazon.com>

* spotless Apply

Signed-off-by: Dhrubo Saha <dhrubo@amazon.com>

* updating ML_NODE_TOTAL_MODEL_COUNT to ML_DEPLOYED_MODEL_COUNT

Signed-off-by: Dhrubo Saha <dhrubo@amazon.com>

* fixing metrics count

Signed-off-by: Dhrubo Saha <dhrubo@amazon.com>

* spotless

Signed-off-by: Dhrubo Saha <dhrubo@amazon.com>

* fixing executing task

Signed-off-by: Dhrubo Saha <dhrubo@amazon.com>

* updating comment

Signed-off-by: Dhrubo Saha <dhrubo@amazon.com>

---------

Signed-off-by: Dhrubo Saha <dhrubo@amazon.com>
  • Loading branch information
dhrubo-os authored Aug 22, 2023
1 parent 47bf678 commit 86eb953
Show file tree
Hide file tree
Showing 34 changed files with 151 additions and 164 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,6 @@
import org.opensearch.ml.helper.ModelAccessControlHelper;
import org.opensearch.ml.model.MLModelManager;
import org.opensearch.ml.settings.MLFeatureEnabledSetting;
import org.opensearch.ml.stats.MLNodeLevelStat;
import org.opensearch.ml.stats.MLStats;
import org.opensearch.ml.task.MLTaskDispatcher;
import org.opensearch.ml.task.MLTaskManager;
Expand Down Expand Up @@ -148,8 +147,6 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<MLDepl
if (!allowCustomDeploymentPlan && !deployToAllNodes) {
throw new IllegalArgumentException("Don't allow custom deployment plan");
}
// mlStats.getStat(MLNodeLevelStat.ML_NODE_EXECUTING_TASK_COUNT).increment();
mlStats.getStat(MLNodeLevelStat.ML_NODE_TOTAL_REQUEST_COUNT).increment();
DiscoveryNode[] allEligibleNodes = nodeFilter.getEligibleNodes(functionName);
Map<String, DiscoveryNode> nodeMapping = new HashMap<>();
for (DiscoveryNode node : allEligibleNodes) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,6 @@
import org.opensearch.ml.indices.MLIndicesHandler;
import org.opensearch.ml.model.MLModelGroupManager;
import org.opensearch.ml.model.MLModelManager;
import org.opensearch.ml.stats.MLNodeLevelStat;
import org.opensearch.ml.stats.MLStats;
import org.opensearch.ml.task.MLTaskDispatcher;
import org.opensearch.ml.task.MLTaskManager;
Expand Down Expand Up @@ -234,12 +233,6 @@ private void registerModel(MLRegisterModelInput registerModelInput, ActionListen
throw new IllegalArgumentException("URL can't match trusted url regex");
}
}
// mlStats.getStat(MLNodeLevelStat.ML_NODE_EXECUTING_TASK_COUNT).increment();
mlStats.getStat(MLNodeLevelStat.ML_NODE_TOTAL_REQUEST_COUNT).increment();
// //TODO: track executing task; track register failures
// mlStats.createCounterStatIfAbsent(FunctionName.TEXT_EMBEDDING,
// ActionName.REGISTER,
// MLActionLevelStat.ML_ACTION_REQUEST_COUNT).increment();
boolean isAsync = registerModelInput.getFunctionName() != FunctionName.REMOTE;
MLTask mlTask = MLTask
.builder()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -97,9 +97,9 @@ MLStatsNodeResponse createMLStatsNodeResponse(MLStatsNodesRequest mlStatsNodesRe
MLStatsInput mlStatsInput = mlStatsNodesRequest.getMlStatsInput();
// return node level stats
if (mlStatsInput.getTargetStatLevels().contains(MLStatLevel.NODE)) {
if (mlStatsInput.retrieveStat(MLNodeLevelStat.ML_NODE_JVM_HEAP_USAGE)) {
if (mlStatsInput.retrieveStat(MLNodeLevelStat.ML_JVM_HEAP_USAGE)) {
long heapUsedPercent = jvmService.stats().getMem().getHeapUsedPercent();
statValues.put(MLNodeLevelStat.ML_NODE_JVM_HEAP_USAGE, heapUsedPercent);
statValues.put(MLNodeLevelStat.ML_JVM_HEAP_USAGE, heapUsedPercent);
}

for (Enum statName : mlStats.getNodeStats().keySet()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -228,8 +228,7 @@ protected MLUndeployModelNodeResponse nodeOperation(MLUndeployModelNodeRequest r
}

private MLUndeployModelNodeResponse createUndeployModelNodeResponse(MLUndeployModelNodesRequest MLUndeployModelNodesRequest) {
mlStats.getStat(MLNodeLevelStat.ML_NODE_EXECUTING_TASK_COUNT).increment();
mlStats.getStat(MLNodeLevelStat.ML_NODE_TOTAL_REQUEST_COUNT).increment();
mlStats.getStat(MLNodeLevelStat.ML_EXECUTING_TASK_COUNT).increment();

String[] modelIds = MLUndeployModelNodesRequest.getModelIds();

Expand All @@ -246,7 +245,7 @@ private MLUndeployModelNodeResponse createUndeployModelNodeResponse(MLUndeployMo
}

Map<String, String> modelUndeployStatus = mlModelManager.undeployModel(modelIds);
mlStats.getStat(MLNodeLevelStat.ML_NODE_EXECUTING_TASK_COUNT).decrement();
mlStats.getStat(MLNodeLevelStat.ML_EXECUTING_TASK_COUNT).decrement();
return new MLUndeployModelNodeResponse(clusterService.localNode(), modelUndeployStatus, modelWorkerNodesMap);
}
}
39 changes: 16 additions & 23 deletions plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@ public MLModelManager(
public void registerModelMeta(MLRegisterModelMetaInput mlRegisterModelMetaInput, ActionListener<String> listener) {
try {
FunctionName functionName = mlRegisterModelMetaInput.getFunctionName();
mlStats.getStat(MLNodeLevelStat.ML_NODE_TOTAL_REQUEST_COUNT).increment();
mlStats.getStat(MLNodeLevelStat.ML_REQUEST_COUNT).increment();
mlStats.createCounterStatIfAbsent(functionName, REGISTER, ML_ACTION_REQUEST_COUNT).increment();
String modelGroupId = mlRegisterModelMetaInput.getModelGroupId();
if (Strings.isBlank(modelGroupId)) {
Expand Down Expand Up @@ -322,9 +322,9 @@ public void registerMLModel(MLRegisterModelInput registerModelInput, MLTask mlTa

checkAndAddRunningTask(mlTask, maxRegisterTasksPerNode);
try {
mlStats.getStat(MLNodeLevelStat.ML_NODE_TOTAL_REQUEST_COUNT).increment();
mlStats.getStat(MLNodeLevelStat.ML_NODE_EXECUTING_TASK_COUNT).increment();
mlStats.getStat(MLNodeLevelStat.ML_REQUEST_COUNT).increment();
mlStats.createCounterStatIfAbsent(mlTask.getFunctionName(), REGISTER, ML_ACTION_REQUEST_COUNT).increment();
mlStats.getStat(MLNodeLevelStat.ML_EXECUTING_TASK_COUNT).increment();

String modelGroupId = registerModelInput.getModelGroupId();
GetRequest getModelGroupRequest = new GetRequest(ML_MODEL_GROUP_INDEX).id(modelGroupId);
Expand Down Expand Up @@ -384,17 +384,14 @@ public void registerMLModel(MLRegisterModelInput registerModelInput, MLTask mlTa
} catch (Exception e) {
handleException(registerModelInput.getFunctionName(), mlTask.getTaskId(), e);
} finally {
mlStats.getStat(MLNodeLevelStat.ML_NODE_EXECUTING_TASK_COUNT).increment();
mlStats.getStat(MLNodeLevelStat.ML_EXECUTING_TASK_COUNT).decrement();
}
}

private void indexRemoteModel(MLRegisterModelInput registerModelInput, MLTask mlTask, String modelVersion) {
String taskId = mlTask.getTaskId();
FunctionName functionName = mlTask.getFunctionName();
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
mlStats.getStat(MLNodeLevelStat.ML_NODE_TOTAL_REQUEST_COUNT).increment();
mlStats.createCounterStatIfAbsent(functionName, REGISTER, ML_ACTION_REQUEST_COUNT).increment();
mlStats.getStat(MLNodeLevelStat.ML_NODE_EXECUTING_TASK_COUNT).increment();

String modelName = registerModelInput.getModelName();
String version = modelVersion == null ? registerModelInput.getVersion() : modelVersion;
Expand Down Expand Up @@ -443,8 +440,6 @@ private void indexRemoteModel(MLRegisterModelInput registerModelInput, MLTask ml
} catch (Exception e) {
logException("Failed to upload model", e, log);
handleException(functionName, taskId, e);
} finally {
mlStats.getStat(MLNodeLevelStat.ML_NODE_EXECUTING_TASK_COUNT).increment();
}
}

Expand All @@ -462,9 +457,6 @@ private void registerModelFromUrl(MLRegisterModelInput registerModelInput, MLTas
String taskId = mlTask.getTaskId();
FunctionName functionName = mlTask.getFunctionName();
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
mlStats.getStat(MLNodeLevelStat.ML_NODE_TOTAL_REQUEST_COUNT).increment();
mlStats.createCounterStatIfAbsent(functionName, REGISTER, ML_ACTION_REQUEST_COUNT).increment();
mlStats.getStat(MLNodeLevelStat.ML_NODE_EXECUTING_TASK_COUNT).increment();
String modelName = registerModelInput.getModelName();
String version = modelVersion == null ? registerModelInput.getVersion() : modelVersion;
String modelGroupId = registerModelInput.getModelGroupId();
Expand Down Expand Up @@ -509,8 +501,6 @@ private void registerModelFromUrl(MLRegisterModelInput registerModelInput, MLTas
} catch (Exception e) {
logException("Failed to register model", e, log);
handleException(functionName, taskId, e);
} finally {
mlStats.getStat(MLNodeLevelStat.ML_NODE_EXECUTING_TASK_COUNT).increment();
}
}

Expand Down Expand Up @@ -693,7 +683,7 @@ private void handleException(FunctionName functionName, String taskId, Exception
&& !(e instanceof MLResourceNotFoundException)
&& !(e instanceof IllegalArgumentException)) {
mlStats.createCounterStatIfAbsent(functionName, REGISTER, MLActionLevelStat.ML_ACTION_FAILURE_COUNT).increment();
mlStats.getStat(MLNodeLevelStat.ML_NODE_TOTAL_FAILURE_COUNT).increment();
mlStats.getStat(MLNodeLevelStat.ML_FAILURE_COUNT).increment();
}
Map<String, Object> updated = ImmutableMap.of(ERROR_FIELD, MLExceptionUtils.getRootCauseMessage(e), STATE_FIELD, FAILED);
mlTaskManager.updateMLTask(taskId, updated, TIMEOUT_IN_MILLIS, true);
Expand All @@ -718,7 +708,8 @@ public void deployModel(
ActionListener<String> listener
) {
mlStats.createCounterStatIfAbsent(functionName, ActionName.DEPLOY, ML_ACTION_REQUEST_COUNT).increment();
mlStats.getStat(MLNodeLevelStat.ML_NODE_TOTAL_REQUEST_COUNT).increment();
mlStats.getStat(MLNodeLevelStat.ML_EXECUTING_TASK_COUNT).increment();
mlStats.getStat(MLNodeLevelStat.ML_REQUEST_COUNT).increment();
List<String> workerNodes = mlTask.getWorkerNodes();
if (modelCacheHelper.isModelDeployed(modelId)) {
if (workerNodes != null && workerNodes.size() > 0) {
Expand Down Expand Up @@ -800,7 +791,7 @@ public void deployModel(
MLExecutable mlExecutable = mlEngine.deployExecute(mlModel, params);
try {
modelCacheHelper.setMLExecutor(modelId, mlExecutable);
mlStats.getStat(MLNodeLevelStat.ML_NODE_TOTAL_MODEL_COUNT).increment();
mlStats.getStat(MLNodeLevelStat.ML_DEPLOYED_MODEL_COUNT).increment();
modelCacheHelper.setModelState(modelId, MLModelState.DEPLOYED);
listener.onResponse("successful");
} catch (Exception e) {
Expand All @@ -813,7 +804,7 @@ public void deployModel(
Predictable predictable = mlEngine.deploy(mlModel, params);
try {
modelCacheHelper.setPredictor(modelId, predictable);
mlStats.getStat(MLNodeLevelStat.ML_NODE_TOTAL_MODEL_COUNT).increment();
mlStats.getStat(MLNodeLevelStat.ML_DEPLOYED_MODEL_COUNT).increment();
modelCacheHelper.setModelState(modelId, MLModelState.DEPLOYED);
Long modelContentSizeInBytes = mlModel.getModelContentSizeInBytes();
long contentSize = modelContentSizeInBytes == null
Expand All @@ -837,6 +828,8 @@ public void deployModel(
})));
} catch (Exception e) {
handleDeployModelException(modelId, functionName, listener, e);
} finally {
mlStats.getStat(MLNodeLevelStat.ML_EXECUTING_TASK_COUNT).decrement();
}
}

Expand All @@ -846,7 +839,7 @@ private void handleDeployModelException(String modelId, FunctionName functionNam
&& !(e instanceof MLResourceNotFoundException)
&& !(e instanceof IllegalArgumentException)) {
mlStats.createCounterStatIfAbsent(functionName, ActionName.DEPLOY, MLActionLevelStat.ML_ACTION_FAILURE_COUNT).increment();
mlStats.getStat(MLNodeLevelStat.ML_NODE_TOTAL_FAILURE_COUNT).increment();
mlStats.getStat(MLNodeLevelStat.ML_FAILURE_COUNT).increment();
}
removeModel(modelId);
listener.onFailure(e);
Expand All @@ -855,7 +848,7 @@ private void handleDeployModelException(String modelId, FunctionName functionNam
private void setupPredictable(String modelId, MLModel mlModel, Map<String, Object> params) {
Predictable predictable = mlEngine.deploy(mlModel, params);
modelCacheHelper.setPredictor(modelId, predictable);
mlStats.getStat(MLNodeLevelStat.ML_NODE_TOTAL_MODEL_COUNT).increment();
mlStats.getStat(MLNodeLevelStat.ML_DEPLOYED_MODEL_COUNT).increment();
modelCacheHelper.setModelState(modelId, MLModelState.DEPLOYED);
}

Expand Down Expand Up @@ -1056,8 +1049,8 @@ public synchronized Map<String, String> undeployModel(String[] modelIds) {
for (String modelId : modelIds) {
if (modelCacheHelper.isModelDeployed(modelId)) {
modelUndeployStatus.put(modelId, UNDEPLOYED);
mlStats.getStat(MLNodeLevelStat.ML_NODE_TOTAL_MODEL_COUNT).decrement();
mlStats.getStat(MLNodeLevelStat.ML_NODE_TOTAL_REQUEST_COUNT).increment();
mlStats.getStat(MLNodeLevelStat.ML_DEPLOYED_MODEL_COUNT).decrement();
mlStats.getStat(MLNodeLevelStat.ML_REQUEST_COUNT).increment();
mlStats
.createCounterStatIfAbsent(getModelFunctionName(modelId), ActionName.UNDEPLOY, ML_ACTION_REQUEST_COUNT)
.increment();
Expand All @@ -1070,7 +1063,7 @@ public synchronized Map<String, String> undeployModel(String[] modelIds) {
log.debug("undeploy all models {}", Arrays.toString(getLocalDeployedModels()));
for (String modelId : getLocalDeployedModels()) {
modelUndeployStatus.put(modelId, UNDEPLOYED);
mlStats.getStat(MLNodeLevelStat.ML_NODE_TOTAL_MODEL_COUNT).decrement();
mlStats.getStat(MLNodeLevelStat.ML_DEPLOYED_MODEL_COUNT).decrement();
mlStats.createCounterStatIfAbsent(getModelFunctionName(modelId), ActionName.UNDEPLOY, ML_ACTION_REQUEST_COUNT).increment();
removeModel(modelId);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -305,11 +305,11 @@ public Collection<Object> createComponents(
stats.put(MLClusterLevelStat.ML_MODEL_COUNT, new MLStat<>(true, new CounterSupplier()));
stats.put(MLClusterLevelStat.ML_CONNECTOR_COUNT, new MLStat<>(true, new CounterSupplier()));
// node level stats
stats.put(MLNodeLevelStat.ML_NODE_EXECUTING_TASK_COUNT, new MLStat<>(false, new CounterSupplier()));
stats.put(MLNodeLevelStat.ML_NODE_TOTAL_REQUEST_COUNT, new MLStat<>(false, new CounterSupplier()));
stats.put(MLNodeLevelStat.ML_NODE_TOTAL_FAILURE_COUNT, new MLStat<>(false, new CounterSupplier()));
stats.put(MLNodeLevelStat.ML_NODE_TOTAL_MODEL_COUNT, new MLStat<>(false, new CounterSupplier()));
stats.put(MLNodeLevelStat.ML_NODE_TOTAL_CIRCUIT_BREAKER_TRIGGER_COUNT, new MLStat<>(false, new CounterSupplier()));
stats.put(MLNodeLevelStat.ML_EXECUTING_TASK_COUNT, new MLStat<>(false, new CounterSupplier()));
stats.put(MLNodeLevelStat.ML_REQUEST_COUNT, new MLStat<>(false, new CounterSupplier()));
stats.put(MLNodeLevelStat.ML_FAILURE_COUNT, new MLStat<>(false, new CounterSupplier()));
stats.put(MLNodeLevelStat.ML_DEPLOYED_MODEL_COUNT, new MLStat<>(false, new CounterSupplier()));
stats.put(MLNodeLevelStat.ML_CIRCUIT_BREAKER_TRIGGER_COUNT, new MLStat<>(false, new CounterSupplier()));
this.mlStats = new MLStats(stats);

mlIndicesHandler = new MLIndicesHandler(clusterService, client);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import java.util.Locale;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;

import org.opensearch.client.node.NodeClient;
Expand Down Expand Up @@ -60,6 +61,12 @@ public class RestMLStatsAction extends BaseRestHandler {
private static final String QUERY_ALL_MODEL_META_DOC =
"{\"query\":{\"bool\":{\"must_not\":{\"exists\":{\"field\":\"chunk_number\"}}}}}";

private static final Set<String> ML_NODE_STAT_NAMES = EnumSet
.allOf(MLNodeLevelStat.class)
.stream()
.map(stat -> stat.name())
.collect(Collectors.toSet());

/**
* Constructor
* @param mlStats MLStats object
Expand Down Expand Up @@ -148,6 +155,7 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli
}

MLStatsInput createMlStatsInputFromRequestParams(RestRequest request) {

MLStatsInput mlStatsInput = new MLStatsInput();
Optional<String[]> nodeIds = splitCommaSeparatedParam(request, "nodeId");
if (nodeIds.isPresent()) {
Expand All @@ -158,7 +166,7 @@ MLStatsInput createMlStatsInputFromRequestParams(RestRequest request) {
for (String state : stats.get()) {
state = state.toUpperCase(Locale.ROOT);
// only support cluster and node level stats for bwc
if (state.startsWith("ML_NODE")) {
if (ML_NODE_STAT_NAMES.contains(state)) {
mlStatsInput.getNodeLevelStats().add(MLNodeLevelStat.from(state));
} else {
mlStatsInput.getClusterLevelStats().add(MLClusterLevelStat.from(state));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,13 @@
* This enum represents node level stats.
*/
public enum MLNodeLevelStat {
ML_NODE_JVM_HEAP_USAGE,
ML_NODE_EXECUTING_TASK_COUNT,
ML_NODE_TOTAL_REQUEST_COUNT,
ML_NODE_TOTAL_FAILURE_COUNT,
ML_NODE_TOTAL_MODEL_COUNT,
ML_NODE_TOTAL_CIRCUIT_BREAKER_TRIGGER_COUNT;
ML_JVM_HEAP_USAGE,
ML_EXECUTING_TASK_COUNT, // How many tasks are executing currently. If any task starts, then it will increase by 1,
// if the task finished then it will decrease by 0.
ML_REQUEST_COUNT,
ML_FAILURE_COUNT,
ML_DEPLOYED_MODEL_COUNT,
ML_CIRCUIT_BREAKER_TRIGGER_COUNT;

public static MLNodeLevelStat from(String value) {
try {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,8 +88,8 @@ protected TransportResponseHandler<MLExecuteTaskResponse> getResponseHandler(Act
protected void executeTask(MLExecuteTaskRequest request, ActionListener<MLExecuteTaskResponse> listener) {
threadPool.executor(EXECUTE_THREAD_POOL).execute(() -> {
try {
mlStats.getStat(MLNodeLevelStat.ML_NODE_EXECUTING_TASK_COUNT).increment();
mlStats.getStat(MLNodeLevelStat.ML_NODE_TOTAL_REQUEST_COUNT).increment();
mlStats.getStat(MLNodeLevelStat.ML_EXECUTING_TASK_COUNT).increment();
mlStats.getStat(MLNodeLevelStat.ML_REQUEST_COUNT).increment();
mlStats
.createCounterStatIfAbsent(request.getFunctionName(), ActionName.EXECUTE, MLActionLevelStat.ML_ACTION_REQUEST_COUNT)
.increment();
Expand All @@ -113,7 +113,7 @@ protected void executeTask(MLExecuteTaskRequest request, ActionListener<MLExecut
.increment();
listener.onFailure(e);
} finally {
mlStats.getStat(MLNodeLevelStat.ML_NODE_EXECUTING_TASK_COUNT).decrement();
mlStats.getStat(MLNodeLevelStat.ML_EXECUTING_TASK_COUNT).decrement();
}
});
}
Expand Down
Loading

0 comments on commit 86eb953

Please sign in to comment.