Skip to content
This repository has been archived by the owner on Aug 2, 2022. It is now read-only.

add async getRcfResult #69

Merged
merged 1 commit into from
Mar 19, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@

import com.google.gson.Gson;

import org.elasticsearch.action.ActionListener;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.monitor.jvm.JvmService;

Expand Down Expand Up @@ -319,6 +320,58 @@ public RcfResult getRcfResult(String detectorId, String modelId, double[] point)
return new RcfResult(score, confidence, forestSize);
}

/**
* Returns to listener the RCF anomaly result using the specified model.
*
* @param detectorId ID of the detector
* @param modelId ID of the model to score the point
* @param point features of the data point
* @param listener onResponse is called with RCF result for the input point, including a score
* onFailure is called with ResourceNotFoundException when the model is not found
* onFailure is called with LimitExceededException when a limit is exceeded for the model
*/
public void getRcfResult(String detectorId, String modelId, double[] point, ActionListener<RcfResult> listener) {
if (forests.containsKey(modelId)) {
getRcfResult(forests.get(modelId), point, listener);
} else {
checkpointDao
.getModelCheckpoint(
modelId,
ActionListener
.wrap(checkpoint -> processRcfCheckpoint(checkpoint, modelId, detectorId, point, listener), listener::onFailure)
);
}
}

private void getRcfResult(ModelState<RandomCutForest> modelState, double[] point, ActionListener<RcfResult> listener) {
RandomCutForest rcf = modelState.getModel();
double score = rcf.getAnomalyScore(point);
double confidence = computeRcfConfidence(rcf);
int forestSize = rcf.getNumberOfTrees();
rcf.update(point);
modelState.setLastUsedTime(clock.instant());
listener.onResponse(new RcfResult(score, confidence, forestSize));
}

private void processRcfCheckpoint(
Optional<String> rcfCheckpoint,
String modelId,
String detectorId,
double[] point,
ActionListener<RcfResult> listener
) {
Optional<ModelState<RandomCutForest>> model = rcfCheckpoint
.map(checkpoint -> AccessController.doPrivileged((PrivilegedAction<RandomCutForest>) () -> rcfSerde.fromJson(checkpoint)))
.filter(rcf -> isHostingAllowed(detectorId, rcf))
.map(rcf -> new ModelState<>(rcf, modelId, detectorId, ModelType.RCF.getName(), clock.instant()));
if (model.isPresent()) {
forests.put(modelId, model.get());
getRcfResult(model.get(), point, listener);
} else {
throw new ResourceNotFoundException(detectorId, CommonErrorMessages.NO_CHECKPOINT_ERR_MSG + modelId);
}
}

/**
* Gets the result using the specified thresholding model.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
import junitparams.JUnitParamsRunner;
import junitparams.Parameters;

import org.elasticsearch.action.ActionListener;
import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.collect.ImmutableOpenMap;
Expand All @@ -47,6 +48,7 @@
import org.junit.runner.RunWith;

import org.mockito.Answers;
import org.mockito.ArgumentCaptor;
import org.mockito.Mock;
import org.mockito.MockitoAnnotations;

Expand All @@ -65,8 +67,10 @@
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;

import static org.mockito.Matchers.any;
import static org.mockito.Matchers.anyObject;
import static org.mockito.Matchers.eq;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.doReturn;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.mock;
Expand Down Expand Up @@ -130,6 +134,7 @@ public class ModelManagerTests {
private String modelId;
private String rcfModelId;
private String thresholdModelId;
private String checkpoint;

@Before
public void setup() {
Expand Down Expand Up @@ -188,6 +193,7 @@ public void setup() {
modelId = "modelId";
rcfModelId = "detectorId_model_rcf_1";
thresholdModelId = "detectorId_model_threshold";
checkpoint = "testcheckpoint";
}

private Object[] getDetectorIdForModelIdData() {
Expand Down Expand Up @@ -363,6 +369,72 @@ public void getRcfResult_throwLimitExceeded_whenHeapLimitReached() {
modelManager.getRcfResult(detectorId, modelId, new double[0]);
}

@Test
@SuppressWarnings("unchecked")
public void getRcfResult_returnExpectedToListener() {
double[] point = new double[0];
RandomCutForest forest = mock(RandomCutForest.class);
double score = 11.;

doAnswer(invocation -> {
ActionListener<Optional<String>> listener = invocation.getArgument(1);
listener.onResponse(Optional.of(checkpoint));
return null;
}).when(checkpointDao).getModelCheckpoint(eq(rcfModelId), any(ActionListener.class));
when(rcfSerde.fromJson(checkpoint)).thenReturn(forest);
when(forest.getAnomalyScore(point)).thenReturn(score);
when(forest.getNumberOfTrees()).thenReturn(numTrees);
when(forest.getLambda()).thenReturn(rcfTimeDecay);
when(forest.getSampleSize()).thenReturn(numSamples);
when(forest.getTotalUpdates()).thenReturn((long) numSamples);

ActionListener<RcfResult> listener = mock(ActionListener.class);
modelManager.getRcfResult(detectorId, rcfModelId, point, listener);

RcfResult expected = new RcfResult(score, 0, numTrees);
verify(listener).onResponse(eq(expected));

when(forest.getTotalUpdates()).thenReturn(numSamples + 1L);
listener = mock(ActionListener.class);
modelManager.getRcfResult(detectorId, rcfModelId, point, listener);

ArgumentCaptor<RcfResult> responseCaptor = ArgumentCaptor.forClass(RcfResult.class);
verify(listener).onResponse(responseCaptor.capture());
assertEquals(0.091353632, responseCaptor.getValue().getConfidence(), 1e-6);
}

@Test
@SuppressWarnings("unchecked")
public void getRcfResult_throwToListener_whenNoCheckpoint() {
doAnswer(invocation -> {
ActionListener<Optional<String>> listener = invocation.getArgument(1);
listener.onResponse(Optional.empty());
return null;
}).when(checkpointDao).getModelCheckpoint(eq(rcfModelId), any(ActionListener.class));

ActionListener<RcfResult> listener = mock(ActionListener.class);
modelManager.getRcfResult(detectorId, rcfModelId, new double[0], listener);

verify(listener).onFailure(any(ResourceNotFoundException.class));
}

@Test
@SuppressWarnings("unchecked")
public void getRcfResult_throwToListener_whenHeapLimitExceed() {
doAnswer(invocation -> {
ActionListener<Optional<String>> listener = invocation.getArgument(1);
listener.onResponse(Optional.of(checkpoint));
return null;
}).when(checkpointDao).getModelCheckpoint(eq(rcfModelId), any(ActionListener.class));
when(rcfSerde.fromJson(checkpoint)).thenReturn(rcf);
when(jvmService.info().getMem().getHeapMax().getBytes()).thenReturn(1_000L);

ActionListener<RcfResult> listener = mock(ActionListener.class);
modelManager.getRcfResult(detectorId, rcfModelId, new double[0], listener);

verify(listener).onFailure(any(LimitExceededException.class));
}

@Test
public void getThresholdingResult_returnExpected() {
String modelId = "testModelId";
Expand Down