Skip to content

Commit 7e04653

Browse files
authored
fixing metrics correlation algorithm (opensearch-project#4200)
* fixing metrics correlation algorithm Signed-off-by: Dhrubo Saha <dhrubo@amazon.com> * adding unit test Signed-off-by: Dhrubo Saha <dhrubo@amazon.com> --------- Signed-off-by: Dhrubo Saha <dhrubo@amazon.com>
1 parent d252824 commit 7e04653

File tree

2 files changed

+605
-532
lines changed

2 files changed

+605
-532
lines changed

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/metrics_correlation/MetricsCorrelation.java

Lines changed: 26 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -85,12 +85,14 @@ public class MetricsCorrelation extends DLModelExecute {
8585
private Client client;
8686
private final Settings settings;
8787
private final ClusterService clusterService;
88-
// As metrics correlation is an experimental feature we are marking the version as 1.0.0b1
88+
// As metrics correlation is an experimental feature we are marking the version
89+
// as 1.0.0b1
8990
public static final String MCORR_ML_VERSION = "1.0.0b1";
9091
// This is python based model which is developed in house.
9192
public static final String MODEL_TYPE = "in-house";
9293
// This is the opensearch release artifact url for the model
93-
// TODO: we need to make this URL more dynamic so that user can define the version from the settings to pull
94+
// TODO: we need to make this URL more dynamic so that user can define the
95+
// version from the settings to pull
9496
// up the most updated model version.
9597
public static final String MCORR_MODEL_URL =
9698
"https://artifacts.opensearch.org/models/ml-models/amazon/metrics_correlation/1.0.0b1/torch_script/metrics_correlation-1.0.0b1-torch_script.zip";
@@ -102,16 +104,14 @@ public MetricsCorrelation(Client client, Settings settings, ClusterService clust
102104
}
103105

104106
/**
105-
* @param input input data for metrics correlation. This input expects a list of float array (List<float[]>)
106-
* @return MetricsCorrelationOutput output of the metrics correlation algorithm is a list of objects. Each object
107-
* contains 3 properties event_window, event_pattern and suspected_metrics
108-
* @throws ExecuteException
109-
*/
110-
/**
107+
* Executes the metrics correlation algorithm.
111108
*
112-
* @param input input data for metrics correlation. This input expects a list of float array (List<float[]>)
113-
* @param listener action listener which response is MetricsCorrelationOutput, output of the metrics correlation
114-
* algorithm is a list of objects. Each object contains 3 properties event_window, event_pattern and suspected_metrics
109+
* @param input input data for metrics correlation. This input expects a list
110+
* of float arrays (List<float[]>)
111+
* @param listener action listener which receives MetricsCorrelationOutput, a
112+
* list of objects where each object
113+
* contains 3 properties: event_window, event_pattern, and
114+
* suspected_metrics
115115
*/
116116
@Override
117117
public void execute(Input input, ActionListener<org.opensearch.ml.common.output.Output> listener) {
@@ -129,12 +129,15 @@ public void execute(Input input, ActionListener<org.opensearch.ml.common.output.
129129
boolean hasModelGroupIndex = clusterService.state().getMetadata().hasIndex(ML_MODEL_GROUP_INDEX);
130130
if (!hasModelGroupIndex) { // Create model group index if it doesn't exist
131131
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
132-
CreateIndexRequest request = new CreateIndexRequest(ML_MODEL_GROUP_INDEX)
133-
.mapping(ML_MODEL_GROUP_INDEX_MAPPING_PATH, XContentType.JSON);
132+
// Load the mapping content from the file
133+
String mappingContent = org.opensearch.ml.common.utils.IndexUtils.getMappingFromFile(ML_MODEL_GROUP_INDEX_MAPPING_PATH);
134+
CreateIndexRequest request = new CreateIndexRequest(ML_MODEL_GROUP_INDEX).mapping(mappingContent, XContentType.JSON);
134135
CreateIndexResponse createIndexResponse = client.admin().indices().create(request).actionGet(1000);
135136
if (!createIndexResponse.isAcknowledged()) {
136137
throw new MLException("Failed to create model group index");
137138
}
139+
} catch (java.io.IOException e) {
140+
throw new MLException("Failed to load model group index mapping", e);
138141
}
139142
}
140143

@@ -161,7 +164,8 @@ public void execute(Input input, ActionListener<org.opensearch.ml.common.output.
161164
Map<String, Object> sourceAsMap = r.getSourceAsMap();
162165
String state = (String) sourceAsMap.get(MODEL_STATE_FIELD);
163166
if (!MLModelState.DEPLOYED.name().equals(state) && !MLModelState.PARTIALLY_DEPLOYED.name().equals(state)) {
164-
// if we find a model in the index but the model is not deployed then we will deploy the model
167+
// if we find a model in the index but the model is not deployed then we will
168+
// deploy the model
165169
deployModel(
166170
r.getId(),
167171
ActionListener
@@ -173,7 +177,8 @@ public void execute(Input input, ActionListener<org.opensearch.ml.common.output.
173177
}
174178
} else { // If model index doesn't exist, register model
175179
log.info("metric correlation model not registered yet");
176-
// if we don't find any model in the index then we will register a model in the index
180+
// if we don't find any model in the index then we will register a model in the
181+
// index
177182
registerModel(
178183
ActionListener
179184
.wrap(
@@ -200,7 +205,8 @@ public void execute(Input input, ActionListener<org.opensearch.ml.common.output.
200205
}
201206
}
202207

203-
// We will be waiting here until actionListeners set the model id to the modelId.
208+
// We will be waiting here until actionListeners set the model id to the
209+
// modelId.
204210
waitUntil(() -> {
205211
if (modelId != null) {
206212
MLModelState modelState = getModel(modelId).getModelState();
@@ -370,13 +376,15 @@ public MLModel getModel(String modelId) {
370376

371377
/**
372378
* Parse model output to model tensor output and apply result filter.
373-
* @param output model output
379+
*
380+
* @param output model output
374381
* @param resultFilter result filter
375382
* @return model tensor output
376383
*/
377384
public MCorrModelTensors parseModelTensorOutput(ai.djl.modality.Output output, ModelResultFilter resultFilter) {
378385

379-
// This is where we are making the pause. We need find out what will be the best way
386+
// This is where we are making the pause. We need find out what will be the best
387+
// way
380388
// to represent the model output.
381389
if (output == null) {
382390
throw new MLException("No output generated");

0 commit comments

Comments
 (0)