From 2c67fb2e88bd4afae049a68d502a5064f5fe5ce7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ond=C5=99ej=20Hrstka?= Date: Mon, 27 Feb 2023 11:01:49 +0100 Subject: [PATCH] fix fallback for suspended methods (#8825) Fixes #7101 --- .../retry/intercept/RecoveryInterceptor.java | 36 +++- .../kotlin/io/micronaut/retry/FallbackSpec.kt | 183 ++++++++++++++++++ 2 files changed, 217 insertions(+), 2 deletions(-) create mode 100644 test-suite-kotlin/src/test/kotlin/io/micronaut/retry/FallbackSpec.kt diff --git a/runtime/src/main/java/io/micronaut/retry/intercept/RecoveryInterceptor.java b/runtime/src/main/java/io/micronaut/retry/intercept/RecoveryInterceptor.java index 6a45a8d2430..03a1bf0576f 100644 --- a/runtime/src/main/java/io/micronaut/retry/intercept/RecoveryInterceptor.java +++ b/runtime/src/main/java/io/micronaut/retry/intercept/RecoveryInterceptor.java @@ -84,9 +84,15 @@ public Object intercept(MethodInvocationContext context) { fallbackForReactiveType(context, interceptedMethod.interceptResultAsPublisher()) ); case COMPLETION_STAGE: - return interceptedMethod.handleResult( + if (context.isSuspend()) { + return interceptedMethod.handleResult( + fallbackForSuspend(context, interceptedMethod.interceptResultAsCompletionStage()) + ); + } else { + return interceptedMethod.handleResult( fallbackForFuture(context, interceptedMethod.interceptResultAsCompletionStage()) - ); + ); + } case SYNCHRONOUS: try { return context.proceed(); @@ -193,6 +199,32 @@ private CompletionStage fallbackForFuture(MethodInvocationContext fallbackForSuspend(MethodInvocationContext context, CompletionStage result) { + CompletableFuture newFuture = new CompletableFuture<>(); + result.whenComplete((o, throwable) -> { + if (throwable == null) { + newFuture.complete(o); + } else { + Optional> fallbackMethod = findFallbackMethod(context); + if (fallbackMethod.isPresent()) { + MethodExecutionHandle fallbackHandle = fallbackMethod.get(); + if (LOG.isDebugEnabled()) { + LOG.debug("Type [{}] resolved fallback: {}", context.getTarget().getClass(), fallbackHandle); + } + try { + newFuture.complete(fallbackHandle.invoke(context.getParameterValues())); + } catch (Throwable t) { + newFuture.completeExceptionally(t); + } + } else { + newFuture.completeExceptionally(throwable); + } + } + }); + + return newFuture; + } + /** * Resolves a fallback for the given execution context and exception. * diff --git a/test-suite-kotlin/src/test/kotlin/io/micronaut/retry/FallbackSpec.kt b/test-suite-kotlin/src/test/kotlin/io/micronaut/retry/FallbackSpec.kt new file mode 100644 index 00000000000..1f623c3f7e1 --- /dev/null +++ b/test-suite-kotlin/src/test/kotlin/io/micronaut/retry/FallbackSpec.kt @@ -0,0 +1,183 @@ +package io.micronaut.retry + +import io.micronaut.context.ApplicationContext +import io.micronaut.context.annotation.Requires +import io.micronaut.http.HttpResponse +import io.micronaut.http.HttpStatus +import io.micronaut.http.annotation.Controller +import io.micronaut.http.annotation.Post +import io.micronaut.http.client.annotation.Client +import io.micronaut.retry.annotation.Fallback +import io.micronaut.retry.annotation.Recoverable +import io.micronaut.runtime.server.EmbeddedServer +import kotlinx.coroutines.runBlocking +import org.junit.jupiter.api.AfterEach +import org.junit.jupiter.api.BeforeEach +import org.junit.jupiter.api.Test +import org.junit.jupiter.api.assertThrows +import kotlin.test.assertEquals +import kotlin.test.assertNull + +class FallbackSpec { + + lateinit var server: EmbeddedServer + lateinit var fallbackClient: FallbackClient + + @BeforeEach + fun setUp() { + server = ApplicationContext.run(EmbeddedServer::class.java, mapOf("spec.name" to "FallbackClientSpec")) + val context = server.applicationContext + fallbackClient = context.getBean(FallbackClient::class.java) + } + + @AfterEach + fun tearDown() { + server.close() + } + + @Test + fun `server ok with string output`() { + runBlocking { + val response = fallbackClient.stringOutput(false, false) + assertEquals("server ok", response) + } + } + + @Test + fun `server ok with HttpResponse output`() { + runBlocking { + val response = fallbackClient.httpResponseOutput(false, false) + assertEquals(HttpStatus.OK, response.status) + assertEquals("server ok", response.body()) + } + } + + @Test + fun `server ok with null`() { + runBlocking { + val response = fallbackClient.nullOutput(false, false) + assertNull(response) + } + } + + @Test + fun `server fail with string output`() { + runBlocking { + val response = fallbackClient.stringOutput(true, false) + assertEquals("fallback ok", response) + } + } + + @Test + fun `server fail with HttpResponse output`() { + runBlocking { + val response = fallbackClient.httpResponseOutput(true, false) + assertEquals(HttpStatus.OK, response.status) + assertEquals("fallback ok", response.body()) + } + } + + @Test + fun `server fail with null`() { + runBlocking { + val response = fallbackClient.nullOutput(true, false) + assertNull(response) + } + } + + @Test + fun `faillback fail with string output`() { + runBlocking { + val exception = assertThrows { fallbackClient.stringOutput(true, true) } + assertEquals("fallback fail", exception.message) + } + } + + @Test + fun `faillback fail with HttpResponse output`() { + runBlocking { + val exception = assertThrows { fallbackClient.httpResponseOutput(true, true) } + assertEquals("fallback fail", exception.message) + } + } + + @Test + fun `faillback fail with null`() { + runBlocking { + val exception = assertThrows { fallbackClient.nullOutput(true, true) } + assertEquals("fallback fail", exception.message) + } + } + + +} + + +@Requires(property = "spec.name", value = "FallbackClientSpec") +@Controller("/fallback") +class FallbackClientController { + + @Post("stringOutput") + fun stringOutput(serverFail: Boolean, fallbackFail: Boolean): HttpResponse { + return httpResponseOutput(serverFail, fallbackFail) + } + + @Post("httpResponseOutput") + fun httpResponseOutput(serverFail: Boolean, fallbackFail: Boolean): HttpResponse { + return if (serverFail) { + HttpResponse.serverError("server fail") + } else { + HttpResponse.ok("server ok") + } + } + + @Post("nullOutput") + fun nullOutput(serverFail: Boolean, fallbackFail: Boolean): HttpResponse { + return if (serverFail) { + HttpResponse.serverError() + } else { + HttpResponse.ok() + } + } +} + +@Client("/fallback") +@Recoverable(api = FallbackClientFallback::class) +interface FallbackClient { + + @Post("stringOutput") + suspend fun stringOutput(serverFail: Boolean, fallbackFail: Boolean): String + + @Post("httpResponseOutput") + suspend fun httpResponseOutput(serverFail: Boolean, fallbackFail: Boolean): HttpResponse + + @Post("nullOutput") + suspend fun nullOutput(serverFail: Boolean, fallbackFail: Boolean): String? +} + +@Fallback +open class FallbackClientFallback : FallbackClient { + override suspend fun stringOutput(serverFail: Boolean, fallbackFail: Boolean): String { + return if (fallbackFail) { + throw RuntimeException("fallback fail") + } else { + "fallback ok" + } + } + + override suspend fun httpResponseOutput(serverFail: Boolean, fallbackFail: Boolean): HttpResponse { + return if (fallbackFail) { + throw RuntimeException("fallback fail") + } else { + HttpResponse.ok("fallback ok") + } + } + + override suspend fun nullOutput(serverFail: Boolean, fallbackFail: Boolean): String? { + return if (fallbackFail) { + throw RuntimeException("fallback fail") + } else { + null + } + } +}