@@ -85,15 +85,16 @@ public class MetricsCorrelation extends DLModelExecute {
8585 private Client client ;
8686 private final Settings settings ;
8787 private final ClusterService clusterService ;
88- // As metrics correlation is an experimental feature we are marking the version as 1.0.0b1
88+ // As metrics correlation is an experimental feature we are marking the version
89+ // as 1.0.0b1
8990 public static final String MCORR_ML_VERSION = "1.0.0b1" ;
9091 // This is python based model which is developed in house.
9192 public static final String MODEL_TYPE = "in-house" ;
9293 // This is the opensearch release artifact url for the model
93- // TODO: we need to make this URL more dynamic so that user can define the version from the settings to pull
94+ // TODO: we need to make this URL more dynamic so that user can define the
95+ // version from the settings to pull
9496 // up the most updated model version.
95- public static final String MCORR_MODEL_URL =
96- "https://artifacts.opensearch.org/models/ml-models/amazon/metrics_correlation/1.0.0b1/torch_script/metrics_correlation-1.0.0b1-torch_script.zip" ;
97+ public static final String MCORR_MODEL_URL = "https://artifacts.opensearch.org/models/ml-models/amazon/metrics_correlation/1.0.0b1/torch_script/metrics_correlation-1.0.0b1-torch_script.zip" ;
9798
9899 public MetricsCorrelation (Client client , Settings settings , ClusterService clusterService ) {
99100 this .client = client ;
@@ -102,16 +103,14 @@ public MetricsCorrelation(Client client, Settings settings, ClusterService clust
102103 }
103104
104105 /**
105- * @param input input data for metrics correlation. This input expects a list of float array (List<float[]>)
106- * @return MetricsCorrelationOutput output of the metrics correlation algorithm is a list of objects. Each object
107- * contains 3 properties event_window, event_pattern and suspected_metrics
108- * @throws ExecuteException
109- */
110- /**
106+ * Executes the metrics correlation algorithm.
111107 *
112- * @param input input data for metrics correlation. This input expects a list of float array (List<float[]>)
113- * @param listener action listener which response is MetricsCorrelationOutput, output of the metrics correlation
114- * algorithm is a list of objects. Each object contains 3 properties event_window, event_pattern and suspected_metrics
108+ * @param input input data for metrics correlation. This input expects a list
109+ * of float arrays (List<float[]>)
110+ * @param listener action listener which receives MetricsCorrelationOutput, a
111+ * list of objects where each object
112+ * contains 3 properties: event_window, event_pattern, and
113+ * suspected_metrics
115114 */
116115 @ Override
117116 public void execute (Input input , ActionListener <org .opensearch .ml .common .output .Output > listener ) {
@@ -129,12 +128,17 @@ public void execute(Input input, ActionListener<org.opensearch.ml.common.output.
129128 boolean hasModelGroupIndex = clusterService .state ().getMetadata ().hasIndex (ML_MODEL_GROUP_INDEX );
130129 if (!hasModelGroupIndex ) { // Create model group index if it doesn't exist
131130 try (ThreadContext .StoredContext context = client .threadPool ().getThreadContext ().stashContext ()) {
132- CreateIndexRequest request = new CreateIndexRequest (ML_MODEL_GROUP_INDEX )
133- .mapping (ML_MODEL_GROUP_INDEX_MAPPING_PATH , XContentType .JSON );
131+ // Load the mapping content from the file
132+ String mappingContent = org .opensearch .ml .common .utils .IndexUtils
133+ .getMappingFromFile (ML_MODEL_GROUP_INDEX_MAPPING_PATH );
134+ CreateIndexRequest request = new CreateIndexRequest (ML_MODEL_GROUP_INDEX ).mapping (mappingContent ,
135+ XContentType .JSON );
134136 CreateIndexResponse createIndexResponse = client .admin ().indices ().create (request ).actionGet (1000 );
135137 if (!createIndexResponse .isAcknowledged ()) {
136138 throw new MLException ("Failed to create model group index" );
137139 }
140+ } catch (java .io .IOException e ) {
141+ throw new MLException ("Failed to load model group index mapping" , e );
138142 }
139143 }
140144
@@ -143,64 +147,75 @@ public void execute(Input input, ActionListener<org.opensearch.ml.common.output.
143147 log .warn ("Model Index Not found. Register metric correlation model" );
144148 try {
145149 registerModel (
146- ActionListener
147- .wrap (
148- registerModelResponse -> modelId = getTask (registerModelResponse . getTaskId ()). getModelId (),
149- ex -> log . error ( "Exception during registering the Metrics correlation model" , ex )
150- )
151- );
150+ ActionListener
151+ .wrap (
152+ registerModelResponse -> modelId = getTask (
153+ registerModelResponse . getTaskId ()). getModelId (),
154+ ex -> log . error (
155+ "Exception during registering the Metrics correlation model" , ex )) );
152156 } catch (InterruptedException ex ) {
153157 throw new RuntimeException (ex );
154158 }
155159 } else {
156160 try (ThreadContext .StoredContext context = client .threadPool ().getThreadContext ().stashContext ()) {
157- GetRequest getModelRequest = new GetRequest (ML_MODEL_INDEX ).id (FunctionName .METRICS_CORRELATION .name ());
161+ GetRequest getModelRequest = new GetRequest (ML_MODEL_INDEX )
162+ .id (FunctionName .METRICS_CORRELATION .name ());
158163 ActionListener <GetResponse > actionListener = ActionListener .wrap (r -> {
159164 if (r .isExists ()) {
160165 modelId = r .getId ();
161166 Map <String , Object > sourceAsMap = r .getSourceAsMap ();
162167 String state = (String ) sourceAsMap .get (MODEL_STATE_FIELD );
163- if (!MLModelState .DEPLOYED .name ().equals (state ) && !MLModelState .PARTIALLY_DEPLOYED .name ().equals (state )) {
164- // if we find a model in the index but the model is not deployed then we will deploy the model
168+ if (!MLModelState .DEPLOYED .name ().equals (state )
169+ && !MLModelState .PARTIALLY_DEPLOYED .name ().equals (state )) {
170+ // if we find a model in the index but the model is not deployed then we will
171+ // deploy the model
165172 deployModel (
166- r .getId (),
167- ActionListener
168- .wrap (
169- deployModelResponse -> modelId = getTask (deployModelResponse .getTaskId ()).getModelId (),
170- e -> log .error ("Metrics correlation model didn't get deployed to the index successfully" , e )
171- )
172- );
173+ r .getId (),
174+ ActionListener
175+ .wrap (
176+ deployModelResponse -> modelId = getTask (
177+ deployModelResponse .getTaskId ()).getModelId (),
178+ e -> log .error (
179+ "Metrics correlation model didn't get deployed to the index successfully" ,
180+ e )));
173181 }
174182 } else { // If model index doesn't exist, register model
175183 log .info ("metric correlation model not registered yet" );
176- // if we don't find any model in the index then we will register a model in the index
184+ // if we don't find any model in the index then we will register a model in the
185+ // index
177186 registerModel (
178- ActionListener
179- .wrap (
180- registerModelResponse -> modelId = getTask (registerModelResponse .getTaskId ()).getModelId (),
181- e -> log .error ("Metrics correlation model didn't get registered to the index successfully" , e )
182- )
183- );
187+ ActionListener
188+ .wrap (
189+ registerModelResponse -> modelId = getTask (
190+ registerModelResponse .getTaskId ()).getModelId (),
191+ e -> log .error (
192+ "Metrics correlation model didn't get registered to the index successfully" ,
193+ e )));
184194 }
185- }, e -> { log .error ("Failed to get model" , e ); });
195+ }, e -> {
196+ log .error ("Failed to get model" , e );
197+ });
186198 client .get (getModelRequest , ActionListener .runBefore (actionListener , context ::restore ));
187199 }
188200 }
189201 } else {
190202 MLModel model = getModel (modelId );
191- if (model .getModelState () != MLModelState .DEPLOYED && model .getModelState () != MLModelState .PARTIALLY_DEPLOYED ) {
203+ if (model .getModelState () != MLModelState .DEPLOYED
204+ && model .getModelState () != MLModelState .PARTIALLY_DEPLOYED ) {
192205 deployModel (
193- modelId ,
194- ActionListener
195- .wrap (
196- deployModelResponse -> modelId = getTask (deployModelResponse .getTaskId ()).getModelId (),
197- e -> log .error ("Metrics correlation model didn't get deployed to the index successfully" , e )
198- )
199- );
206+ modelId ,
207+ ActionListener
208+ .wrap (
209+ deployModelResponse -> modelId = getTask (deployModelResponse .getTaskId ())
210+ .getModelId (),
211+ e -> log .error (
212+ "Metrics correlation model didn't get deployed to the index successfully" ,
213+ e )));
200214 }
201215 }
202216
203- // We will be waiting here until actionListeners set the model id to the modelId.
217+ // We will be waiting here until actionListeners set the model id to the
218+ // modelId.
204219 waitUntil (() -> {
205220 if (modelId != null ) {
206221 MLModelState modelState = getModel (modelId ).getModelState ();
@@ -210,13 +225,14 @@ public void execute(Input input, ActionListener<org.opensearch.ml.common.output.
210225 } else if (modelState == MLModelState .UNDEPLOYED || modelState == MLModelState .DEPLOY_FAILED ) {
211226 log .info ("Model not deployed: " + modelState );
212227 deployModel (
213- modelId ,
214- ActionListener
215- .wrap (
216- deployModelResponse -> modelId = getTask (deployModelResponse .getTaskId ()).getModelId (),
217- e -> log .error ("Metrics correlation model didn't get deployed to the index successfully" , e )
218- )
219- );
228+ modelId ,
229+ ActionListener
230+ .wrap (
231+ deployModelResponse -> modelId = getTask (deployModelResponse .getTaskId ())
232+ .getModelId (),
233+ e -> log .error (
234+ "Metrics correlation model didn't get deployed to the index successfully" ,
235+ e )));
220236 return false ;
221237 }
222238 }
@@ -243,37 +259,39 @@ void registerModel(ActionListener<MLRegisterModelResponse> listener) throws Inte
243259 FunctionName functionName = FunctionName .METRICS_CORRELATION ;
244260 MLModelFormat modelFormat = MLModelFormat .TORCH_SCRIPT ;
245261
246- MLModelConfig modelConfig = MetricsCorrelationModelConfig .builder ().modelType (MODEL_TYPE ).allConfig (null ).build ();
262+ MLModelConfig modelConfig = MetricsCorrelationModelConfig .builder ().modelType (MODEL_TYPE ).allConfig (null )
263+ .build ();
247264 MLRegisterModelInput input = MLRegisterModelInput
248- .builder ()
249- .functionName (functionName )
250- .modelName (FunctionName .METRICS_CORRELATION .name ())
251- .version (MCORR_ML_VERSION )
252- .modelGroupId (functionName .name ())
253- .modelFormat (modelFormat )
254- .hashValue (MODEL_CONTENT_HASH )
255- .modelConfig (modelConfig )
256- .url (MCORR_MODEL_URL )
257- .deployModel (true )
258- .build ();
265+ .builder ()
266+ .functionName (functionName )
267+ .modelName (FunctionName .METRICS_CORRELATION .name ())
268+ .version (MCORR_ML_VERSION )
269+ .modelGroupId (functionName .name ())
270+ .modelFormat (modelFormat )
271+ .hashValue (MODEL_CONTENT_HASH )
272+ .modelConfig (modelConfig )
273+ .url (MCORR_MODEL_URL )
274+ .deployModel (true )
275+ .build ();
259276 MLRegisterModelRequest registerRequest = MLRegisterModelRequest .builder ().registerModelInput (input ).build ();
260277
261278 try (ThreadContext .StoredContext context = client .threadPool ().getThreadContext ().stashContext ()) {
262279 IndexRequest createModelGroupRequest = new IndexRequest (ML_MODEL_GROUP_INDEX ).id (functionName .name ());
263280 MLModelGroup modelGroup = MLModelGroup
264- .builder ()
265- .name (functionName .name ())
266- .access (AccessMode .PUBLIC .getValue ())
267- .createdTime (Instant .now ())
268- .build ();
281+ .builder ()
282+ .name (functionName .name ())
283+ .access (AccessMode .PUBLIC .getValue ())
284+ .createdTime (Instant .now ())
285+ .build ();
269286 XContentBuilder builder = XContentBuilder .builder (XContentType .JSON .xContent ());
270287 modelGroup .toXContent (builder , ToXContent .EMPTY_PARAMS );
271288 createModelGroupRequest .source (builder );
272289 client .index (createModelGroupRequest , ActionListener .runBefore (ActionListener .wrap (r -> {
273- client .execute (MLRegisterModelAction .INSTANCE , registerRequest , ActionListener .wrap (listener ::onResponse , e -> {
274- log .error ("Failed to Register Model" , e );
275- listener .onFailure (e );
276- }));
290+ client .execute (MLRegisterModelAction .INSTANCE , registerRequest ,
291+ ActionListener .wrap (listener ::onResponse , e -> {
292+ log .error ("Failed to Register Model" , e );
293+ listener .onFailure (e );
294+ }));
277295 }, listener ::onFailure ), context ::restore ));
278296 } catch (IOException e ) {
279297 throw new MLException (e );
@@ -283,7 +301,8 @@ void registerModel(ActionListener<MLRegisterModelResponse> listener) throws Inte
283301
284302 @ VisibleForTesting
285303 void deployModel (final String modelId , ActionListener <MLDeployModelResponse > listener ) {
286- MLDeployModelRequest loadRequest = MLDeployModelRequest .builder ().modelId (modelId ).async (false ).dispatchTask (false ).build ();
304+ MLDeployModelRequest loadRequest = MLDeployModelRequest .builder ().modelId (modelId ).async (false )
305+ .dispatchTask (false ).build ();
287306 client .execute (MLDeployModelAction .INSTANCE , loadRequest , ActionListener .wrap (listener ::onResponse , e -> {
288307 log .error ("Failed to deploy Model" , e );
289308 listener .onFailure (e );
@@ -310,25 +329,25 @@ public MetricsCorrelationTranslator getTranslator() {
310329 SearchRequest getSearchRequest () {
311330 SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder ();
312331 searchSourceBuilder
313- .fetchSource (
314- new String [] {
315- MLModel .MODEL_ID_FIELD ,
316- MLModel .MODEL_NAME_FIELD ,
317- MODEL_STATE_FIELD ,
318- MLModel .MODEL_VERSION_FIELD ,
319- MLModel .MODEL_CONTENT_FIELD },
320- new String [] { MLModel .MODEL_CONTENT_FIELD }
321- );
332+ .fetchSource (
333+ new String [] {
334+ MLModel .MODEL_ID_FIELD ,
335+ MLModel .MODEL_NAME_FIELD ,
336+ MODEL_STATE_FIELD ,
337+ MLModel .MODEL_VERSION_FIELD ,
338+ MLModel .MODEL_CONTENT_FIELD },
339+ new String [] { MLModel .MODEL_CONTENT_FIELD });
322340
323341 BoolQueryBuilder boolQueryBuilder = QueryBuilders
324- .boolQuery ()
325- .should (termQuery (MLModel .MODEL_NAME_FIELD , FunctionName .METRICS_CORRELATION .name ()))
326- .should (termQuery (MLModel .MODEL_VERSION_FIELD , MCORR_ML_VERSION ));
342+ .boolQuery ()
343+ .should (termQuery (MLModel .MODEL_NAME_FIELD , FunctionName .METRICS_CORRELATION .name ()))
344+ .should (termQuery (MLModel .MODEL_VERSION_FIELD , MCORR_ML_VERSION ));
327345 searchSourceBuilder .query (boolQueryBuilder );
328346 return new SearchRequest ().source (searchSourceBuilder ).indices (CommonValue .ML_MODEL_INDEX );
329347 }
330348
331- public static boolean waitUntil (BooleanSupplier breakSupplier , long maxWaitTime , TimeUnit unit ) throws ExecuteException {
349+ public static boolean waitUntil (BooleanSupplier breakSupplier , long maxWaitTime , TimeUnit unit )
350+ throws ExecuteException {
332351 long maxTimeInMillis = TimeUnit .MILLISECONDS .convert (maxWaitTime , unit );
333352 long timeInMillis = 1 ;
334353 long sum = 0 ;
@@ -370,13 +389,15 @@ public MLModel getModel(String modelId) {
370389
371390 /**
372391 * Parse model output to model tensor output and apply result filter.
373- * @param output model output
392+ *
393+ * @param output model output
374394 * @param resultFilter result filter
375395 * @return model tensor output
376396 */
377397 public MCorrModelTensors parseModelTensorOutput (ai .djl .modality .Output output , ModelResultFilter resultFilter ) {
378398
379- // This is where we are making the pause. We need find out what will be the best way
399+ // This is where we are making the pause. We need find out what will be the best
400+ // way
380401 // to represent the model output.
381402 if (output == null ) {
382403 throw new MLException ("No output generated" );
0 commit comments