Skip to content

Commit bc4bcc5

Browse files
committed
fixing metrics correlation algorithm
Signed-off-by: Dhrubo Saha <dhrubo@amazon.com>
1 parent 266bcfe commit bc4bcc5

File tree

2 files changed

+131
-92
lines changed

2 files changed

+131
-92
lines changed

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/metrics_correlation/MetricsCorrelation.java

Lines changed: 113 additions & 92 deletions
Original file line numberDiff line numberDiff line change
@@ -85,15 +85,16 @@ public class MetricsCorrelation extends DLModelExecute {
8585
private Client client;
8686
private final Settings settings;
8787
private final ClusterService clusterService;
88-
// As metrics correlation is an experimental feature we are marking the version as 1.0.0b1
88+
// As metrics correlation is an experimental feature we are marking the version
89+
// as 1.0.0b1
8990
public static final String MCORR_ML_VERSION = "1.0.0b1";
9091
// This is python based model which is developed in house.
9192
public static final String MODEL_TYPE = "in-house";
9293
// This is the opensearch release artifact url for the model
93-
// TODO: we need to make this URL more dynamic so that user can define the version from the settings to pull
94+
// TODO: we need to make this URL more dynamic so that user can define the
95+
// version from the settings to pull
9496
// up the most updated model version.
95-
public static final String MCORR_MODEL_URL =
96-
"https://artifacts.opensearch.org/models/ml-models/amazon/metrics_correlation/1.0.0b1/torch_script/metrics_correlation-1.0.0b1-torch_script.zip";
97+
public static final String MCORR_MODEL_URL = "https://artifacts.opensearch.org/models/ml-models/amazon/metrics_correlation/1.0.0b1/torch_script/metrics_correlation-1.0.0b1-torch_script.zip";
9798

9899
public MetricsCorrelation(Client client, Settings settings, ClusterService clusterService) {
99100
this.client = client;
@@ -102,16 +103,14 @@ public MetricsCorrelation(Client client, Settings settings, ClusterService clust
102103
}
103104

104105
/**
105-
* @param input input data for metrics correlation. This input expects a list of float array (List<float[]>)
106-
* @return MetricsCorrelationOutput output of the metrics correlation algorithm is a list of objects. Each object
107-
* contains 3 properties event_window, event_pattern and suspected_metrics
108-
* @throws ExecuteException
109-
*/
110-
/**
106+
* Executes the metrics correlation algorithm.
111107
*
112-
* @param input input data for metrics correlation. This input expects a list of float array (List<float[]>)
113-
* @param listener action listener which response is MetricsCorrelationOutput, output of the metrics correlation
114-
* algorithm is a list of objects. Each object contains 3 properties event_window, event_pattern and suspected_metrics
108+
* @param input input data for metrics correlation. This input expects a list
109+
* of float arrays (List<float[]>)
110+
* @param listener action listener which receives MetricsCorrelationOutput, a
111+
* list of objects where each object
112+
* contains 3 properties: event_window, event_pattern, and
113+
* suspected_metrics
115114
*/
116115
@Override
117116
public void execute(Input input, ActionListener<org.opensearch.ml.common.output.Output> listener) {
@@ -129,12 +128,17 @@ public void execute(Input input, ActionListener<org.opensearch.ml.common.output.
129128
boolean hasModelGroupIndex = clusterService.state().getMetadata().hasIndex(ML_MODEL_GROUP_INDEX);
130129
if (!hasModelGroupIndex) { // Create model group index if it doesn't exist
131130
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
132-
CreateIndexRequest request = new CreateIndexRequest(ML_MODEL_GROUP_INDEX)
133-
.mapping(ML_MODEL_GROUP_INDEX_MAPPING_PATH, XContentType.JSON);
131+
// Load the mapping content from the file
132+
String mappingContent = org.opensearch.ml.common.utils.IndexUtils
133+
.getMappingFromFile(ML_MODEL_GROUP_INDEX_MAPPING_PATH);
134+
CreateIndexRequest request = new CreateIndexRequest(ML_MODEL_GROUP_INDEX).mapping(mappingContent,
135+
XContentType.JSON);
134136
CreateIndexResponse createIndexResponse = client.admin().indices().create(request).actionGet(1000);
135137
if (!createIndexResponse.isAcknowledged()) {
136138
throw new MLException("Failed to create model group index");
137139
}
140+
} catch (java.io.IOException e) {
141+
throw new MLException("Failed to load model group index mapping", e);
138142
}
139143
}
140144

@@ -143,64 +147,75 @@ public void execute(Input input, ActionListener<org.opensearch.ml.common.output.
143147
log.warn("Model Index Not found. Register metric correlation model");
144148
try {
145149
registerModel(
146-
ActionListener
147-
.wrap(
148-
registerModelResponse -> modelId = getTask(registerModelResponse.getTaskId()).getModelId(),
149-
ex -> log.error("Exception during registering the Metrics correlation model", ex)
150-
)
151-
);
150+
ActionListener
151+
.wrap(
152+
registerModelResponse -> modelId = getTask(
153+
registerModelResponse.getTaskId()).getModelId(),
154+
ex -> log.error(
155+
"Exception during registering the Metrics correlation model", ex)));
152156
} catch (InterruptedException ex) {
153157
throw new RuntimeException(ex);
154158
}
155159
} else {
156160
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
157-
GetRequest getModelRequest = new GetRequest(ML_MODEL_INDEX).id(FunctionName.METRICS_CORRELATION.name());
161+
GetRequest getModelRequest = new GetRequest(ML_MODEL_INDEX)
162+
.id(FunctionName.METRICS_CORRELATION.name());
158163
ActionListener<GetResponse> actionListener = ActionListener.wrap(r -> {
159164
if (r.isExists()) {
160165
modelId = r.getId();
161166
Map<String, Object> sourceAsMap = r.getSourceAsMap();
162167
String state = (String) sourceAsMap.get(MODEL_STATE_FIELD);
163-
if (!MLModelState.DEPLOYED.name().equals(state) && !MLModelState.PARTIALLY_DEPLOYED.name().equals(state)) {
164-
// if we find a model in the index but the model is not deployed then we will deploy the model
168+
if (!MLModelState.DEPLOYED.name().equals(state)
169+
&& !MLModelState.PARTIALLY_DEPLOYED.name().equals(state)) {
170+
// if we find a model in the index but the model is not deployed then we will
171+
// deploy the model
165172
deployModel(
166-
r.getId(),
167-
ActionListener
168-
.wrap(
169-
deployModelResponse -> modelId = getTask(deployModelResponse.getTaskId()).getModelId(),
170-
e -> log.error("Metrics correlation model didn't get deployed to the index successfully", e)
171-
)
172-
);
173+
r.getId(),
174+
ActionListener
175+
.wrap(
176+
deployModelResponse -> modelId = getTask(
177+
deployModelResponse.getTaskId()).getModelId(),
178+
e -> log.error(
179+
"Metrics correlation model didn't get deployed to the index successfully",
180+
e)));
173181
}
174182
} else { // If model index doesn't exist, register model
175183
log.info("metric correlation model not registered yet");
176-
// if we don't find any model in the index then we will register a model in the index
184+
// if we don't find any model in the index then we will register a model in the
185+
// index
177186
registerModel(
178-
ActionListener
179-
.wrap(
180-
registerModelResponse -> modelId = getTask(registerModelResponse.getTaskId()).getModelId(),
181-
e -> log.error("Metrics correlation model didn't get registered to the index successfully", e)
182-
)
183-
);
187+
ActionListener
188+
.wrap(
189+
registerModelResponse -> modelId = getTask(
190+
registerModelResponse.getTaskId()).getModelId(),
191+
e -> log.error(
192+
"Metrics correlation model didn't get registered to the index successfully",
193+
e)));
184194
}
185-
}, e -> { log.error("Failed to get model", e); });
195+
}, e -> {
196+
log.error("Failed to get model", e);
197+
});
186198
client.get(getModelRequest, ActionListener.runBefore(actionListener, context::restore));
187199
}
188200
}
189201
} else {
190202
MLModel model = getModel(modelId);
191-
if (model.getModelState() != MLModelState.DEPLOYED && model.getModelState() != MLModelState.PARTIALLY_DEPLOYED) {
203+
if (model.getModelState() != MLModelState.DEPLOYED
204+
&& model.getModelState() != MLModelState.PARTIALLY_DEPLOYED) {
192205
deployModel(
193-
modelId,
194-
ActionListener
195-
.wrap(
196-
deployModelResponse -> modelId = getTask(deployModelResponse.getTaskId()).getModelId(),
197-
e -> log.error("Metrics correlation model didn't get deployed to the index successfully", e)
198-
)
199-
);
206+
modelId,
207+
ActionListener
208+
.wrap(
209+
deployModelResponse -> modelId = getTask(deployModelResponse.getTaskId())
210+
.getModelId(),
211+
e -> log.error(
212+
"Metrics correlation model didn't get deployed to the index successfully",
213+
e)));
200214
}
201215
}
202216

203-
// We will be waiting here until actionListeners set the model id to the modelId.
217+
// We will be waiting here until actionListeners set the model id to the
218+
// modelId.
204219
waitUntil(() -> {
205220
if (modelId != null) {
206221
MLModelState modelState = getModel(modelId).getModelState();
@@ -210,13 +225,14 @@ public void execute(Input input, ActionListener<org.opensearch.ml.common.output.
210225
} else if (modelState == MLModelState.UNDEPLOYED || modelState == MLModelState.DEPLOY_FAILED) {
211226
log.info("Model not deployed: " + modelState);
212227
deployModel(
213-
modelId,
214-
ActionListener
215-
.wrap(
216-
deployModelResponse -> modelId = getTask(deployModelResponse.getTaskId()).getModelId(),
217-
e -> log.error("Metrics correlation model didn't get deployed to the index successfully", e)
218-
)
219-
);
228+
modelId,
229+
ActionListener
230+
.wrap(
231+
deployModelResponse -> modelId = getTask(deployModelResponse.getTaskId())
232+
.getModelId(),
233+
e -> log.error(
234+
"Metrics correlation model didn't get deployed to the index successfully",
235+
e)));
220236
return false;
221237
}
222238
}
@@ -243,37 +259,39 @@ void registerModel(ActionListener<MLRegisterModelResponse> listener) throws Inte
243259
FunctionName functionName = FunctionName.METRICS_CORRELATION;
244260
MLModelFormat modelFormat = MLModelFormat.TORCH_SCRIPT;
245261

246-
MLModelConfig modelConfig = MetricsCorrelationModelConfig.builder().modelType(MODEL_TYPE).allConfig(null).build();
262+
MLModelConfig modelConfig = MetricsCorrelationModelConfig.builder().modelType(MODEL_TYPE).allConfig(null)
263+
.build();
247264
MLRegisterModelInput input = MLRegisterModelInput
248-
.builder()
249-
.functionName(functionName)
250-
.modelName(FunctionName.METRICS_CORRELATION.name())
251-
.version(MCORR_ML_VERSION)
252-
.modelGroupId(functionName.name())
253-
.modelFormat(modelFormat)
254-
.hashValue(MODEL_CONTENT_HASH)
255-
.modelConfig(modelConfig)
256-
.url(MCORR_MODEL_URL)
257-
.deployModel(true)
258-
.build();
265+
.builder()
266+
.functionName(functionName)
267+
.modelName(FunctionName.METRICS_CORRELATION.name())
268+
.version(MCORR_ML_VERSION)
269+
.modelGroupId(functionName.name())
270+
.modelFormat(modelFormat)
271+
.hashValue(MODEL_CONTENT_HASH)
272+
.modelConfig(modelConfig)
273+
.url(MCORR_MODEL_URL)
274+
.deployModel(true)
275+
.build();
259276
MLRegisterModelRequest registerRequest = MLRegisterModelRequest.builder().registerModelInput(input).build();
260277

261278
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
262279
IndexRequest createModelGroupRequest = new IndexRequest(ML_MODEL_GROUP_INDEX).id(functionName.name());
263280
MLModelGroup modelGroup = MLModelGroup
264-
.builder()
265-
.name(functionName.name())
266-
.access(AccessMode.PUBLIC.getValue())
267-
.createdTime(Instant.now())
268-
.build();
281+
.builder()
282+
.name(functionName.name())
283+
.access(AccessMode.PUBLIC.getValue())
284+
.createdTime(Instant.now())
285+
.build();
269286
XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent());
270287
modelGroup.toXContent(builder, ToXContent.EMPTY_PARAMS);
271288
createModelGroupRequest.source(builder);
272289
client.index(createModelGroupRequest, ActionListener.runBefore(ActionListener.wrap(r -> {
273-
client.execute(MLRegisterModelAction.INSTANCE, registerRequest, ActionListener.wrap(listener::onResponse, e -> {
274-
log.error("Failed to Register Model", e);
275-
listener.onFailure(e);
276-
}));
290+
client.execute(MLRegisterModelAction.INSTANCE, registerRequest,
291+
ActionListener.wrap(listener::onResponse, e -> {
292+
log.error("Failed to Register Model", e);
293+
listener.onFailure(e);
294+
}));
277295
}, listener::onFailure), context::restore));
278296
} catch (IOException e) {
279297
throw new MLException(e);
@@ -283,7 +301,8 @@ void registerModel(ActionListener<MLRegisterModelResponse> listener) throws Inte
283301

284302
@VisibleForTesting
285303
void deployModel(final String modelId, ActionListener<MLDeployModelResponse> listener) {
286-
MLDeployModelRequest loadRequest = MLDeployModelRequest.builder().modelId(modelId).async(false).dispatchTask(false).build();
304+
MLDeployModelRequest loadRequest = MLDeployModelRequest.builder().modelId(modelId).async(false)
305+
.dispatchTask(false).build();
287306
client.execute(MLDeployModelAction.INSTANCE, loadRequest, ActionListener.wrap(listener::onResponse, e -> {
288307
log.error("Failed to deploy Model", e);
289308
listener.onFailure(e);
@@ -310,25 +329,25 @@ public MetricsCorrelationTranslator getTranslator() {
310329
SearchRequest getSearchRequest() {
311330
SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder();
312331
searchSourceBuilder
313-
.fetchSource(
314-
new String[] {
315-
MLModel.MODEL_ID_FIELD,
316-
MLModel.MODEL_NAME_FIELD,
317-
MODEL_STATE_FIELD,
318-
MLModel.MODEL_VERSION_FIELD,
319-
MLModel.MODEL_CONTENT_FIELD },
320-
new String[] { MLModel.MODEL_CONTENT_FIELD }
321-
);
332+
.fetchSource(
333+
new String[] {
334+
MLModel.MODEL_ID_FIELD,
335+
MLModel.MODEL_NAME_FIELD,
336+
MODEL_STATE_FIELD,
337+
MLModel.MODEL_VERSION_FIELD,
338+
MLModel.MODEL_CONTENT_FIELD },
339+
new String[] { MLModel.MODEL_CONTENT_FIELD });
322340

323341
BoolQueryBuilder boolQueryBuilder = QueryBuilders
324-
.boolQuery()
325-
.should(termQuery(MLModel.MODEL_NAME_FIELD, FunctionName.METRICS_CORRELATION.name()))
326-
.should(termQuery(MLModel.MODEL_VERSION_FIELD, MCORR_ML_VERSION));
342+
.boolQuery()
343+
.should(termQuery(MLModel.MODEL_NAME_FIELD, FunctionName.METRICS_CORRELATION.name()))
344+
.should(termQuery(MLModel.MODEL_VERSION_FIELD, MCORR_ML_VERSION));
327345
searchSourceBuilder.query(boolQueryBuilder);
328346
return new SearchRequest().source(searchSourceBuilder).indices(CommonValue.ML_MODEL_INDEX);
329347
}
330348

331-
public static boolean waitUntil(BooleanSupplier breakSupplier, long maxWaitTime, TimeUnit unit) throws ExecuteException {
349+
public static boolean waitUntil(BooleanSupplier breakSupplier, long maxWaitTime, TimeUnit unit)
350+
throws ExecuteException {
332351
long maxTimeInMillis = TimeUnit.MILLISECONDS.convert(maxWaitTime, unit);
333352
long timeInMillis = 1;
334353
long sum = 0;
@@ -370,13 +389,15 @@ public MLModel getModel(String modelId) {
370389

371390
/**
372391
* Parse model output to model tensor output and apply result filter.
373-
* @param output model output
392+
*
393+
* @param output model output
374394
* @param resultFilter result filter
375395
* @return model tensor output
376396
*/
377397
public MCorrModelTensors parseModelTensorOutput(ai.djl.modality.Output output, ModelResultFilter resultFilter) {
378398

379-
// This is where we are making the pause. We need find out what will be the best way
399+
// This is where we are making the pause. We need find out what will be the best
400+
// way
380401
// to represent the model output.
381402
if (output == null) {
382403
throw new MLException("No output generated");

0 commit comments

Comments
 (0)