From 2b1535f539bfe95941d09d82305cf24b11a8803c Mon Sep 17 00:00:00 2001 From: Dmytro Nosan Date: Fri, 25 Oct 2024 15:00:10 +0300 Subject: [PATCH] TestcontainersBeanRegistrationAotProcessor that replaces InstanceSupplier of Container by a reflection equivalent --- .../ImportTestcontainersTests.java | 90 +++++++++++++++ .../TestcontainerFieldBeanDefinition.java | 5 +- ...ontainersBeanRegistrationAotProcessor.java | 109 ++++++++++++++++++ .../TestcontainersPropertySource.java | 11 ++ .../resources/META-INF/spring/aot.factories | 6 +- 5 files changed, 218 insertions(+), 3 deletions(-) create mode 100644 spring-boot-project/spring-boot-testcontainers/src/main/java/org/springframework/boot/testcontainers/context/TestcontainersBeanRegistrationAotProcessor.java diff --git a/spring-boot-project/spring-boot-testcontainers/src/dockerTest/java/org/springframework/boot/testcontainers/ImportTestcontainersTests.java b/spring-boot-project/spring-boot-testcontainers/src/dockerTest/java/org/springframework/boot/testcontainers/ImportTestcontainersTests.java index c3d0bd43703b..7ccb0cb3c6cf 100644 --- a/spring-boot-project/spring-boot-testcontainers/src/dockerTest/java/org/springframework/boot/testcontainers/ImportTestcontainersTests.java +++ b/spring-boot-project/spring-boot-testcontainers/src/dockerTest/java/org/springframework/boot/testcontainers/ImportTestcontainersTests.java @@ -18,17 +18,26 @@ import java.lang.annotation.Retention; import java.lang.annotation.RetentionPolicy; +import java.util.function.BiConsumer; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.Test; import org.testcontainers.containers.Container; import org.testcontainers.containers.PostgreSQLContainer; +import org.springframework.aot.test.generate.TestGenerationContext; import org.springframework.boot.testcontainers.beans.TestcontainerBeanDefinition; import org.springframework.boot.testcontainers.context.ImportTestcontainers; import org.springframework.boot.testsupport.container.DisabledIfDockerUnavailable; import org.springframework.boot.testsupport.container.TestImage; +import org.springframework.context.ApplicationContextInitializer; import org.springframework.context.annotation.AnnotationConfigApplicationContext; +import org.springframework.context.aot.ApplicationContextAotGenerator; +import org.springframework.context.support.GenericApplicationContext; +import org.springframework.core.test.tools.CompileWithForkedClassLoader; +import org.springframework.core.test.tools.Compiled; +import org.springframework.core.test.tools.TestCompiler; +import org.springframework.javapoet.ClassName; import org.springframework.test.context.DynamicPropertyRegistry; import org.springframework.test.context.DynamicPropertySource; @@ -43,6 +52,8 @@ @DisabledIfDockerUnavailable class ImportTestcontainersTests { + private final TestGenerationContext generationContext = new TestGenerationContext(); + private AnnotationConfigApplicationContext applicationContext; @AfterEach @@ -122,6 +133,70 @@ void importWhenHasBadArgsDynamicPropertySourceMethod() { .withMessage("@DynamicPropertySource method 'containerProperties' must be static"); } + @Test + @CompileWithForkedClassLoader + void importTestcontainersImportWithoutValueAotContributionRegistersTestcontainers() { + this.applicationContext = new AnnotationConfigApplicationContext(); + this.applicationContext.register(ImportWithoutValue.class); + compile((freshContext, compiled) -> { + PostgreSQLContainer container = freshContext.getBean(PostgreSQLContainer.class); + assertThat(container).isSameAs(ImportWithoutValue.container); + }); + } + + @Test + @CompileWithForkedClassLoader + void importTestcontainersImportWithValueAotContributionRegistersTestcontainers() { + this.applicationContext = new AnnotationConfigApplicationContext(); + this.applicationContext.register(ImportWithValue.class); + compile((freshContext, compiled) -> { + PostgreSQLContainer container = freshContext.getBean(PostgreSQLContainer.class); + assertThat(container).isSameAs(ContainerDefinitions.container); + }); + } + + @Test + @CompileWithForkedClassLoader + void importTestcontainersWithDynamicPropertySourceAotContributionRegistersTestcontainers() { + this.applicationContext = new AnnotationConfigApplicationContext(); + this.applicationContext.register(ContainerDefinitionsWithDynamicPropertySource.class); + compile((freshContext, compiled) -> { + PostgreSQLContainer container = freshContext.getBean(PostgreSQLContainer.class); + assertThat(container).isSameAs(ContainerDefinitionsWithDynamicPropertySource.container); + }); + } + + @Test + @CompileWithForkedClassLoader + void importTestcontainersWithCustomPostgreSQLContainerAotContributionRegistersTestcontainers() { + this.applicationContext = new AnnotationConfigApplicationContext(); + this.applicationContext.register(CustomPostgreSQLContainerDefinitions.class); + compile((freshContext, compiled) -> { + PostgreSQLContainer container = freshContext.getBean(PostgreSQLContainer.class); + assertThat(container).isSameAs(CustomPostgreSQLContainerDefinitions.container); + }); + } + + @SuppressWarnings("unchecked") + private void compile(BiConsumer result) { + ClassName className = processAheadOfTime(); + TestCompiler.forSystem().with(this.generationContext).compile((compiled) -> { + GenericApplicationContext freshApplicationContext = new GenericApplicationContext(); + ApplicationContextInitializer initializer = compiled + .getInstance(ApplicationContextInitializer.class, className.toString()); + initializer.initialize(freshApplicationContext); + freshApplicationContext.refresh(); + result.accept(freshApplicationContext, compiled); + }); + } + + private ClassName processAheadOfTime() { + ClassName className = new ApplicationContextAotGenerator().processAheadOfTime(this.applicationContext, + this.generationContext); + this.generationContext.writeGeneratedContent(); + return className; + } + @ImportTestcontainers static class ImportWithoutValue { @@ -196,4 +271,19 @@ void containerProperties() { } + @ImportTestcontainers + static class CustomPostgreSQLContainerDefinitions { + + static CustomPostgreSQLContainer container = new CustomPostgreSQLContainer(); + + } + + static class CustomPostgreSQLContainer extends PostgreSQLContainer { + + CustomPostgreSQLContainer() { + super("postgres:14"); + } + + } + } diff --git a/spring-boot-project/spring-boot-testcontainers/src/main/java/org/springframework/boot/testcontainers/context/TestcontainerFieldBeanDefinition.java b/spring-boot-project/spring-boot-testcontainers/src/main/java/org/springframework/boot/testcontainers/context/TestcontainerFieldBeanDefinition.java index c5cf32d4b1aa..2f81ad8a7ced 100644 --- a/spring-boot-project/spring-boot-testcontainers/src/main/java/org/springframework/boot/testcontainers/context/TestcontainerFieldBeanDefinition.java +++ b/spring-boot-project/spring-boot-testcontainers/src/main/java/org/springframework/boot/testcontainers/context/TestcontainerFieldBeanDefinition.java @@ -1,5 +1,5 @@ /* - * Copyright 2012-2023 the original author or authors. + * Copyright 2012-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -38,9 +38,10 @@ class TestcontainerFieldBeanDefinition extends RootBeanDefinition implements Tes TestcontainerFieldBeanDefinition(Field field, Container container) { this.container = container; this.annotations = MergedAnnotations.from(field); - this.setBeanClass(container.getClass()); + setBeanClass(container.getClass()); setInstanceSupplier(() -> container); setRole(ROLE_INFRASTRUCTURE); + setAttribute(TestcontainerFieldBeanDefinition.class.getName(), field); } @Override diff --git a/spring-boot-project/spring-boot-testcontainers/src/main/java/org/springframework/boot/testcontainers/context/TestcontainersBeanRegistrationAotProcessor.java b/spring-boot-project/spring-boot-testcontainers/src/main/java/org/springframework/boot/testcontainers/context/TestcontainersBeanRegistrationAotProcessor.java new file mode 100644 index 000000000000..f520d6214b93 --- /dev/null +++ b/spring-boot-project/spring-boot-testcontainers/src/main/java/org/springframework/boot/testcontainers/context/TestcontainersBeanRegistrationAotProcessor.java @@ -0,0 +1,109 @@ +/* + * Copyright 2012-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.boot.testcontainers.context; + +import java.lang.reflect.Field; + +import javax.lang.model.element.Modifier; + +import org.testcontainers.containers.Container; + +import org.springframework.aot.generate.GeneratedMethod; +import org.springframework.aot.generate.GenerationContext; +import org.springframework.beans.factory.aot.BeanRegistrationAotContribution; +import org.springframework.beans.factory.aot.BeanRegistrationAotProcessor; +import org.springframework.beans.factory.aot.BeanRegistrationCode; +import org.springframework.beans.factory.aot.BeanRegistrationCodeFragments; +import org.springframework.beans.factory.aot.BeanRegistrationCodeFragmentsDecorator; +import org.springframework.beans.factory.support.InstanceSupplier; +import org.springframework.beans.factory.support.RegisteredBean; +import org.springframework.beans.factory.support.RootBeanDefinition; +import org.springframework.javapoet.ClassName; +import org.springframework.javapoet.CodeBlock; +import org.springframework.util.Assert; +import org.springframework.util.ClassUtils; +import org.springframework.util.ReflectionUtils; + +/** + * {@link BeanRegistrationAotProcessor} that replaces InstanceSupplier of + * {@link Container} by a reflection equivalent. + * + * @author Dmytro Nosan + */ +class TestcontainersBeanRegistrationAotProcessor implements BeanRegistrationAotProcessor { + + @Override + public BeanRegistrationAotContribution processAheadOfTime(RegisteredBean registeredBean) { + RootBeanDefinition bd = registeredBean.getMergedBeanDefinition(); + String attributeName = TestcontainerFieldBeanDefinition.class.getName(); + Object field = bd.getAttribute(attributeName); + if (field != null) { + Assert.isInstanceOf(Field.class, field, + "BeanDefinition attribute '" + attributeName + "' value must be a type of '" + Field.class + "'"); + return BeanRegistrationAotContribution.withCustomCodeFragments( + (codeFragments) -> new AotContribution(codeFragments, registeredBean, ((Field) field))); + } + return null; + } + + static class AotContribution extends BeanRegistrationCodeFragmentsDecorator { + + private final RegisteredBean registeredBean; + + private final Field field; + + AotContribution(BeanRegistrationCodeFragments delegate, RegisteredBean registeredBean, Field field) { + super(delegate); + this.registeredBean = registeredBean; + this.field = field; + } + + @Override + public ClassName getTarget(RegisteredBean registeredBean) { + return ClassName.get(this.registeredBean.getBeanClass()); + } + + @Override + public CodeBlock generateInstanceSupplierCode(GenerationContext generationContext, + BeanRegistrationCode beanRegistrationCode, boolean allowDirectSupplierShortcut) { + Class beanClass = this.registeredBean.getBeanClass(); + Class testClass = this.field.getDeclaringClass(); + String fieldName = this.field.getName(); + GeneratedMethod generatedMethod = beanRegistrationCode.getMethods() + .add("getInstance", (method) -> method + .addJavadoc("Get the bean instance for '$L'.", this.registeredBean.getBeanName()) + .addModifiers(Modifier.PRIVATE, Modifier.STATIC) + .returns(beanClass) + .addStatement("$T testClass = $T.forName($S, null)", Class.class, ClassUtils.class, + testClass.getName()) + .addStatement("$T field = $T.findField(testClass, $S)", Field.class, ReflectionUtils.class, + fieldName) + .addStatement("$T.notNull(field, $S)", Assert.class, "Field '" + fieldName + "' is not found") + .addStatement("$T.makeAccessible(field)", ReflectionUtils.class) + .addStatement("$T container = ($T) $T.getField(field, null)", beanClass, beanClass, + ReflectionUtils.class) + .addStatement("$T.notNull(container, $S)", Assert.class, + "Container field '" + fieldName + "' must not have a null value") + .addStatement("return container") + .addException(ClassNotFoundException.class)); + return CodeBlock.of("$T.using($T::$L)", InstanceSupplier.class, beanRegistrationCode.getClassName(), + generatedMethod.getName()); + } + + } + +} diff --git a/spring-boot-project/spring-boot-testcontainers/src/main/java/org/springframework/boot/testcontainers/properties/TestcontainersPropertySource.java b/spring-boot-project/spring-boot-testcontainers/src/main/java/org/springframework/boot/testcontainers/properties/TestcontainersPropertySource.java index f1ecfe878c80..d49ef3413bce 100644 --- a/spring-boot-project/spring-boot-testcontainers/src/main/java/org/springframework/boot/testcontainers/properties/TestcontainersPropertySource.java +++ b/spring-boot-project/spring-boot-testcontainers/src/main/java/org/springframework/boot/testcontainers/properties/TestcontainersPropertySource.java @@ -26,9 +26,11 @@ import org.testcontainers.containers.Container; import org.springframework.beans.BeansException; +import org.springframework.beans.factory.aot.BeanRegistrationExcludeFilter; import org.springframework.beans.factory.config.BeanFactoryPostProcessor; import org.springframework.beans.factory.config.ConfigurableListableBeanFactory; import org.springframework.beans.factory.support.BeanDefinitionRegistry; +import org.springframework.beans.factory.support.RegisteredBean; import org.springframework.beans.factory.support.RootBeanDefinition; import org.springframework.context.ApplicationEventPublisher; import org.springframework.context.ApplicationEventPublisherAware; @@ -166,4 +168,13 @@ public void postProcessBeanFactory(ConfigurableListableBeanFactory beanFactory) } + static class TestcontainersEventPublisherBeanRegistrationExcludeFilter implements BeanRegistrationExcludeFilter { + + @Override + public boolean isExcludedFromAotProcessing(RegisteredBean registeredBean) { + return EventPublisherRegistrar.NAME.equals(registeredBean.getBeanName()); + } + + } + } diff --git a/spring-boot-project/spring-boot-testcontainers/src/main/resources/META-INF/spring/aot.factories b/spring-boot-project/spring-boot-testcontainers/src/main/resources/META-INF/spring/aot.factories index 5b3d49bd5020..61ff6cf6d122 100644 --- a/spring-boot-project/spring-boot-testcontainers/src/main/resources/META-INF/spring/aot.factories +++ b/spring-boot-project/spring-boot-testcontainers/src/main/resources/META-INF/spring/aot.factories @@ -1,5 +1,9 @@ org.springframework.beans.factory.aot.BeanRegistrationExcludeFilter=\ -org.springframework.boot.testcontainers.service.connection.ConnectionDetailsRegistrar.ServiceConnectionBeanRegistrationExcludeFilter +org.springframework.boot.testcontainers.service.connection.ConnectionDetailsRegistrar.ServiceConnectionBeanRegistrationExcludeFilter,\ +org.springframework.boot.testcontainers.properties.TestcontainersPropertySource.TestcontainersEventPublisherBeanRegistrationExcludeFilter org.springframework.aot.hint.RuntimeHintsRegistrar=\ org.springframework.boot.testcontainers.service.connection.ContainerConnectionDetailsFactory.ContainerConnectionDetailsFactoriesRuntimeHints + +org.springframework.beans.factory.aot.BeanRegistrationAotProcessor=\ +org.springframework.boot.testcontainers.context.TestcontainersBeanRegistrationAotProcessor