Skip to content

Commit

Permalink
Introduce @⁠DisabledInAotMode in the TestContext framework
Browse files Browse the repository at this point in the history
This commit introduces @⁠DisabledInAotMode in the TestContext
framework to support the following use cases.

- Disabling AOT build-time processing of a test ApplicationContext --
  applicable to any testing framework (JUnit 4, JUnit Jupiter, etc.).

- Disabling an entire test class or a single test method at run time
  when the test suite is run with AOT optimizations enabled -- only
  applicable to JUnit Jupiter based tests.

Closes gh-30834
  • Loading branch information
sbrannen committed Oct 15, 2023
1 parent 8e5f39b commit 39a282e
Show file tree
Hide file tree
Showing 10 changed files with 384 additions and 40 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
/*
* Copyright 2002-2023 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.springframework.test.context.aot;

import java.lang.annotation.Documented;
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;

import org.junit.jupiter.api.condition.DisabledIf;

/**
* {@code @DisabledInAotMode} signals that an annotated test class is <em>disabled</em>
* in Spring AOT (ahead-of-time) mode, which means that the {@code ApplicationContext}
* for the test class will not be processed for AOT optimizations at build time.
*
* <p>If a test class is annotated with {@code @DisabledInAotMode}, all other test
* classes which specify configuration to load the same {@code ApplicationContext}
* must also be annotated with {@code @DisabledInAotMode}. Failure to annotate
* all such test classes will result in a exception, either at build time or
* run time.
*
* <p>When used with JUnit Jupiter based tests, {@code @DisabledInAotMode} also
* signals that the annotated test class or test method is <em>disabled</em> when
* running the test suite in Spring AOT mode. When applied at the class level,
* all test methods within that class will be disabled. In this sense,
* {@code @DisabledInAotMode} has semantics similar to those of JUnit Jupiter's
* {@link org.junit.jupiter.api.condition.DisabledInNativeImage @DisabledInNativeImage}
* annotation.
*
* <p>This annotation may be used as a meta-annotation in order to create a
* custom <em>composed annotation</em> that inherits the semantics of this
* annotation.
*
* @author Sam Brannen
* @since 6.1
* @see org.springframework.aot.AotDetector#useGeneratedArtifacts() AotDetector.useGeneratedArtifacts()
* @see org.junit.jupiter.api.condition.EnabledInNativeImage @EnabledInNativeImage
* @see org.junit.jupiter.api.condition.DisabledInNativeImage @DisabledInNativeImage
*/
@Target({ElementType.TYPE, ElementType.METHOD})
@Retention(RetentionPolicy.RUNTIME)
@Documented
@DisabledIf(value = "org.springframework.aot.AotDetector#useGeneratedArtifacts",
disabledReason = "Disabled in Spring AOT mode")
public @interface DisabledInAotMode {
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,11 @@

import java.util.Arrays;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Predicate;
import java.util.stream.Stream;

import org.apache.commons.logging.Log;
Expand Down Expand Up @@ -91,6 +93,9 @@ public class TestContextAotGenerator {

private static final Log logger = LogFactory.getLog(TestContextAotGenerator.class);

private static final Predicate<? super Class<?>> isDisabledInAotMode =
testClass -> MergedAnnotations.from(testClass).isPresent(DisabledInAotMode.class);


private final ApplicationContextAotGenerator aotGenerator = new ApplicationContextAotGenerator();

Expand Down Expand Up @@ -235,35 +240,56 @@ private MultiValueMap<ClassName, Class<?>> processAheadOfTime(
ClassLoader classLoader = getClass().getClassLoader();
MultiValueMap<ClassName, Class<?>> initializerClassMappings = new LinkedMultiValueMap<>();
mergedConfigMappings.forEach((mergedConfig, testClasses) -> {
if (logger.isDebugEnabled()) {
logger.debug("Generating AOT artifacts for test classes " +
testClasses.stream().map(Class::getName).toList());
}
this.mergedConfigRuntimeHints.registerHints(this.runtimeHints, mergedConfig, classLoader);
try {
// Use first test class discovered for a given unique MergedContextConfiguration.
Class<?> testClass = testClasses.get(0);
DefaultGenerationContext generationContext = createGenerationContext(testClass);
ClassName initializer = processAheadOfTime(mergedConfig, generationContext);
Assert.state(!initializerClassMappings.containsKey(initializer),
() -> "ClassName [%s] already encountered".formatted(initializer.reflectionName()));
initializerClassMappings.addAll(initializer, testClasses);
generationContext.writeGeneratedContent();
}
catch (Exception ex) {
if (this.failOnError) {
throw new TestContextAotException("Failed to generate AOT artifacts for test classes " +
testClasses.stream().map(Class::getName).toList(), ex);
long numDisabled = testClasses.stream().filter(isDisabledInAotMode).count();
// At least one test class is disabled?
if (numDisabled > 0) {
// Then all related test classes should be disabled.
if (numDisabled != testClasses.size()) {
if (this.failOnError) {
throw new TestContextAotException("""
All test classes that share an ApplicationContext must be annotated
with @DisabledInAotMode if one of them is: """ + classNames(testClasses));
}
else if (logger.isWarnEnabled()) {
logger.warn("""
All test classes that share an ApplicationContext must be annotated
with @DisabledInAotMode if one of them is: """ + classNames(testClasses));
}
}
if (logger.isInfoEnabled()) {
logger.info("Skipping AOT processing due to the presence of @DisabledInAotMode for test classes " +
classNames(testClasses));
}
}
else {
if (logger.isDebugEnabled()) {
logger.debug("Failed to generate AOT artifacts for test classes " +
testClasses.stream().map(Class::getName).toList(), ex);
logger.debug("Generating AOT artifacts for test classes " + classNames(testClasses));
}
this.mergedConfigRuntimeHints.registerHints(this.runtimeHints, mergedConfig, classLoader);
try {
// Use first test class discovered for a given unique MergedContextConfiguration.
Class<?> testClass = testClasses.get(0);
DefaultGenerationContext generationContext = createGenerationContext(testClass);
ClassName initializer = processAheadOfTime(mergedConfig, generationContext);
Assert.state(!initializerClassMappings.containsKey(initializer),
() -> "ClassName [%s] already encountered".formatted(initializer.reflectionName()));
initializerClassMappings.addAll(initializer, testClasses);
generationContext.writeGeneratedContent();
}
else if (logger.isWarnEnabled()) {
logger.warn("""
catch (Exception ex) {
if (this.failOnError) {
throw new TestContextAotException("Failed to generate AOT artifacts for test classes " +
classNames(testClasses), ex);
}
if (logger.isDebugEnabled()) {
logger.debug("Failed to generate AOT artifacts for test classes " + classNames(testClasses), ex);
}
else if (logger.isWarnEnabled()) {
logger.warn("""
Failed to generate AOT artifacts for test classes %s. \
Enable DEBUG logging to view the stack trace. %s"""
.formatted(testClasses.stream().map(Class::getName).toList(), ex));
.formatted(classNames(testClasses), ex));
}
}
}
});
Expand Down Expand Up @@ -401,4 +427,8 @@ private static boolean getFailOnErrorFlag() {
return true;
}

private static List<String> classNames(List<Class<?>> classes) {
return classes.stream().map(Class::getName).toList();
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,13 @@ abstract class AbstractAotTests {
"org/springframework/context/event/EventListenerMethodProcessor__TestContext005_BeanDefinitions.java",
"org/springframework/test/context/aot/samples/basic/BasicSpringVintageTests__TestContext005_ApplicationContextInitializer.java",
"org/springframework/test/context/aot/samples/basic/BasicSpringVintageTests__TestContext005_BeanFactoryRegistrations.java",
"org/springframework/test/context/aot/samples/basic/BasicTestConfiguration__TestContext005_BeanDefinitions.java"
"org/springframework/test/context/aot/samples/basic/BasicTestConfiguration__TestContext005_BeanDefinitions.java",
// DisabledInAotRuntimeMethodLevelTests
"org/springframework/context/event/DefaultEventListenerFactory__TestContext006_BeanDefinitions.java",
"org/springframework/context/event/EventListenerMethodProcessor__TestContext006_BeanDefinitions.java",
"org/springframework/test/context/aot/samples/basic/DisabledInAotRuntimeMethodLevelTests__TestContext006_ApplicationContextInitializer.java",
"org/springframework/test/context/aot/samples/basic/DisabledInAotRuntimeMethodLevelTests__TestContext006_BeanDefinitions.java",
"org/springframework/test/context/aot/samples/basic/DisabledInAotRuntimeMethodLevelTests__TestContext006_BeanFactoryRegistrations.java"
};

Stream<Class<?>> scan() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,9 @@
import org.springframework.test.context.aot.samples.basic.BasicSpringJupiterTests;
import org.springframework.test.context.aot.samples.basic.BasicSpringTestNGTests;
import org.springframework.test.context.aot.samples.basic.BasicSpringVintageTests;
import org.springframework.test.context.aot.samples.basic.DisabledInAotProcessingTests;
import org.springframework.test.context.aot.samples.basic.DisabledInAotRuntimeClassLevelTests;
import org.springframework.test.context.aot.samples.basic.DisabledInAotRuntimeMethodLevelTests;

import static org.assertj.core.api.Assertions.assertThat;
import static org.junit.platform.engine.discovery.DiscoverySelectors.selectClass;
Expand Down Expand Up @@ -98,15 +101,20 @@ void endToEndTests() {
// .printFiles(System.out)
.compile(compiled ->
// AOT RUN-TIME: EXECUTION
runTestsInAotMode(6, List.of(
BasicSpringJupiterSharedConfigTests.class,
BasicSpringJupiterTests.class, // NestedTests get executed automatically
runTestsInAotMode(7, List.of(
// The #s represent how many tests should run from each test class, which
// must add up to the expectedNumTests above.
/* 1 */ BasicSpringJupiterSharedConfigTests.class,
/* 2 */ BasicSpringJupiterTests.class, // NestedTests get executed automatically
// Run @Import tests AFTER the tests with otherwise identical config
// in order to ensure that the other test classes are not accidentally
// using the config for the @Import tests.
BasicSpringJupiterImportedConfigTests.class,
BasicSpringTestNGTests.class,
BasicSpringVintageTests.class)));
/* 1 */ BasicSpringJupiterImportedConfigTests.class,
/* 1 */ BasicSpringTestNGTests.class,
/* 1 */ BasicSpringVintageTests.class,
/* 0 */ DisabledInAotProcessingTests.class,
/* 0 */ DisabledInAotRuntimeClassLevelTests.class,
/* 1 */ DisabledInAotRuntimeMethodLevelTests.class)));
}

@Disabled("Uncomment to run all Spring integration tests in `spring-test`")
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2002-2022 the original author or authors.
* Copyright 2002-2023 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -34,6 +34,9 @@
import org.springframework.test.context.aot.samples.basic.BasicSpringJupiterTests;
import org.springframework.test.context.aot.samples.basic.BasicSpringTestNGTests;
import org.springframework.test.context.aot.samples.basic.BasicSpringVintageTests;
import org.springframework.test.context.aot.samples.basic.DisabledInAotProcessingTests;
import org.springframework.test.context.aot.samples.basic.DisabledInAotRuntimeClassLevelTests;
import org.springframework.test.context.aot.samples.basic.DisabledInAotRuntimeMethodLevelTests;
import org.springframework.util.ClassUtils;

import static org.assertj.core.api.Assertions.assertThat;
Expand All @@ -56,7 +59,10 @@ void process(@TempDir(cleanup = CleanupMode.ON_SUCCESS) Path tempDir) throws Exc
BasicSpringJupiterTests.class,
BasicSpringJupiterTests.NestedTests.class,
BasicSpringTestNGTests.class,
BasicSpringVintageTests.class
BasicSpringVintageTests.class,
DisabledInAotProcessingTests.class,
DisabledInAotRuntimeClassLevelTests.class,
DisabledInAotRuntimeMethodLevelTests.class
).forEach(testClass -> copy(testClass, classpathRoot));

Set<Path> classpathRoots = Set.of(classpathRoot);
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2002-2022 the original author or authors.
* Copyright 2002-2023 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -23,6 +23,9 @@
import org.springframework.test.context.aot.samples.basic.BasicSpringJupiterTests;
import org.springframework.test.context.aot.samples.basic.BasicSpringTestNGTests;
import org.springframework.test.context.aot.samples.basic.BasicSpringVintageTests;
import org.springframework.test.context.aot.samples.basic.DisabledInAotProcessingTests;
import org.springframework.test.context.aot.samples.basic.DisabledInAotRuntimeClassLevelTests;
import org.springframework.test.context.aot.samples.basic.DisabledInAotRuntimeMethodLevelTests;

import static org.assertj.core.api.Assertions.assertThat;

Expand All @@ -42,17 +45,26 @@ void scanBasicTestClasses() {
BasicSpringJupiterSharedConfigTests.class,
BasicSpringJupiterTests.class,
BasicSpringJupiterTests.NestedTests.class,
BasicSpringTestNGTests.class,
BasicSpringVintageTests.class,
BasicSpringTestNGTests.class
DisabledInAotProcessingTests.class,
DisabledInAotRuntimeClassLevelTests.class,
DisabledInAotRuntimeMethodLevelTests.class
);
}

@Test
void scanTestSuitesForJupiter() {
assertThat(scan("org.springframework.test.context.aot.samples.suites.jupiter"))
.containsExactlyInAnyOrder(BasicSpringJupiterImportedConfigTests.class,
BasicSpringJupiterSharedConfigTests.class, BasicSpringJupiterTests.class,
BasicSpringJupiterTests.NestedTests.class);
.containsExactlyInAnyOrder(
BasicSpringJupiterImportedConfigTests.class,
BasicSpringJupiterSharedConfigTests.class,
BasicSpringJupiterTests.class,
BasicSpringJupiterTests.NestedTests.class,
DisabledInAotProcessingTests.class,
DisabledInAotRuntimeClassLevelTests.class,
DisabledInAotRuntimeMethodLevelTests.class
);
}

@Test
Expand All @@ -76,7 +88,10 @@ void scanTestSuitesForAllTestEngines() {
BasicSpringJupiterTests.class,
BasicSpringJupiterTests.NestedTests.class,
BasicSpringVintageTests.class,
BasicSpringTestNGTests.class
BasicSpringTestNGTests.class,
DisabledInAotProcessingTests.class,
DisabledInAotRuntimeClassLevelTests.class,
DisabledInAotRuntimeMethodLevelTests.class
);
}

Expand All @@ -88,7 +103,10 @@ void scanTestSuitesWithNestedSuites() {
BasicSpringJupiterSharedConfigTests.class,
BasicSpringJupiterTests.class,
BasicSpringJupiterTests.NestedTests.class,
BasicSpringVintageTests.class
BasicSpringVintageTests.class,
DisabledInAotProcessingTests.class,
DisabledInAotRuntimeClassLevelTests.class,
DisabledInAotRuntimeMethodLevelTests.class
);
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
/*
* Copyright 2002-2023 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.springframework.test.context.aot.samples.basic;

import org.junit.jupiter.api.Test;

import org.springframework.aot.AotDetector;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.config.BeanFactoryPostProcessor;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.test.context.aot.DisabledInAotMode;
import org.springframework.test.context.aot.TestContextAotGenerator;
import org.springframework.test.context.junit.jupiter.SpringJUnitConfig;
import org.springframework.util.Assert;

import static org.assertj.core.api.Assertions.assertThat;

/**
* {@code @DisabledInAotMode} test class which verifies that the application context
* for the test class is skipped during AOT processing.
*
* @author Sam Brannen
* @since 6.1
*/
@SpringJUnitConfig
@DisabledInAotMode
public class DisabledInAotProcessingTests {

@Test
void disabledInAotMode(@Autowired String enigma) {
assertThat(AotDetector.useGeneratedArtifacts()).as("Should be disabled in AOT mode").isFalse();
assertThat(enigma).isEqualTo("puzzle");
}

@Configuration
static class Config {

@Bean
String enigma() {
return "puzzle";
}

@Bean
static BeanFactoryPostProcessor bfppBrokenDuringAotProcessing() {
boolean runningDuringAotProcessing = StackWalker.getInstance().walk(stream ->
stream.anyMatch(stackFrame -> stackFrame.getClassName().equals(TestContextAotGenerator.class.getName())));

return beanFactory -> Assert.state(!runningDuringAotProcessing, "Should not be used during AOT processing");
}
}

}
Loading

0 comments on commit 39a282e

Please sign in to comment.