Skip to content

Commit

Permalink
Adding negative cache to AD
Browse files Browse the repository at this point in the history
  • Loading branch information
zhanghg08 committed Jan 30, 2020
1 parent 9aa9e46 commit bc6a763
Show file tree
Hide file tree
Showing 10 changed files with 225 additions and 43 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ public HourlyCron(ClusterService clusterService, Client client) {
public void run() {
DiscoveryNode[] dataNodes = clusterService.state().nodes().getDataNodes().values().toArray(DiscoveryNode.class);

// we also add the cancel query function here based on query text from the negative cache.

CronRequest modelDeleteRequest = new CronRequest(dataNodes);
client.execute(CronAction.INSTANCE, modelDeleteRequest, ActionListener.wrap(response -> {
if (response.hasFailures()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@

import com.amazon.opendistroforelasticsearch.ad.model.AnomalyDetector;
import com.amazon.opendistroforelasticsearch.ad.model.IntervalTimeConfiguration;
import com.amazon.opendistroforelasticsearch.ad.transport.ADStateManager;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.action.ActionListener;
Expand Down Expand Up @@ -115,15 +116,16 @@ public FeatureManager(
* @param detector anomaly detector for which the features are returned
* @param startTime start time of the data point in epoch milliseconds
* @param endTime end time of the data point in epoch milliseconds
* @param stateManager ADStateManager
* @return unprocessed features and processed features for the current data point
*/
@Deprecated
public SinglePointFeatures getCurrentFeatures(AnomalyDetector detector, long startTime, long endTime) {
public SinglePointFeatures getCurrentFeatures(AnomalyDetector detector, long startTime, long endTime, ADStateManager stateManager) {
double[][] currentPoints = null;
Deque<Entry<Long, double[]>> shingle = detectorIdsToTimeShingles
.computeIfAbsent(detector.getDetectorId(), id -> new ArrayDeque<Entry<Long, double[]>>(shingleSize));
if (shingle.isEmpty() || shingle.getLast().getKey() < endTime) {
Optional<double[]> point = searchFeatureDao.getFeaturesForPeriod(detector, startTime, endTime);
Optional<double[]> point = searchFeatureDao.getFeaturesForPeriod(detector, startTime, endTime, stateManager);
if (point.isPresent()) {
if (shingle.size() == shingleSize) {
shingle.remove();
Expand Down Expand Up @@ -174,13 +176,16 @@ private double[][] filterAndFill(Deque<Entry<Long, double[]>> shingle, long endT
* in dimension via shingling.
*
* @param detector contains data info (indices, documents, etc)
* @param stateManager ADStateManager
* @return data for cold-start training, or empty if unavailable
*/
@Deprecated
public Optional<double[][]> getColdStartData(AnomalyDetector detector) {
public Optional<double[][]> getColdStartData(AnomalyDetector detector, ADStateManager stateManager) {
return searchFeatureDao
.getLatestDataTime(detector)
.flatMap(latest -> searchFeatureDao.getFeaturesForSampledPeriods(detector, maxTrainSamples, maxSampleStride, latest))
.flatMap(
latest -> searchFeatureDao.getFeaturesForSampledPeriods(detector, maxTrainSamples, maxSampleStride, latest, stateManager)
)
.map(
samples -> transpose(
interpolator.interpolate(transpose(samples.getKey()), samples.getValue() * (samples.getKey().length - 1) + 1)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@

import com.amazon.opendistroforelasticsearch.ad.model.AnomalyDetector;
import com.amazon.opendistroforelasticsearch.ad.model.IntervalTimeConfiguration;
import com.amazon.opendistroforelasticsearch.ad.transport.ADStateManager;
import com.amazon.opendistroforelasticsearch.ad.util.ClientUtil;
import com.amazon.opendistroforelasticsearch.ad.util.ParseUtils;
import org.apache.logging.log4j.LogManager;
Expand Down Expand Up @@ -114,18 +115,24 @@ public Optional<Long> getLatestDataTime(AnomalyDetector detector) {
}

/**
* Gets features for the given time period.
* Gets features for the given time period. This function also add given detector to negative cache before sending es request.
* Once we get response/exception within timeout, we treat this request as complete and clear the negative cache.
* Otherwise this detector entry remain in the negative to reject further request.
*
* @param detector info about indices, documents, feature query
* @param startTime epoch milliseconds at the beginning of the period
* @param endTime epoch milliseconds at the end of the period
* @param stateManager ADStateManager
* @throws IllegalStateException when unexpected failures happen
* @return features from search results, empty when no data found
*/
public Optional<double[]> getFeaturesForPeriod(AnomalyDetector detector, long startTime, long endTime) {
public Optional<double[]> getFeaturesForPeriod(AnomalyDetector detector, long startTime, long endTime, ADStateManager stateManager) {
SearchRequest searchRequest = createFeatureSearchRequest(detector, startTime, endTime, Optional.empty());
// add (detectorId, filteredQuery) to negative cache
stateManager.insertFilteredQuery(detector, searchRequest);
// send throttled request: this request will clear the negative cache if the request finished within timeout
return clientUtil
.<SearchRequest, SearchResponse>timedRequest(searchRequest, logger, client::search)
.<SearchRequest, SearchResponse>throttledTimedRequest(searchRequest, logger, client::search, stateManager, detector)
.flatMap(resp -> parseResponse(resp, detector.getEnabledFeatureIds()));
}

Expand Down Expand Up @@ -242,20 +249,22 @@ public void getFeatureSamplesForPeriods(
* @param maxSamples the maximum number of samples to return
* @param maxStride the maximum number of periods between samples
* @param endTime the end time of the latest period
* @param stateManager ADStateManager
* @return sampled features and stride, empty when no data found
*/
public Optional<Entry<double[][], Integer>> getFeaturesForSampledPeriods(
AnomalyDetector detector,
int maxSamples,
int maxStride,
long endTime
long endTime,
ADStateManager stateManager
) {
Map<Long, double[]> cache = new HashMap<>();
int currentStride = maxStride;
Optional<double[][]> features = Optional.empty();
while (currentStride >= 1) {
boolean isInterpolatable = currentStride < maxStride;
features = getFeaturesForSampledPeriods(detector, maxSamples, currentStride, endTime, cache, isInterpolatable);
features = getFeaturesForSampledPeriods(detector, maxSamples, currentStride, endTime, cache, isInterpolatable, stateManager);
if (!features.isPresent() || features.get().length > maxSamples / 2 || currentStride == 1) {
break;
} else {
Expand All @@ -275,7 +284,8 @@ private Optional<double[][]> getFeaturesForSampledPeriods(
int stride,
long endTime,
Map<Long, double[]> cache,
boolean isInterpolatable
boolean isInterpolatable,
ADStateManager stateManager
) {
ArrayDeque<double[]> sampledFeatures = new ArrayDeque<>(maxSamples);
for (int i = 0; i < maxSamples; i++) {
Expand All @@ -284,7 +294,7 @@ private Optional<double[][]> getFeaturesForSampledPeriods(
if (cache.containsKey(end)) {
sampledFeatures.addFirst(cache.get(end));
} else {
Optional<double[]> features = getFeaturesForPeriod(detector, end - span, end);
Optional<double[]> features = getFeaturesForPeriod(detector, end - span, end, stateManager);
if (features.isPresent()) {
cache.put(end, features.get());
sampledFeatures.addFirst(features.get());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
import org.apache.logging.log4j.Logger;
import org.elasticsearch.action.get.GetRequest;
import org.elasticsearch.action.get.GetResponse;
import org.elasticsearch.action.search.SearchRequest;
import org.elasticsearch.client.Client;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.xcontent.LoggingDeprecationHandler;
Expand All @@ -55,6 +56,9 @@ public class ADStateManager {
private static final Logger LOG = LogManager.getLogger(ADStateManager.class);
private ConcurrentHashMap<String, Entry<AnomalyDetector, Instant>> currentDetectors;
private ConcurrentHashMap<String, Entry<Integer, Instant>> partitionNumber;
// negativeCache is used to reject search query if given detector already has one query running
// key is detectorId, value is an entry. Key is QueryBuilder and value is the timestamp
private ConcurrentHashMap<String, Entry<SearchRequest, Instant>> negativeCache;
private Client client;
private Random random;
private ModelManager modelManager;
Expand Down Expand Up @@ -83,6 +87,7 @@ public ADStateManager(
this.partitionNumber = new ConcurrentHashMap<>();
this.clientUtil = clientUtil;
this.backpressureMuter = new ConcurrentHashMap<>();
this.negativeCache = new ConcurrentHashMap<>();
this.clock = clock;
this.settings = settings;
this.stateTtl = stateTtl;
Expand Down Expand Up @@ -119,6 +124,47 @@ public int getPartitionNumber(String adID) throws InterruptedException {
return partitionNum;
}

/**
* Get negative cache value(QueryBuilder, Instant) for given detector
* If detectorId is null, return Optional.empty()
* @param detector AnomalyDetector
* @return negative cache value(QueryBuilder, Instant)
*/
public Optional<Entry<SearchRequest, Instant>> getFilteredQuery(AnomalyDetector detector) {
if (detector.getDetectorId() == null) {
return Optional.empty();
}
if (negativeCache.containsKey(detector.getDetectorId())) {
return Optional.of(negativeCache.get(detector.getDetectorId()));
}
return Optional.empty();
}

/**
* Insert the negative cache entry for given detector
* If detectorId is null, do nothing
* @param detector AnomalyDetector
* @param searchRequest ES search request
*/
public void insertFilteredQuery(AnomalyDetector detector, SearchRequest searchRequest) {
if (detector.getDetectorId() == null) {
return;
}
negativeCache.putIfAbsent(detector.getDetectorId(), new SimpleEntry<>(searchRequest, clock.instant()));
}

/**
* Clear the negative cache for given detector.
* If detectorId is null, do nothing
* @param detector AnomalyDetector
*/
public void clearFilteredQuery(AnomalyDetector detector) {
if (detector.getDetectorId() == null) {
return;
}
negativeCache.keySet().removeIf(key -> key.equals(detector.getDetectorId()));
}

public Optional<AnomalyDetector> getAnomalyDetector(String adID) {
Entry<AnomalyDetector, Instant> detectorAndTime = currentDetectors.get(adID);
if (detectorAndTime != null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import java.util.Iterator;
import java.util.List;
import java.util.Locale;
import java.util.Map.Entry;
import java.util.Optional;
import java.util.concurrent.Callable;
import java.util.concurrent.CountDownLatch;
Expand Down Expand Up @@ -67,6 +68,7 @@
import org.elasticsearch.action.bulk.BackoffPolicy;
import org.elasticsearch.action.index.IndexRequest;
import org.elasticsearch.action.index.IndexResponse;
import org.elasticsearch.action.search.SearchRequest;
import org.elasticsearch.action.support.ActionFilters;
import org.elasticsearch.action.support.HandledTransportAction;
import org.elasticsearch.action.support.IndicesOptions;
Expand Down Expand Up @@ -249,6 +251,12 @@ protected void doExecute(Task task, ActionRequest actionRequest, ActionListener<
return;
}
AnomalyDetector anomalyDetector = detector.get();
Optional<Entry<SearchRequest, Instant>> queryEntry = stateManager.getFilteredQuery(anomalyDetector);
if (queryEntry.isPresent()) {
LOG.info("There is one query running for detectorId: {}", anomalyDetector.getDetectorId());
listener.onResponse(new AnomalyResultResponse(Double.NaN, Double.NaN, new ArrayList<FeatureData>()));
return;
}

String thresholdModelID = modelManager.getThresholdModelId(adID);
Optional<DiscoveryNode> thresholdNode = hashRing.getOwningNode(thresholdModelID);
Expand All @@ -270,7 +278,7 @@ protected void doExecute(Task task, ActionRequest actionRequest, ActionListener<
long startTime = request.getStart() - delayMillis;
long endTime = request.getEnd() - delayMillis;

SinglePointFeatures featureOptional = featureManager.getCurrentFeatures(anomalyDetector, startTime, endTime);
SinglePointFeatures featureOptional = featureManager.getCurrentFeatures(anomalyDetector, startTime, endTime, stateManager);

List<FeatureData> featureInResponse = null;

Expand Down Expand Up @@ -811,7 +819,7 @@ class ColdStartJob implements Callable<Boolean> {
@Override
public Boolean call() {
try {
Optional<double[][]> traingData = featureManager.getColdStartData(detector);
Optional<double[][]> traingData = featureManager.getColdStartData(detector, stateManager);
if (traingData.isPresent()) {
modelManager.trainModel(detector, traingData.get());
return true;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
import java.util.function.BiConsumer;
import java.util.function.Function;

import com.amazon.opendistroforelasticsearch.ad.model.AnomalyDetector;
import com.amazon.opendistroforelasticsearch.ad.transport.ADStateManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.ElasticsearchTimeoutException;
import org.elasticsearch.action.Action;
Expand Down Expand Up @@ -152,4 +154,35 @@ public <Request extends ActionRequest, Response extends ActionResponse> Response
) {
return function.apply(request).actionGet(requestTimeout);
}

public <Request extends ActionRequest, Response extends ActionResponse> Optional<Response> throttledTimedRequest(
Request request,
Logger LOG,
BiConsumer<Request, ActionListener<Response>> consumer,
ADStateManager stateManager,
AnomalyDetector detector
) {
try {
AtomicReference<Response> respReference = new AtomicReference<>();
final CountDownLatch latch = new CountDownLatch(1);

consumer.accept(request, new LatchedActionListener<Response>(ActionListener.wrap(response -> {
// clear negative cache
stateManager.clearFilteredQuery(detector);
respReference.set(response);
}, exception -> {
// clear negative cache
stateManager.clearFilteredQuery(detector);
LOG.error("Cannot get response for request {}, error: {}", request, exception);
}), latch));

if (!latch.await(requestTimeout.getSeconds(), TimeUnit.SECONDS)) {
throw new ElasticsearchTimeoutException("Cannot get response within time limit: " + request.toString());
}
return Optional.ofNullable(respReference.get());
} catch (InterruptedException e1) {
LOG.error(CommonErrorMessages.WAIT_ERR_MSG);
throw new IllegalStateException(e1);
}
}
}
Loading

0 comments on commit bc6a763

Please sign in to comment.