Skip to content

Commit

Permalink
anomaly localization integration step 4 and 5
Browse files Browse the repository at this point in the history
Signed-off-by: lai <57818076+wnbts@users.noreply.github.com>
  • Loading branch information
wnbts committed Jan 21, 2022
1 parent 1d5da1d commit d397a54
Show file tree
Hide file tree
Showing 3 changed files with 97 additions and 22 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,14 @@
import java.util.Map;
import java.util.Optional;
import java.util.PriorityQueue;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicReference;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

import org.opensearch.action.ActionListener;
import org.opensearch.action.LatchedActionListener;
import org.opensearch.action.NotifyOnceListener;
import org.opensearch.action.search.MultiSearchRequest;
import org.opensearch.action.search.MultiSearchResponse;
Expand All @@ -36,6 +39,9 @@
import org.opensearch.index.query.QueryBuilders;
import org.opensearch.index.query.RangeQueryBuilder;
import org.opensearch.index.query.TermQueryBuilder;
import org.opensearch.ml.common.parameter.Input;
import org.opensearch.ml.common.parameter.Output;
import org.opensearch.ml.engine.Executable;
import org.opensearch.search.aggregations.AggregationBuilder;
import org.opensearch.search.aggregations.AggregationBuilders;
import org.opensearch.search.aggregations.bucket.composite.CompositeAggregation;
Expand All @@ -58,7 +64,7 @@
* Implementation of AnomalyLocalizer.
*/
@Log4j2
public class AnomalyLocalizerImpl implements AnomalyLocalizer {
public class AnomalyLocalizerImpl implements AnomalyLocalizer, Executable {

// Localize when the change of new value over base value is over the percentage.
protected static final double MIN_DIFF_PCT = 0.01;
Expand Down Expand Up @@ -101,19 +107,22 @@ public void getLocalizationResults(AnomalyLocalizationInput input, ActionListene
/**
* Bucketizes data by time and get overall aggregates.
*/
private void localizeByBuckets(AnomalyLocalizationInput input, AggregationBuilder agg, AnomalyLocalizationOutput output, ActionListener<AnomalyLocalizationOutput> listener) {
private void localizeByBuckets(AnomalyLocalizationInput input, AggregationBuilder agg, AnomalyLocalizationOutput output,
ActionListener<AnomalyLocalizationOutput> listener) {
LocalizationTimeBuckets timeBuckets = getTimeBuckets(input);
getOverallAggregates(input, timeBuckets, agg, output, listener);
}

private void getOverallAggregates(AnomalyLocalizationInput input, LocalizationTimeBuckets timeBuckets, AggregationBuilder agg, AnomalyLocalizationOutput output,
private void getOverallAggregates(AnomalyLocalizationInput input, LocalizationTimeBuckets timeBuckets, AggregationBuilder agg,
AnomalyLocalizationOutput output,
ActionListener<AnomalyLocalizationOutput> listener) {
MultiSearchRequest searchRequest = newSearchRequestForOverallAggregates(input, agg, timeBuckets);
client.multiSearch(searchRequest, wrap(r -> onOverallAggregatesResponse(r, input, agg, output, timeBuckets, listener),
listener::onFailure));
}

private void onOverallAggregatesResponse(MultiSearchResponse response, AnomalyLocalizationInput input, AggregationBuilder agg, AnomalyLocalizationOutput output,
private void onOverallAggregatesResponse(MultiSearchResponse response, AnomalyLocalizationInput input, AggregationBuilder agg,
AnomalyLocalizationOutput output,
LocalizationTimeBuckets timeBuckets, ActionListener<AnomalyLocalizationOutput> listener) {
AnomalyLocalizationOutput.Result result = new AnomalyLocalizationOutput.Result();
List<Map.Entry<Long, Long>> intervals = timeBuckets.getAllIntervals();
Expand All @@ -134,7 +143,8 @@ private void onOverallAggregatesResponse(MultiSearchResponse response, AnomalyLo
/**
* Identifies buckets of data that need localization and localizes entities in the bucket.
*/
private void getLocalizedEntities(AnomalyLocalizationInput input, AggregationBuilder agg, AnomalyLocalizationOutput.Result result, AnomalyLocalizationOutput output,
private void getLocalizedEntities(AnomalyLocalizationInput input, AggregationBuilder agg, AnomalyLocalizationOutput.Result result,
AnomalyLocalizationOutput output,
ActionListener<AnomalyLocalizationOutput> listener) {
if (setBase(result, input)) {
Counter counter = new HybridCounter();
Expand All @@ -155,8 +165,10 @@ private boolean isResultComplete(AnomalyLocalizationOutput.Result result) {
return result.getBuckets().stream().allMatch(e -> e.getCompleted() == null || e.getCompleted().get() == true);
}

private void processBaseEntry(AnomalyLocalizationInput input, AggregationBuilder agg, AnomalyLocalizationOutput.Result result, AnomalyLocalizationOutput.Bucket bucket, Counter counter,
Optional<Map<String, Object>> afterKey, AnomalyLocalizationOutput output, ActionListener<AnomalyLocalizationOutput> listener) {
private void processBaseEntry(AnomalyLocalizationInput input, AggregationBuilder agg, AnomalyLocalizationOutput.Result result,
AnomalyLocalizationOutput.Bucket bucket, Counter counter,
Optional<Map<String, Object>> afterKey, AnomalyLocalizationOutput output,
ActionListener<AnomalyLocalizationOutput> listener) {
SearchRequest request = newSearchRequestForEntry(input, agg, bucket, afterKey);
client.search(request, wrap(r -> onBaseEntryResponse(r, input, agg, result, bucket, counter, output, listener),
listener::onFailure));
Expand All @@ -165,8 +177,10 @@ private void processBaseEntry(AnomalyLocalizationInput input, AggregationBuilder
/**
* Keeps info from entities in the base bucket to compare entities from new buckets against.
*/
private void onBaseEntryResponse(SearchResponse response, AnomalyLocalizationInput input, AggregationBuilder agg, AnomalyLocalizationOutput.Result result,
AnomalyLocalizationOutput.Bucket bucket, Counter counter, AnomalyLocalizationOutput output, ActionListener<AnomalyLocalizationOutput> listener) {
private void onBaseEntryResponse(SearchResponse response, AnomalyLocalizationInput input, AggregationBuilder agg,
AnomalyLocalizationOutput.Result result,
AnomalyLocalizationOutput.Bucket bucket, Counter counter, AnomalyLocalizationOutput output,
ActionListener<AnomalyLocalizationOutput> listener) {
Optional<CompositeAggregation> respAgg =
Optional.ofNullable(response.getAggregations()).map(aggs -> (CompositeAggregation) aggs.get(agg.getName()));
respAgg.map(a -> a.getBuckets()).orElse(Collections.emptyList()).stream().forEach(b -> {
Expand All @@ -193,17 +207,21 @@ private void onBaseEntryResponse(SearchResponse response, AnomalyLocalizationInp
}
}

private void processNewEntry(AnomalyLocalizationInput input, AggregationBuilder agg, AnomalyLocalizationOutput.Result result, AnomalyLocalizationOutput.Bucket bucket, Optional<Map<String,
Object>> afterKey, PriorityQueue<AnomalyLocalizationOutput.Entity> queue, AnomalyLocalizationOutput output, ActionListener<AnomalyLocalizationOutput> listener) {
private void processNewEntry(AnomalyLocalizationInput input, AggregationBuilder agg, AnomalyLocalizationOutput.Result result,
AnomalyLocalizationOutput.Bucket bucket, Optional<Map<String,
Object>> afterKey, PriorityQueue<AnomalyLocalizationOutput.Entity> queue, AnomalyLocalizationOutput output,
ActionListener<AnomalyLocalizationOutput> listener) {
SearchRequest request = newSearchRequestForEntry(input, agg, bucket, afterKey);
client.search(request, wrap(r -> onNewEntryResponse(r, input, agg, result, bucket, queue, output, listener), listener::onFailure));
}

/**
* Chooses entities from the new bucket that contribute the most to the overall change.
*/
private void onNewEntryResponse(SearchResponse response, AnomalyLocalizationInput input, AggregationBuilder agg, AnomalyLocalizationOutput.Result result,
AnomalyLocalizationOutput.Bucket outputBucket, PriorityQueue<AnomalyLocalizationOutput.Entity> queue, AnomalyLocalizationOutput output,
private void onNewEntryResponse(SearchResponse response, AnomalyLocalizationInput input, AggregationBuilder agg,
AnomalyLocalizationOutput.Result result,
AnomalyLocalizationOutput.Bucket outputBucket, PriorityQueue<AnomalyLocalizationOutput.Entity> queue,
AnomalyLocalizationOutput output,
ActionListener<AnomalyLocalizationOutput> listener) {
Optional<CompositeAggregation> respAgg =
Optional.ofNullable(response.getAggregations()).map(aggs -> (CompositeAggregation) aggs.get(agg.getName()));
Expand Down Expand Up @@ -235,8 +253,10 @@ private void onNewEntryResponse(SearchResponse response, AnomalyLocalizationInpu
/**
* Updates to date entity contribution values in final output.
*/
private void onEntityKeysResponse(SearchResponse response, AnomalyLocalizationInput input, AggregationBuilder agg, AnomalyLocalizationOutput.Result result,
AnomalyLocalizationOutput.Bucket bucket, PriorityQueue<AnomalyLocalizationOutput.Entity> queue, AnomalyLocalizationOutput output,
private void onEntityKeysResponse(SearchResponse response, AnomalyLocalizationInput input, AggregationBuilder agg,
AnomalyLocalizationOutput.Result result,
AnomalyLocalizationOutput.Bucket bucket, PriorityQueue<AnomalyLocalizationOutput.Entity> queue,
AnomalyLocalizationOutput output,
ActionListener<AnomalyLocalizationOutput> listener) {
List<AnomalyLocalizationOutput.Entity> entities = new ArrayList<AnomalyLocalizationOutput.Entity>(queue);
Optional<Filters> respAgg = Optional.ofNullable(response.getAggregations()).map(aggs -> (Filters) aggs.get(agg.getName()));
Expand All @@ -257,7 +277,8 @@ private void onEntityKeysResponse(SearchResponse response, AnomalyLocalizationIn
outputIfResultsAreComplete(output, listener);
}

private SearchRequest newSearchRequestForEntityKeys(AnomalyLocalizationInput input, AggregationBuilder agg, AnomalyLocalizationOutput.Bucket bucket,
private SearchRequest newSearchRequestForEntityKeys(AnomalyLocalizationInput input, AggregationBuilder agg,
AnomalyLocalizationOutput.Bucket bucket,
List<List<String>> keys) {
RangeQueryBuilder timeRangeFilter = new RangeQueryBuilder(input.getTimeFieldName())
.from(bucket.getBase().get().getStartTime(), true)
Expand All @@ -283,7 +304,8 @@ private List<String> toStringKey(Map<String, Object> key, AnomalyLocalizationInp
return input.getAttributeFieldNames().stream().map(name -> key.get(name).toString()).collect(Collectors.toList());
}

private SearchRequest newSearchRequestForEntry(AnomalyLocalizationInput input, AggregationBuilder agg, AnomalyLocalizationOutput.Bucket bucket, Optional<Map<String,
private SearchRequest newSearchRequestForEntry(AnomalyLocalizationInput input, AggregationBuilder agg,
AnomalyLocalizationOutput.Bucket bucket, Optional<Map<String,
Object>> afterKey) {
RangeQueryBuilder timeRangeFilter = new RangeQueryBuilder(input.getTimeFieldName())
.from(bucket.getStartTime(), true)
Expand Down Expand Up @@ -397,4 +419,23 @@ protected List<Map.Entry<Long, Long>> getAllIntervals() {
return intervals;
}
}

@Override
public Output execute(Input input) {
CountDownLatch latch = new CountDownLatch(1);
AtomicReference<AnomalyLocalizationOutput> outRef = new AtomicReference<>();
AtomicReference<Exception> exRef = new AtomicReference<>();
getLocalizationResults((AnomalyLocalizationInput) input,
new LatchedActionListener(ActionListener.<AnomalyLocalizationOutput>wrap(o -> outRef.set(o), e -> exRef.set(e)), latch));
try {
latch.await();
} catch (InterruptedException e) {
throw new IllegalStateException(e);
}
if (exRef.get() != null) {
throw new RuntimeException(exRef.get());
} else {
return outRef.get();
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,6 @@ public class AnomalyLocalizerImplTests {
private String indexName = "indexName";
private String attributeFieldNameOne = "attributeOne";
private AggregationBuilder agg = AggregationBuilders.count("count").field("field");
;
private String timeFieldName = "timeFieldName";
private long startTime = 0;
private long endTime = 2;
Expand All @@ -87,7 +86,8 @@ public void setup() {
settings = Settings.builder().build();
anomalyLocalizer = new AnomalyLocalizerImpl(client, settings);

input = new AnomalyLocalizationInput(indexName, Arrays.asList(attributeFieldNameOne), Arrays.asList(agg), timeFieldName, startTime, endTime,
input = new AnomalyLocalizationInput(indexName, Arrays.asList(attributeFieldNameOne), Arrays.asList(agg), timeFieldName,
startTime, endTime,
minTimeInterval, numOutput, Optional.empty(), Optional.empty());

when(valueOne.value()).thenReturn(0.);
Expand Down Expand Up @@ -232,7 +232,8 @@ public void testGetLocalizedResultsGivenNoAnomaly() {
@Test
public void testGetLocalizedResultsGivenAnomaly() {
when(valueThree.value()).thenReturn(Double.NaN);
input = new AnomalyLocalizationInput(indexName, Arrays.asList(attributeFieldNameOne), Arrays.asList(agg), timeFieldName, startTime, endTime,
input = new AnomalyLocalizationInput(indexName, Arrays.asList(attributeFieldNameOne), Arrays.asList(agg), timeFieldName,
startTime, endTime,
minTimeInterval, numOutput, Optional.of(1L), Optional.of(mock(QueryBuilder.class)));

anomalyLocalizer.getLocalizationResults(input, outputListener);
Expand All @@ -245,7 +246,8 @@ public void testGetLocalizedResultsGivenAnomaly() {

@Test(expected = RuntimeException.class)
public void testGetLocalizedResultsForInvalidTimeRange() {
input = new AnomalyLocalizationInput(indexName, Arrays.asList(attributeFieldNameOne), Arrays.asList(agg), timeFieldName, startTime, startTime,
input = new AnomalyLocalizationInput(indexName, Arrays.asList(attributeFieldNameOne), Arrays.asList(agg), timeFieldName,
startTime, startTime,
minTimeInterval, numOutput, Optional.empty(), Optional.empty());

anomalyLocalizer.getLocalizationResults(input, outputListener);
Expand Down Expand Up @@ -308,7 +310,8 @@ public void testGetLocalizedResultsOverallUnchange() {

@Test
public void testGetLocalizedResultsFilterEntity() {
input = new AnomalyLocalizationInput(indexName, Arrays.asList(attributeFieldNameOne), Arrays.asList(agg), timeFieldName, startTime, endTime,
input = new AnomalyLocalizationInput(indexName, Arrays.asList(attributeFieldNameOne), Arrays.asList(agg), timeFieldName,
startTime, endTime,
minTimeInterval, 2, Optional.empty(), Optional.empty());

anomalyLocalizer.getLocalizationResults(input, outputListener);
Expand All @@ -318,6 +321,33 @@ public void testGetLocalizedResultsFilterEntity() {
AnomalyLocalizationOutput actualOutput = outputCaptor.getValue();
assertEquals(expectedOutput, actualOutput);
}

@Test
public void testExecuteSucceed() {
AnomalyLocalizationOutput actualOutput = (AnomalyLocalizationOutput) anomalyLocalizer.execute(input);

assertEquals(expectedOutput, actualOutput);
}

@Test(expected = RuntimeException.class)
public void testExecuteFail() {
doAnswer(new Answer() {
public Object answer(InvocationOnMock invocation) {
Object[] args = invocation.getArguments();
ActionListener<MultiSearchResponse> listener = (ActionListener<MultiSearchResponse>) args[1];
listener.onFailure(new RuntimeException());
return null;
}
}
).when(client).multiSearch(any(), any());
anomalyLocalizer.execute(input);
}

@Test(expected = RuntimeException.class)
public void testExecuteInterrupted() {
Thread.currentThread().interrupt();
anomalyLocalizer.execute(input);
}
}


Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
import org.opensearch.ml.common.transport.training.MLTrainingTaskAction;
import org.opensearch.ml.engine.MLEngineClassLoader;
import org.opensearch.ml.engine.algorithms.anomalylocalization.AnomalyLocalizationInput;
import org.opensearch.ml.engine.algorithms.anomalylocalization.AnomalyLocalizerImpl;
import org.opensearch.ml.engine.algorithms.sample.LocalSampleCalculator;
import org.opensearch.ml.indices.MLIndicesHandler;
import org.opensearch.ml.indices.MLInputDatasetHandler;
Expand Down Expand Up @@ -184,6 +185,9 @@ public Collection<Object> createComponents(
LocalSampleCalculator localSampleCalculator = new LocalSampleCalculator(client, settings);
MLEngineClassLoader.register(FunctionName.LOCAL_SAMPLE_CALCULATOR, localSampleCalculator);

AnomalyLocalizerImpl anomalyLocalizer = new AnomalyLocalizerImpl(client, settings);
MLEngineClassLoader.register(FunctionName.ANOMALY_LOCALIZATION, anomalyLocalizer);

return ImmutableList
.of(
mlStats,
Expand Down

0 comments on commit d397a54

Please sign in to comment.