Skip to content

Commit

Permalink
GH-8703: Fix MessagingAnnotationPP for AOT
Browse files Browse the repository at this point in the history
Fixes #8703

* Instantiate a `MessagingAnnotationBeanPostProcessor` via factory method from `MessagingAnnotationPostProcessor`
avoiding extra code generation on an explicitly provided complex `Map` for bean definition property
* Fix test to react properly to a new logic of `MessagingAnnotationBeanPostProcessor` bean registration
  • Loading branch information
EddieChoCho authored and artembilan committed Aug 16, 2023
1 parent ba6d35d commit ef5db30
Show file tree
Hide file tree
Showing 12 changed files with 73 additions and 71 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@

package org.springframework.integration.config;

import java.beans.Introspector;

import org.springframework.beans.factory.config.BeanDefinition;
import org.springframework.beans.factory.support.BeanDefinitionBuilder;
import org.springframework.beans.factory.support.BeanDefinitionRegistry;
Expand Down Expand Up @@ -99,19 +101,28 @@ private void registerIntegrationConfigurationBeanFactoryPostProcessor(BeanDefini

/**
* Register {@link MessagingAnnotationPostProcessor} and
* {@link org.springframework.integration.aop.PublisherAnnotationBeanPostProcessor},
* {@link MessagingAnnotationBeanPostProcessor},
* if necessary.
* Inject {@code defaultPublishedChannel} from provided {@link AnnotationMetadata}, if any.
* @param registry The {@link BeanDefinitionRegistry} to register additional {@link BeanDefinition}s.
* @see MessagingAnnotationPostProcessor#messagingAnnotationBeanPostProcessor()
*/
private void registerMessagingAnnotationPostProcessors(BeanDefinitionRegistry registry) {
if (!registry.containsBeanDefinition(IntegrationContextUtils.MESSAGING_ANNOTATION_POSTPROCESSOR_NAME)) {
BeanDefinitionBuilder builder =
registry.registerBeanDefinition(IntegrationContextUtils.MESSAGING_ANNOTATION_POSTPROCESSOR_NAME,
BeanDefinitionBuilder.genericBeanDefinition(MessagingAnnotationPostProcessor.class)
.setRole(BeanDefinition.ROLE_INFRASTRUCTURE);
.setRole(BeanDefinition.ROLE_INFRASTRUCTURE)
.getBeanDefinition());
}

registry.registerBeanDefinition(IntegrationContextUtils.MESSAGING_ANNOTATION_POSTPROCESSOR_NAME,
builder.getBeanDefinition());

String beanName = Introspector.decapitalize(MessagingAnnotationBeanPostProcessor.class.getName());
if (!registry.containsBeanDefinition(beanName)) {
registry.registerBeanDefinition(beanName,
BeanDefinitionBuilder.genericBeanDefinition()
.setFactoryMethodOnBean("messagingAnnotationBeanPostProcessor",
IntegrationContextUtils.MESSAGING_ANNOTATION_POSTPROCESSOR_NAME)
.setRole(BeanDefinition.ROLE_INFRASTRUCTURE)
.getBeanDefinition());
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,16 +65,13 @@ public class MessagingAnnotationBeanPostProcessor

private final List<Runnable> methodsToPostProcessAfterContextInitialization = new ArrayList<>();

private final BeanDefinitionRegistry registry;

private ConfigurableListableBeanFactory beanFactory;

private volatile boolean initialized;

public MessagingAnnotationBeanPostProcessor(BeanDefinitionRegistry registry,
public MessagingAnnotationBeanPostProcessor(
Map<Class<? extends Annotation>, MethodAnnotationPostProcessor<?>> postProcessors) {

this.registry = registry;
this.postProcessors = postProcessors;
}

Expand Down Expand Up @@ -187,13 +184,12 @@ private void postProcessMethodAndRegisterEndpointIfAny(Object bean, String beanN

String endpointBeanName = generateBeanName(beanName, method, annotationType);
endpoint.setBeanName(endpointBeanName);
this.registry.registerBeanDefinition(endpointBeanName,
((BeanDefinitionRegistry) this.beanFactory).registerBeanDefinition(endpointBeanName,
new RootBeanDefinition((Class<AbstractEndpoint>) endpoint.getClass(), () -> endpoint));
this.beanFactory.getBean(endpointBeanName);
}
}


protected String generateBeanName(String originalBeanName, Method method,
Class<? extends Annotation> annotationType) {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@

package org.springframework.integration.config;

import java.beans.Introspector;
import java.lang.annotation.Annotation;
import java.util.HashMap;
import java.util.List;
Expand All @@ -29,7 +28,6 @@
import org.springframework.beans.factory.config.BeanDefinition;
import org.springframework.beans.factory.config.BeanPostProcessor;
import org.springframework.beans.factory.config.ConfigurableListableBeanFactory;
import org.springframework.beans.factory.support.BeanDefinitionBuilder;
import org.springframework.beans.factory.support.BeanDefinitionRegistry;
import org.springframework.beans.factory.support.BeanDefinitionRegistryPostProcessor;
import org.springframework.beans.factory.support.BeanDefinitionValidationException;
Expand Down Expand Up @@ -91,14 +89,6 @@ public void postProcessBeanDefinitionRegistry(BeanDefinitionRegistry registry) t
.map(BeanFactoryAware.class::cast)
.forEach((processor) -> processor.setBeanFactory((BeanFactory) this.registry));

this.registry.registerBeanDefinition(
Introspector.decapitalize(MessagingAnnotationBeanPostProcessor.class.getName()),
BeanDefinitionBuilder.rootBeanDefinition(MessagingAnnotationBeanPostProcessor.class)
.setRole(BeanDefinition.ROLE_INFRASTRUCTURE)
.addConstructorArgValue(this.registry)
.addConstructorArgValue(this.postProcessors)
.getBeanDefinition());

String[] beanNames = registry.getBeanDefinitionNames();

for (String beanName : beanNames) {
Expand All @@ -111,6 +101,16 @@ public void postProcessBeanDefinitionRegistry(BeanDefinitionRegistry registry) t
}
}

/**
* The factory method for {@link MessagingAnnotationBeanPostProcessor} based
* on the environment from this {@link MessagingAnnotationPostProcessor}.
* @return the {@link MessagingAnnotationBeanPostProcessor} instance based on {@link #postProcessors}.
* @since 6.2
*/
public MessagingAnnotationBeanPostProcessor messagingAnnotationBeanPostProcessor() {
return new MessagingAnnotationBeanPostProcessor(this.postProcessors);
}

private void processCandidate(String beanName, AnnotatedBeanDefinition beanDefinition) {
MethodMetadata methodMetadata = beanDefinition.getFactoryMethodMetadata();
MergedAnnotations annotations = methodMetadata.getAnnotations(); // NOSONAR
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
import org.springframework.integration.annotation.ServiceActivator;
import org.springframework.integration.channel.DirectChannel;
import org.springframework.integration.channel.QueueChannel;
import org.springframework.integration.config.MessagingAnnotationPostProcessor;
import org.springframework.integration.config.IntegrationRegistrar;
import org.springframework.integration.context.IntegrationContextUtils;
import org.springframework.integration.endpoint.EventDrivenConsumer;
import org.springframework.integration.handler.AbstractReplyProducingMessageHandler;
Expand All @@ -38,6 +38,7 @@

import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
import static org.mockito.Mockito.mock;

/**
* @author Mark Fisher
Expand All @@ -54,7 +55,7 @@ public class DirectChannelSubscriptionTests {

@BeforeEach
public void setupChannels() {
this.context.registerBean(MessagingAnnotationPostProcessor.class);
new IntegrationRegistrar().registerBeanDefinitions(mock(), this.context.getDefaultListableBeanFactory());
this.context.registerChannel("sourceChannel", this.sourceChannel);
this.context.registerChannel("targetChannel", this.targetChannel);
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2002-2022 the original author or authors.
* Copyright 2002-2023 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -31,6 +31,7 @@
import org.springframework.messaging.support.GenericMessage;

import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.Mockito.mock;

/**
* @author Mark Fisher
Expand All @@ -42,8 +43,7 @@ public class ServiceActivatorAnnotationPostProcessorTests {
public void testAnnotatedMethod() throws InterruptedException {
CountDownLatch latch = new CountDownLatch(1);
try (TestApplicationContext context = TestUtils.createTestApplicationContext()) {
RootBeanDefinition postProcessorDef = new RootBeanDefinition(MessagingAnnotationPostProcessor.class);
context.registerBeanDefinition("postProcessor", postProcessorDef);
new IntegrationRegistrar().registerBeanDefinitions(mock(), context.getDefaultListableBeanFactory());
context.registerBeanDefinition("testChannel", new RootBeanDefinition(DirectChannel.class));
RootBeanDefinition beanDefinition = new RootBeanDefinition(SimpleServiceActivatorAnnotationTestBean.class);
beanDefinition.getConstructorArgumentValues().addGenericArgumentValue(latch);
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2017-2022 the original author or authors.
* Copyright 2017-2023 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -33,7 +33,6 @@
import org.springframework.beans.DirectFieldAccessor;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Qualifier;
import org.springframework.beans.factory.config.ConfigurableListableBeanFactory;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.core.annotation.AnnotationUtils;
Expand Down Expand Up @@ -103,9 +102,7 @@ public void testLogAnnotation() {
public static class Config {

@Bean(name = IntegrationContextUtils.MESSAGING_ANNOTATION_POSTPROCESSOR_NAME)
public static MessagingAnnotationPostProcessor messagingAnnotationPostProcessor(
ConfigurableListableBeanFactory beanFactory) {

public static MessagingAnnotationPostProcessor messagingAnnotationPostProcessor() {
MessagingAnnotationPostProcessor messagingAnnotationPostProcessor = new MessagingAnnotationPostProcessor();
messagingAnnotationPostProcessor.
addMessagingAnnotationPostProcessor(Logging.class, new LogAnnotationPostProcessor());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@
import org.springframework.integration.annotation.MessageEndpoint;
import org.springframework.integration.channel.DirectChannel;
import org.springframework.integration.channel.QueueChannel;
import org.springframework.integration.config.IntegrationRegistrar;
import org.springframework.integration.config.MessagingAnnotationBeanPostProcessor;
import org.springframework.integration.config.MessagingAnnotationPostProcessor;
import org.springframework.integration.endpoint.EventDrivenConsumer;
import org.springframework.integration.handler.advice.AbstractRequestHandlerAdvice;
import org.springframework.integration.test.util.TestUtils;
Expand All @@ -38,6 +38,7 @@

import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
import static org.mockito.Mockito.mock;

/**
* @author Mark Fisher
Expand All @@ -56,7 +57,7 @@ public class FilterAnnotationPostProcessorTests {

@BeforeEach
public void init() {
this.context.registerBean(MessagingAnnotationPostProcessor.class);
new IntegrationRegistrar().registerBeanDefinitions(mock(), this.context.getDefaultListableBeanFactory());
this.context.registerChannel("input", this.inputChannel);
this.context.registerChannel("output", this.outputChannel);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,8 @@
import org.springframework.integration.annotation.Transformer;
import org.springframework.integration.channel.DirectChannel;
import org.springframework.integration.channel.QueueChannel;
import org.springframework.integration.config.IntegrationRegistrar;
import org.springframework.integration.config.MessagingAnnotationBeanPostProcessor;
import org.springframework.integration.config.MessagingAnnotationPostProcessor;
import org.springframework.integration.endpoint.AbstractEndpoint;
import org.springframework.integration.handler.advice.AbstractRequestHandlerAdvice;
import org.springframework.integration.support.MessageBuilder;
Expand All @@ -49,6 +49,7 @@
import org.springframework.messaging.support.GenericMessage;

import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.Mockito.mock;

/**
* @author Mark Fisher
Expand All @@ -59,10 +60,9 @@ public class MessagingAnnotationPostProcessorTests {

@Test
public void serviceActivatorAnnotation() {
TestApplicationContext context = TestUtils.createTestApplicationContext();
TestApplicationContext context = createTestApplicationContext();
DirectChannel inputChannel = new DirectChannel();
context.registerChannel("inputChannel", inputChannel);
context.registerBean(MessagingAnnotationPostProcessor.class);
context.refresh();

MessagingAnnotationBeanPostProcessor postProcessor = context.getBean(MessagingAnnotationBeanPostProcessor.class);
Expand Down Expand Up @@ -134,8 +134,7 @@ public void typeConvertingHandler() {

@Test
public void outboundOnlyServiceActivator() throws InterruptedException {
TestApplicationContext context = TestUtils.createTestApplicationContext();
context.registerBean(MessagingAnnotationPostProcessor.class);
TestApplicationContext context = createTestApplicationContext();
context.registerChannel("testChannel", new DirectChannel());
CountDownLatch latch = new CountDownLatch(1);
OutboundOnlyTestBean testBean = new OutboundOnlyTestBean(latch);
Expand All @@ -144,16 +143,14 @@ public void outboundOnlyServiceActivator() throws InterruptedException {
DestinationResolver<MessageChannel> channelResolver = new BeanFactoryChannelResolver(context);
MessageChannel testChannel = channelResolver.resolveDestination("testChannel");
testChannel.send(new GenericMessage<>("foo"));
latch.await(1000, TimeUnit.MILLISECONDS);
assertThat(latch.getCount()).isEqualTo(0);
assertThat(latch.await(10, TimeUnit.SECONDS)).isTrue();
assertThat(testBean.getMessageText()).isEqualTo("foo");
context.close();
}

@Test
public void testChannelResolution() {
TestApplicationContext context = TestUtils.createTestApplicationContext();
context.registerBean(MessagingAnnotationPostProcessor.class);
TestApplicationContext context = createTestApplicationContext();
DirectChannel inputChannel = new DirectChannel();
QueueChannel outputChannel = new QueueChannel();
DirectChannel eventBus = new DirectChannel();
Expand All @@ -177,8 +174,7 @@ public void testChannelResolution() {

@Test
public void testProxiedMessageEndpointAnnotation() {
TestApplicationContext context = TestUtils.createTestApplicationContext();
context.registerBean(MessagingAnnotationPostProcessor.class);
TestApplicationContext context = createTestApplicationContext();
DirectChannel inputChannel = new DirectChannel();
QueueChannel outputChannel = new QueueChannel();
context.registerChannel("inputChannel", inputChannel);
Expand All @@ -195,8 +191,7 @@ public void testProxiedMessageEndpointAnnotation() {

@Test
public void testMessageEndpointAnnotationInherited() {
TestApplicationContext context = TestUtils.createTestApplicationContext();
context.registerBean(MessagingAnnotationPostProcessor.class);
TestApplicationContext context = createTestApplicationContext();
DirectChannel inputChannel = new DirectChannel();
QueueChannel outputChannel = new QueueChannel();
context.registerChannel("inputChannel", inputChannel);
Expand All @@ -211,8 +206,7 @@ public void testMessageEndpointAnnotationInherited() {

@Test
public void testMessageEndpointAnnotationInheritedWithProxy() {
TestApplicationContext context = TestUtils.createTestApplicationContext();
context.registerBean(MessagingAnnotationPostProcessor.class);
TestApplicationContext context = createTestApplicationContext();
DirectChannel inputChannel = new DirectChannel();
QueueChannel outputChannel = new QueueChannel();
context.registerChannel("inputChannel", inputChannel);
Expand All @@ -229,8 +223,7 @@ public void testMessageEndpointAnnotationInheritedWithProxy() {

@Test
public void testMessageEndpointAnnotationInheritedFromInterface() {
TestApplicationContext context = TestUtils.createTestApplicationContext();
context.registerBean(MessagingAnnotationPostProcessor.class);
TestApplicationContext context = createTestApplicationContext();
DirectChannel inputChannel = new DirectChannel();
QueueChannel outputChannel = new QueueChannel();
context.registerChannel("inputChannel", inputChannel);
Expand All @@ -245,8 +238,7 @@ public void testMessageEndpointAnnotationInheritedFromInterface() {

@Test
public void testMessageEndpointAnnotationInheritedFromInterfaceWithAutoCreatedChannels() {
TestApplicationContext context = TestUtils.createTestApplicationContext();
context.registerBean(MessagingAnnotationPostProcessor.class);
TestApplicationContext context = createTestApplicationContext();
DirectChannel inputChannel = new DirectChannel();
QueueChannel outputChannel = new QueueChannel();
context.registerChannel("inputChannel", inputChannel);
Expand All @@ -261,8 +253,7 @@ public void testMessageEndpointAnnotationInheritedFromInterfaceWithAutoCreatedCh

@Test
public void testMessageEndpointAnnotationInheritedFromInterfaceWithProxy() {
TestApplicationContext context = TestUtils.createTestApplicationContext();
context.registerBean(MessagingAnnotationPostProcessor.class);
TestApplicationContext context = createTestApplicationContext();
DirectChannel inputChannel = new DirectChannel();
QueueChannel outputChannel = new QueueChannel();
context.registerChannel("inputChannel", inputChannel);
Expand All @@ -278,8 +269,7 @@ public void testMessageEndpointAnnotationInheritedFromInterfaceWithProxy() {

@Test
public void testTransformer() {
TestApplicationContext context = TestUtils.createTestApplicationContext();
context.registerBean(MessagingAnnotationPostProcessor.class);
TestApplicationContext context = createTestApplicationContext();
DirectChannel inputChannel = new DirectChannel();
context.registerChannel("inputChannel", inputChannel);
QueueChannel outputChannel = new QueueChannel();
Expand All @@ -298,6 +288,12 @@ public void testTransformer() {
context.close();
}

private static TestApplicationContext createTestApplicationContext() {
TestApplicationContext context = TestUtils.createTestApplicationContext();
new IntegrationRegistrar().registerBeanDefinitions(mock(), context.getDefaultListableBeanFactory());
return context;
}

@MessageEndpoint
public static class OutboundOnlyTestBean {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,14 @@
import org.springframework.integration.annotation.Router;
import org.springframework.integration.channel.DirectChannel;
import org.springframework.integration.channel.QueueChannel;
import org.springframework.integration.config.MessagingAnnotationPostProcessor;
import org.springframework.integration.config.IntegrationRegistrar;
import org.springframework.integration.test.util.TestUtils;
import org.springframework.integration.test.util.TestUtils.TestApplicationContext;
import org.springframework.messaging.Message;
import org.springframework.messaging.support.GenericMessage;

import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.Mockito.mock;

/**
* @author Mark Fisher
Expand All @@ -56,7 +57,7 @@ public class RouterAnnotationPostProcessorTests {

@BeforeEach
public void init() {
this.context.registerBean(MessagingAnnotationPostProcessor.class);
new IntegrationRegistrar().registerBeanDefinitions(mock(), this.context.getDefaultListableBeanFactory());
context.registerChannel("input", inputChannel);
context.registerChannel("output", outputChannel);
context.registerChannel("routingChannel", routingChannel);
Expand Down
Loading

0 comments on commit ef5db30

Please sign in to comment.