Skip to content

ClusterStateTaskListener usage refactoring in MasterServiceTests #82869

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Jan 24, 2022
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import org.apache.logging.log4j.Level;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.lucene.util.SetOnce;
import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.Version;
import org.elasticsearch.action.support.master.AcknowledgedResponse;
Expand All @@ -20,6 +21,7 @@
import org.elasticsearch.cluster.ClusterStatePublicationEvent;
import org.elasticsearch.cluster.ClusterStateTaskConfig;
import org.elasticsearch.cluster.ClusterStateTaskExecutor;
import org.elasticsearch.cluster.ClusterStateTaskExecutor.ClusterTasksResult;
import org.elasticsearch.cluster.ClusterStateTaskListener;
import org.elasticsearch.cluster.ClusterStateUpdateTask;
import org.elasticsearch.cluster.LocalMasterServiceTask;
Expand Down Expand Up @@ -51,14 +53,10 @@

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.BrokenBarrierException;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.CyclicBarrier;
import java.util.concurrent.Semaphore;
Expand All @@ -67,13 +65,14 @@
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Function;

import static java.util.Collections.emptyMap;
import static java.util.Collections.emptySet;
import static java.util.stream.Collectors.toMap;
import static org.hamcrest.Matchers.anyOf;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.hasKey;

public class MasterServiceTests extends ESTestCase {

Expand Down Expand Up @@ -263,9 +262,18 @@ public void testClusterStateTaskListenerThrowingExceptionIsOkay() throws Interru
AtomicBoolean published = new AtomicBoolean();

try (MasterService masterService = createMasterService(true)) {
ClusterStateTaskListener update = new ClusterStateTaskListener() {
@Override
public void clusterStateProcessed(ClusterState oldState, ClusterState newState) {
throw new RuntimeException("testing exception handling");
}

@Override
public void onFailure(Exception e) {}
};
masterService.submitStateUpdateTask(
"testClusterStateTaskListenerThrowingExceptionIsOkay",
new Object(),
update,
ClusterStateTaskConfig.build(Priority.NORMAL),
new ClusterStateTaskExecutor<Object>() {
@Override
Expand All @@ -280,15 +288,7 @@ public void clusterStatePublished(ClusterStatePublicationEvent clusterStatePubli
latch.countDown();
}
},
new ClusterStateTaskListener() {
@Override
public void clusterStateProcessed(ClusterState oldState, ClusterState newState) {
throw new IllegalStateException();
}

@Override
public void onFailure(Exception e) {}
}
update
);

latch.await();
Expand Down Expand Up @@ -464,23 +464,39 @@ public void onFailure(Exception e) {
}

public void testClusterStateBatchedUpdates() throws BrokenBarrierException, InterruptedException {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One important thing that this test seems not to check is that batches submitted with the same executor are executed together. It's a bit tricky to test because you have to block the master service while you're submitting the batches to make sure that they all arrive "at once". Still, I think this is something worth addressing either randomly in this test or else in a separate test.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

On reflection I think it doesn't make sense to try and extend this test to cover this extra invariant too. Let's leave it as it is and add a new test in a follow-up.

Copy link
Contributor Author

@idegtiarenko idegtiarenko Jan 24, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe that is asserted here: https://github.com/idegtiarenko/elasticsearch/blob/665c374b90098f4dd34bcf238de21342a2b4c47a/server/src/test/java/org/elasticsearch/cluster/service/MasterServiceTests.java#L538-L545 (unless I'm missing something).

Update: it does not check that the tasks run "all at once", but at least it verifies the grouping.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This asserts that every set of tasks that was submitted together is executed as an atomic unit. That's certainly important but in practice we only submit multiple tasks in CandidateJoinAccumulator when completing an election. We also want to know that tasks submitted with multiple calls to submitStateUpdateTask are executed as a batch.

AtomicInteger counter = new AtomicInteger();
class Task {
private AtomicBoolean state = new AtomicBoolean();

AtomicInteger executedTasks = new AtomicInteger();
AtomicInteger submittedTasks = new AtomicInteger();
AtomicInteger processedStates = new AtomicInteger();
SetOnce<CountDownLatch> processedStatesLatch = new SetOnce<>();

class Task implements ClusterStateTaskListener {
private final AtomicBoolean executed = new AtomicBoolean();
private final int id;

Task(int id) {
this.id = id;
}

public void execute() {
if (state.compareAndSet(false, true) == false) {
throw new IllegalStateException();
if (executed.compareAndSet(false, true) == false) {
throw new AssertionError("Task [" + id + "] should only be executed once");
} else {
counter.incrementAndGet();
executedTasks.incrementAndGet();
}
}

@Override
public void onFailure(Exception e) {
throw new AssertionError(e);
}

@Override
public void clusterStateProcessed(ClusterState oldState, ClusterState newState) {
processedStates.incrementAndGet();
processedStatesLatch.get().countDown();
}

@Override
public boolean equals(Object o) {
if (this == o) {
Expand All @@ -491,7 +507,6 @@ public boolean equals(Object o) {
}
Task task = (Task) o;
return id == task.id;

}

@Override
Expand All @@ -505,38 +520,43 @@ public String toString() {
}
}

int numberOfThreads = randomIntBetween(2, 8);
int taskSubmissionsPerThread = randomIntBetween(1, 64);
int numberOfExecutors = Math.max(1, numberOfThreads / 4);
final Semaphore semaphore = new Semaphore(numberOfExecutors);
final int numberOfThreads = randomIntBetween(2, 8);
final int taskSubmissionsPerThread = randomIntBetween(1, 64);
final int numberOfExecutors = Math.max(1, numberOfThreads / 4);
final Semaphore semaphore = new Semaphore(1);

class TaskExecutor implements ClusterStateTaskExecutor<Task> {
private final List<Set<Task>> taskGroups;
private AtomicInteger counter = new AtomicInteger();
private AtomicInteger batches = new AtomicInteger();
private AtomicInteger published = new AtomicInteger();

TaskExecutor(List<Set<Task>> taskGroups) {
this.taskGroups = taskGroups;
}
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Moved this to `assignments`. Now the executor will verify that it only executes its own tasks.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we don't actually verify this yet, we're now only checking that every group of our own tasks is completely executed. Previously we were checking that every group of any executor's tasks is completely executed so technically this is now a weaker test.

I think this would be a good property to verify, but to do this we also need to say that every task we're executing belongs to one of our own groups.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I just added `assertThat("All tasks should belong to this executor", totalCount, equalTo(tasks.size()));` to verify that every single task from the list belongs to the current executor.

There is also a check at the very end of the test, `assertEquals(executor.assigned.get(), executor.executed.get())`, that verifies that all assigned tasks are executed. We also have a check that no task is executed twice.

I believe this should cover everything, doesn't it?

private final AtomicInteger executed = new AtomicInteger();
private final AtomicInteger assigned = new AtomicInteger();
private final AtomicInteger batches = new AtomicInteger();
private final AtomicInteger published = new AtomicInteger();
private final List<Set<Task>> assignments = new ArrayList<>();

@Override
public ClusterTasksResult<Task> execute(ClusterState currentState, List<Task> tasks) throws Exception {
for (Set<Task> expectedSet : taskGroups) {
long count = tasks.stream().filter(expectedSet::contains).count();
int totalCount = 0;
for (Set<Task> group : assignments) {
long count = tasks.stream().filter(group::contains).count();
assertThat(
"batched set should be executed together or not at all. Expected " + expectedSet + "s. Executing " + tasks,
"batched set should be executed together or not at all. Expected " + group + "s. Executing " + tasks,
count,
anyOf(equalTo(0L), equalTo((long) expectedSet.size()))
anyOf(equalTo(0L), equalTo((long) group.size()))
);
totalCount += count;
}
assertThat("All tasks should belong to this executor", totalCount, equalTo(tasks.size()));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

neat :)

tasks.forEach(Task::execute);
counter.addAndGet(tasks.size());
executed.addAndGet(tasks.size());
ClusterState maybeUpdatedClusterState = currentState;
if (randomBoolean()) {
maybeUpdatedClusterState = ClusterState.builder(currentState).build();
batches.incrementAndGet();
semaphore.acquire();
assertThat(
"All cluster state modifications should be executed on a single thread",
semaphore.tryAcquire(),
equalTo(true)
);
}
return ClusterTasksResult.<Task>builder().successes(tasks).build(maybeUpdatedClusterState);
}
Expand All @@ -548,40 +568,27 @@ public void clusterStatePublished(ClusterStatePublicationEvent clusterPublicatio
}
}

ConcurrentMap<String, AtomicInteger> processedStates = new ConcurrentHashMap<>();

List<Set<Task>> taskGroups = new ArrayList<>();
List<TaskExecutor> executors = new ArrayList<>();
for (int i = 0; i < numberOfExecutors; i++) {
executors.add(new TaskExecutor(taskGroups));
executors.add(new TaskExecutor());
}

// randomly assign tasks to executors
List<Tuple<TaskExecutor, Set<Task>>> assignments = new ArrayList<>();
int taskId = 0;
AtomicInteger totalTasks = new AtomicInteger();
for (int i = 0; i < numberOfThreads; i++) {
for (int j = 0; j < taskSubmissionsPerThread; j++) {
TaskExecutor executor = randomFrom(executors);
Set<Task> tasks = new HashSet<>();
for (int t = randomInt(3); t >= 0; t--) {
tasks.add(new Task(taskId++));
}
taskGroups.add(tasks);
var executor = randomFrom(executors);
var tasks = Set.copyOf(randomList(1, 3, () -> new Task(totalTasks.getAndIncrement())));

assignments.add(Tuple.tuple(executor, tasks));
executor.assigned.addAndGet(tasks.size());
executor.assignments.add(tasks);
}
}

Map<TaskExecutor, Integer> counts = new HashMap<>();
int totalTaskCount = 0;
for (Tuple<TaskExecutor, Set<Task>> assignment : assignments) {
final int taskCount = assignment.v2().size();
counts.merge(assignment.v1(), taskCount, (previous, count) -> previous + count);
totalTaskCount += taskCount;
}
final CountDownLatch updateLatch = new CountDownLatch(totalTaskCount);
processedStatesLatch.set(new CountDownLatch(totalTasks.get()));

try (MasterService masterService = createMasterService(true)) {
final ConcurrentMap<String, AtomicInteger> submittedTasksPerThread = new ConcurrentHashMap<>();
CyclicBarrier barrier = new CyclicBarrier(1 + numberOfThreads);
for (int i = 0; i < numberOfThreads; i++) {
final int index = i;
Expand All @@ -590,36 +597,23 @@ public void clusterStatePublished(ClusterStatePublicationEvent clusterPublicatio
try {
barrier.await();
for (int j = 0; j < taskSubmissionsPerThread; j++) {
Tuple<TaskExecutor, Set<Task>> assignment = assignments.get(index * taskSubmissionsPerThread + j);
final Set<Task> tasks = assignment.v2();
submittedTasksPerThread.computeIfAbsent(threadName, key -> new AtomicInteger()).addAndGet(tasks.size());
final TaskExecutor executor = assignment.v1();
final ClusterStateTaskListener listener = new ClusterStateTaskListener() {
@Override
public void onFailure(Exception e) {
throw new AssertionError(e);
}

@Override
public void clusterStateProcessed(ClusterState oldState, ClusterState newState) {
processedStates.computeIfAbsent(threadName, key -> new AtomicInteger()).incrementAndGet();
updateLatch.countDown();
}
};
var assignment = assignments.get(index * taskSubmissionsPerThread + j);
var tasks = assignment.v2();
var executor = assignment.v1();
submittedTasks.addAndGet(tasks.size());
if (tasks.size() == 1) {
var update = tasks.iterator().next();
masterService.submitStateUpdateTask(
threadName,
tasks.stream().findFirst().get(),
update,
ClusterStateTaskConfig.build(randomFrom(Priority.values())),
executor,
listener
update
);
} else {
Map<Task, ClusterStateTaskListener> taskListeners = new HashMap<>();
tasks.forEach(t -> taskListeners.put(t, listener));
masterService.submitStateUpdateTasks(
threadName,
taskListeners,
tasks.stream().collect(toMap(Function.<Task>identity(), Function.<ClusterStateTaskListener>identity())),
ClusterStateTaskConfig.build(randomFrom(Priority.values())),
executor
);
Expand All @@ -639,29 +633,19 @@ public void clusterStateProcessed(ClusterState oldState, ClusterState newState)
barrier.await();

// wait until all the cluster state updates have been processed
updateLatch.await();
// and until all of the publication callbacks have completed
semaphore.acquire(numberOfExecutors);
processedStatesLatch.get().await();
// and until all the publication callbacks have completed
semaphore.acquire();

// assert the number of executed tasks is correct
assertEquals(totalTaskCount, counter.get());
assertThat(submittedTasks.get(), equalTo(totalTasks.get()));
assertThat(executedTasks.get(), equalTo(totalTasks.get()));
assertThat(processedStates.get(), equalTo(totalTasks.get()));

// assert each executor executed the correct number of tasks
for (TaskExecutor executor : executors) {
if (counts.containsKey(executor)) {
assertEquals((int) counts.get(executor), executor.counter.get());
assertEquals(executor.batches.get(), executor.published.get());
}
}

// assert the correct number of clusterStateProcessed events were triggered
for (Map.Entry<String, AtomicInteger> entry : processedStates.entrySet()) {
assertThat(submittedTasksPerThread, hasKey(entry.getKey()));
assertEquals(
"not all tasks submitted by " + entry.getKey() + " received a processed event",
entry.getValue().get(),
submittedTasksPerThread.get(entry.getKey()).get()
);
assertEquals(executor.assigned.get(), executor.executed.get());
assertEquals(executor.batches.get(), executor.published.get());
}
}
}
Expand All @@ -672,36 +656,37 @@ public void testBlockingCallInClusterStateTaskListenerFails() throws Interrupted
final AtomicReference<AssertionError> assertionRef = new AtomicReference<>();

try (MasterService masterService = createMasterService(true)) {
ClusterStateTaskListener update = new ClusterStateTaskListener() {
@Override
public void clusterStateProcessed(ClusterState oldState, ClusterState newState) {
BaseFuture<Void> future = new BaseFuture<Void>() {
};
try {
if (randomBoolean()) {
future.get(1L, TimeUnit.SECONDS);
} else {
future.get();
}
} catch (Exception e) {
throw new RuntimeException(e);
} catch (AssertionError e) {
assertionRef.set(e);
latch.countDown();
}
}

@Override
public void onFailure(Exception e) {}
};
masterService.submitStateUpdateTask(
"testBlockingCallInClusterStateTaskListenerFails",
new Object(),
update,
ClusterStateTaskConfig.build(Priority.NORMAL),
(currentState, tasks) -> {
ClusterState newClusterState = ClusterState.builder(currentState).build();
return ClusterStateTaskExecutor.ClusterTasksResult.builder().successes(tasks).build(newClusterState);
return ClusterTasksResult.<ClusterStateTaskListener>builder().successes(tasks).build(newClusterState);
},
new ClusterStateTaskListener() {
@Override
public void clusterStateProcessed(ClusterState oldState, ClusterState newState) {
BaseFuture<Void> future = new BaseFuture<Void>() {
};
try {
if (randomBoolean()) {
future.get(1L, TimeUnit.SECONDS);
} else {
future.get();
}
} catch (Exception e) {
throw new RuntimeException(e);
} catch (AssertionError e) {
assertionRef.set(e);
latch.countDown();
}
}

@Override
public void onFailure(Exception e) {}
}
update
);

latch.await();
Expand Down