Skip to content

Commit aee2df8

Browse files
committed
Improve WebFlux suspending handler method support
Support for suspending handler methods introduced in Spring Framework 5.2 M1 does not detect types correctly and does not support suspending handler methods returning Flow which is a common use case with WebClient. This commit fixes these issues and adds Coroutines integration tests. Closes gh-22820 Closes gh-22827
1 parent dab90cb commit aee2df8

File tree

3 files changed

+184
-7
lines changed

3 files changed

+184
-7
lines changed

spring-core-coroutines/src/main/kotlin/org/springframework/core/CoroutinesUtils.kt

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,12 @@ package org.springframework.core
1919

2020
import kotlinx.coroutines.Deferred
2121
import kotlinx.coroutines.Dispatchers
22+
import kotlinx.coroutines.FlowPreview
2223
import kotlinx.coroutines.GlobalScope
2324
import kotlinx.coroutines.async
25+
import kotlinx.coroutines.flow.Flow
26+
import kotlinx.coroutines.flow.collect
27+
import kotlinx.coroutines.flow.flow
2428
import kotlinx.coroutines.reactive.awaitFirstOrNull
2529

2630
import kotlinx.coroutines.reactor.mono
@@ -29,6 +33,8 @@ import reactor.core.publisher.onErrorMap
2933
import java.lang.reflect.InvocationTargetException
3034
import java.lang.reflect.Method
3135
import kotlin.reflect.full.callSuspend
36+
import kotlin.reflect.full.isSubtypeOf
37+
import kotlin.reflect.full.starProjectedType
3238
import kotlin.reflect.jvm.kotlinFunction
3339

3440
/**
@@ -50,18 +56,29 @@ internal fun <T: Any> monoToDeferred(source: Mono<T>) =
5056
GlobalScope.async(Dispatchers.Unconfined) { source.awaitFirstOrNull() }
5157

5258
/**
53-
* Invoke an handler method converting suspending method to [Mono] if necessary.
59+
* Invoke an handler method converting suspending method to [Mono] or [Flow] if necessary.
5460
*
5561
* @author Sebastien Deleuze
5662
* @since 5.2
5763
*/
64+
@Suppress("UNCHECKED_CAST")
65+
@FlowPreview
5866
internal fun invokeHandlerMethod(method: Method, bean: Any, vararg args: Any?): Any? {
5967
val function = method.kotlinFunction!!
6068
return if (function.isSuspend) {
61-
GlobalScope.mono(Dispatchers.Unconfined) {
62-
function.callSuspend(bean, *args.sliceArray(0..(args.size-2)))
63-
.let { if (it == Unit) null else it} }
64-
.onErrorMap(InvocationTargetException::class) { it.targetException }
69+
if (function.returnType.isSubtypeOf(Flow::class.starProjectedType)) {
70+
flow {
71+
(function.callSuspend(bean, *args.sliceArray(0..(args.size-2))) as Flow<*>).collect {
72+
emit(it)
73+
}
74+
}
75+
}
76+
else {
77+
GlobalScope.mono(Dispatchers.Unconfined) {
78+
function.callSuspend(bean, *args.sliceArray(0..(args.size-2)))
79+
.let { if (it == Unit) null else it}
80+
}.onErrorMap(InvocationTargetException::class) { it.targetException }
81+
}
6582
}
6683
else {
6784
function.call(bean, *args)

spring-webflux/src/main/java/org/springframework/web/reactive/result/method/annotation/AbstractMessageWriterResultHandler.java

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,16 @@
1616

1717
package org.springframework.web.reactive.result.method.annotation;
1818

19+
import java.lang.reflect.Method;
1920
import java.util.ArrayList;
2021
import java.util.List;
2122

23+
import kotlin.reflect.KFunction;
24+
import kotlin.reflect.jvm.ReflectJvmMapping;
2225
import org.reactivestreams.Publisher;
2326
import reactor.core.publisher.Mono;
2427

28+
import org.springframework.core.KotlinDetector;
2529
import org.springframework.core.MethodParameter;
2630
import org.springframework.core.ReactiveAdapter;
2731
import org.springframework.core.ReactiveAdapterRegistry;
@@ -48,6 +52,8 @@
4852
*/
4953
public abstract class AbstractMessageWriterResultHandler extends HandlerResultHandlerSupport {
5054

55+
private static final String COROUTINES_FLOW_CLASS_NAME = "kotlinx.coroutines.flow.Flow";
56+
5157
private final List<HttpMessageWriter<?>> messageWriters;
5258

5359

@@ -110,7 +116,7 @@ protected Mono<Void> writeBody(@Nullable Object body, MethodParameter bodyParame
110116
* @return indicates completion or error
111117
* @since 5.0.2
112118
*/
113-
@SuppressWarnings({"unchecked", "rawtypes"})
119+
@SuppressWarnings({"unchecked", "rawtypes", "ConstantConditions"})
114120
protected Mono<Void> writeBody(@Nullable Object body, MethodParameter bodyParameter,
115121
@Nullable MethodParameter actualParam, ServerWebExchange exchange) {
116122

@@ -122,7 +128,11 @@ protected Mono<Void> writeBody(@Nullable Object body, MethodParameter bodyParame
122128
ResolvableType elementType;
123129
if (adapter != null) {
124130
publisher = adapter.toPublisher(body);
125-
ResolvableType genericType = bodyType.getGeneric();
131+
boolean isUnwrapped = KotlinDetector.isKotlinReflectPresent() &&
132+
KotlinDetector.isKotlinType(bodyParameter.getContainingClass()) &&
133+
KotlinDelegate.isSuspend(bodyParameter.getMethod()) &&
134+
!COROUTINES_FLOW_CLASS_NAME.equals(bodyType.toClass().getName());
135+
ResolvableType genericType = isUnwrapped ? bodyType : bodyType.getGeneric();
126136
elementType = getElementType(adapter, genericType);
127137
}
128138
else {
@@ -183,4 +193,15 @@ private List<MediaType> getMediaTypesFor(ResolvableType elementType) {
183193
return writableMediaTypes;
184194
}
185195

196+
/**
197+
* Inner class to avoid a hard dependency on Kotlin at runtime.
198+
*/
199+
private static class KotlinDelegate {
200+
201+
static private boolean isSuspend(Method method) {
202+
KFunction<?> function = ReflectJvmMapping.getKotlinFunction(method);
203+
return function != null && function.isSuspend();
204+
}
205+
}
206+
186207
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,139 @@
1+
/*
2+
* Copyright 2002-2019 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.web.reactive.result.method.annotation
18+
19+
import kotlinx.coroutines.Deferred
20+
import kotlinx.coroutines.FlowPreview
21+
import kotlinx.coroutines.GlobalScope
22+
import kotlinx.coroutines.async
23+
import kotlinx.coroutines.delay
24+
import kotlinx.coroutines.flow.Flow
25+
import kotlinx.coroutines.flow.flow
26+
import org.junit.Assert.assertEquals
27+
import org.junit.Test
28+
import org.springframework.context.ApplicationContext
29+
import org.springframework.context.annotation.AnnotationConfigApplicationContext
30+
import org.springframework.context.annotation.ComponentScan
31+
import org.springframework.context.annotation.Configuration
32+
import org.springframework.http.HttpHeaders
33+
import org.springframework.http.HttpStatus
34+
import org.springframework.web.bind.annotation.GetMapping
35+
import org.springframework.web.bind.annotation.RestController
36+
import org.springframework.web.client.HttpServerErrorException
37+
import org.springframework.web.reactive.config.EnableWebFlux
38+
39+
@FlowPreview
40+
class CoroutinesIntegrationTests : AbstractRequestMappingIntegrationTests() {
41+
42+
override fun initApplicationContext(): ApplicationContext {
43+
val context = AnnotationConfigApplicationContext()
44+
context.register(WebConfig::class.java)
45+
context.refresh()
46+
return context
47+
}
48+
49+
@Test
50+
fun `Suspending handler method`() {
51+
val entity = performGet<String>("/suspend", HttpHeaders.EMPTY, String::class.java)
52+
assertEquals(HttpStatus.OK, entity.statusCode)
53+
assertEquals("foo", entity.body)
54+
}
55+
56+
@Test
57+
fun `Handler method returning Deferred`() {
58+
val entity = performGet<String>("/deferred", HttpHeaders.EMPTY, String::class.java)
59+
assertEquals(HttpStatus.OK, entity.statusCode)
60+
assertEquals("foo", entity.body)
61+
}
62+
63+
@Test
64+
fun `Handler method returning Flow`() {
65+
val entity = performGet<String>("/flow", HttpHeaders.EMPTY, String::class.java)
66+
assertEquals(HttpStatus.OK, entity.statusCode)
67+
assertEquals("foobar", entity.body)
68+
}
69+
70+
@Test
71+
fun `Suspending handler method returning Flow`() {
72+
val entity = performGet<String>("/suspending-flow", HttpHeaders.EMPTY, String::class.java)
73+
assertEquals(HttpStatus.OK, entity.statusCode)
74+
assertEquals("foobar", entity.body)
75+
}
76+
77+
@Test(expected = HttpServerErrorException.InternalServerError::class)
78+
fun `Suspending handler method throwing exception`() {
79+
performGet<String>("/error", HttpHeaders.EMPTY, String::class.java)
80+
}
81+
82+
@Test(expected = HttpServerErrorException.InternalServerError::class)
83+
fun `Handler method returning Flow throwing exception`() {
84+
performGet<String>("/flow-error", HttpHeaders.EMPTY, String::class.java)
85+
}
86+
87+
@Configuration
88+
@EnableWebFlux
89+
@ComponentScan(resourcePattern = "**/CoroutinesIntegrationTests*")
90+
open class WebConfig
91+
92+
@RestController
93+
class CoroutinesController {
94+
95+
@GetMapping("/suspend")
96+
suspend fun suspendingEndpoint(): String {
97+
delay(1)
98+
return "foo"
99+
}
100+
101+
@GetMapping("/deferred")
102+
fun deferredEndpoint(): Deferred<String> = GlobalScope.async {
103+
delay(1)
104+
"foo"
105+
}
106+
107+
@GetMapping("/flow")
108+
fun flowEndpoint()= flow {
109+
emit("foo")
110+
delay(1)
111+
emit("bar")
112+
delay(1)
113+
}
114+
115+
@GetMapping("/suspending-flow")
116+
suspend fun suspendingFlowEndpoint(): Flow<String> {
117+
delay(10)
118+
return flow {
119+
emit("foo")
120+
delay(1)
121+
emit("bar")
122+
delay(1)
123+
}
124+
}
125+
126+
@GetMapping("/error")
127+
suspend fun error() {
128+
delay(1)
129+
throw IllegalStateException()
130+
}
131+
132+
@GetMapping("/flow-error")
133+
suspend fun flowError() = flow<String> {
134+
delay(1)
135+
throw IllegalStateException()
136+
}
137+
138+
}
139+
}

0 commit comments

Comments
 (0)