diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/anomalylocalization/AnomalyLocalizerImpl.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/anomalylocalization/AnomalyLocalizerImpl.java index a3b1998bc4..031f601abe 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/anomalylocalization/AnomalyLocalizerImpl.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/anomalylocalization/AnomalyLocalizerImpl.java @@ -20,11 +20,14 @@ import java.util.Map; import java.util.Optional; import java.util.PriorityQueue; +import java.util.concurrent.CountDownLatch; import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicReference; import java.util.stream.Collectors; import java.util.stream.IntStream; import org.opensearch.action.ActionListener; +import org.opensearch.action.LatchedActionListener; import org.opensearch.action.NotifyOnceListener; import org.opensearch.action.search.MultiSearchRequest; import org.opensearch.action.search.MultiSearchResponse; @@ -36,6 +39,9 @@ import org.opensearch.index.query.QueryBuilders; import org.opensearch.index.query.RangeQueryBuilder; import org.opensearch.index.query.TermQueryBuilder; +import org.opensearch.ml.common.parameter.Input; +import org.opensearch.ml.common.parameter.Output; +import org.opensearch.ml.engine.Executable; import org.opensearch.search.aggregations.AggregationBuilder; import org.opensearch.search.aggregations.AggregationBuilders; import org.opensearch.search.aggregations.bucket.composite.CompositeAggregation; @@ -58,7 +64,7 @@ * Implementation of AnomalyLocalizer. */ @Log4j2 -public class AnomalyLocalizerImpl implements AnomalyLocalizer { +public class AnomalyLocalizerImpl implements AnomalyLocalizer, Executable { // Localize when the change of new value over base value is over the percentage. protected static final double MIN_DIFF_PCT = 0.01; @@ -101,19 +107,22 @@ public void getLocalizationResults(AnomalyLocalizationInput input, ActionListene /** * Bucketizes data by time and get overall aggregates. */ - private void localizeByBuckets(AnomalyLocalizationInput input, AggregationBuilder agg, AnomalyLocalizationOutput output, ActionListener listener) { + private void localizeByBuckets(AnomalyLocalizationInput input, AggregationBuilder agg, AnomalyLocalizationOutput output, + ActionListener listener) { LocalizationTimeBuckets timeBuckets = getTimeBuckets(input); getOverallAggregates(input, timeBuckets, agg, output, listener); } - private void getOverallAggregates(AnomalyLocalizationInput input, LocalizationTimeBuckets timeBuckets, AggregationBuilder agg, AnomalyLocalizationOutput output, + private void getOverallAggregates(AnomalyLocalizationInput input, LocalizationTimeBuckets timeBuckets, AggregationBuilder agg, + AnomalyLocalizationOutput output, ActionListener listener) { MultiSearchRequest searchRequest = newSearchRequestForOverallAggregates(input, agg, timeBuckets); client.multiSearch(searchRequest, wrap(r -> onOverallAggregatesResponse(r, input, agg, output, timeBuckets, listener), listener::onFailure)); } - private void onOverallAggregatesResponse(MultiSearchResponse response, AnomalyLocalizationInput input, AggregationBuilder agg, AnomalyLocalizationOutput output, + private void onOverallAggregatesResponse(MultiSearchResponse response, AnomalyLocalizationInput input, AggregationBuilder agg, + AnomalyLocalizationOutput output, LocalizationTimeBuckets timeBuckets, ActionListener listener) { AnomalyLocalizationOutput.Result result = new AnomalyLocalizationOutput.Result(); List> intervals = timeBuckets.getAllIntervals(); @@ -134,7 +143,8 @@ private void onOverallAggregatesResponse(MultiSearchResponse response, AnomalyLo /** * Identifies buckets of data that need localization and localizes entities in the bucket. */ - private void getLocalizedEntities(AnomalyLocalizationInput input, AggregationBuilder agg, AnomalyLocalizationOutput.Result result, AnomalyLocalizationOutput output, + private void getLocalizedEntities(AnomalyLocalizationInput input, AggregationBuilder agg, AnomalyLocalizationOutput.Result result, + AnomalyLocalizationOutput output, ActionListener listener) { if (setBase(result, input)) { Counter counter = new HybridCounter(); @@ -155,8 +165,10 @@ private boolean isResultComplete(AnomalyLocalizationOutput.Result result) { return result.getBuckets().stream().allMatch(e -> e.getCompleted() == null || e.getCompleted().get() == true); } - private void processBaseEntry(AnomalyLocalizationInput input, AggregationBuilder agg, AnomalyLocalizationOutput.Result result, AnomalyLocalizationOutput.Bucket bucket, Counter counter, - Optional> afterKey, AnomalyLocalizationOutput output, ActionListener listener) { + private void processBaseEntry(AnomalyLocalizationInput input, AggregationBuilder agg, AnomalyLocalizationOutput.Result result, + AnomalyLocalizationOutput.Bucket bucket, Counter counter, + Optional> afterKey, AnomalyLocalizationOutput output, + ActionListener listener) { SearchRequest request = newSearchRequestForEntry(input, agg, bucket, afterKey); client.search(request, wrap(r -> onBaseEntryResponse(r, input, agg, result, bucket, counter, output, listener), listener::onFailure)); @@ -165,8 +177,10 @@ private void processBaseEntry(AnomalyLocalizationInput input, AggregationBuilder /** * Keeps info from entities in the base bucket to compare entities from new buckets against. */ - private void onBaseEntryResponse(SearchResponse response, AnomalyLocalizationInput input, AggregationBuilder agg, AnomalyLocalizationOutput.Result result, - AnomalyLocalizationOutput.Bucket bucket, Counter counter, AnomalyLocalizationOutput output, ActionListener listener) { + private void onBaseEntryResponse(SearchResponse response, AnomalyLocalizationInput input, AggregationBuilder agg, + AnomalyLocalizationOutput.Result result, + AnomalyLocalizationOutput.Bucket bucket, Counter counter, AnomalyLocalizationOutput output, + ActionListener listener) { Optional respAgg = Optional.ofNullable(response.getAggregations()).map(aggs -> (CompositeAggregation) aggs.get(agg.getName())); respAgg.map(a -> a.getBuckets()).orElse(Collections.emptyList()).stream().forEach(b -> { @@ -193,8 +207,10 @@ private void onBaseEntryResponse(SearchResponse response, AnomalyLocalizationInp } } - private void processNewEntry(AnomalyLocalizationInput input, AggregationBuilder agg, AnomalyLocalizationOutput.Result result, AnomalyLocalizationOutput.Bucket bucket, Optional> afterKey, PriorityQueue queue, AnomalyLocalizationOutput output, ActionListener listener) { + private void processNewEntry(AnomalyLocalizationInput input, AggregationBuilder agg, AnomalyLocalizationOutput.Result result, + AnomalyLocalizationOutput.Bucket bucket, Optional> afterKey, PriorityQueue queue, AnomalyLocalizationOutput output, + ActionListener listener) { SearchRequest request = newSearchRequestForEntry(input, agg, bucket, afterKey); client.search(request, wrap(r -> onNewEntryResponse(r, input, agg, result, bucket, queue, output, listener), listener::onFailure)); } @@ -202,8 +218,10 @@ private void processNewEntry(AnomalyLocalizationInput input, AggregationBuilder /** * Chooses entities from the new bucket that contribute the most to the overall change. */ - private void onNewEntryResponse(SearchResponse response, AnomalyLocalizationInput input, AggregationBuilder agg, AnomalyLocalizationOutput.Result result, - AnomalyLocalizationOutput.Bucket outputBucket, PriorityQueue queue, AnomalyLocalizationOutput output, + private void onNewEntryResponse(SearchResponse response, AnomalyLocalizationInput input, AggregationBuilder agg, + AnomalyLocalizationOutput.Result result, + AnomalyLocalizationOutput.Bucket outputBucket, PriorityQueue queue, + AnomalyLocalizationOutput output, ActionListener listener) { Optional respAgg = Optional.ofNullable(response.getAggregations()).map(aggs -> (CompositeAggregation) aggs.get(agg.getName())); @@ -235,8 +253,10 @@ private void onNewEntryResponse(SearchResponse response, AnomalyLocalizationInpu /** * Updates to date entity contribution values in final output. */ - private void onEntityKeysResponse(SearchResponse response, AnomalyLocalizationInput input, AggregationBuilder agg, AnomalyLocalizationOutput.Result result, - AnomalyLocalizationOutput.Bucket bucket, PriorityQueue queue, AnomalyLocalizationOutput output, + private void onEntityKeysResponse(SearchResponse response, AnomalyLocalizationInput input, AggregationBuilder agg, + AnomalyLocalizationOutput.Result result, + AnomalyLocalizationOutput.Bucket bucket, PriorityQueue queue, + AnomalyLocalizationOutput output, ActionListener listener) { List entities = new ArrayList(queue); Optional respAgg = Optional.ofNullable(response.getAggregations()).map(aggs -> (Filters) aggs.get(agg.getName())); @@ -257,7 +277,8 @@ private void onEntityKeysResponse(SearchResponse response, AnomalyLocalizationIn outputIfResultsAreComplete(output, listener); } - private SearchRequest newSearchRequestForEntityKeys(AnomalyLocalizationInput input, AggregationBuilder agg, AnomalyLocalizationOutput.Bucket bucket, + private SearchRequest newSearchRequestForEntityKeys(AnomalyLocalizationInput input, AggregationBuilder agg, + AnomalyLocalizationOutput.Bucket bucket, List> keys) { RangeQueryBuilder timeRangeFilter = new RangeQueryBuilder(input.getTimeFieldName()) .from(bucket.getBase().get().getStartTime(), true) @@ -283,7 +304,8 @@ private List toStringKey(Map key, AnomalyLocalizationInp return input.getAttributeFieldNames().stream().map(name -> key.get(name).toString()).collect(Collectors.toList()); } - private SearchRequest newSearchRequestForEntry(AnomalyLocalizationInput input, AggregationBuilder agg, AnomalyLocalizationOutput.Bucket bucket, Optional> afterKey) { RangeQueryBuilder timeRangeFilter = new RangeQueryBuilder(input.getTimeFieldName()) .from(bucket.getStartTime(), true) @@ -397,4 +419,23 @@ protected List> getAllIntervals() { return intervals; } } + + @Override + public Output execute(Input input) { + CountDownLatch latch = new CountDownLatch(1); + AtomicReference outRef = new AtomicReference<>(); + AtomicReference exRef = new AtomicReference<>(); + getLocalizationResults((AnomalyLocalizationInput) input, + new LatchedActionListener(ActionListener.wrap(o -> outRef.set(o), e -> exRef.set(e)), latch)); + try { + latch.await(); + } catch (InterruptedException e) { + throw new IllegalStateException(e); + } + if (exRef.get() != null) { + throw new RuntimeException(exRef.get()); + } else { + return outRef.get(); + } + } } diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/anomalylocalization/AnomalyLocalizerImplTests.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/anomalylocalization/AnomalyLocalizerImplTests.java index ced511a48e..d610cfaf02 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/anomalylocalization/AnomalyLocalizerImplTests.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/anomalylocalization/AnomalyLocalizerImplTests.java @@ -61,7 +61,6 @@ public class AnomalyLocalizerImplTests { private String indexName = "indexName"; private String attributeFieldNameOne = "attributeOne"; private AggregationBuilder agg = AggregationBuilders.count("count").field("field"); - ; private String timeFieldName = "timeFieldName"; private long startTime = 0; private long endTime = 2; @@ -87,7 +86,8 @@ public void setup() { settings = Settings.builder().build(); anomalyLocalizer = new AnomalyLocalizerImpl(client, settings); - input = new AnomalyLocalizationInput(indexName, Arrays.asList(attributeFieldNameOne), Arrays.asList(agg), timeFieldName, startTime, endTime, + input = new AnomalyLocalizationInput(indexName, Arrays.asList(attributeFieldNameOne), Arrays.asList(agg), timeFieldName, + startTime, endTime, minTimeInterval, numOutput, Optional.empty(), Optional.empty()); when(valueOne.value()).thenReturn(0.); @@ -232,7 +232,8 @@ public void testGetLocalizedResultsGivenNoAnomaly() { @Test public void testGetLocalizedResultsGivenAnomaly() { when(valueThree.value()).thenReturn(Double.NaN); - input = new AnomalyLocalizationInput(indexName, Arrays.asList(attributeFieldNameOne), Arrays.asList(agg), timeFieldName, startTime, endTime, + input = new AnomalyLocalizationInput(indexName, Arrays.asList(attributeFieldNameOne), Arrays.asList(agg), timeFieldName, + startTime, endTime, minTimeInterval, numOutput, Optional.of(1L), Optional.of(mock(QueryBuilder.class))); anomalyLocalizer.getLocalizationResults(input, outputListener); @@ -245,7 +246,8 @@ public void testGetLocalizedResultsGivenAnomaly() { @Test(expected = RuntimeException.class) public void testGetLocalizedResultsForInvalidTimeRange() { - input = new AnomalyLocalizationInput(indexName, Arrays.asList(attributeFieldNameOne), Arrays.asList(agg), timeFieldName, startTime, startTime, + input = new AnomalyLocalizationInput(indexName, Arrays.asList(attributeFieldNameOne), Arrays.asList(agg), timeFieldName, + startTime, startTime, minTimeInterval, numOutput, Optional.empty(), Optional.empty()); anomalyLocalizer.getLocalizationResults(input, outputListener); @@ -308,7 +310,8 @@ public void testGetLocalizedResultsOverallUnchange() { @Test public void testGetLocalizedResultsFilterEntity() { - input = new AnomalyLocalizationInput(indexName, Arrays.asList(attributeFieldNameOne), Arrays.asList(agg), timeFieldName, startTime, endTime, + input = new AnomalyLocalizationInput(indexName, Arrays.asList(attributeFieldNameOne), Arrays.asList(agg), timeFieldName, + startTime, endTime, minTimeInterval, 2, Optional.empty(), Optional.empty()); anomalyLocalizer.getLocalizationResults(input, outputListener); @@ -318,6 +321,33 @@ public void testGetLocalizedResultsFilterEntity() { AnomalyLocalizationOutput actualOutput = outputCaptor.getValue(); assertEquals(expectedOutput, actualOutput); } + + @Test + public void testExecuteSucceed() { + AnomalyLocalizationOutput actualOutput = (AnomalyLocalizationOutput) anomalyLocalizer.execute(input); + + assertEquals(expectedOutput, actualOutput); + } + + @Test(expected = RuntimeException.class) + public void testExecuteFail() { + doAnswer(new Answer() { + public Object answer(InvocationOnMock invocation) { + Object[] args = invocation.getArguments(); + ActionListener listener = (ActionListener) args[1]; + listener.onFailure(new RuntimeException()); + return null; + } + } + ).when(client).multiSearch(any(), any()); + anomalyLocalizer.execute(input); + } + + @Test(expected = RuntimeException.class) + public void testExecuteInterrupted() { + Thread.currentThread().interrupt(); + anomalyLocalizer.execute(input); + } } diff --git a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java index 5ef1958047..aa06a6406d 100644 --- a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java +++ b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java @@ -53,6 +53,7 @@ import org.opensearch.ml.common.transport.training.MLTrainingTaskAction; import org.opensearch.ml.engine.MLEngineClassLoader; import org.opensearch.ml.engine.algorithms.anomalylocalization.AnomalyLocalizationInput; +import org.opensearch.ml.engine.algorithms.anomalylocalization.AnomalyLocalizerImpl; import org.opensearch.ml.engine.algorithms.sample.LocalSampleCalculator; import org.opensearch.ml.indices.MLIndicesHandler; import org.opensearch.ml.indices.MLInputDatasetHandler; @@ -184,6 +185,9 @@ public Collection createComponents( LocalSampleCalculator localSampleCalculator = new LocalSampleCalculator(client, settings); MLEngineClassLoader.register(FunctionName.LOCAL_SAMPLE_CALCULATOR, localSampleCalculator); + AnomalyLocalizerImpl anomalyLocalizer = new AnomalyLocalizerImpl(client, settings); + MLEngineClassLoader.register(FunctionName.ANOMALY_LOCALIZATION, anomalyLocalizer); + return ImmutableList .of( mlStats,