Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

anomaly localization integration step 4 and 5 #125

Merged
merged 1 commit into from
Jan 21, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,14 @@
import java.util.Map;
import java.util.Optional;
import java.util.PriorityQueue;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicReference;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

import org.opensearch.action.ActionListener;
import org.opensearch.action.LatchedActionListener;
import org.opensearch.action.NotifyOnceListener;
import org.opensearch.action.search.MultiSearchRequest;
import org.opensearch.action.search.MultiSearchResponse;
Expand All @@ -36,6 +39,9 @@
import org.opensearch.index.query.QueryBuilders;
import org.opensearch.index.query.RangeQueryBuilder;
import org.opensearch.index.query.TermQueryBuilder;
import org.opensearch.ml.common.parameter.Input;
import org.opensearch.ml.common.parameter.Output;
import org.opensearch.ml.engine.Executable;
import org.opensearch.search.aggregations.AggregationBuilder;
import org.opensearch.search.aggregations.AggregationBuilders;
import org.opensearch.search.aggregations.bucket.composite.CompositeAggregation;
Expand All @@ -58,7 +64,7 @@
* Implementation of AnomalyLocalizer.
*/
@Log4j2
public class AnomalyLocalizerImpl implements AnomalyLocalizer {
public class AnomalyLocalizerImpl implements AnomalyLocalizer, Executable {

// Localize when the change of new value over base value is over the percentage.
protected static final double MIN_DIFF_PCT = 0.01;
Expand Down Expand Up @@ -101,19 +107,22 @@ public void getLocalizationResults(AnomalyLocalizationInput input, ActionListene
/**
* Bucketizes data by time and get overall aggregates.
*/
private void localizeByBuckets(AnomalyLocalizationInput input, AggregationBuilder agg, AnomalyLocalizationOutput output, ActionListener<AnomalyLocalizationOutput> listener) {
private void localizeByBuckets(AnomalyLocalizationInput input, AggregationBuilder agg, AnomalyLocalizationOutput output,
ActionListener<AnomalyLocalizationOutput> listener) {
LocalizationTimeBuckets timeBuckets = getTimeBuckets(input);
getOverallAggregates(input, timeBuckets, agg, output, listener);
}

private void getOverallAggregates(AnomalyLocalizationInput input, LocalizationTimeBuckets timeBuckets, AggregationBuilder agg, AnomalyLocalizationOutput output,
private void getOverallAggregates(AnomalyLocalizationInput input, LocalizationTimeBuckets timeBuckets, AggregationBuilder agg,
AnomalyLocalizationOutput output,
ActionListener<AnomalyLocalizationOutput> listener) {
MultiSearchRequest searchRequest = newSearchRequestForOverallAggregates(input, agg, timeBuckets);
client.multiSearch(searchRequest, wrap(r -> onOverallAggregatesResponse(r, input, agg, output, timeBuckets, listener),
listener::onFailure));
}

private void onOverallAggregatesResponse(MultiSearchResponse response, AnomalyLocalizationInput input, AggregationBuilder agg, AnomalyLocalizationOutput output,
private void onOverallAggregatesResponse(MultiSearchResponse response, AnomalyLocalizationInput input, AggregationBuilder agg,
AnomalyLocalizationOutput output,
LocalizationTimeBuckets timeBuckets, ActionListener<AnomalyLocalizationOutput> listener) {
AnomalyLocalizationOutput.Result result = new AnomalyLocalizationOutput.Result();
List<Map.Entry<Long, Long>> intervals = timeBuckets.getAllIntervals();
Expand All @@ -134,7 +143,8 @@ private void onOverallAggregatesResponse(MultiSearchResponse response, AnomalyLo
/**
* Identifies buckets of data that need localization and localizes entities in the bucket.
*/
private void getLocalizedEntities(AnomalyLocalizationInput input, AggregationBuilder agg, AnomalyLocalizationOutput.Result result, AnomalyLocalizationOutput output,
private void getLocalizedEntities(AnomalyLocalizationInput input, AggregationBuilder agg, AnomalyLocalizationOutput.Result result,
AnomalyLocalizationOutput output,
ActionListener<AnomalyLocalizationOutput> listener) {
if (setBase(result, input)) {
Counter counter = new HybridCounter();
Expand All @@ -155,8 +165,10 @@ private boolean isResultComplete(AnomalyLocalizationOutput.Result result) {
return result.getBuckets().stream().allMatch(e -> e.getCompleted() == null || e.getCompleted().get() == true);
}

private void processBaseEntry(AnomalyLocalizationInput input, AggregationBuilder agg, AnomalyLocalizationOutput.Result result, AnomalyLocalizationOutput.Bucket bucket, Counter counter,
Optional<Map<String, Object>> afterKey, AnomalyLocalizationOutput output, ActionListener<AnomalyLocalizationOutput> listener) {
private void processBaseEntry(AnomalyLocalizationInput input, AggregationBuilder agg, AnomalyLocalizationOutput.Result result,
AnomalyLocalizationOutput.Bucket bucket, Counter counter,
Optional<Map<String, Object>> afterKey, AnomalyLocalizationOutput output,
ActionListener<AnomalyLocalizationOutput> listener) {
SearchRequest request = newSearchRequestForEntry(input, agg, bucket, afterKey);
client.search(request, wrap(r -> onBaseEntryResponse(r, input, agg, result, bucket, counter, output, listener),
listener::onFailure));
Expand All @@ -165,8 +177,10 @@ private void processBaseEntry(AnomalyLocalizationInput input, AggregationBuilder
/**
* Keeps info from entities in the base bucket to compare entities from new buckets against.
*/
private void onBaseEntryResponse(SearchResponse response, AnomalyLocalizationInput input, AggregationBuilder agg, AnomalyLocalizationOutput.Result result,
AnomalyLocalizationOutput.Bucket bucket, Counter counter, AnomalyLocalizationOutput output, ActionListener<AnomalyLocalizationOutput> listener) {
private void onBaseEntryResponse(SearchResponse response, AnomalyLocalizationInput input, AggregationBuilder agg,
AnomalyLocalizationOutput.Result result,
AnomalyLocalizationOutput.Bucket bucket, Counter counter, AnomalyLocalizationOutput output,
ActionListener<AnomalyLocalizationOutput> listener) {
Optional<CompositeAggregation> respAgg =
Optional.ofNullable(response.getAggregations()).map(aggs -> (CompositeAggregation) aggs.get(agg.getName()));
respAgg.map(a -> a.getBuckets()).orElse(Collections.emptyList()).stream().forEach(b -> {
Expand All @@ -193,17 +207,21 @@ private void onBaseEntryResponse(SearchResponse response, AnomalyLocalizationInp
}
}

private void processNewEntry(AnomalyLocalizationInput input, AggregationBuilder agg, AnomalyLocalizationOutput.Result result, AnomalyLocalizationOutput.Bucket bucket, Optional<Map<String,
Object>> afterKey, PriorityQueue<AnomalyLocalizationOutput.Entity> queue, AnomalyLocalizationOutput output, ActionListener<AnomalyLocalizationOutput> listener) {
private void processNewEntry(AnomalyLocalizationInput input, AggregationBuilder agg, AnomalyLocalizationOutput.Result result,
AnomalyLocalizationOutput.Bucket bucket, Optional<Map<String,
Object>> afterKey, PriorityQueue<AnomalyLocalizationOutput.Entity> queue, AnomalyLocalizationOutput output,
ActionListener<AnomalyLocalizationOutput> listener) {
SearchRequest request = newSearchRequestForEntry(input, agg, bucket, afterKey);
client.search(request, wrap(r -> onNewEntryResponse(r, input, agg, result, bucket, queue, output, listener), listener::onFailure));
}

/**
* Chooses entities from the new bucket that contribute the most to the overall change.
*/
private void onNewEntryResponse(SearchResponse response, AnomalyLocalizationInput input, AggregationBuilder agg, AnomalyLocalizationOutput.Result result,
AnomalyLocalizationOutput.Bucket outputBucket, PriorityQueue<AnomalyLocalizationOutput.Entity> queue, AnomalyLocalizationOutput output,
private void onNewEntryResponse(SearchResponse response, AnomalyLocalizationInput input, AggregationBuilder agg,
AnomalyLocalizationOutput.Result result,
AnomalyLocalizationOutput.Bucket outputBucket, PriorityQueue<AnomalyLocalizationOutput.Entity> queue,
AnomalyLocalizationOutput output,
ActionListener<AnomalyLocalizationOutput> listener) {
Optional<CompositeAggregation> respAgg =
Optional.ofNullable(response.getAggregations()).map(aggs -> (CompositeAggregation) aggs.get(agg.getName()));
Expand Down Expand Up @@ -235,8 +253,10 @@ private void onNewEntryResponse(SearchResponse response, AnomalyLocalizationInpu
/**
* Updates to date entity contribution values in final output.
*/
private void onEntityKeysResponse(SearchResponse response, AnomalyLocalizationInput input, AggregationBuilder agg, AnomalyLocalizationOutput.Result result,
AnomalyLocalizationOutput.Bucket bucket, PriorityQueue<AnomalyLocalizationOutput.Entity> queue, AnomalyLocalizationOutput output,
private void onEntityKeysResponse(SearchResponse response, AnomalyLocalizationInput input, AggregationBuilder agg,
AnomalyLocalizationOutput.Result result,
AnomalyLocalizationOutput.Bucket bucket, PriorityQueue<AnomalyLocalizationOutput.Entity> queue,
AnomalyLocalizationOutput output,
ActionListener<AnomalyLocalizationOutput> listener) {
List<AnomalyLocalizationOutput.Entity> entities = new ArrayList<AnomalyLocalizationOutput.Entity>(queue);
Optional<Filters> respAgg = Optional.ofNullable(response.getAggregations()).map(aggs -> (Filters) aggs.get(agg.getName()));
Expand All @@ -257,7 +277,8 @@ private void onEntityKeysResponse(SearchResponse response, AnomalyLocalizationIn
outputIfResultsAreComplete(output, listener);
}

private SearchRequest newSearchRequestForEntityKeys(AnomalyLocalizationInput input, AggregationBuilder agg, AnomalyLocalizationOutput.Bucket bucket,
private SearchRequest newSearchRequestForEntityKeys(AnomalyLocalizationInput input, AggregationBuilder agg,
AnomalyLocalizationOutput.Bucket bucket,
List<List<String>> keys) {
RangeQueryBuilder timeRangeFilter = new RangeQueryBuilder(input.getTimeFieldName())
.from(bucket.getBase().get().getStartTime(), true)
Expand All @@ -283,7 +304,8 @@ private List<String> toStringKey(Map<String, Object> key, AnomalyLocalizationInp
return input.getAttributeFieldNames().stream().map(name -> key.get(name).toString()).collect(Collectors.toList());
}

private SearchRequest newSearchRequestForEntry(AnomalyLocalizationInput input, AggregationBuilder agg, AnomalyLocalizationOutput.Bucket bucket, Optional<Map<String,
private SearchRequest newSearchRequestForEntry(AnomalyLocalizationInput input, AggregationBuilder agg,
AnomalyLocalizationOutput.Bucket bucket, Optional<Map<String,
Object>> afterKey) {
RangeQueryBuilder timeRangeFilter = new RangeQueryBuilder(input.getTimeFieldName())
.from(bucket.getStartTime(), true)
Expand Down Expand Up @@ -397,4 +419,23 @@ protected List<Map.Entry<Long, Long>> getAllIntervals() {
return intervals;
}
}

@Override
public Output execute(Input input) {
CountDownLatch latch = new CountDownLatch(1);
AtomicReference<AnomalyLocalizationOutput> outRef = new AtomicReference<>();
AtomicReference<Exception> exRef = new AtomicReference<>();
getLocalizationResults((AnomalyLocalizationInput) input,
new LatchedActionListener(ActionListener.<AnomalyLocalizationOutput>wrap(o -> outRef.set(o), e -> exRef.set(e)), latch));
try {
latch.await();
} catch (InterruptedException e) {
throw new IllegalStateException(e);
}
if (exRef.get() != null) {
throw new RuntimeException(exRef.get());
} else {
return outRef.get();
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,6 @@ public class AnomalyLocalizerImplTests {
private String indexName = "indexName";
private String attributeFieldNameOne = "attributeOne";
private AggregationBuilder agg = AggregationBuilders.count("count").field("field");
;
private String timeFieldName = "timeFieldName";
private long startTime = 0;
private long endTime = 2;
Expand All @@ -87,7 +86,8 @@ public void setup() {
settings = Settings.builder().build();
anomalyLocalizer = new AnomalyLocalizerImpl(client, settings);

input = new AnomalyLocalizationInput(indexName, Arrays.asList(attributeFieldNameOne), Arrays.asList(agg), timeFieldName, startTime, endTime,
input = new AnomalyLocalizationInput(indexName, Arrays.asList(attributeFieldNameOne), Arrays.asList(agg), timeFieldName,
startTime, endTime,
minTimeInterval, numOutput, Optional.empty(), Optional.empty());

when(valueOne.value()).thenReturn(0.);
Expand Down Expand Up @@ -232,7 +232,8 @@ public void testGetLocalizedResultsGivenNoAnomaly() {
@Test
public void testGetLocalizedResultsGivenAnomaly() {
when(valueThree.value()).thenReturn(Double.NaN);
input = new AnomalyLocalizationInput(indexName, Arrays.asList(attributeFieldNameOne), Arrays.asList(agg), timeFieldName, startTime, endTime,
input = new AnomalyLocalizationInput(indexName, Arrays.asList(attributeFieldNameOne), Arrays.asList(agg), timeFieldName,
startTime, endTime,
minTimeInterval, numOutput, Optional.of(1L), Optional.of(mock(QueryBuilder.class)));

anomalyLocalizer.getLocalizationResults(input, outputListener);
Expand All @@ -245,7 +246,8 @@ public void testGetLocalizedResultsGivenAnomaly() {

@Test(expected = RuntimeException.class)
public void testGetLocalizedResultsForInvalidTimeRange() {
input = new AnomalyLocalizationInput(indexName, Arrays.asList(attributeFieldNameOne), Arrays.asList(agg), timeFieldName, startTime, startTime,
input = new AnomalyLocalizationInput(indexName, Arrays.asList(attributeFieldNameOne), Arrays.asList(agg), timeFieldName,
startTime, startTime,
minTimeInterval, numOutput, Optional.empty(), Optional.empty());

anomalyLocalizer.getLocalizationResults(input, outputListener);
Expand Down Expand Up @@ -308,7 +310,8 @@ public void testGetLocalizedResultsOverallUnchange() {

@Test
public void testGetLocalizedResultsFilterEntity() {
input = new AnomalyLocalizationInput(indexName, Arrays.asList(attributeFieldNameOne), Arrays.asList(agg), timeFieldName, startTime, endTime,
input = new AnomalyLocalizationInput(indexName, Arrays.asList(attributeFieldNameOne), Arrays.asList(agg), timeFieldName,
startTime, endTime,
minTimeInterval, 2, Optional.empty(), Optional.empty());

anomalyLocalizer.getLocalizationResults(input, outputListener);
Expand All @@ -318,6 +321,33 @@ public void testGetLocalizedResultsFilterEntity() {
AnomalyLocalizationOutput actualOutput = outputCaptor.getValue();
assertEquals(expectedOutput, actualOutput);
}

@Test
public void testExecuteSucceed() {
AnomalyLocalizationOutput actualOutput = (AnomalyLocalizationOutput) anomalyLocalizer.execute(input);

assertEquals(expectedOutput, actualOutput);
}

@Test(expected = RuntimeException.class)
public void testExecuteFail() {
doAnswer(new Answer() {
public Object answer(InvocationOnMock invocation) {
Object[] args = invocation.getArguments();
ActionListener<MultiSearchResponse> listener = (ActionListener<MultiSearchResponse>) args[1];
listener.onFailure(new RuntimeException());
return null;
}
}
).when(client).multiSearch(any(), any());
anomalyLocalizer.execute(input);
}

@Test(expected = RuntimeException.class)
public void testExecuteInterrupted() {
Thread.currentThread().interrupt();
anomalyLocalizer.execute(input);
}
}


Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
import org.opensearch.ml.common.transport.training.MLTrainingTaskAction;
import org.opensearch.ml.engine.MLEngineClassLoader;
import org.opensearch.ml.engine.algorithms.anomalylocalization.AnomalyLocalizationInput;
import org.opensearch.ml.engine.algorithms.anomalylocalization.AnomalyLocalizerImpl;
import org.opensearch.ml.engine.algorithms.sample.LocalSampleCalculator;
import org.opensearch.ml.indices.MLIndicesHandler;
import org.opensearch.ml.indices.MLInputDatasetHandler;
Expand Down Expand Up @@ -184,6 +185,9 @@ public Collection<Object> createComponents(
LocalSampleCalculator localSampleCalculator = new LocalSampleCalculator(client, settings);
MLEngineClassLoader.register(FunctionName.LOCAL_SAMPLE_CALCULATOR, localSampleCalculator);

AnomalyLocalizerImpl anomalyLocalizer = new AnomalyLocalizerImpl(client, settings);
MLEngineClassLoader.register(FunctionName.ANOMALY_LOCALIZATION, anomalyLocalizer);

return ImmutableList
.of(
mlStats,
Expand Down