diff --git a/junit-platform-engine/src/main/java/org/junit/platform/engine/support/hierarchical/ForkJoinPoolHierarchicalTestExecutorService.java b/junit-platform-engine/src/main/java/org/junit/platform/engine/support/hierarchical/ForkJoinPoolHierarchicalTestExecutorService.java index 5b8a821b11bb..fc0aae08f68c 100644 --- a/junit-platform-engine/src/main/java/org/junit/platform/engine/support/hierarchical/ForkJoinPoolHierarchicalTestExecutorService.java +++ b/junit-platform-engine/src/main/java/org/junit/platform/engine/support/hierarchical/ForkJoinPoolHierarchicalTestExecutorService.java @@ -53,7 +53,9 @@ @API(status = STABLE, since = "1.10") public class ForkJoinPoolHierarchicalTestExecutorService implements HierarchicalTestExecutorService { - private final ForkJoinPool forkJoinPool; + // package-private for testing + final ForkJoinPool forkJoinPool; + private final TaskEventListener taskEventListener; private final int parallelism; private final ThreadLocal threadLocks = ThreadLocal.withInitial(ThreadLock::new); diff --git a/platform-tests/src/test/java/org/junit/platform/engine/support/hierarchical/ForkJoinPoolHierarchicalTestExecutorServiceTests.java b/platform-tests/src/test/java/org/junit/platform/engine/support/hierarchical/ForkJoinPoolHierarchicalTestExecutorServiceTests.java index dbca07d0d253..2e0648df63bf 100644 --- a/platform-tests/src/test/java/org/junit/platform/engine/support/hierarchical/ForkJoinPoolHierarchicalTestExecutorServiceTests.java +++ b/platform-tests/src/test/java/org/junit/platform/engine/support/hierarchical/ForkJoinPoolHierarchicalTestExecutorServiceTests.java @@ -10,12 +10,14 @@ package org.junit.platform.engine.support.hierarchical; +import static java.util.concurrent.TimeUnit.SECONDS; import static org.assertj.core.api.Assertions.assertThat; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertNotEquals; import static org.junit.jupiter.api.Assertions.assertNotNull; import static org.junit.jupiter.api.Assertions.assertNull; import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; import static org.junit.jupiter.params.provider.Arguments.arguments; import static org.junit.platform.engine.support.hierarchical.ExclusiveResource.GLOBAL_READ; import static org.junit.platform.engine.support.hierarchical.ExclusiveResource.GLOBAL_READ_WRITE; @@ -25,13 +27,16 @@ import java.util.List; import java.util.Map; import java.util.Set; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.CopyOnWriteArrayList; import java.util.concurrent.CountDownLatch; -import java.util.concurrent.ExecutionException; import java.util.concurrent.atomic.AtomicReference; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Timeout; import org.junit.jupiter.api.function.Executable; +import org.junit.jupiter.api.function.ThrowingConsumer; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.MethodSource; @@ -44,6 +49,9 @@ @Timeout(5) class ForkJoinPoolHierarchicalTestExecutorServiceTests { + DummyTaskFactory taskFactory = new DummyTaskFactory(); + LockManager lockManager = new LockManager(); + @Test void exceptionsFromInvalidConfigurationAreNotSwallowed() { var configuration = new DefaultParallelExecutionConfiguration(2, 1, 1, 1, 0, __ -> true); @@ -91,7 +99,6 @@ static List incompatibleLockCombinations() { arguments(// Set.of(GLOBAL_READ, new ExclusiveResource("a", LockMode.READ), new ExclusiveResource("b", LockMode.READ), new ExclusiveResource("d", LockMode.READ)), - // Set.of(GLOBAL_READ, new ExclusiveResource("a", LockMode.READ), new ExclusiveResource("c", LockMode.READ)) // )// @@ -101,9 +108,8 @@ static List incompatibleLockCombinations() { @ParameterizedTest @MethodSource("incompatibleLockCombinations") void defersTasksWithIncompatibleLocks(Set initialResources, - Set incompatibleResources) throws Exception { + Set incompatibleResources) throws Throwable { - var lockManager = new LockManager(); var initialLock = lockManager.getLockForResources(initialResources); var incompatibleLock = lockManager.getLockForResources(incompatibleResources); @@ -115,20 +121,14 @@ void defersTasksWithIncompatibleLocks(Set initialResources, deferred.countDown(); }; - var incompatibleTask = new DummyTestTask("incompatibleTask", incompatibleLock); + var incompatibleTask = taskFactory.create("incompatibleTask", incompatibleLock); - var tasks = runWithAttemptedWorkStealing(taskEventListener, incompatibleTask, initialLock, () -> { - try { - deferred.await(); - } - catch (InterruptedException e) { - System.out.println("Interrupted while waiting for task to be deferred"); - } - }); + var tasks = runWithAttemptedWorkStealing(taskEventListener, incompatibleTask, initialLock, + () -> await(deferred, "Interrupted while waiting for task to be deferred")); assertEquals(incompatibleTask, deferredTask.get()); - assertEquals(tasks.get("nestedTask").threadName, tasks.get("leafTask2").threadName); - assertNotEquals(tasks.get("leafTask1").threadName, tasks.get("leafTask2").threadName); + assertEquals(tasks.get("nestedTask").threadName, tasks.get("leafTaskB").threadName); + assertNotEquals(tasks.get("leafTaskA").threadName, tasks.get("leafTaskB").threadName); } static List compatibleLockCombinations() { @@ -169,67 +169,132 @@ static List compatibleLockCombinations() { @ParameterizedTest @MethodSource("compatibleLockCombinations") void canWorkStealTaskWithCompatibleLocks(Set initialResources, - Set compatibleResources) throws Exception { + Set compatibleResources) throws Throwable { - var lockManager = new LockManager(); var initialLock = lockManager.getLockForResources(initialResources); var compatibleLock = lockManager.getLockForResources(compatibleResources); var deferredTask = new AtomicReference(); var workStolen = new CountDownLatch(1); - var compatibleTask = new DummyTestTask("compatibleTask", compatibleLock, workStolen::countDown); + var compatibleTask = taskFactory.create("compatibleTask", compatibleLock, workStolen::countDown); - var tasks = runWithAttemptedWorkStealing(deferredTask::set, compatibleTask, initialLock, () -> { - try { - workStolen.await(); - } - catch (InterruptedException e) { - System.out.println("Interrupted while waiting for work to be stolen"); - } - }); + var tasks = runWithAttemptedWorkStealing(deferredTask::set, compatibleTask, initialLock, + () -> await(workStolen, "Interrupted while waiting for work to be stolen")); assertNull(deferredTask.get()); - assertEquals(tasks.get("nestedTask").threadName, tasks.get("leafTask2").threadName); - assertNotEquals(tasks.get("leafTask1").threadName, tasks.get("leafTask2").threadName); + assertEquals(tasks.get("nestedTask").threadName, tasks.get("leafTaskB").threadName); + assertNotEquals(tasks.get("leafTaskA").threadName, tasks.get("leafTaskB").threadName); } - private static Map runWithAttemptedWorkStealing(TaskEventListener taskEventListener, - DummyTestTask taskToBeStolen, ResourceLock initialLock, Runnable waitAction) - throws InterruptedException, ExecutionException { + @Test + void defersTasksWithIncompatibleLocksOnMultipleLevels() throws Throwable { + + var initialLock = lockManager.getLockForResources( + Set.of(GLOBAL_READ, new ExclusiveResource("a", LockMode.READ))); + var incompatibleLock1 = lockManager.getLockForResource(new ExclusiveResource("a", LockMode.READ_WRITE)); + var compatibleLock1 = lockManager.getLockForResource(new ExclusiveResource("b", LockMode.READ)); + var incompatibleLock2 = lockManager.getLockForResource(new ExclusiveResource("b", LockMode.READ_WRITE)); + + var deferred = new ConcurrentHashMap(); + var deferredTasks = new CopyOnWriteArrayList(); + TaskEventListener taskEventListener = testTask -> { + deferredTasks.add(testTask); + deferred.get(testTask).countDown(); + }; + + var incompatibleTask1 = taskFactory.create("incompatibleTask1", incompatibleLock1); + deferred.put(incompatibleTask1, new CountDownLatch(1)); + + var incompatibleTask2 = taskFactory.create("incompatibleTask2", incompatibleLock2); + deferred.put(incompatibleTask2, new CountDownLatch(1)); + + var configuration = new DefaultParallelExecutionConfiguration(2, 2, 2, 2, 1, __1 -> true); + + withForkJoinPoolHierarchicalTestExecutorService(configuration, taskEventListener, service -> { + + var nestedTask2 = createNestedTaskWithTwoConcurrentLeafTasks(service, "2", compatibleLock1, + List.of(incompatibleTask2), // + () -> await(deferred.get(incompatibleTask2), incompatibleTask2.identifier + " to be deferred")); + + var nestedTask1 = createNestedTaskWithTwoConcurrentLeafTasks(service, "1", initialLock, + List.of(incompatibleTask1, nestedTask2), // + () -> { + await(deferred.get(incompatibleTask1), incompatibleTask1.identifier + " to be deferred"); + await(nestedTask2.started, nestedTask2.identifier + " to be started"); + }); + + service.submit(nestedTask1).get(); + }); + + assertThat(deferredTasks) // + .containsExactly(incompatibleTask1, incompatibleTask2, incompatibleTask1); // incompatibleTask1 may be deferred multiple times + assertThat(taskFactory.tasks) // + .hasSize(3 + 3 + 2) // + .values().extracting(it -> it.completion.isDone()).containsOnly(true); + assertThat(taskFactory.tasks) // + .values().extracting(it -> it.completion.isCompletedExceptionally()).containsOnly(false); + } - var tasks = new HashMap(); - tasks.put(taskToBeStolen.identifier, taskToBeStolen); + private Map runWithAttemptedWorkStealing(TaskEventListener taskEventListener, + DummyTestTask taskToBeStolen, ResourceLock initialLock, Runnable waitAction) throws Throwable { var configuration = new DefaultParallelExecutionConfiguration(2, 2, 2, 2, 1, __ -> true); - try (var pool = new ForkJoinPoolHierarchicalTestExecutorService(configuration, taskEventListener)) { + withForkJoinPoolHierarchicalTestExecutorService(configuration, taskEventListener, service -> { + + var nestedTask = createNestedTaskWithTwoConcurrentLeafTasks(service, "", initialLock, + List.of(taskToBeStolen), waitAction); + + service.submit(nestedTask).get(); + }); + + return taskFactory.tasks; + } + + private DummyTestTask createNestedTaskWithTwoConcurrentLeafTasks( + ForkJoinPoolHierarchicalTestExecutorService service, String identifierSuffix, ResourceLock parentLock, + List tasksToFork, Runnable waitAction) { + + return taskFactory.create("nestedTask" + identifierSuffix, parentLock, () -> { - var extraTask = pool.new ExclusiveTask(taskToBeStolen); var bothLeafTasksAreRunning = new CountDownLatch(2); - var nestedTask = new DummyTestTask("nestedTask", initialLock, () -> { - var leafTask1 = new DummyTestTask("leafTask1", NopLock.INSTANCE, () -> { - extraTask.fork(); - bothLeafTasksAreRunning.countDown(); - bothLeafTasksAreRunning.await(); - waitAction.run(); - }); - tasks.put(leafTask1.identifier, leafTask1); - var leafTask2 = new DummyTestTask("leafTask2", NopLock.INSTANCE, () -> { - bothLeafTasksAreRunning.countDown(); - bothLeafTasksAreRunning.await(); - }); - tasks.put(leafTask2.identifier, leafTask2); - pool.invokeAll(List.of(leafTask1, leafTask2)); + var leafTaskA = taskFactory.create("leafTaskA" + identifierSuffix, NopLock.INSTANCE, () -> { + tasksToFork.forEach(task -> service.new ExclusiveTask(task).fork()); + bothLeafTasksAreRunning.countDown(); + bothLeafTasksAreRunning.await(); + waitAction.run(); + }); + + var leafTaskB = taskFactory.create("leafTaskB" + identifierSuffix, NopLock.INSTANCE, () -> { + bothLeafTasksAreRunning.countDown(); + bothLeafTasksAreRunning.await(); }); - tasks.put(nestedTask.identifier, nestedTask); - pool.submit(nestedTask).get(); - extraTask.join(); + service.invokeAll(List.of(leafTaskA, leafTaskB)); + }); + } + + private static void await(CountDownLatch latch, String message) { + try { + latch.await(); } + catch (InterruptedException e) { + System.out.println("Interrupted while waiting for " + message); + } + } + + private void withForkJoinPoolHierarchicalTestExecutorService(ParallelExecutionConfiguration configuration, + TaskEventListener taskEventListener, ThrowingConsumer action) + throws Throwable { + try (var service = new ForkJoinPoolHierarchicalTestExecutorService(configuration, taskEventListener)) { - return tasks; + action.accept(service); + + service.forkJoinPool.shutdown(); + assertTrue(service.forkJoinPool.awaitTermination(5, SECONDS), "Pool did not terminate within timeout"); + } } static final class DummyTestTask implements TestTask { @@ -238,12 +303,9 @@ static final class DummyTestTask implements TestTask { private final ResourceLock resourceLock; private final Executable action; - private String threadName; - - DummyTestTask(String identifier, ResourceLock resourceLock) { - this(identifier, resourceLock, () -> { - }); - } + private volatile String threadName; + private final CountDownLatch started = new CountDownLatch(1); + private final CompletableFuture completion = new CompletableFuture<>(); DummyTestTask(String identifier, ResourceLock resourceLock, Executable action) { this.identifier = identifier; @@ -264,10 +326,13 @@ public ResourceLock getResourceLock() { @Override public void execute() { threadName = Thread.currentThread().getName(); + started.countDown(); try { action.execute(); + completion.complete(null); } catch (Throwable e) { + completion.completeExceptionally(e); throw new RuntimeException("Action " + identifier + " failed", e); } } @@ -277,4 +342,20 @@ public String toString() { return identifier; } } + + static final class DummyTaskFactory { + + final Map tasks = new HashMap<>(); + + DummyTestTask create(String identifier, ResourceLock resourceLock) { + return create(identifier, resourceLock, () -> { + }); + } + + DummyTestTask create(String identifier, ResourceLock resourceLock, Executable action) { + DummyTestTask task = new DummyTestTask(identifier, resourceLock, action); + tasks.put(task.identifier, task); + return task; + } + } }