diff --git a/documentation/src/docs/asciidoc/release-notes/release-notes-5.11.0-RC1.adoc b/documentation/src/docs/asciidoc/release-notes/release-notes-5.11.0-RC1.adoc index e76d8c4020aa..1a90b20a5407 100644 --- a/documentation/src/docs/asciidoc/release-notes/release-notes-5.11.0-RC1.adoc +++ b/documentation/src/docs/asciidoc/release-notes/release-notes-5.11.0-RC1.adoc @@ -41,12 +41,15 @@ repository on GitHub. [[release-notes-5.11.0-RC1-junit-jupiter-bug-fixes]] ==== Bug Fixes -* ❓ +* `TestInstancePostProcessor` extensions can now be registered via the `@ExtendWith` + annotation on non-static fields. [[release-notes-5.11.0-RC1-junit-jupiter-deprecations-and-breaking-changes]] ==== Deprecations and Breaking Changes -* ❓ +* The registration order of extensions was changed to allow non-static fields to be + processed earlier. This change may affect extensions that rely on the order of + registration. [[release-notes-5.11.0-RC1-junit-jupiter-new-features-and-improvements]] ==== New Features and Improvements diff --git a/documentation/src/docs/asciidoc/user-guide/extensions.adoc b/documentation/src/docs/asciidoc/user-guide/extensions.adoc index b5e4e89dc614..bffbd523d5ab 100644 --- a/documentation/src/docs/asciidoc/user-guide/extensions.adoc +++ b/documentation/src/docs/asciidoc/user-guide/extensions.adoc @@ -129,21 +129,9 @@ important to note which extension APIs are implemented and for what reasons. Specifically, `RandomNumberExtension` implements the following extension APIs: - `BeforeAllCallback`: to support static field injection -- `BeforeEachCallback`: to support non-static field injection +- `TestInstancePostProcessor`: to support non-static field injection - `ParameterResolver`: to support constructor and method injection -[NOTE] -==== -Ideally, the `RandomNumberExtension` would implement `TestInstancePostProcessor` instead -of `BeforeEachCallback` in order to support non-static field injection immediately after -the test class has been instantiated. - -However, JUnit Jupiter currently does not allow a `TestInstancePostProcessor` to be -registered via `@ExtendWith` on a non-static field (see -link:{junit5-repo}/issues/3437[issue 3437]). In light of that, the `RandomNumberExtension` -implements `BeforeEachCallback` as an alternative approach. -==== - [source,java,indent=0] ---- include::{testDir}/example/extensions/RandomNumberExtension.java[tags=user_guide] @@ -272,11 +260,8 @@ will be registered after the test class has been instantiated and after each reg (potentially injecting the instance of the extension to be used into the annotated field). Thus, if such an _instance extension_ implements class-level or instance-level extension APIs such as `BeforeAllCallback`, `AfterAllCallback`, or -`TestInstancePostProcessor`, those APIs will not be honored. By default, an instance -extension will be registered _after_ extensions that are registered at the method level -via `@ExtendWith`; however, if the test class is configured with -`@TestInstance(Lifecycle.PER_CLASS)` semantics, an instance extension will be registered -_before_ extensions that are registered at the method level via `@ExtendWith`. +`TestInstancePostProcessor`, those APIs will not be honored. Instance extensions will be +registered _before_ extensions that are registered at the method level via `@ExtendWith`. In the following example, the `docs` field in the test class is initialized programmatically by invoking a custom `lookUpDocsDir()` method and supplying the result diff --git a/documentation/src/test/java/example/extensions/RandomNumberExtension.java b/documentation/src/test/java/example/extensions/RandomNumberExtension.java index 2b16cc1c38b1..2317997eb8c3 100644 --- a/documentation/src/test/java/example/extensions/RandomNumberExtension.java +++ b/documentation/src/test/java/example/extensions/RandomNumberExtension.java @@ -18,17 +18,17 @@ import java.util.function.Predicate; import org.junit.jupiter.api.extension.BeforeAllCallback; -import org.junit.jupiter.api.extension.BeforeEachCallback; import org.junit.jupiter.api.extension.ExtensionContext; import org.junit.jupiter.api.extension.ParameterContext; import org.junit.jupiter.api.extension.ParameterResolver; +import org.junit.jupiter.api.extension.TestInstancePostProcessor; import org.junit.platform.commons.support.ModifierSupport; // end::user_guide[] // @formatter:off // tag::user_guide[] class RandomNumberExtension - implements BeforeAllCallback, BeforeEachCallback, ParameterResolver { + implements BeforeAllCallback, TestInstancePostProcessor, ParameterResolver { private final java.util.Random random = new java.util.Random(System.nanoTime()); @@ -47,9 +47,8 @@ public void beforeAll(ExtensionContext context) { * {@code @Random} and can be assigned an integer value. */ @Override - public void beforeEach(ExtensionContext context) { + public void postProcessTestInstance(Object testInstance, ExtensionContext context) { Class testClass = context.getRequiredTestClass(); - Object testInstance = context.getRequiredTestInstance(); injectFields(testClass, testInstance, ModifierSupport::isNotStatic); } diff --git a/junit-jupiter-api/src/main/java/org/junit/jupiter/api/extension/TestInstancePostProcessor.java b/junit-jupiter-api/src/main/java/org/junit/jupiter/api/extension/TestInstancePostProcessor.java index 1bc5609e81a2..6b0cd8e59b17 100644 --- a/junit-jupiter-api/src/main/java/org/junit/jupiter/api/extension/TestInstancePostProcessor.java +++ b/junit-jupiter-api/src/main/java/org/junit/jupiter/api/extension/TestInstancePostProcessor.java @@ -23,7 +23,9 @@ * etc. * *

Extensions that implement {@code TestInstancePostProcessor} must be - * registered at the class level. + * registered at the class level, {@linkplain ExtendWith declaratively} via a + * field of the test class, or {@linkplain RegisterExtension programmatically} + * via a static field of the test class. * *

Constructor Requirements

* diff --git a/junit-jupiter-engine/src/main/java/org/junit/jupiter/engine/descriptor/ClassBasedTestDescriptor.java b/junit-jupiter-engine/src/main/java/org/junit/jupiter/engine/descriptor/ClassBasedTestDescriptor.java index fd16422a98a0..4f52ac305f7c 100644 --- a/junit-jupiter-engine/src/main/java/org/junit/jupiter/engine/descriptor/ClassBasedTestDescriptor.java +++ b/junit-jupiter-engine/src/main/java/org/junit/jupiter/engine/descriptor/ClassBasedTestDescriptor.java @@ -15,7 +15,8 @@ import static org.junit.jupiter.engine.descriptor.ExtensionUtils.populateNewExtensionRegistryFromExtendWithAnnotation; import static org.junit.jupiter.engine.descriptor.ExtensionUtils.registerExtensionsFromConstructorParameters; import static org.junit.jupiter.engine.descriptor.ExtensionUtils.registerExtensionsFromExecutableParameters; -import static org.junit.jupiter.engine.descriptor.ExtensionUtils.registerExtensionsFromFields; +import static org.junit.jupiter.engine.descriptor.ExtensionUtils.registerExtensionsFromInstanceFields; +import static org.junit.jupiter.engine.descriptor.ExtensionUtils.registerExtensionsFromStaticFields; import static org.junit.jupiter.engine.descriptor.LifecycleMethodUtils.findAfterAllMethods; import static org.junit.jupiter.engine.descriptor.LifecycleMethodUtils.findAfterEachMethods; import static org.junit.jupiter.engine.descriptor.LifecycleMethodUtils.findBeforeAllMethods; @@ -152,7 +153,7 @@ public JupiterEngineExecutionContext prepare(JupiterEngineExecutionContext conte // Register extensions from static fields here, at the class level but // after extensions registered via @ExtendWith. - registerExtensionsFromFields(registry, this.testClass, null); + registerExtensionsFromStaticFields(registry, this.testClass); // Resolve the TestInstanceFactory at the class level in order to fail // the entire class in case of configuration errors (e.g., more than @@ -175,6 +176,7 @@ public JupiterEngineExecutionContext prepare(JupiterEngineExecutionContext conte registerBeforeEachMethodAdapters(registry); registerAfterEachMethodAdapters(registry); this.afterAllMethods.forEach(method -> registerExtensionsFromExecutableParameters(registry, method)); + registerExtensionsFromInstanceFields(registry, this.testClass); ThrowableCollector throwableCollector = createThrowableCollector(); ExecutableInvoker executableInvoker = new DefaultExecutableInvoker(context); @@ -288,10 +290,10 @@ private TestInstances instantiateAndPostProcessTestInstance(JupiterEngineExecuti throwableCollector); throwableCollector.execute(() -> { invokeTestInstancePostProcessors(instances.getInnermostInstance(), registry, extensionContext); - // In addition, we register extensions from instance fields here since the - // best time to do that is immediately following test class instantiation - // and post processing. - registerExtensionsFromFields(registrar, this.testClass, instances.getInnermostInstance()); + // In addition, we initialize extension registered programmatically from instance fields here + // since the best time to do that is immediately following test class instantiation + // and post-processing. + registrar.initializeExtensions(this.testClass, instances.getInnermostInstance()); }); return instances; } diff --git a/junit-jupiter-engine/src/main/java/org/junit/jupiter/engine/descriptor/ExtensionUtils.java b/junit-jupiter-engine/src/main/java/org/junit/jupiter/engine/descriptor/ExtensionUtils.java index 0cb68545e454..9cc9c64cf849 100644 --- a/junit-jupiter-engine/src/main/java/org/junit/jupiter/engine/descriptor/ExtensionUtils.java +++ b/junit-jupiter-engine/src/main/java/org/junit/jupiter/engine/descriptor/ExtensionUtils.java @@ -35,6 +35,7 @@ import org.junit.jupiter.api.extension.RegisterExtension; import org.junit.jupiter.engine.extension.ExtensionRegistrar; import org.junit.jupiter.engine.extension.MutableExtensionRegistry; +import org.junit.platform.commons.PreconditionViolationException; import org.junit.platform.commons.util.Preconditions; import org.junit.platform.commons.util.ReflectionUtils; @@ -71,60 +72,94 @@ static MutableExtensionRegistry populateNewExtensionRegistryFromExtendWithAnnota Preconditions.notNull(parentRegistry, "Parent ExtensionRegistry must not be null"); Preconditions.notNull(annotatedElement, "AnnotatedElement must not be null"); - return MutableExtensionRegistry.createRegistryFrom(parentRegistry, streamExtensionTypes(annotatedElement)); + return MutableExtensionRegistry.createRegistryFrom(parentRegistry, + streamDeclarativeExtensionTypes(annotatedElement)); } /** - * Register extensions using the supplied registrar from fields in the supplied - * class that are annotated with {@link ExtendWith @ExtendWith} or - * {@link RegisterExtension @RegisterExtension}. + * Register extensions using the supplied registrar from static fields in + * the supplied class that are annotated with {@link ExtendWith @ExtendWith} + * or {@link RegisterExtension @RegisterExtension}. * *

The extensions will be sorted according to {@link Order @Order} semantics * prior to registration. * * @param registrar the registrar with which to register the extensions; never {@code null} * @param clazz the class or interface in which to find the fields; never {@code null} - * @param instance the instance of the supplied class; may be {@code null} - * when searching for {@code static} fields in the class + * @since 5.11 */ - static void registerExtensionsFromFields(ExtensionRegistrar registrar, Class clazz, Object instance) { - Preconditions.notNull(registrar, "ExtensionRegistrar must not be null"); - Preconditions.notNull(clazz, "Class must not be null"); + static void registerExtensionsFromStaticFields(ExtensionRegistrar registrar, Class clazz) { + streamExtensionRegisteringFields(clazz, ReflectionUtils::isStatic) // + .forEach(field -> { + List> extensionTypes = streamDeclarativeExtensionTypes(field).collect( + toList()); + boolean isExtendWithPresent = !extensionTypes.isEmpty(); - Predicate predicate = (instance == null ? ReflectionUtils::isStatic : ReflectionUtils::isNotStatic); + if (isExtendWithPresent) { + extensionTypes.forEach(registrar::registerExtension); + } + if (isAnnotated(field, RegisterExtension.class)) { + Extension extension = readAndValidateExtensionFromField(field, null, extensionTypes); + registrar.registerExtension(extension, field); + } + }); + } - streamFields(clazz, predicate, TOP_DOWN)// - .sorted(orderComparator)// + /** + * Register extensions using the supplied registrar from instance fields in + * the supplied class that are annotated with {@link ExtendWith @ExtendWith} + * or {@link RegisterExtension @RegisterExtension}. + * + *

The extensions will be sorted according to {@link Order @Order} semantics + * prior to registration. + * + * @param registrar the registrar with which to register the extensions; never {@code null} + * @param clazz the class or interface in which to find the fields; never {@code null} + * @since 5.11 + */ + static void registerExtensionsFromInstanceFields(ExtensionRegistrar registrar, Class clazz) { + streamExtensionRegisteringFields(clazz, ReflectionUtils::isNotStatic) // .forEach(field -> { - List> extensionTypes = streamExtensionTypes(field).collect(toList()); + List> extensionTypes = streamDeclarativeExtensionTypes(field).collect( + toList()); boolean isExtendWithPresent = !extensionTypes.isEmpty(); - boolean isRegisterExtensionPresent = isAnnotated(field, RegisterExtension.class); + if (isExtendWithPresent) { extensionTypes.forEach(registrar::registerExtension); } - if (isRegisterExtensionPresent) { - tryToReadFieldValue(field, instance).ifSuccess(value -> { - Preconditions.condition(value instanceof Extension, () -> String.format( - "Failed to register extension via @RegisterExtension field [%s]: field value's type [%s] must implement an [%s] API.", - field, (value != null ? value.getClass().getName() : null), Extension.class.getName())); - - if (isExtendWithPresent) { - Class valueType = value.getClass(); - extensionTypes.forEach(extensionType -> { - Preconditions.condition(!extensionType.equals(valueType), - () -> String.format("Failed to register extension via field [%s]. " - + "The field registers an extension of type [%s] via @RegisterExtension and @ExtendWith, " - + "but only one registration of a given extension type is permitted.", - field, valueType.getName())); - }); - } - - registrar.registerExtension((Extension) value, field); - }); + if (isAnnotated(field, RegisterExtension.class)) { + registrar.registerUninitializedExtension(clazz, field, + instance -> readAndValidateExtensionFromField(field, instance, extensionTypes)); } }); } + /** + * @since 5.11 + */ + private static Extension readAndValidateExtensionFromField(Field field, Object instance, + List> declarativeExtensionTypes) { + Object value = tryToReadFieldValue(field, instance) // + .getOrThrow(e -> new PreconditionViolationException( + String.format("Failed to read @RegisterExtension field [%s]", field), e)); + + Preconditions.condition(value instanceof Extension, () -> String.format( + "Failed to register extension via @RegisterExtension field [%s]: field value's type [%s] must implement an [%s] API.", + field, (value != null ? value.getClass().getName() : null), Extension.class.getName())); + + declarativeExtensionTypes.forEach(extensionType -> { + Class valueType = value.getClass(); + Preconditions.condition(!extensionType.equals(valueType), + () -> String.format( + "Failed to register extension via field [%s]. " + + "The field registers an extension of type [%s] via @RegisterExtension and @ExtendWith, " + + "but only one registration of a given extension type is permitted.", + field, valueType.getName())); + }); + + return (Extension) value; + } + /** * Register extensions using the supplied registrar from parameters in the * declared constructor of the supplied class that are annotated with @@ -157,22 +192,34 @@ static void registerExtensionsFromExecutableParameters(ExtensionRegistrar regist // @formatter:off Arrays.stream(executable.getParameters()) .map(parameter -> findRepeatableAnnotations(parameter, index.getAndIncrement(), ExtendWith.class)) - .flatMap(ExtensionUtils::streamExtensionTypes) + .flatMap(ExtensionUtils::streamDeclarativeExtensionTypes) .forEach(registrar::registerExtension); // @formatter:on } /** - * @since 5.8 + * @since 5.11 */ - private static Stream> streamExtensionTypes(AnnotatedElement annotatedElement) { - return streamExtensionTypes(findRepeatableAnnotations(annotatedElement, ExtendWith.class)); + private static Stream streamExtensionRegisteringFields(Class clazz, Predicate predicate) { + Predicate composedPredicate = predicate.and( + field -> isAnnotated(field, ExtendWith.class) || isAnnotated(field, RegisterExtension.class)); + return streamFields(clazz, composedPredicate, TOP_DOWN)// + .sorted(orderComparator); } /** - * @since 5.8 + * @since 5.11 + */ + private static Stream> streamDeclarativeExtensionTypes( + AnnotatedElement annotatedElement) { + return streamDeclarativeExtensionTypes(findRepeatableAnnotations(annotatedElement, ExtendWith.class)); + } + + /** + * @since 5.11 */ - private static Stream> streamExtensionTypes(List extendWithAnnotations) { + private static Stream> streamDeclarativeExtensionTypes( + List extendWithAnnotations) { return extendWithAnnotations.stream().map(ExtendWith::value).flatMap(Arrays::stream); } diff --git a/junit-jupiter-engine/src/main/java/org/junit/jupiter/engine/descriptor/TestTemplateTestDescriptor.java b/junit-jupiter-engine/src/main/java/org/junit/jupiter/engine/descriptor/TestTemplateTestDescriptor.java index d02f61def41a..353c5c8325f1 100644 --- a/junit-jupiter-engine/src/main/java/org/junit/jupiter/engine/descriptor/TestTemplateTestDescriptor.java +++ b/junit-jupiter-engine/src/main/java/org/junit/jupiter/engine/descriptor/TestTemplateTestDescriptor.java @@ -74,7 +74,7 @@ public boolean mayRegisterTests() { // --- Node ---------------------------------------------------------------- @Override - public JupiterEngineExecutionContext prepare(JupiterEngineExecutionContext context) throws Exception { + public JupiterEngineExecutionContext prepare(JupiterEngineExecutionContext context) { MutableExtensionRegistry registry = populateNewExtensionRegistryFromExtendWithAnnotation( context.getExtensionRegistry(), getTestMethod()); diff --git a/junit-jupiter-engine/src/main/java/org/junit/jupiter/engine/extension/ExtensionRegistrar.java b/junit-jupiter-engine/src/main/java/org/junit/jupiter/engine/extension/ExtensionRegistrar.java index 8764904451d6..3c4e90ed4d13 100644 --- a/junit-jupiter-engine/src/main/java/org/junit/jupiter/engine/extension/ExtensionRegistrar.java +++ b/junit-jupiter-engine/src/main/java/org/junit/jupiter/engine/extension/ExtensionRegistrar.java @@ -12,8 +12,12 @@ import static org.apiguardian.api.API.Status.INTERNAL; +import java.lang.reflect.Field; +import java.util.function.Function; + import org.apiguardian.api.API; import org.junit.jupiter.api.extension.Extension; +import org.junit.jupiter.api.extension.RegisterExtension; /** * An {@code ExtensionRegistrar} is used to register extensions. @@ -45,11 +49,11 @@ public interface ExtensionRegistrar { * {@link org.junit.jupiter.api.extension.ExtendWith @ExtendWith}, the * {@code source} and the {@code extension} should be the same object. * However, if an extension is registered programmatically via - * {@link org.junit.jupiter.api.extension.RegisterExtension @RegisterExtension}, - * the {@code source} object should be the {@link java.lang.reflect.Field} - * that is annotated with {@code @RegisterExtension}. Similarly, if an - * extension is registered programmatically as a lambda expression - * or method reference, the {@code source} object should be the underlying + * {@link RegisterExtension @RegisterExtension}, the {@code source} object + * should be the {@link java.lang.reflect.Field} that is annotated with + * {@code @RegisterExtension}. Similarly, if an extension is registered + * programmatically as a lambda expression or method reference, the + * {@code source} object should be the underlying * {@link java.lang.reflect.Method} that implements the extension API. * * @param extension the extension to register; never {@code null} @@ -68,4 +72,34 @@ public interface ExtensionRegistrar { */ void registerSyntheticExtension(Extension extension, Object source); + /** + * Register an uninitialized extension for the supplied {@code testClass} to + * be initialized using the supplied {@code initializer} when an instance of + * the test class is created. + * + *

Uninitialized extensions are typically registered for fields annotated + * with {@link RegisterExtension @RegisterExtension} that cannot be + * initialized until an instance of the test class is created. Until they + * are initialized, such extensions are not available for use. + * + * @param testClass the test class for which the extension is registered; + * never {@code null} + * @param source the source of the extension; never {@code null} + * @param initializer the initializer function to be used to create the + * extension; never {@code null} + */ + void registerUninitializedExtension(Class testClass, Field source, + Function initializer); + + /** + * Initialize all registered extensions for the supplied {@code testClass} + * using the supplied {@code testInstance}. + * + * @param testClass the test class for which the extensions are initialized; + * never {@code null} + * @param testInstance the test instance to be used to initialize the + * extensions; never {@code null} + */ + void initializeExtensions(Class testClass, Object testInstance); + } diff --git a/junit-jupiter-engine/src/main/java/org/junit/jupiter/engine/extension/MutableExtensionRegistry.java b/junit-jupiter-engine/src/main/java/org/junit/jupiter/engine/extension/MutableExtensionRegistry.java index 3791f83b8c7b..b7b3f43034dd 100644 --- a/junit-jupiter-engine/src/main/java/org/junit/jupiter/engine/extension/MutableExtensionRegistry.java +++ b/junit-jupiter-engine/src/main/java/org/junit/jupiter/engine/extension/MutableExtensionRegistry.java @@ -10,18 +10,24 @@ package org.junit.jupiter.engine.extension; -import static java.util.stream.Stream.concat; +import static java.util.Collections.emptyList; +import static java.util.Collections.emptySet; import static org.apiguardian.api.API.Status.INTERNAL; +import java.lang.reflect.Field; import java.lang.reflect.Member; import java.lang.reflect.Method; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; +import java.util.LinkedHashMap; import java.util.LinkedHashSet; import java.util.List; +import java.util.Map; +import java.util.Optional; import java.util.ServiceLoader; import java.util.Set; +import java.util.function.Function; import java.util.stream.Stream; import org.apiguardian.api.API; @@ -36,10 +42,6 @@ /** * Default, mutable implementation of {@link ExtensionRegistry}. * - *

A registry has a reference to its parent registry, and all lookups are - * performed first in the current registry itself and then recursively in its - * ancestors. - * * @since 5.5 */ @API(status = INTERNAL, since = "5.5") @@ -69,7 +71,7 @@ public class MutableExtensionRegistry implements ExtensionRegistry, ExtensionReg * @return a new {@code ExtensionRegistry}; never {@code null} */ public static MutableExtensionRegistry createRegistryWithDefaultExtensions(JupiterConfiguration configuration) { - MutableExtensionRegistry extensionRegistry = new MutableExtensionRegistry(null); + MutableExtensionRegistry extensionRegistry = new MutableExtensionRegistry(); DEFAULT_STATELESS_EXTENSIONS.forEach(extensionRegistry::registerDefaultExtension); @@ -106,38 +108,41 @@ public static MutableExtensionRegistry createRegistryFrom(MutableExtensionRegist return registry; } - private final MutableExtensionRegistry parent; - - private final Set> registeredExtensionTypes = new LinkedHashSet<>(); + private final Set> registeredExtensionTypes; + private final List registeredExtensions; + private final Map, LateInitExtensions> lateInitExtensions; - private final List registeredExtensions = new ArrayList<>(); + private MutableExtensionRegistry() { + this(emptySet(), emptyList()); + } private MutableExtensionRegistry(MutableExtensionRegistry parent) { - this.parent = parent; + this(parent.registeredExtensionTypes, parent.registeredExtensions); } - @Override - public Stream stream(Class extensionType) { - if (this.parent == null) { - return streamLocal(extensionType); - } - return concat(this.parent.stream(extensionType), streamLocal(extensionType)); + private MutableExtensionRegistry(Set> registeredExtensionTypes, + List registeredExtensions) { + this.registeredExtensionTypes = new LinkedHashSet<>(registeredExtensionTypes); + this.registeredExtensions = new ArrayList<>(registeredExtensions.size()); + this.lateInitExtensions = new LinkedHashMap<>(); + registeredExtensions.forEach(entry -> { + Entry newEntry = entry; + if (entry instanceof LateInitEntry) { + LateInitEntry lateInitEntry = (LateInitEntry) entry; + newEntry = lateInitEntry.getExtension() // + .map(Entry::of) // + .orElseGet(() -> getLateInitExtensions(lateInitEntry.getTestClass()).add(lateInitEntry.copy())); + } + this.registeredExtensions.add(newEntry); + }); } - /** - * Stream all {@code Extensions} of the specified type that are present - * in this registry. - * - *

Extensions in ancestors are ignored. - * - * @param extensionType the type of {@link Extension} to stream - */ - private Stream streamLocal(Class extensionType) { - // @formatter:off - return this.registeredExtensions.stream() - .filter(extensionType::isInstance) + @Override + public Stream stream(Class extensionType) { + return this.registeredExtensions.stream() // + .map(p -> p.getExtension().orElse(null)) // + .filter(extensionType::isInstance) // .map(extensionType::cast); - // @formatter:on } @Override @@ -152,8 +157,7 @@ public void registerExtension(Class extensionType) { * parent registry. */ private boolean isAlreadyRegistered(Class extensionType) { - return (this.registeredExtensionTypes.contains(extensionType) - || (this.parent != null && this.parent.isAlreadyRegistered(extensionType))); + return this.registeredExtensionTypes.contains(extensionType); } @Override @@ -167,6 +171,36 @@ public void registerSyntheticExtension(Extension extension, Object source) { registerExtension("synthetic", extension, source); } + @Override + public void registerUninitializedExtension(Class testClass, Field source, + Function initializer) { + Preconditions.notNull(testClass, "testClass must not be null"); + Preconditions.notNull(source, "source must not be null"); + Preconditions.notNull(initializer, "initializer must not be null"); + + logger.trace(() -> String.format("Registering local extension (late-init) for [%s]%s", + source.getType().getName(), buildSourceInfo(source))); + + LateInitEntry entry = getLateInitExtensions(testClass) // + .add(new LateInitEntry(testClass, initializer)); + this.registeredExtensions.add(entry); + } + + @Override + public void initializeExtensions(Class testClass, Object testInstance) { + Preconditions.notNull(testClass, "testClass must not be null"); + Preconditions.notNull(testInstance, "testInstance must not be null"); + + LateInitExtensions extensions = lateInitExtensions.remove(testClass); + if (extensions != null) { + extensions.initialize(testInstance); + } + } + + private LateInitExtensions getLateInitExtensions(Class testClass) { + return this.lateInitExtensions.computeIfAbsent(testClass, __ -> new LateInitExtensions()); + } + private void registerDefaultExtension(Extension extension) { registerExtension("default", extension); } @@ -185,12 +219,12 @@ private void registerExtension(String category, Extension extension) { private void registerExtension(String category, Extension extension, Object source) { Preconditions.notBlank(category, "category must not be null or blank"); - Preconditions.notNull(extension, "Extension must not be null"); + Preconditions.notNull(extension, "extension must not be null"); logger.trace( () -> String.format("Registering %s extension [%s]%s", category, extension, buildSourceInfo(source))); - this.registeredExtensions.add(extension); + this.registeredExtensions.add(Entry.of(extension)); this.registeredExtensionTypes.add(extension.getClass()); } @@ -206,4 +240,62 @@ private String buildSourceInfo(Object source) { return " from source [" + source + "]"; } + private interface Entry { + + static Entry of(Extension extension) { + Optional value = Optional.of(extension); + return () -> value; + } + + Optional getExtension(); + } + + private static class LateInitEntry implements Entry { + + private final Class testClass; + private final Function initializer; + + @SuppressWarnings("OptionalUsedAsFieldOrParameterType") + private Optional extension = Optional.empty(); + + public LateInitEntry(Class testClass, Function initializer) { + this.testClass = testClass; + this.initializer = initializer; + } + + @Override + public Optional getExtension() { + return extension; + } + + public Class getTestClass() { + return testClass; + } + + void initialize(Object testInstance) { + Preconditions.condition(!extension.isPresent(), "Extension already initialized"); + extension = Optional.of(initializer.apply(testInstance)); + } + + LateInitEntry copy() { + Preconditions.condition(!extension.isPresent(), "Extension already initialized"); + return new LateInitEntry(testClass, initializer); + } + } + + private static class LateInitExtensions { + + private final List entries = new ArrayList<>(); + + LateInitEntry add(LateInitEntry entry) { + entries.add(entry); + return entry; + } + + void initialize(Object testInstance) { + entries.forEach(entry -> entry.initialize(testInstance)); + } + + } + } diff --git a/junit-jupiter-engine/src/test/java/org/junit/jupiter/engine/AbstractJupiterTestEngineTests.java b/junit-jupiter-engine/src/test/java/org/junit/jupiter/engine/AbstractJupiterTestEngineTests.java index 75a92cc6bf3e..209f07cc4767 100644 --- a/junit-jupiter-engine/src/test/java/org/junit/jupiter/engine/AbstractJupiterTestEngineTests.java +++ b/junit-jupiter-engine/src/test/java/org/junit/jupiter/engine/AbstractJupiterTestEngineTests.java @@ -21,6 +21,7 @@ import org.junit.platform.engine.TestDescriptor; import org.junit.platform.engine.UniqueId; import org.junit.platform.launcher.LauncherDiscoveryRequest; +import org.junit.platform.launcher.core.LauncherDiscoveryRequestBuilder; import org.junit.platform.testkit.engine.EngineExecutionResults; import org.junit.platform.testkit.engine.EngineTestKit; @@ -38,7 +39,11 @@ protected EngineExecutionResults executeTestsForClass(Class testClass) { } protected EngineExecutionResults executeTests(DiscoverySelector... selectors) { - return executeTests(request().selectors(selectors).build()); + return executeTests(request().selectors(selectors)); + } + + protected EngineExecutionResults executeTests(LauncherDiscoveryRequestBuilder builder) { + return executeTests(builder.build()); } protected EngineExecutionResults executeTests(LauncherDiscoveryRequest request) { diff --git a/junit-jupiter-engine/src/test/java/org/junit/jupiter/engine/extension/ExtensionRegistrationViaParametersAndFieldsTests.java b/junit-jupiter-engine/src/test/java/org/junit/jupiter/engine/extension/ExtensionRegistrationViaParametersAndFieldsTests.java index 0d10f40a5fbd..6aeb38a2926d 100644 --- a/junit-jupiter-engine/src/test/java/org/junit/jupiter/engine/extension/ExtensionRegistrationViaParametersAndFieldsTests.java +++ b/junit-jupiter-engine/src/test/java/org/junit/jupiter/engine/extension/ExtensionRegistrationViaParametersAndFieldsTests.java @@ -10,12 +10,15 @@ package org.junit.jupiter.engine.extension; -import static java.util.stream.Collectors.toList; import static org.assertj.core.api.Assertions.assertThat; import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertSame; import static org.junit.jupiter.api.DynamicTest.dynamicTest; +import static org.junit.jupiter.api.parallel.ExecutionMode.CONCURRENT; import static org.junit.platform.commons.util.AnnotationUtils.findAnnotatedFields; import static org.junit.platform.commons.util.ReflectionUtils.makeAccessible; +import static org.junit.platform.engine.discovery.DiscoverySelectors.selectClass; +import static org.junit.platform.launcher.core.LauncherDiscoveryRequestBuilder.request; import static org.junit.platform.testkit.engine.EventConditions.finishedWithFailure; import static org.junit.platform.testkit.engine.TestExecutionResultConditions.instanceOf; import static org.junit.platform.testkit.engine.TestExecutionResultConditions.message; @@ -42,8 +45,10 @@ import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.DynamicTest; +import org.junit.jupiter.api.Named; import org.junit.jupiter.api.Nested; import org.junit.jupiter.api.Order; +import org.junit.jupiter.api.RepeatedTest; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.TestFactory; import org.junit.jupiter.api.TestInfo; @@ -51,7 +56,6 @@ import org.junit.jupiter.api.TestInstance.Lifecycle; import org.junit.jupiter.api.TestTemplate; import org.junit.jupiter.api.extension.BeforeAllCallback; -import org.junit.jupiter.api.extension.BeforeEachCallback; import org.junit.jupiter.api.extension.ExtendWith; import org.junit.jupiter.api.extension.Extension; import org.junit.jupiter.api.extension.ExtensionContext; @@ -59,15 +63,19 @@ import org.junit.jupiter.api.extension.ParameterResolutionException; import org.junit.jupiter.api.extension.ParameterResolver; import org.junit.jupiter.api.extension.RegisterExtension; +import org.junit.jupiter.api.extension.TestInstancePostProcessor; import org.junit.jupiter.api.extension.TestTemplateInvocationContext; import org.junit.jupiter.api.extension.TestTemplateInvocationContextProvider; import org.junit.jupiter.api.fixtures.TrackLogRecords; +import org.junit.jupiter.api.parallel.Execution; import org.junit.jupiter.engine.AbstractJupiterTestEngineTests; +import org.junit.jupiter.engine.config.JupiterConfiguration; import org.junit.jupiter.engine.execution.injection.sample.LongParameterResolver; import org.junit.platform.commons.PreconditionViolationException; import org.junit.platform.commons.logging.LogRecordListener; import org.junit.platform.commons.util.ExceptionUtils; import org.junit.platform.commons.util.ReflectionUtils; +import org.junit.platform.testkit.engine.EngineExecutionResults; /** * Integration tests that verify support for extension registration via @@ -75,6 +83,7 @@ * * @since 5.8 */ +@SuppressWarnings("JUnitMalformedDeclaration") class ExtensionRegistrationViaParametersAndFieldsTests extends AbstractJupiterTestEngineTests { @Test @@ -156,62 +165,67 @@ void duplicateRegistrationViaField() { finishedWithFailure(instanceOf(PreconditionViolationException.class), message(expectedMessage))); } + @TestFactory + Stream registrationOrder(@TrackLogRecords LogRecordListener listener) { + return Stream.of( // + Named.named("per-method", AllInOneWithTestInstancePerMethodTestCase.class), // + Named.named("per-class", AllInOneWithTestInstancePerClassTestCase.class) // + ) // + .map(it -> dynamicTest(it.getName(), () -> { + listener.clear(); + assertOneTestSucceeded(it.getPayload()); + assertThat(getRegisteredLocalExtensions(listener))// + .containsExactly(// + "ClassLevelExtension2", // @RegisterExtension on static field + "StaticField2", // @ExtendWith on static field + "ClassLevelExtension1", // @RegisterExtension on static field + "StaticField1", // @ExtendWith on static field + "ConstructorParameter", // @ExtendWith on parameter in constructor + "BeforeAllParameter", // @ExtendWith on parameter in static @BeforeAll method + "BeforeEachParameter", // @ExtendWith on parameter in @BeforeEach method + "AfterEachParameter", // @ExtendWith on parameter in @AfterEach method + "AfterAllParameter", // @ExtendWith on parameter in static @AfterAll method + "InstanceLevelExtension1", // @RegisterExtension on instance field + "InstanceField1", // @ExtendWith on instance field + "InstanceLevelExtension2", // @RegisterExtension on instance field + "InstanceField2", // @ExtendWith on instance field + "TestParameter" // @ExtendWith on parameter in @Test method + ); + })); + } + + @Test + void registersProgrammaticTestInstancePostProcessors() { + assertOneTestSucceeded(ProgrammaticTestInstancePostProcessorTestCase.class); + } + @Test - void registrationOrder(@TrackLogRecords LogRecordListener listener) { - assertOneTestSucceeded(AllInOneWithTestInstancePerMethodTestCase.class); - assertThat(getRegisteredLocalExtensions(listener))// - .containsExactly(// - "ClassLevelExtension2", // @RegisterExtension on static field - "StaticField2", // @ExtendWith on static field - "ClassLevelExtension1", // @RegisterExtension on static field - "StaticField1", // @ExtendWith on static field - "ConstructorParameter", // @ExtendWith on parameter in constructor - "BeforeAllParameter", // @ExtendWith on parameter in static @BeforeAll method - "BeforeEachParameter", // @ExtendWith on parameter in @BeforeEach method - "AfterEachParameter", // @ExtendWith on parameter in @AfterEach method - "AfterAllParameter", // @ExtendWith on parameter in static @AfterAll method - "TestParameter", // @ExtendWith on parameter in @Test method - "InstanceLevelExtension1", // @RegisterExtension on instance field - "InstanceField1", // @ExtendWith on instance field - "InstanceLevelExtension2", // @RegisterExtension on instance field - "InstanceField2" // @ExtendWith on instance field - ); - - listener.clear(); - assertOneTestSucceeded(AllInOneWithTestInstancePerClassTestCase.class); - assertThat(getRegisteredLocalExtensions(listener))// - .containsExactly(// - "ClassLevelExtension2", // @RegisterExtension on static field - "StaticField2", // @ExtendWith on static field - "ClassLevelExtension1", // @RegisterExtension on static field - "StaticField1", // @ExtendWith on static field - "ConstructorParameter", // @ExtendWith on parameter in constructor - "BeforeAllParameter", // @ExtendWith on parameter in static @BeforeAll method - "BeforeEachParameter", // @ExtendWith on parameter in @BeforeEach method - "AfterEachParameter", // @ExtendWith on parameter in @AfterEach method - "AfterAllParameter", // @ExtendWith on parameter in static @AfterAll method - "InstanceLevelExtension1", // @RegisterExtension on instance field - "InstanceField1", // @ExtendWith on instance field - "InstanceLevelExtension2", // @RegisterExtension on instance field - "InstanceField2", // @ExtendWith on instance field - "TestParameter" // @ExtendWith on parameter in @Test method - ); + void createsExtensionPerInstance() { + var results = executeTests(request() // + .selectors(selectClass(InitializationPerInstanceTestCase.class)) // + .configurationParameter(JupiterConfiguration.PARALLEL_EXECUTION_ENABLED_PROPERTY_NAME, "true") // + ); + assertTestsSucceeded(results, 100); } private List getRegisteredLocalExtensions(LogRecordListener listener) { - // @formatter:off - return listener.stream(MutableExtensionRegistry.class, Level.FINER) - .map(LogRecord::getMessage) - .filter(message -> message.contains("local extension")) - .map(message -> { - message = message.replaceAll("from source .+", ""); - int indexOfDollarSign = message.indexOf("$"); - int indexOfAtSign = message.indexOf("@"); - int endIndex = (indexOfDollarSign > 1 ? indexOfDollarSign : indexOfAtSign); - return message.substring(message.lastIndexOf('.') + 1, endIndex); - }) - .collect(toList()); - // @formatter:on + return listener.stream(MutableExtensionRegistry.class, Level.FINER) // + .map(LogRecord::getMessage) // + .filter(message -> message.contains("local extension")) // + .map(message -> { + message = message.replaceAll(" from source .+", ""); + int beginIndex = message.lastIndexOf('.') + 1; + if (message.contains("late-init")) { + return message.substring(beginIndex, message.indexOf("]")); + } + else { + int indexOfDollarSign = message.indexOf("$"); + int indexOfAtSign = message.indexOf("@"); + int endIndex = (indexOfDollarSign > 1 ? indexOfDollarSign : indexOfAtSign); + return message.substring(beginIndex, endIndex); + } + }) // + .toList(); } private void assertOneTestSucceeded(Class testClass) { @@ -219,7 +233,11 @@ private void assertOneTestSucceeded(Class testClass) { } private void assertTestsSucceeded(Class testClass, int expected) { - executeTestsForClass(testClass).testEvents().assertStatistics( + assertTestsSucceeded(executeTestsForClass(testClass), expected); + } + + private static void assertTestsSucceeded(EngineExecutionResults results, int expected) { + results.testEvents().assertStatistics( stats -> stats.started(expected).succeeded(expected).skipped(0).aborted(0).failed(0)); } @@ -557,7 +575,7 @@ static class MultipleRegistrationsViaFieldTestCase { @ExtendWith(LongParameterResolver.class) @RegisterExtension - Extension dummy = new DummyExtension(); + DummyExtension dummy = new DummyExtension(); @Test void test() { @@ -580,6 +598,7 @@ void test() { */ static class StaticFieldTestCase { + @SuppressWarnings("unused") @MagicField private static String staticField1; @@ -612,8 +631,8 @@ static class InstanceFieldTestCase { @Test void test() { - assertThat(instanceField1).isEqualTo("beforeEach - instanceField1"); - assertThat(instanceField2).isEqualTo("beforeEach - instanceField2"); + assertThat(instanceField1).isEqualTo("postProcessTestInstance - instanceField1"); + assertThat(instanceField2).isEqualTo("postProcessTestInstance - instanceField2"); } } @@ -633,13 +652,13 @@ static class TestInstancePerClassFieldTestCase { @BeforeAll void beforeAll() { assertThat(staticField).isEqualTo("beforeAll - staticField"); - assertThat(instanceField).isNull(); + assertThat(instanceField).isEqualTo("postProcessTestInstance - instanceField"); } @Test void test() { assertThat(staticField).isEqualTo("beforeAll - staticField"); - assertThat(instanceField).isEqualTo("beforeEach - instanceField"); + assertThat(instanceField).isEqualTo("postProcessTestInstance - instanceField"); } } @@ -672,11 +691,11 @@ static class AllInOneWithTestInstancePerMethodTestCase { @RegisterExtension @Order(1) - private Extension instanceLevelExtension1 = new InstanceLevelExtension1(); + private InstanceLevelExtension1 instanceLevelExtension1 = new InstanceLevelExtension1(); @RegisterExtension @Order(3) - Extension instanceLevelExtension2 = new InstanceLevelExtension2(); + InstanceLevelExtension2 instanceLevelExtension2 = new InstanceLevelExtension2(); AllInOneWithTestInstancePerMethodTestCase(@ConstructorParameter String text) { assertThat(text).isEqualTo("enigma"); @@ -694,8 +713,8 @@ void beforeEach(@BeforeEachParameter String text) { assertThat(text).isEqualTo("enigma"); assertThat(staticField1).isEqualTo("beforeAll - staticField1"); assertThat(staticField2).isEqualTo("beforeAll - staticField2"); - assertThat(instanceField1).isEqualTo("beforeEach - instanceField1"); - assertThat(instanceField2).isEqualTo("beforeEach - instanceField2"); + assertThat(instanceField1).isEqualTo("postProcessTestInstance - instanceField1"); + assertThat(instanceField2).isEqualTo("postProcessTestInstance - instanceField2"); } @Test @@ -703,8 +722,8 @@ void test(@TestParameter String text) { assertThat(text).isEqualTo("enigma"); assertThat(staticField1).isEqualTo("beforeAll - staticField1"); assertThat(staticField2).isEqualTo("beforeAll - staticField2"); - assertThat(instanceField1).isEqualTo("beforeEach - instanceField1"); - assertThat(instanceField2).isEqualTo("beforeEach - instanceField2"); + assertThat(instanceField1).isEqualTo("postProcessTestInstance - instanceField1"); + assertThat(instanceField2).isEqualTo("postProcessTestInstance - instanceField2"); } @AfterEach @@ -712,8 +731,8 @@ void afterEach(@AfterEachParameter String text) { assertThat(text).isEqualTo("enigma"); assertThat(staticField1).isEqualTo("beforeAll - staticField1"); assertThat(staticField2).isEqualTo("beforeAll - staticField2"); - assertThat(instanceField1).isEqualTo("beforeEach - instanceField1"); - assertThat(instanceField2).isEqualTo("beforeEach - instanceField2"); + assertThat(instanceField1).isEqualTo("postProcessTestInstance - instanceField1"); + assertThat(instanceField2).isEqualTo("postProcessTestInstance - instanceField2"); } @AfterAll @@ -733,6 +752,55 @@ static class AllInOneWithTestInstancePerClassTestCase extends AllInOneWithTestIn } } + static class ProgrammaticTestInstancePostProcessorTestCase { + + @RegisterExtension + static Extension resolver = new InstanceField2.Extension(); + + @InstanceField2 + String instanceField2; + + @Test + void test() { + assertThat(instanceField2).isEqualTo("postProcessTestInstance - instanceField2"); + } + } + + @Execution(CONCURRENT) + static class InitializationPerInstanceTestCase { + @RegisterExtension + Extension extension = new InstanceParameterResolver<>(this); + + @Nested + class Wrapper { + + @RegisterExtension + Extension extension = new InstanceParameterResolver<>(this); + + @RepeatedTest(100) + void test(InitializationPerInstanceTestCase outerInstance, Wrapper innerInstance) { + assertSame(InitializationPerInstanceTestCase.this, outerInstance); + assertSame(Wrapper.this, innerInstance); + } + + } + + private record InstanceParameterResolver(T instance) implements ParameterResolver { + + @Override + public boolean supportsParameter(ParameterContext parameterContext, ExtensionContext extensionContext) + throws ParameterResolutionException { + return instance.getClass().equals(parameterContext.getParameter().getType()); + } + + @Override + public Object resolveParameter(ParameterContext parameterContext, ExtensionContext extensionContext) + throws ParameterResolutionException { + return instance; + } + } + } + } @Target(ElementType.PARAMETER) @@ -838,7 +906,7 @@ class Extension extends BaseParameterExtension { class DummyExtension implements Extension { } -class BaseFieldExtension implements BeforeAllCallback, BeforeEachCallback { +class BaseFieldExtension implements BeforeAllCallback, TestInstancePostProcessor { private final Class annotationType; @@ -849,13 +917,13 @@ class BaseFieldExtension implements BeforeAllCallback, Bef } @Override - public final void beforeAll(ExtensionContext context) throws Exception { + public final void beforeAll(ExtensionContext context) { injectFields("beforeAll", context.getRequiredTestClass(), null, ReflectionUtils::isStatic); } @Override - public final void beforeEach(ExtensionContext context) throws Exception { - injectFields("beforeEach", context.getRequiredTestClass(), context.getRequiredTestInstance(), + public final void postProcessTestInstance(Object testInstance, ExtensionContext context) { + injectFields("postProcessTestInstance", context.getRequiredTestClass(), testInstance, ReflectionUtils::isNotStatic); } diff --git a/junit-jupiter-engine/src/test/java/org/junit/jupiter/engine/extension/TestWatcherTests.java b/junit-jupiter-engine/src/test/java/org/junit/jupiter/engine/extension/TestWatcherTests.java index b9763b382f71..27ca3d2d77c9 100644 --- a/junit-jupiter-engine/src/test/java/org/junit/jupiter/engine/extension/TestWatcherTests.java +++ b/junit-jupiter-engine/src/test/java/org/junit/jupiter/engine/extension/TestWatcherTests.java @@ -11,7 +11,6 @@ package org.junit.jupiter.engine.extension; import static java.util.function.Predicate.not; -import static java.util.stream.Collectors.toUnmodifiableList; import static org.assertj.core.api.Assertions.assertThat; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertTrue; @@ -60,10 +59,10 @@ */ class TestWatcherTests extends AbstractJupiterTestEngineTests { - private static final List testWatcherMethodNames = Arrays.stream(TestWatcher.class.getDeclaredMethods())// - .filter(not(Method::isSynthetic))// - .map(Method::getName)// - .collect(toUnmodifiableList()); + private static final List testWatcherMethodNames = Arrays.stream(TestWatcher.class.getDeclaredMethods()) // + .filter(not(Method::isSynthetic)) // + .map(Method::getName) // + .toList(); @BeforeEach void clearResults() {