Skip to content

Commit

Permalink
Reuse real bean definitions for test bean overrides
Browse files Browse the repository at this point in the history
But create new bean definition for nonexistent bean definitions based
on BeanOverrideFactoryBean.

See spring-projectsgh-32933
  • Loading branch information
sbrannen committed Sep 24, 2024
1 parent 11a3a42 commit b0d3fce
Show file tree
Hide file tree
Showing 6 changed files with 194 additions and 22 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@
*
* @author Simon Baslé
* @author Stephane Nicoll
* @author Sam Brannen
* @since 6.2
*/
class BeanOverrideBeanFactoryPostProcessor implements BeanFactoryPostProcessor, Ordered {
Expand Down Expand Up @@ -155,21 +156,41 @@ else if (enforceExistingDefinition) {
beanNameIncludingFactory = beanName;
}

if (existingBeanDefinition != null) {
copyBeanDefinitionDetails(existingBeanDefinition, beanDefinition);
registry.removeBeanDefinition(beanName);
if (existingBeanDefinition instanceof RootBeanDefinition rbd) {
rbd.setQualifiedElement(overrideMetadata.getField());
}

// BeanOverrideFactoryBean registered during AOT processing?
boolean beanOverrideFactoryBeanRegistered = ((existingBeanDefinition != null) &&
BeanOverrideFactoryBean.class.getName().equals(existingBeanDefinition.getBeanClassName()));

// Need to register a BeanDefinition for a nonexistent bean?
if (!beanOverrideFactoryBeanRegistered && (existingBeanDefinition == null)) {
convertToBeanOverrideFactoryBeanDefinition(beanDefinition, beanName, overrideMetadata);

beanOverrideFactoryBeanRegistered = true;
registry.registerBeanDefinition(beanName, beanDefinition);
}
registry.registerBeanDefinition(beanName, beanDefinition);

Object override = overrideMetadata.createOverride(beanName, existingBeanDefinition, null);
if (beanFactory.isSingleton(beanNameIncludingFactory)) {
boolean isFactoryBean = beanFactory.isFactoryBean(beanName);
if (!beanOverrideFactoryBeanRegistered &&
(beanFactory.isSingleton(beanNameIncludingFactory) || isFactoryBean)) {

// Need to remove singleton registration of FactoryBean?
if (isFactoryBean && existingBeanDefinition != null) {
registry.removeBeanDefinition(beanName);
registry.registerBeanDefinition(beanName, existingBeanDefinition);
}

// Now we have an instance (the override) that we can register.
// At this stage we don't expect a singleton instance to be present,
// and this call will throw if there is such an instance already.
beanFactory.registerSingleton(beanName, override);
}

overrideMetadata.track(override, beanFactory);
this.overrideRegistrar.registerBeanInstance(beanName, override);
this.overrideRegistrar.registerNameForMetadata(overrideMetadata, beanNameIncludingFactory);
}

Expand Down Expand Up @@ -270,6 +291,17 @@ private Set<String> getExistingBeanNamesByType(ConfigurableListableBeanFactory b
}


private static void convertToBeanOverrideFactoryBeanDefinition(RootBeanDefinition beanDefinition, String beanName,
OverrideMetadata overrideMetadata) {

beanDefinition.setBeanClass(BeanOverrideFactoryBean.class);
beanDefinition.setTargetType(ResolvableType.forClassWithGenerics(BeanOverrideFactoryBean.class, overrideMetadata.getBeanType()));
beanDefinition.setAttribute(FactoryBean.OBJECT_TYPE_ATTRIBUTE, overrideMetadata.getBeanType());
beanDefinition.getConstructorArgumentValues().addIndexedArgumentValue(0, beanName);
beanDefinition.getConstructorArgumentValues().addIndexedArgumentValue(1, overrideMetadata.getBeanType().resolve());
}


static class WrapEarlyBeanPostProcessor implements SmartInstantiationAwareBeanPostProcessor,
PriorityOrdered {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import java.util.function.Consumer;

import org.springframework.beans.factory.config.BeanDefinition;
import org.springframework.beans.factory.config.ConfigurableListableBeanFactory;
import org.springframework.beans.factory.config.ConstructorArgumentValues;
import org.springframework.beans.factory.config.RuntimeBeanReference;
import org.springframework.beans.factory.support.BeanDefinitionRegistry;
Expand All @@ -35,6 +36,7 @@
*
* @author Simon Baslé
* @author Stephane Nicoll
* @author Sam Brannen
* @since 6.2
*/
class BeanOverrideContextCustomizer implements ContextCustomizer {
Expand Down Expand Up @@ -62,6 +64,11 @@ public void customizeContext(ConfigurableApplicationContext context, MergedConte
"that doesn't implement BeanDefinitionRegistry: " + context.getClass());
}
registerInfrastructure(registry);

ConfigurableListableBeanFactory beanFactory = context.getBeanFactory();
BeanOverrideRegistrar beanOverrideRegistrar = beanFactory.getBean(REGISTRAR_BEAN_NAME, BeanOverrideRegistrar.class);
beanFactory.registerSingleton(INFRASTRUCTURE_BEAN_NAME,
new BeanOverrideBeanFactoryPostProcessor(this.metadata, beanOverrideRegistrar));
}

Set<OverrideMetadata> getMetadata() {
Expand All @@ -75,11 +82,6 @@ private void registerInfrastructure(BeanDefinitionRegistry registry) {
RuntimeBeanReference registrarReference = new RuntimeBeanReference(REGISTRAR_BEAN_NAME);
addInfrastructureBeanDefinition(registry, WrapEarlyBeanPostProcessor.class, EARLY_INFRASTRUCTURE_BEAN_NAME,
constructorArgs -> constructorArgs.addIndexedArgumentValue(0, registrarReference));
addInfrastructureBeanDefinition(registry, BeanOverrideBeanFactoryPostProcessor.class, INFRASTRUCTURE_BEAN_NAME,
constructorArgs -> {
constructorArgs.addIndexedArgumentValue(0, this.metadata);
constructorArgs.addIndexedArgumentValue(1, registrarReference);
});
}

private void addInfrastructureBeanDefinition(BeanDefinitionRegistry registry,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
/*
* Copyright 2002-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.springframework.test.context.bean.override;

import org.springframework.beans.BeansException;
import org.springframework.beans.factory.BeanFactory;
import org.springframework.beans.factory.BeanFactoryAware;
import org.springframework.beans.factory.FactoryBean;
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;

/**
* {@link FactoryBean} that retrieves a bean override instance from the
* {@link BeanOverrideRegistrar} registered in the {@link BeanFactory}.
*
* @author Sam Brannen
* @since 6.2
* @param <T> the type of bean override instance returned from this factory
*/
@SuppressWarnings("rawtypes")
class BeanOverrideFactoryBean<T> implements FactoryBean<T>, BeanFactoryAware {

private final String beanName;

private final Class<T> beanType;

@Nullable
private BeanOverrideRegistrar beanOverrideRegistrar;


/**
* Create a new {@code BeanOverrideFactoryBean} for the given bean override
* name and type.
* @param beanName the name of the bean override instance
* @param beanType the type of the bean override instance
*/
BeanOverrideFactoryBean(String beanName, Class<T> beanType) {
this.beanName = beanName;
this.beanType = beanType;
}


@Override
public void setBeanFactory(BeanFactory beanFactory) throws BeansException {
this.beanOverrideRegistrar = beanFactory.getBean(BeanOverrideRegistrar.class);
}

@Override
public Class<T> getObjectType() {
return this.beanType;
}

@Override
public T getObject() throws Exception {
Assert.notNull(this.beanOverrideRegistrar, "BeanOverrideRegistrar must be available");
return this.beanOverrideRegistrar.getBeanInstance(this.beanName, this.beanType);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,15 @@
* for test execution listeners.
*
* @author Simon Baslé
* @author Sam Brannen
* @since 6.2
*/
class BeanOverrideRegistrar implements BeanFactoryAware {

private final Map<OverrideMetadata, String> beanNameRegistry = new HashMap<>();

private final Map<String, Object> beanInstanceRegistry = new HashMap<>();

private final Map<String, OverrideMetadata> earlyOverrideMetadata = new HashMap<>();

@Nullable
Expand Down Expand Up @@ -80,6 +83,32 @@ void registerNameForMetadata(OverrideMetadata metadata, String beanName) {
this.beanNameRegistry.put(metadata, beanName);
}

/**
* Register the provided bean override instance under the supplied {@code beanName}.
* @see #getBeanInstance(String, Class)
*/
void registerBeanInstance(String beanName, Object instance) {
this.beanInstanceRegistry.put(beanName, instance);
}

/**
* Get the bean override instance registered under the supplied {@code beanName}.
* @param beanName the name of the bean override instance
* @param requiredType the required type of the bean override instance
* @return the corresponding bean override instance
* @throws IllegalArgumentException if no bean override instance has been
* registered under the supplied name
* @throws ClassCastException if the bean override instance is not of the
* required type
* @see #registerBeanInstance(String, Object)
*/
<T> T getBeanInstance(String beanName, Class<T> requiredType) {
Assert.isTrue(this.beanInstanceRegistry.containsKey(beanName),
() -> "Bean instance registry does not contain an entry for bean with name '%s'"
.formatted(beanName));
return requiredType.cast(this.beanInstanceRegistry.get(beanName));
}

/**
* Mark the provided {@link OverrideMetadata} and {@code beanName} as "wrap
* early", allowing for later bean override using {@link #wrapIfNecessary(Object, String)}.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
import java.util.Set;
import java.util.stream.Stream;

import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
import org.junit.platform.engine.TestSource;
import org.junit.platform.engine.support.descriptor.ClassSource;
Expand Down Expand Up @@ -122,13 +121,14 @@ void endToEndTests() {
/* 1 */ DisabledInAotRuntimeMethodLevelTests.class)));
}

@Disabled("Comment out to run all integration tests in spring-test in AOT mode")
// @Disabled("Comment out to run all integration tests in spring-test in AOT mode")
@Test
void endToEndTestsForEntireSpringTestModule() {
// AOT BUILD-TIME: CLASSPATH SCANNING
List<Class<?>> testClasses = createTestClassScanner()
// Scan all base packages in spring-test.
.scan("org.springframework.mock", "org.springframework.test")
// .scan("org.springframework.mock", "org.springframework.test")
.scan("org.springframework.test.context.bean.override")
// Or limit execution to a particular package and its subpackages.
// - For example, to test JDBC support:
// .scan("org.springframework.test.context.jdbc")
Expand All @@ -142,13 +142,14 @@ void endToEndTestsForEntireSpringTestModule() {
.toList();

// Optionally set failOnError flag to true to halt processing at the first failure.
runEndToEndTests(testClasses, false);
runEndToEndTests(testClasses, true);
}

@Disabled("Comment out to run @TestBean integration tests in AOT mode")
// @Disabled("Comment out to run @TestBean integration tests in AOT mode")
@Test
void endToEndTestsForTestBeanOverrideTestClasses() {
List<Class<?>> testClasses = List.of(
org.springframework.test.context.bean.override.convention.TestBeanForByTypeLookupIntegrationTests.class,
org.springframework.test.context.aot.samples.bean.override.convention.TestBeanJupiterTests.class,
org.springframework.test.context.bean.override.convention.TestBeanForByNameLookupIntegrationTests.class,
org.springframework.test.context.bean.override.convention.TestBeanForByNameLookupIntegrationTests.TestBeanFieldInEnclosingClassTests.class,
Expand All @@ -160,14 +161,20 @@ void endToEndTestsForTestBeanOverrideTestClasses() {
runEndToEndTests(testClasses, true);
}

@Disabled("Comment out to run selected integration tests in AOT mode")
// @Disabled("Comment out to run selected integration tests in AOT mode")
@Test
void endToEndTestsForSelectedTestClasses() {
List<Class<?>> testClasses = List.of(
org.springframework.test.context.bean.override.easymock.EasyMockBeanIntegrationTests.class,
org.springframework.test.context.bean.override.mockito.MockitoBeanForByNameLookupIntegrationTests.class,
org.springframework.test.context.junit4.SpringJUnit4ClassRunnerAppCtxTests.class,
org.springframework.test.context.junit4.ParameterizedDependencyInjectionTests.class
org.springframework.test.context.bean.override.mockito.MockitoBeanWithResetIntegrationTests.class,
org.springframework.test.context.bean.override.mockito.MockitoBeanForBeanFactoryIntegrationTests.class
// org.springframework.test.context.aot.samples.bean.override.convention.TestBeanJupiterTests.class,
// org.springframework.test.context.bean.override.easymock.EasyMockBeanIntegrationTests.class,
// org.springframework.test.context.bean.override.mockito.MockitoBeanForByNameLookupIntegrationTests.class,
// org.springframework.test.context.bean.override.mockito.MockitoBeanForByNameLookupIntegrationTests.MockitoBeanNestedTests.class,
// org.springframework.test.context.bean.override.mockito.MockitoSpyBeanForByNameLookupIntegrationTests.class,
// org.springframework.test.context.bean.override.mockito.MockitoSpyBeanForByNameLookupIntegrationTests.MockitoSpyBeanNestedTests.class
// org.springframework.test.context.junit4.SpringJUnit4ClassRunnerAppCtxTests.class,
// org.springframework.test.context.junit4.ParameterizedDependencyInjectionTests.class
);

runEndToEndTests(testClasses, true);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
*
* @author Simon Baslé
* @author Stephane Nicoll
* @author Sam Brannen
*/
class BeanOverrideBeanFactoryPostProcessorTests {

Expand Down Expand Up @@ -210,7 +211,7 @@ void allowReplaceDefinitionWhenSingletonDefinitionPresent() {
}

@Test
void copyDefinitionPrimaryFallbackAndScope() {
void beanDefinitionIsRetainedForNonFactoryBean() {
AnnotationConfigApplicationContext context = createContext(CaseByName.class);
context.getBeanFactory().registerScope("customScope", new SimpleThreadScope());
RootBeanDefinition definition = new RootBeanDefinition(String.class, () -> "ORIGINAL");
Expand All @@ -219,9 +220,23 @@ void copyDefinitionPrimaryFallbackAndScope() {
definition.setFallback(true);
context.registerBeanDefinition("descriptionBean", definition);

assertThatNoException().isThrownBy(context::refresh);
assertThat(context.getBeanDefinition("descriptionBean")).isSameAs(definition);
}

@Test
void primaryFallbackAndScopeAreCopiedForFactoryBean() {
AnnotationConfigApplicationContext context = createContext(CaseByName.class);
context.getBeanFactory().registerScope("customScope", new SimpleThreadScope());
RootBeanDefinition definition = new RootBeanDefinition(StringFactoryBean.class);
definition.setScope("customScope");
definition.setPrimary(true);
definition.setFallback(true);
context.registerBeanDefinition("descriptionBean", definition);

assertThatNoException().isThrownBy(context::refresh);
assertThat(context.getBeanDefinition("descriptionBean"))
.isNotSameAs(definition)
.isSameAs(definition)
.matches(BeanDefinition::isPrimary, "isPrimary")
.matches(BeanDefinition::isFallback, "isFallback")
.satisfies(d -> assertThat(d.getScope()).isEqualTo("customScope"))
Expand Down Expand Up @@ -327,6 +342,20 @@ public boolean isSingleton() {
}
}

static class StringFactoryBean implements FactoryBean<String> {

@Override
public String getObject() throws Exception {
return "enigma";
}

@Override
public Class<?> getObjectType() {
return String.class;
}

}

static class FactoryBeanRegisteringPostProcessor implements BeanFactoryPostProcessor, Ordered {

@Override
Expand Down

0 comments on commit b0d3fce

Please sign in to comment.