Skip to content

Commit

Permalink
add circuit breaker trigger count stat (#274)
Browse files Browse the repository at this point in the history
Signed-off-by: Yaliang Wu <ylwu@amazon.com>
  • Loading branch information
ylwu-amzn authored Apr 8, 2022
1 parent 08ca3ab commit 18f9065
Show file tree
Hide file tree
Showing 5 changed files with 20 additions and 6 deletions.
3 changes: 1 addition & 2 deletions plugin/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -207,8 +207,7 @@ jacocoTestReport {
List<String> jacocoExclusions = [
// TODO: add more unit test to meet the minimal test coverage.
'org.opensearch.ml.constant.CommonValue',
'org.opensearch.ml.plugin.*',
'org.opensearch.ml.task.MLPredictTaskRunner',
'org.opensearch.ml.plugin.MachineLearningPlugin*',
'org.opensearch.ml.rest.AbstractMLSearchAction*',
'org.opensearch.ml.rest.RestMLExecuteAction' //0.3
]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,7 @@ public Collection<Object> createComponents(
stats.put(StatNames.ML_TOTAL_REQUEST_COUNT, new MLStat<>(false, new CounterSupplier()));
stats.put(StatNames.ML_TOTAL_FAILURE_COUNT, new MLStat<>(false, new CounterSupplier()));
stats.put(StatNames.ML_TOTAL_MODEL_COUNT, new MLStat<>(false, new CounterSupplier()));
stats.put(StatNames.ML_TOTAL_CIRCUIT_BREAKER_TRIGGER_COUNT, new MLStat<>(false, new CounterSupplier()));
this.mlStats = new MLStats(stats);

mlIndicesHandler = new MLIndicesHandler(clusterService, client);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ public class StatNames {
public static String ML_TOTAL_REQUEST_COUNT = "ml_total_request_count";
public static String ML_TOTAL_FAILURE_COUNT = "ml_total_failure_count";
public static String ML_TOTAL_MODEL_COUNT = "ml_total_model_count";
public static String ML_TOTAL_CIRCUIT_BREAKER_TRIGGER_COUNT = "ml_total_circuit_breaker_trigger_count";

public static String requestCountStat(FunctionName functionName, ActionName actionName) {
return String.format(Locale.ROOT, "ml_%s_%s_request_count", functionName, actionName).toLowerCase(Locale.ROOT);
Expand Down
2 changes: 2 additions & 0 deletions plugin/src/main/java/org/opensearch/ml/task/MLTaskRunner.java
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
package org.opensearch.ml.task;

import static org.opensearch.ml.stats.StatNames.ML_EXECUTING_TASK_COUNT;
import static org.opensearch.ml.stats.StatNames.ML_TOTAL_CIRCUIT_BREAKER_TRIGGER_COUNT;

import java.util.HashMap;
import java.util.Map;
Expand Down Expand Up @@ -78,6 +79,7 @@ protected void handleAsyncMLTaskComplete(MLTask mlTask) {

public void run(Request request, TransportService transportService, ActionListener<Response> listener) {
if (mlCircuitBreakerService.isOpen()) {
mlStats.getStat(ML_TOTAL_CIRCUIT_BREAKER_TRIGGER_COUNT).increment();
throw new MLLimitExceededException("Circuit breaker is open");
}
try {
Expand Down
19 changes: 15 additions & 4 deletions plugin/src/test/java/org/opensearch/ml/task/TaskRunnerTests.java
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

import java.time.Instant;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

import org.junit.Before;
import org.junit.Rule;
Expand All @@ -29,15 +30,17 @@
import org.opensearch.ml.common.MLTaskType;
import org.opensearch.ml.common.breaker.MLCircuitBreakerService;
import org.opensearch.ml.common.exception.MLLimitExceededException;
import org.opensearch.ml.stats.MLStat;
import org.opensearch.ml.stats.MLStats;
import org.opensearch.ml.stats.StatNames;
import org.opensearch.ml.stats.suppliers.CounterSupplier;
import org.opensearch.test.OpenSearchTestCase;
import org.opensearch.transport.TransportService;

public class TaskRunnerTests extends OpenSearchTestCase {

@Mock
MLTaskManager mlTaskManager;
@Mock
MLStats mlStats;
@Mock
MLTaskDispatcher mlTaskDispatcher;
Expand All @@ -52,6 +55,14 @@ public class TaskRunnerTests extends OpenSearchTestCase {

@Before
public void setup() {
Map<String, MLStat<?>> stats = new ConcurrentHashMap<>();
stats.put(StatNames.ML_EXECUTING_TASK_COUNT, new MLStat<>(false, new CounterSupplier()));
stats.put(StatNames.ML_TOTAL_REQUEST_COUNT, new MLStat<>(false, new CounterSupplier()));
stats.put(StatNames.ML_TOTAL_FAILURE_COUNT, new MLStat<>(false, new CounterSupplier()));
stats.put(StatNames.ML_TOTAL_MODEL_COUNT, new MLStat<>(false, new CounterSupplier()));
stats.put(StatNames.ML_TOTAL_CIRCUIT_BREAKER_TRIGGER_COUNT, new MLStat<>(false, new CounterSupplier()));
mlStats = new MLStats(stats);

MockitoAnnotations.openMocks(this);
mlTaskRunner = new MLTaskRunner(mlTaskManager, mlStats, mlTaskDispatcher, mlCircuitBreakerService) {
@Override
Expand Down Expand Up @@ -98,11 +109,11 @@ public void testHandleAsyncMLTaskComplete_SyncTask() {
}

public void testRun_CircuitBreakerOpen() {
exceptionRule.expect(MLLimitExceededException.class);
exceptionRule.expectMessage("Circuit breaker is open");
when(mlCircuitBreakerService.isOpen()).thenReturn(true);
TransportService transportService = mock(TransportService.class);
ActionListener listener = mock(ActionListener.class);
mlTaskRunner.run(null, transportService, listener);
expectThrows(MLLimitExceededException.class, () -> mlTaskRunner.run(null, transportService, listener));
Long value = (Long) mlStats.getStat(StatNames.ML_TOTAL_CIRCUIT_BREAKER_TRIGGER_COUNT).getValue();
assertEquals(1L, value.longValue());
}
}

0 comments on commit 18f9065

Please sign in to comment.