diff --git a/.github/workflows/graalvm.yml b/.github/workflows/graalvm.yml index 23d722775de..7443db9c146 100644 --- a/.github/workflows/graalvm.yml +++ b/.github/workflows/graalvm.yml @@ -21,6 +21,9 @@ jobs: matrix: java: ['17'] graalvm: ['latest', 'dev'] + include: + - graalvm: 'latest' + java: '11' steps: # https://github.com/actions/virtual-environments/issues/709 - name: Free disk space diff --git a/core/src/main/java/io/micronaut/core/convert/value/ConvertibleMultiValuesMap.java b/core/src/main/java/io/micronaut/core/convert/value/ConvertibleMultiValuesMap.java index 6ae90088e84..f0df52f9562 100644 --- a/core/src/main/java/io/micronaut/core/convert/value/ConvertibleMultiValuesMap.java +++ b/core/src/main/java/io/micronaut/core/convert/value/ConvertibleMultiValuesMap.java @@ -26,6 +26,7 @@ import java.util.LinkedHashSet; import java.util.List; import java.util.Map; +import java.util.Objects; import java.util.Optional; import java.util.Set; import java.util.stream.Collectors; @@ -144,6 +145,23 @@ protected Map> wrapValues(Map> value return Collections.unmodifiableMap(values); } + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + ConvertibleMultiValuesMap that = (ConvertibleMultiValuesMap) o; + return values.equals(that.values); + } + + @Override + public int hashCode() { + return Objects.hash(values); + } + @Override public ConversionService getConversionService() { return conversionService; diff --git a/core/src/main/java/io/micronaut/core/convert/value/ConvertibleValuesMap.java b/core/src/main/java/io/micronaut/core/convert/value/ConvertibleValuesMap.java index f8492761b48..d752474a22e 100644 --- a/core/src/main/java/io/micronaut/core/convert/value/ConvertibleValuesMap.java +++ b/core/src/main/java/io/micronaut/core/convert/value/ConvertibleValuesMap.java @@ -24,6 +24,7 @@ import java.util.Collections; import java.util.LinkedHashMap; import java.util.Map; +import java.util.Objects; import java.util.Optional; import java.util.Set; import java.util.stream.Collectors; @@ -106,6 +107,23 @@ public static ConvertibleValues empty() { return EMPTY; } + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + ConvertibleValuesMap that = (ConvertibleValuesMap) o; + return map.equals(that.map); + } + + @Override + public int hashCode() { + return Objects.hash(map); + } + @Override public void setConversionService(ConversionService conversionService) { this.conversionService = conversionService; diff --git a/gradle.properties b/gradle.properties index eed27c334c0..e50ca3c7f65 100644 --- a/gradle.properties +++ b/gradle.properties @@ -44,7 +44,7 @@ developers=Graeme Rocher kapt.use.worker.api=true # Dependency Versions -micronautMavenPluginVersion=3.5.0 +micronautMavenPluginVersion=3.5.1 chromedriverVersion=79.0.3945.36 geckodriverVersion=0.26.0 webdriverBinariesVersion=1.4 diff --git a/http-client-core/src/main/java/io/micronaut/http/client/HttpClient.java b/http-client-core/src/main/java/io/micronaut/http/client/HttpClient.java index b54aa145806..fdba6c1d9bb 100644 --- a/http-client-core/src/main/java/io/micronaut/http/client/HttpClient.java +++ b/http-client-core/src/main/java/io/micronaut/http/client/HttpClient.java @@ -152,9 +152,15 @@ default Publisher> exchange(@NonNull HttpRequest reque * @param The error type * @return A {@link Publisher} that emits a result of the given type */ + @SuppressWarnings("unchecked") default Publisher retrieve(@NonNull HttpRequest request, @NonNull Argument bodyType, @NonNull Argument errorType) { // note: this default impl isn't used by us anymore, it's overridden by DefaultHttpClient - return Flux.from(exchange(request, bodyType, errorType)).map(response -> { + Flux> exchange = Flux.from(exchange(request, bodyType, errorType)); + if (bodyType.getType() == void.class) { + // exchange() returns a HttpResponse, we can't map the Void body properly, so just drop it and complete + return (Publisher) exchange.ignoreElements(); + } + return exchange.map(response -> { if (bodyType.getType() == HttpStatus.class) { return (O) response.getStatus(); } else { diff --git a/http-client-core/src/main/java/io/micronaut/http/client/interceptor/HttpClientIntroductionAdvice.java b/http-client-core/src/main/java/io/micronaut/http/client/interceptor/HttpClientIntroductionAdvice.java index df185295320..34ca6b9ce24 100644 --- a/http-client-core/src/main/java/io/micronaut/http/client/interceptor/HttpClientIntroductionAdvice.java +++ b/http-client-core/src/main/java/io/micronaut/http/client/interceptor/HttpClientIntroductionAdvice.java @@ -299,11 +299,8 @@ public Object intercept(MethodInvocationContext context) { request.setAttribute(HttpAttributes.INVOCATION_CONTEXT, context); // Set the URI template used to make the request for tracing purposes request.setAttribute(HttpAttributes.URI_TEMPLATE, resolveTemplate(annotationMetadata, uriTemplate.toString())); - String serviceId = getClientId(annotationMetadata); Argument errorType = annotationMetadata.classValue(Client.class, "errorType") .map((Function) Argument::of).orElse(HttpClient.DEFAULT_ERROR_TYPE); - request.setAttribute(HttpAttributes.SERVICE_ID, serviceId); - final MediaType[] acceptTypes; Collection accept = request.accept(); @@ -426,7 +423,7 @@ private Publisher httpClientResponsePublisher(HttpClient httpClient, MutableHttp Class argumentType = reactiveValueArgument.getType(); if (Void.class == argumentType || returnType.isVoid()) { request.getHeaders().remove(HttpHeaders.ACCEPT); - return httpClient.exchange(request, Argument.VOID, errorType); + return httpClient.retrieve(request, Argument.VOID, errorType); } else { if (HttpResponse.class.isAssignableFrom(argumentType)) { return httpClient.exchange(request, reactiveValueArgument, errorType); diff --git a/http-client/src/main/java/io/micronaut/http/client/netty/DefaultHttpClient.java b/http-client/src/main/java/io/micronaut/http/client/netty/DefaultHttpClient.java index 8f08d27e8bc..06be496caa1 100644 --- a/http-client/src/main/java/io/micronaut/http/client/netty/DefaultHttpClient.java +++ b/http-client/src/main/java/io/micronaut/http/client/netty/DefaultHttpClient.java @@ -35,7 +35,7 @@ import io.micronaut.core.util.ArrayUtils; import io.micronaut.core.util.CollectionUtils; import io.micronaut.core.util.StringUtils; -import io.micronaut.core.util.SupplierUtil; +import io.micronaut.http.HttpAttributes; import io.micronaut.http.HttpResponse; import io.micronaut.http.HttpResponseWrapper; import io.micronaut.http.HttpStatus; @@ -91,6 +91,7 @@ import io.micronaut.http.sse.Event; import io.micronaut.http.uri.UriBuilder; import io.micronaut.http.uri.UriTemplate; +import io.micronaut.http.util.HttpHeadersUtil; import io.micronaut.json.JsonMapper; import io.micronaut.json.codec.JsonMediaTypeCodec; import io.micronaut.json.codec.JsonStreamMediaTypeCodec; @@ -189,7 +190,6 @@ import java.util.function.Consumer; import java.util.function.Function; import java.util.function.Supplier; -import java.util.regex.Pattern; import static io.micronaut.scheduling.instrument.InvocationInstrumenter.NOOP; @@ -216,9 +216,6 @@ public class DefaultHttpClient implements private static final int DEFAULT_HTTP_PORT = 80; private static final int DEFAULT_HTTPS_PORT = 443; - private static final Supplier HEADER_MASK_PATTERNS = SupplierUtil.memoized(() -> - Pattern.compile(".*(password|cred|cert|key|secret|token|auth|signat).*", Pattern.CASE_INSENSITIVE) - ); /** * Which headers not to copy from the first request when redirecting to a second request. There doesn't * appear to be a spec for this. {@link java.net.HttpURLConnection} seems to drop all headers, but that would be a @@ -792,7 +789,12 @@ private Flux> exchange(io.micronaut.http.HttpRequest Publisher retrieve(io.micronaut.http.HttpRequest request, Argument bodyType, Argument errorType) { setupConversionService(request); // mostly same as default impl, but with exception customization - return Flux.from(exchange(request, bodyType, errorType)).map(response -> { + Flux> exchange = Flux.from(exchange(request, bodyType, errorType)); + if (bodyType.getType() == void.class) { + // exchange() returns a HttpResponse, we can't map the Void body properly, so just drop it and complete + return (Publisher) exchange.ignoreElements(); + } + return exchange.map(response -> { if (bodyType.getType() == HttpStatus.class) { return (O) response.getStatus(); } else { @@ -1241,7 +1243,13 @@ private > Publisher applyFilte Publisher responsePublisher) { if (request instanceof MutableHttpRequest) { - ((MutableHttpRequest) request).uri(requestURI); + MutableHttpRequest mutRequest = (MutableHttpRequest) request; + mutRequest.uri(requestURI); + if (informationalServiceId != null && + !mutRequest.getAttribute(HttpAttributes.SERVICE_ID).isPresent()) { + + mutRequest.setAttribute(HttpAttributes.SERVICE_ID, informationalServiceId); + } List filters = filterResolver.resolveFilters(request, clientFilterEntries); @@ -1552,10 +1560,6 @@ private void sendRequestThroughChannel( new FullHttpResponseHandler<>(responsePromise, poolHandle, secure, finalRequest, bodyType, errorType)); poolHandle.notifyRequestPipelineBuilt(); Publisher> publisher = new NettyFuturePublisher<>(responsePromise, true); - if (bodyType != null && bodyType.isVoid()) { - // don't emit response if bodyType is void - publisher = Flux.from(publisher).filter(r -> false); - } publisher.subscribe(new ForwardingSubscriber<>(emitter)); requestWriter.write(poolHandle, secure, emitter); @@ -1706,7 +1710,7 @@ private ClientFilterChain buildChain(AtomicReference> requ public Publisher> proceed(MutableHttpRequest request) { int pos = integer.incrementAndGet(); - if (pos > len) { + if (pos >= len) { throw new IllegalStateException("The FilterChain.proceed(..) method should be invoked exactly once per filter execution. The method has instead been invoked multiple times by an erroneous filter definition."); } HttpClientFilter httpFilter = filters.get(pos); @@ -1815,7 +1819,7 @@ private void debugRequest(URI requestURI, io.netty.handler.codec.http.HttpReques private void traceRequest(io.micronaut.http.HttpRequest request, io.netty.handler.codec.http.HttpRequest nettyRequest) { HttpHeaders headers = nettyRequest.headers(); - traceHeaders(headers); + HttpHeadersUtil.trace(log, headers.names(), headers::getAll); if (io.micronaut.http.HttpMethod.permitsRequestBody(request.getMethod()) && request.getBody().isPresent() && nettyRequest instanceof FullHttpRequest) { FullHttpRequest fullHttpRequest = (FullHttpRequest) nettyRequest; ByteBuf content = fullHttpRequest.content(); @@ -1839,30 +1843,6 @@ private void traceChunk(ByteBuf content) { log.trace("----"); } - private void traceHeaders(HttpHeaders headers) { - for (String name : headers.names()) { - boolean isMasked = HEADER_MASK_PATTERNS.get().matcher(name).matches(); - List all = headers.getAll(name); - if (all.size() > 1) { - for (String value : all) { - String maskedValue = isMasked ? mask(value) : value; - log.trace("{}: {}", name, maskedValue); - } - } else if (!all.isEmpty()) { - String maskedValue = isMasked ? mask(all.get(0)) : all.get(0); - log.trace("{}: {}", name, maskedValue); - } - } - } - - @Nullable - private String mask(@Nullable String value) { - if (value == null) { - return null; - } - return "*MASKED*"; - } - private static MediaTypeCodecRegistry createDefaultMediaTypeRegistry() { JsonMapper mapper = JsonMapper.createDefault(); ApplicationConfiguration configuration = new ApplicationConfiguration(); @@ -2134,7 +2114,7 @@ protected void channelReadInstrumented(ChannelHandlerContext ctx, R msg) throws HttpHeaders headers = msg.headers(); if (log.isTraceEnabled()) { log.trace("HTTP Client Response Received ({}) for Request: {} {}", msg.status(), finalRequest.getMethodName(), finalRequest.getUri()); - traceHeaders(headers); + HttpHeadersUtil.trace(log, headers.names(), headers::getAll); } buildResponse(responsePromise, msg); removeHandler(ctx); diff --git a/http-client/src/main/java/io/micronaut/http/client/netty/websocket/NettyWebSocketClientHandler.java b/http-client/src/main/java/io/micronaut/http/client/netty/websocket/NettyWebSocketClientHandler.java index a7e7690d2a8..e3ee2fcc948 100644 --- a/http-client/src/main/java/io/micronaut/http/client/netty/websocket/NettyWebSocketClientHandler.java +++ b/http-client/src/main/java/io/micronaut/http/client/netty/websocket/NettyWebSocketClientHandler.java @@ -16,7 +16,6 @@ package io.micronaut.http.client.netty.websocket; import io.micronaut.core.annotation.Internal; -import io.micronaut.core.async.publisher.Publishers; import io.micronaut.core.bind.BoundExecutable; import io.micronaut.core.bind.DefaultExecutableBinder; import io.micronaut.core.bind.ExecutableBinder; @@ -24,26 +23,22 @@ import io.micronaut.core.convert.value.ConvertibleValues; import io.micronaut.core.type.Argument; import io.micronaut.http.MutableHttpRequest; -import io.micronaut.http.bind.DefaultRequestBinderRegistry; import io.micronaut.http.bind.RequestBinderRegistry; import io.micronaut.http.codec.MediaTypeCodecRegistry; import io.micronaut.http.netty.websocket.AbstractNettyWebSocketHandler; import io.micronaut.http.netty.websocket.NettyWebSocketSession; import io.micronaut.http.uri.UriMatchInfo; import io.micronaut.http.uri.UriMatchTemplate; -import io.micronaut.inject.MethodExecutionHandle; import io.micronaut.websocket.CloseReason; import io.micronaut.websocket.WebSocketPongMessage; import io.micronaut.websocket.annotation.ClientWebSocket; import io.micronaut.websocket.bind.WebSocketState; -import io.micronaut.websocket.bind.WebSocketStateBinderRegistry; import io.micronaut.websocket.context.WebSocketBean; import io.micronaut.websocket.exceptions.WebSocketClientException; import io.micronaut.websocket.exceptions.WebSocketSessionException; import io.micronaut.websocket.interceptor.WebSocketSessionAware; import io.netty.channel.Channel; import io.netty.channel.ChannelHandlerContext; -import io.netty.channel.ChannelPromise; import io.netty.handler.codec.http.FullHttpResponse; import io.netty.handler.codec.http.HttpHeaderNames; import io.netty.handler.codec.http.websocketx.CloseWebSocketFrame; @@ -52,14 +47,12 @@ import io.netty.handler.ssl.SslHandler; import io.netty.handler.timeout.IdleState; import io.netty.handler.timeout.IdleStateEvent; -import org.reactivestreams.Publisher; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; import reactor.core.publisher.Sinks; import java.util.Collections; import java.util.List; -import java.util.Optional; /** * Handler for WebSocket clients. @@ -78,9 +71,7 @@ public class NettyWebSocketClientHandler extends AbstractNettyWebSocketHandle private final Sinks.One completion = Sinks.one(); private final UriMatchInfo matchInfo; private final MediaTypeCodecRegistry codecRegistry; - private ChannelPromise handshakeFuture; private NettyWebSocketSession clientSession; - private final WebSocketStateBinderRegistry webSocketStateBinderRegistry; private FullHttpResponse handshakeResponse; private Argument clientBodyArgument; private Argument clientPongArgument; @@ -106,12 +97,9 @@ public NettyWebSocketClientHandler( this.codecRegistry = mediaTypeCodecRegistry; this.handshaker = handshaker; this.genericWebSocketBean = webSocketBean; - this.webSocketStateBinderRegistry = new WebSocketStateBinderRegistry(requestBinderRegistry != null ? requestBinderRegistry : new DefaultRequestBinderRegistry(conversionService), conversionService); String clientPath = webSocketBean.getBeanDefinition().stringValue(ClientWebSocket.class).orElse(""); UriMatchTemplate matchTemplate = UriMatchTemplate.of(clientPath); this.matchInfo = matchTemplate.match(request.getPath()).orElse(null); - - callOpenMethod(null); } @Override @@ -142,11 +130,6 @@ public NettyWebSocketSession getSession() { return clientSession; } - @Override - public void handlerAdded(final ChannelHandlerContext ctx) { - handshakeFuture = ctx.newPromise(); - } - @Override public void channelActive(final ChannelHandlerContext ctx) { handshaker.handshake(ctx.channel()).addListener(future -> { @@ -154,7 +137,7 @@ public void channelActive(final ChannelHandlerContext ctx) { ctx.channel().config().setAutoRead(true); ctx.read(); } else { - handshakeFuture.tryFailure(future.cause()); + completion.tryEmitError(future.cause()); } }); } @@ -178,7 +161,6 @@ protected void channelRead0(ChannelHandlerContext ctx, Object msg) { } return; } - handshakeFuture.setSuccess(); this.clientSession = createWebSocketSession(ctx); @@ -188,7 +170,6 @@ protected void channelRead0(ChannelHandlerContext ctx, Object msg) { ((WebSocketSessionAware) targetBean).setWebSocketSession(clientSession); } - ExecutableBinder binder = new DefaultExecutableBinder<>(); BoundExecutable bound = binder.tryBind(messageHandler.getExecutableMethod(), webSocketBinder, new WebSocketState(clientSession, originatingRequest)); List> unboundArguments = bound.getUnboundArguments(); @@ -228,37 +209,11 @@ protected void channelRead0(ChannelHandlerContext ctx, Object msg) { } } - Optional> opt = webSocketBean.openMethod(); - if (opt.isPresent()) { - MethodExecutionHandle openMethod = opt.get(); - - WebSocketState webSocketState = new WebSocketState(clientSession, originatingRequest); - try { - BoundExecutable openMethodBound = binder.bind(openMethod.getExecutableMethod(), webSocketStateBinderRegistry, webSocketState); - Object target = openMethod.getTarget(); - Object result = openMethodBound.invoke(target); - - if (Publishers.isConvertibleToPublisher(result)) { - Publisher reactiveSequence = Publishers.convertPublisher(result, Publisher.class); - Flux.from(reactiveSequence).subscribe( - o -> { }, - error -> completion.tryEmitError(new WebSocketSessionException("Error opening WebSocket client session: " + error.getMessage(), error)), - () -> { - completion.tryEmitValue(targetBean); - } - ); - } else { - completion.tryEmitValue(targetBean); - } - } catch (Throwable e) { - completion.tryEmitError(new WebSocketClientException("Error opening WebSocket client session: " + e.getMessage(), e)); - if (getSession().isOpen()) { - getSession().close(CloseReason.INTERNAL_ERROR); - } - } - } else { - completion.tryEmitValue(targetBean); - } + Flux.from(callOpenMethod(ctx)).subscribe( + o -> { }, + error -> completion.tryEmitError(new WebSocketSessionException("Error opening WebSocket client session: " + error.getMessage(), error)), + () -> completion.tryEmitValue(targetBean) + ); return; } @@ -267,8 +222,6 @@ protected void channelRead0(ChannelHandlerContext ctx, Object msg) { } else { ctx.fireChannelRead(msg); } - - } @Override @@ -296,14 +249,20 @@ public ConvertibleValues getUriVariables() { @Override public void exceptionCaught(final ChannelHandlerContext ctx, final Throwable cause) { - if (!handshakeFuture.isDone()) { - handshakeFuture.setFailure(cause); - } - + completion.tryEmitError(cause); super.exceptionCaught(ctx, cause); } public final Mono getHandshakeCompletedMono() { return completion.asMono(); } + + @Override + protected void handleCloseReason(ChannelHandlerContext ctx, CloseReason cr, boolean writeCloseReason) { + if (!handshaker.isHandshakeComplete()) { + completion.tryEmitError(new WebSocketClientException("Error opening WebSocket client session: " + cr.getReason())); + return; + } + super.handleCloseReason(ctx, cr, writeCloseReason); + } } diff --git a/http-client/src/test/groovy/io/micronaut/http/client/ServiceIdSpec.groovy b/http-client/src/test/groovy/io/micronaut/http/client/ServiceIdSpec.groovy new file mode 100644 index 00000000000..3efb8c7a022 --- /dev/null +++ b/http-client/src/test/groovy/io/micronaut/http/client/ServiceIdSpec.groovy @@ -0,0 +1,101 @@ +package io.micronaut.http.client + +import io.micronaut.context.ApplicationContext +import io.micronaut.context.annotation.Requires +import io.micronaut.http.HttpAttributes +import io.micronaut.http.HttpRequest +import io.micronaut.http.HttpResponse +import io.micronaut.http.HttpVersion +import io.micronaut.http.MutableHttpRequest +import io.micronaut.http.annotation.Controller +import io.micronaut.http.annotation.Filter +import io.micronaut.http.annotation.Get +import io.micronaut.http.client.annotation.Client +import io.micronaut.http.filter.ClientFilterChain +import io.micronaut.http.filter.HttpClientFilter +import io.micronaut.runtime.server.EmbeddedServer +import jakarta.inject.Singleton +import org.reactivestreams.Publisher +import spock.lang.Specification + +class ServiceIdSpec extends Specification { + + def 'service id set by declarative client'() { + given: + def serverCtx = ApplicationContext.run([ + 'spec.name': 'ServiceIdSpec', + ]) + def server = serverCtx.getBean(EmbeddedServer) + server.start() + def clientCtx = ApplicationContext.run([ + 'spec.name': 'ServiceIdSpec', + 'micronaut.http.services.my-client-id.url': server.URI, + ]) + def client = clientCtx.getBean(DeclarativeClient) + def filter = clientCtx.getBean(ServiceIdFilter) + + expect: + filter.serviceId == null + client.index() == "foo" + filter.serviceId == "my-client-id" + + cleanup: + server.close() + serverCtx.close() + clientCtx.close() + } + + def 'service id set by normal client'() { + given: + def serverCtx = ApplicationContext.run([ + 'spec.name': 'ServiceIdSpec', + ]) + def server = serverCtx.getBean(EmbeddedServer) + server.start() + def clientCtx = ApplicationContext.run([ + 'spec.name': 'ServiceIdSpec', + 'micronaut.http.services.my-client-id.url': server.URI, + ]) + def client = clientCtx.getBean(HttpClientRegistry).getClient(HttpVersion.HTTP_1_1, "my-client-id", null) + def filter = clientCtx.getBean(ServiceIdFilter) + + expect: + filter.serviceId == null + client.toBlocking().exchange("/service-id", String).body() == "foo" + filter.serviceId == "my-client-id" + + cleanup: + server.close() + serverCtx.close() + clientCtx.close() + } + + @Client(id = "my-client-id") + static interface DeclarativeClient { + @Get("/service-id") + String index() + } + + @Singleton + @Requires(property = "spec.name", value = "ServiceIdSpec") + @Controller("/service-id") + static class ServiceIdController { + @Get + def index(HttpRequest request) { + return "foo" + } + } + + @Singleton + @Requires(property = "spec.name", value = "ServiceIdSpec") + @Filter(Filter.MATCH_ALL_PATTERN) + static class ServiceIdFilter implements HttpClientFilter { + String serviceId + + @Override + Publisher> doFilter(MutableHttpRequest request, ClientFilterChain chain) { + serviceId = request.getAttribute(HttpAttributes.SERVICE_ID).orElse(null) + return chain.proceed(request) + } + } +} diff --git a/http-client/src/test/groovy/io/micronaut/http/client/aop/ClientFilterSpec.groovy b/http-client/src/test/groovy/io/micronaut/http/client/aop/ClientFilterSpec.groovy index eb57296ac5a..d6efb6145ae 100644 --- a/http-client/src/test/groovy/io/micronaut/http/client/aop/ClientFilterSpec.groovy +++ b/http-client/src/test/groovy/io/micronaut/http/client/aop/ClientFilterSpec.groovy @@ -17,6 +17,7 @@ package io.micronaut.http.client.aop import io.micronaut.context.ApplicationContext import io.micronaut.context.annotation.Requires +import io.micronaut.core.async.annotation.SingleResult import io.micronaut.http.HttpResponse import io.micronaut.http.HttpVersion import io.micronaut.http.MediaType @@ -32,8 +33,9 @@ import io.micronaut.http.filter.ClientFilterChain import io.micronaut.http.filter.HttpClientFilter import io.micronaut.runtime.server.EmbeddedServer import org.reactivestreams.Publisher +import reactor.core.publisher.Flux +import reactor.core.publisher.Mono import spock.lang.AutoCleanup -import spock.lang.Shared import spock.lang.Specification /** @@ -267,4 +269,46 @@ class ClientFilterSpec extends Specification{ throw new RuntimeException("from filter") } } + + void "filter always observes a response"() { + given: + ObservesResponseClient client = context.getBean(ObservesResponseClient) + ObservesResponseFilter filter = context.getBean(ObservesResponseFilter) + + when: + Mono.from(client.monoVoid()).block() == null + then: + filter.observedResponse != null + } + + @Requires(property = 'spec.name', value = "ClientFilterSpec") + @Client('/observes-response') + static interface ObservesResponseClient { + + @Get + @SingleResult + Publisher monoVoid() + } + + @Requires(property = 'spec.name', value = "ClientFilterSpec") + @Filter("/observes-response/**") + static class ObservesResponseFilter implements HttpClientFilter { + HttpResponse observedResponse + + @Override + Publisher> doFilter(MutableHttpRequest request, ClientFilterChain chain) { + return Flux.from(chain.proceed(request)).doOnNext(r -> { + observedResponse = r + }) + } + } + + @Requires(property = 'spec.name', value = "ClientFilterSpec") + @Controller('/observes-response') + static class ObservesResponseController { + @Get + String index() { + return "" + } + } } diff --git a/http-client/src/test/groovy/io/micronaut/http/client/aop/QueryParametersSpec.groovy b/http-client/src/test/groovy/io/micronaut/http/client/aop/QueryParametersSpec.groovy index 4cd249bd05e..4c3b18aa3dd 100644 --- a/http-client/src/test/groovy/io/micronaut/http/client/aop/QueryParametersSpec.groovy +++ b/http-client/src/test/groovy/io/micronaut/http/client/aop/QueryParametersSpec.groovy @@ -27,8 +27,8 @@ import io.micronaut.http.client.HttpClient import io.micronaut.http.client.annotation.Client import io.micronaut.http.client.exceptions.HttpClientResponseException import io.micronaut.runtime.server.EmbeddedServer -import reactor.core.publisher.Flux import spock.lang.AutoCleanup +import spock.lang.Issue import spock.lang.Shared import spock.lang.Specification import spock.lang.Unroll @@ -62,6 +62,18 @@ class QueryParametersSpec extends Specification { flavour << [ "pojo", "singlePojo", "list", "map" ] } + @Issue('https://github.com/micronaut-projects/micronaut-core/issues/8338') + void "test client mappping URL parameters appended through a Map does not modify the Map"() { + when: + // this modification is relatively benign, but if the user passed a Map.of, then trying to remove null leads to + // an exception. Unfortunately we can't test with Map.of. + def map = [term: "Riverside", foo: null] + def result = client.searchExplodedMap("map", map) + then: + result.albums.size() == 2 + map.containsValue(null) + } + @Unroll void "test client mappping multiple URL parameters appended through a Map (served through #flavour)"() { expect: diff --git a/http-client/src/test/groovy/io/micronaut/http/client/netty/DefaultClientHeaderMaskTest.groovy b/http-client/src/test/groovy/io/micronaut/http/client/netty/DefaultClientHeaderMaskTest.groovy index 8f3a693ed91..64134aa22b8 100644 --- a/http-client/src/test/groovy/io/micronaut/http/client/netty/DefaultClientHeaderMaskTest.groovy +++ b/http-client/src/test/groovy/io/micronaut/http/client/netty/DefaultClientHeaderMaskTest.groovy @@ -1,66 +1,64 @@ package io.micronaut.http.client.netty import ch.qos.logback.classic.Level -import ch.qos.logback.classic.Logger +import io.micronaut.context.ApplicationContext +import io.micronaut.context.annotation.Requires +import io.micronaut.http.HttpRequest +import io.micronaut.http.annotation.Controller +import io.micronaut.http.annotation.Get +import io.micronaut.http.client.HttpClient +import io.micronaut.runtime.server.EmbeddedServer +import jakarta.inject.Singleton +import org.slf4j.Logger import ch.qos.logback.classic.spi.ILoggingEvent import ch.qos.logback.core.AppenderBase -import io.micronaut.context.ApplicationContext import io.netty.handler.codec.http.DefaultHttpHeaders import org.slf4j.LoggerFactory import spock.lang.Specification import java.util.concurrent.BlockingQueue import java.util.concurrent.LinkedBlockingQueue -import java.util.concurrent.TimeUnit class DefaultClientHeaderMaskTest extends Specification { - def "check masking works for #value"() { + def "check mask detects common security headers"() { given: - def ctx = ApplicationContext.run() - def client = ctx.createBean(DefaultHttpClient, "http://localhost:8080") + EmbeddedServer server = ApplicationContext.run(EmbeddedServer, ["spec.name": "DefaultClientHeaderMaskTest"]) + ApplicationContext ctx = server.applicationContext + HttpClient client = ctx.createBean(HttpClient, server.URL) expect: - client.mask(value) == expected + client instanceof DefaultHttpClient - cleanup: - ctx.close() + when: + MemoryAppender appender = new MemoryAppender() + Logger log = LoggerFactory.getLogger(DefaultHttpClient.class) - where: - value | expected - null | null - "foo" | "*MASKED*" - "Tim Yates" | "*MASKED*" - } + then: + log instanceof ch.qos.logback.classic.Logger - def "check mask detects common security headers"() { - given: - MemoryAppender appender = new MemoryAppender() - Logger logger = (Logger) LoggerFactory.getLogger(DefaultHttpClient.class) + when: + ch.qos.logback.classic.Logger logger = (ch.qos.logback.classic.Logger) log logger.addAppender(appender) logger.setLevel(Level.TRACE) appender.start() - DefaultHttpHeaders headers = new DefaultHttpHeaders() - headers.add("Authorization", "Bearer foo") - headers.add("Proxy-Authorization", "AWS4-HMAC-SHA256 bar") - headers.add("Cookie", "baz") - headers.add("Set-Cookie", "qux") - headers.add("X-Forwarded-For", "quux") - headers.add("X-Forwarded-Host", "quuz") - headers.add("X-Real-IP", "waldo") - headers.add("X-Forwarded-For", "fred") - headers.add("Credential", "foo") - headers.add("Signature", "bar probably secret") - def ctx = ApplicationContext.run() - def client = ctx.createBean(DefaultHttpClient, "http://localhost:8080") - - when: - client.traceHeaders(headers) + def response = client.toBlocking().exchange(HttpRequest.GET("/masking").headers {headers -> + headers.add("Authorization", "Bearer foo") + headers.add("Proxy-Authorization", "AWS4-HMAC-SHA256 bar") + headers.add("Cookie", "baz") + headers.add("Set-Cookie", "qux") + headers.add("X-Forwarded-For", "quux") + headers.add("X-Forwarded-Host", "quuz") + headers.add("X-Real-IP", "waldo") + headers.add("X-Forwarded-For", "fred") + headers.add("Credential", "foo") + headers.add("Signature", "bar probably secret") + }, String) then: - appender.events.size() == 10 - appender.events.join("\n") == """Authorization: *MASKED* + response.body() == "ok" + appender.events.join("\n").contains("""Authorization: *MASKED* |Proxy-Authorization: *MASKED* |Cookie: baz |Set-Cookie: qux @@ -69,11 +67,21 @@ class DefaultClientHeaderMaskTest extends Specification { |X-Forwarded-Host: quuz |X-Real-IP: waldo |Credential: *MASKED* - |Signature: *MASKED*""".stripMargin() + |Signature: *MASKED*""".stripMargin()) cleanup: - ctx.close() appender.stop() + ctx.close() + } + + @Requires(property = "spec.name", value = "DefaultClientHeaderMaskTest") + @Controller("/masking") + @Singleton + static class MaskedController { + @Get + String get() { + "ok" + } } static class MemoryAppender extends AppenderBase { diff --git a/http-client/src/test/groovy/io/micronaut/http/client/websocket/ClientWebsocketSpec.groovy b/http-client/src/test/groovy/io/micronaut/http/client/websocket/ClientWebsocketSpec.groovy new file mode 100644 index 00000000000..71957bfc1cf --- /dev/null +++ b/http-client/src/test/groovy/io/micronaut/http/client/websocket/ClientWebsocketSpec.groovy @@ -0,0 +1,77 @@ +package io.micronaut.http.client.websocket + +import io.micronaut.context.ApplicationContext +import io.micronaut.context.annotation.Requires +import io.micronaut.websocket.WebSocketClient +import io.micronaut.websocket.annotation.ClientWebSocket +import io.micronaut.websocket.annotation.OnClose +import io.micronaut.websocket.annotation.OnMessage +import io.micronaut.websocket.annotation.OnOpen +import io.micronaut.websocket.exceptions.WebSocketClientException +import jakarta.inject.Inject +import jakarta.inject.Singleton +import reactor.core.publisher.Mono +import spock.lang.Specification + +import java.util.concurrent.ExecutionException + +class ClientWebsocketSpec extends Specification { + void 'websocket bean should not open if there is a connection error'() { + given: + def ctx = ApplicationContext.run(['spec.name': 'ClientWebsocketSpec']) + def client = ctx.getBean(WebSocketClient) + def registry = ctx.getBean(ClientBeanRegistry) + def mono = Mono.from(client.connect(ClientBean.class, 'http://does-not-exist')) + + when: + mono.toFuture().get() + then: + def e = thrown ExecutionException + e.cause instanceof WebSocketClientException + + registry.clientBeans.size() == 1 + !registry.clientBeans[0].opened + !registry.clientBeans[0].autoClosed + !registry.clientBeans[0].onClosed + + cleanup: + client.close() + } + + @Singleton + @Requires(property = 'spec.name', value = 'ClientWebsocketSpec') + static class ClientBeanRegistry { + List clientBeans = new ArrayList<>() + } + + @ClientWebSocket + static class ClientBean implements AutoCloseable { + boolean opened = false + boolean onClosed = false + boolean autoClosed = false + + @Inject + ClientBean(ClientBeanRegistry registry) { + registry.clientBeans.add(this) + } + + @OnOpen + void open() { + opened = true + } + + @OnMessage + void onMessage(String text) { + } + + @OnClose + void onClose() { + onClosed = true + } + + @Override + void close() throws Exception { + autoClosed = true + } + } +} diff --git a/http-netty/src/main/java/io/micronaut/http/netty/websocket/AbstractNettyWebSocketHandler.java b/http-netty/src/main/java/io/micronaut/http/netty/websocket/AbstractNettyWebSocketHandler.java index b7f4d55593c..750110e3077 100644 --- a/http-netty/src/main/java/io/micronaut/http/netty/websocket/AbstractNettyWebSocketHandler.java +++ b/http-netty/src/main/java/io/micronaut/http/netty/websocket/AbstractNettyWebSocketHandler.java @@ -35,6 +35,7 @@ import io.micronaut.inject.MethodExecutionHandle; import io.micronaut.websocket.CloseReason; import io.micronaut.websocket.WebSocketPongMessage; +import io.micronaut.websocket.WebSocketSession; import io.micronaut.websocket.bind.WebSocketState; import io.micronaut.websocket.bind.WebSocketStateBinderRegistry; import io.micronaut.websocket.context.WebSocketBean; @@ -55,6 +56,7 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; import reactor.core.scheduler.Schedulers; import java.io.IOException; @@ -89,14 +91,11 @@ public abstract class AbstractNettyWebSocketHandler extends SimpleChannelInbound protected final HttpRequest originatingRequest; protected final MethodExecutionHandle messageHandler; protected final MethodExecutionHandle pongHandler; - protected final NettyWebSocketSession session; protected final MediaTypeCodecRegistry mediaTypeCodecRegistry; protected final WebSocketVersion webSocketVersion; protected final String subProtocol; protected final WebSocketSessionRepository webSocketSessionRepository; protected final ConversionService conversionService; - private final Argument bodyArgument; - private final Argument pongArgument; private final AtomicBoolean closed = new AtomicBoolean(false); private final AtomicReference frameBuffer = new AtomicReference<>(); @@ -135,139 +134,69 @@ protected AbstractNettyWebSocketHandler( this.pongHandler = webSocketBean.pongMethod().orElse(null); this.mediaTypeCodecRegistry = mediaTypeCodecRegistry; this.webSocketVersion = version; - this.session = createWebSocketSession(ctx); this.conversionService = conversionService; - - if (session != null) { - - ExecutableBinder binder = new DefaultExecutableBinder<>(); - - if (messageHandler != null) { - BoundExecutable bound = binder.tryBind(messageHandler.getExecutableMethod(), webSocketBinder, new WebSocketState(session, originatingRequest)); - List> unboundArguments = bound.getUnboundArguments(); - - if (unboundArguments.size() == 1) { - this.bodyArgument = unboundArguments.iterator().next(); - } else { - this.bodyArgument = null; - if (LOG.isErrorEnabled()) { - LOG.error("WebSocket @OnMessage method " + webSocketBean.getTarget() + "." + messageHandler.getExecutableMethod() + " should define exactly 1 message parameter, but found 2 possible candidates: " + unboundArguments); - } - - if (session.isOpen()) { - session.close(CloseReason.INTERNAL_ERROR); - } - } - } else { - this.bodyArgument = null; - } - - if (pongHandler != null) { - BoundExecutable bound = binder.tryBind(pongHandler.getExecutableMethod(), webSocketBinder, new WebSocketState(session, originatingRequest)); - List> unboundArguments = bound.getUnboundArguments(); - if (unboundArguments.size() == 1 && unboundArguments.get(0).isAssignableFrom(WebSocketPongMessage.class)) { - this.pongArgument = unboundArguments.get(0); - } else { - this.pongArgument = null; - if (LOG.isErrorEnabled()) { - LOG.error("WebSocket @OnMessage pong handler method " + webSocketBean.getTarget() + "." + pongHandler.getExecutableMethod() + " should define exactly 1 message parameter assignable from a WebSocketPongMessage, but found: " + unboundArguments); - } - - if (session.isOpen()) { - session.close(CloseReason.INTERNAL_ERROR); - } - } - } else { - this.pongArgument = null; - } - } else { - this.bodyArgument = null; - this.pongArgument = null; - } } /** * Calls the open method of the websocket bean. * - * @param ctx THe handler context + * @param ctx The handler context + * @return Publisher for any errors, or the result of the open method */ - protected void callOpenMethod(ChannelHandlerContext ctx) { - if (session == null) { - return; - } + protected Publisher callOpenMethod(ChannelHandlerContext ctx) { + WebSocketSession session = getSession(); Optional> executionHandle = webSocketBean.openMethod(); if (executionHandle.isPresent()) { MethodExecutionHandle openMethod = executionHandle.get(); - BoundExecutable boundExecutable = null; + + BoundExecutable boundExecutable; try { boundExecutable = bindMethod(originatingRequest, webSocketBinder, openMethod, Collections.emptyList()); } catch (Throwable e) { - if (LOG.isErrorEnabled()) { - LOG.error("Error Binding method @OnOpen for WebSocket [" + webSocketBean + "]: " + e.getMessage(), e); - } - if (session.isOpen()) { session.close(CloseReason.INTERNAL_ERROR); } + return Mono.error(e); } - if (boundExecutable != null) { - try { - BoundExecutable finalBoundExecutable = boundExecutable; - Object result = invokeExecutable(finalBoundExecutable, openMethod); - if (Publishers.isConvertibleToPublisher(result)) { - Flux flowable = Flux.from(instrumentPublisher(ctx, result)); - flowable.subscribe( - o -> { - }, - error -> { - if (LOG.isErrorEnabled()) { - LOG.error("Error Opening WebSocket [" + webSocketBean + "]: " + error.getMessage(), error); - } - if (session.isOpen()) { - session.close(CloseReason.INTERNAL_ERROR); - } - }, - () -> { - } - ); - } - } catch (Throwable e) { - forwardErrorToUser(ctx, t -> { - if (LOG.isErrorEnabled()) { - LOG.error("Error Opening WebSocket [" + webSocketBean + "]: " + t.getMessage(), t); + try { + Object result = invokeExecutable(boundExecutable, openMethod); + if (Publishers.isConvertibleToPublisher(result)) { + return Flux.from(instrumentPublisher(ctx, result)).doOnError(t -> { + if (session.isOpen()) { + session.close(CloseReason.INTERNAL_ERROR); } - }, e); - // since we failed to call onOpen, we should always close here - if (session.isOpen()) { - session.close(CloseReason.INTERNAL_ERROR); - } + }); + } else { + return Mono.empty(); + } + } catch (Throwable e) { + // since we failed to call onOpen, we should always close here + if (session.isOpen()) { + session.close(CloseReason.INTERNAL_ERROR); } + return Mono.error(e); } + } else { + return Mono.empty(); } } /** * @return The body argument for the message handler */ - public Argument getBodyArgument() { - return bodyArgument; - } + public abstract Argument getBodyArgument(); /** * @return The pong argument for the pong handler */ - public Argument getPongArgument() { - return pongArgument; - } + public abstract Argument getPongArgument(); /** * @return The session */ - public NettyWebSocketSession getSession() { - return session; - } + public abstract NettyWebSocketSession getSession(); @Override public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) { @@ -275,7 +204,7 @@ public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) { forwardErrorToUser(ctx, e -> handleUnexpected(ctx, e), cause); } - private void forwardErrorToUser(ChannelHandlerContext ctx, Consumer fallback, Throwable cause) { + protected final void forwardErrorToUser(ChannelHandlerContext ctx, Consumer fallback, Throwable cause) { Optional> opt = webSocketBean.errorMethod(); if (opt.isPresent()) { @@ -447,10 +376,10 @@ protected void handleWebSocketFrame(ChannelHandlerContext ctx, WebSocketFrame ms o -> { }, error -> messageProcessingException(ctx, error), - () -> messageHandled(ctx, session, v) + () -> messageHandled(ctx, v) ); } else { - messageHandled(ctx, session, v); + messageHandled(ctx, v); } } catch (Throwable e) { messageProcessingException(ctx, e); @@ -532,10 +461,9 @@ private void messageProcessingException(ChannelHandlerContext ctx, Throwable e) * Method called once a message has been handled by the handler. * * @param ctx The channel handler context - * @param session The session * @param message The message that was handled */ - protected void messageHandled(ChannelHandlerContext ctx, NettyWebSocketSession session, Object message) { + protected void messageHandled(ChannelHandlerContext ctx, Object message) { // no-op } @@ -551,12 +479,12 @@ protected void writeCloseFrameAndTerminate(ChannelHandlerContext ctx, CloseReaso } /** - * Used to close thee session with a given reason. + * Used to close the session with a given reason. * @param ctx The context * @param cr The reason * @param writeCloseReason Whether to allow writing the close reason to the remote */ - private void handleCloseReason(ChannelHandlerContext ctx, CloseReason cr, boolean writeCloseReason) { + protected void handleCloseReason(ChannelHandlerContext ctx, CloseReason cr, boolean writeCloseReason) { cleanupBuffer(); if (closed.compareAndSet(false, true)) { if (LOG.isDebugEnabled()) { diff --git a/http-server-netty/src/main/java/io/micronaut/http/server/netty/websocket/NettyServerWebSocketHandler.java b/http-server-netty/src/main/java/io/micronaut/http/server/netty/websocket/NettyServerWebSocketHandler.java index 31963261aeb..7d0d99344bb 100644 --- a/http-server-netty/src/main/java/io/micronaut/http/server/netty/websocket/NettyServerWebSocketHandler.java +++ b/http-server-netty/src/main/java/io/micronaut/http/server/netty/websocket/NettyServerWebSocketHandler.java @@ -20,7 +20,10 @@ import io.micronaut.core.annotation.Nullable; import io.micronaut.core.async.publisher.Publishers; import io.micronaut.core.bind.BoundExecutable; +import io.micronaut.core.bind.DefaultExecutableBinder; +import io.micronaut.core.bind.ExecutableBinder; import io.micronaut.core.convert.value.ConvertibleValues; +import io.micronaut.core.type.Argument; import io.micronaut.core.type.Executable; import io.micronaut.core.util.KotlinUtils; import io.micronaut.http.HttpAttributes; @@ -36,7 +39,9 @@ import io.micronaut.inject.MethodExecutionHandle; import io.micronaut.web.router.UriRouteMatch; import io.micronaut.websocket.CloseReason; +import io.micronaut.websocket.WebSocketPongMessage; import io.micronaut.websocket.WebSocketSession; +import io.micronaut.websocket.bind.WebSocketState; import io.micronaut.websocket.context.WebSocketBean; import io.micronaut.websocket.event.WebSocketMessageProcessedEvent; import io.micronaut.websocket.event.WebSocketSessionClosedEvent; @@ -56,6 +61,7 @@ import reactor.core.scheduler.Schedulers; import java.security.Principal; +import java.util.List; import java.util.Optional; import java.util.Set; import java.util.function.Function; @@ -77,10 +83,14 @@ public class NettyServerWebSocketHandler extends AbstractNettyWebSocketHandler { */ public static final String ID = "websocket-handler"; + private final NettyWebSocketSession serverSession; private final NettyEmbeddedServices nettyEmbeddedServices; @Nullable private final CoroutineHelper coroutineHelper; + private final Argument bodyArgument; + private final Argument pongArgument; + /** * Default constructor. * @@ -114,18 +124,67 @@ public class NettyServerWebSocketHandler extends AbstractNettyWebSocketHandler { webSocketSessionRepository, nettyEmbeddedServices.getApplicationContext().getConversionService()); + this.serverSession = createWebSocketSession(ctx); + + ExecutableBinder binder = new DefaultExecutableBinder<>(); + + if (messageHandler != null) { + BoundExecutable bound = binder.tryBind(messageHandler.getExecutableMethod(), webSocketBinder, new WebSocketState(serverSession, originatingRequest)); + List> unboundArguments = bound.getUnboundArguments(); + + if (unboundArguments.size() == 1) { + this.bodyArgument = unboundArguments.iterator().next(); + } else { + this.bodyArgument = null; + if (LOG.isErrorEnabled()) { + LOG.error("WebSocket @OnMessage method " + webSocketBean.getTarget() + "." + messageHandler.getExecutableMethod() + " should define exactly 1 message parameter, but found 2 possible candidates: " + unboundArguments); + } + + if (serverSession.isOpen()) { + serverSession.close(CloseReason.INTERNAL_ERROR); + } + } + } else { + this.bodyArgument = null; + } + + if (pongHandler != null) { + BoundExecutable bound = binder.tryBind(pongHandler.getExecutableMethod(), webSocketBinder, new WebSocketState(serverSession, originatingRequest)); + List> unboundArguments = bound.getUnboundArguments(); + if (unboundArguments.size() == 1 && unboundArguments.get(0).isAssignableFrom(WebSocketPongMessage.class)) { + this.pongArgument = unboundArguments.get(0); + } else { + this.pongArgument = null; + if (LOG.isErrorEnabled()) { + LOG.error("WebSocket @OnMessage pong handler method " + webSocketBean.getTarget() + "." + pongHandler.getExecutableMethod() + " should define exactly 1 message parameter assignable from a WebSocketPongMessage, but found: " + unboundArguments); + } + + if (serverSession.isOpen()) { + serverSession.close(CloseReason.INTERNAL_ERROR); + } + } + } else { + this.pongArgument = null; + } + this.nettyEmbeddedServices = nettyEmbeddedServices; this.coroutineHelper = coroutineHelper; request.setAttribute(HttpAttributes.ROUTE_MATCH, routeMatch); request.setAttribute(HttpAttributes.ROUTE, routeMatch.getRoute()); - callOpenMethod(ctx); + Flux.from(callOpenMethod(ctx)).subscribe(v -> { }, t -> { + forwardErrorToUser(ctx, e -> { + if (LOG.isErrorEnabled()) { + LOG.error("Error Opening WebSocket [" + webSocketBean + "]: " + e.getMessage(), e); + } + }, t); + }); ApplicationEventPublisher eventPublisher = nettyEmbeddedServices.getEventPublisher(WebSocketSessionOpenEvent.class); try { - eventPublisher.publishEvent(new WebSocketSessionOpenEvent(session)); + eventPublisher.publishEvent(new WebSocketSessionOpenEvent(serverSession)); } catch (Exception e) { if (LOG.isErrorEnabled()) { LOG.error("Error publishing WebSocket opened event: " + e.getMessage(), e); @@ -133,6 +192,21 @@ public class NettyServerWebSocketHandler extends AbstractNettyWebSocketHandler { } } + @Override + public NettyWebSocketSession getSession() { + return serverSession; + } + + @Override + public Argument getBodyArgument() { + return bodyArgument; + } + + @Override + public Argument getPongArgument() { + return pongArgument; + } + @Override public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception { if (evt instanceof IdleStateEvent) { @@ -277,11 +351,11 @@ private Object invokeExecutable0(BoundExecutable boundExecutable, MethodExecutio } @Override - protected void messageHandled(ChannelHandlerContext ctx, NettyWebSocketSession session, Object message) { + protected void messageHandled(ChannelHandlerContext ctx, Object message) { ctx.executor().execute(() -> { try { nettyEmbeddedServices.getEventPublisher(WebSocketMessageProcessedEvent.class) - .publishEvent(new WebSocketMessageProcessedEvent<>(session, message)); + .publishEvent(new WebSocketMessageProcessedEvent<>(getSession(), message)); } catch (Exception e) { if (LOG.isErrorEnabled()) { LOG.error("Error publishing WebSocket message processed event: " + e.getMessage(), e); @@ -295,12 +369,12 @@ public void handlerRemoved(ChannelHandlerContext ctx) throws Exception { Channel channel = ctx.channel(); channel.attr(NettyWebSocketSession.WEB_SOCKET_SESSION_KEY).set(null); if (LOG.isDebugEnabled()) { - LOG.debug("Removing WebSocket Server session: " + session); + LOG.debug("Removing WebSocket Server session: " + serverSession); } webSocketSessionRepository.removeChannel(channel); try { nettyEmbeddedServices.getEventPublisher(WebSocketSessionClosedEvent.class) - .publishEvent(new WebSocketSessionClosedEvent(session)); + .publishEvent(new WebSocketSessionClosedEvent(serverSession)); } catch (Exception e) { if (LOG.isErrorEnabled()) { LOG.error("Error publishing WebSocket closed event: " + e.getMessage(), e); diff --git a/http-server-netty/src/test/groovy/io/micronaut/http/server/netty/cors/CorsFilterEnabledSpec.groovy b/http-server-netty/src/test/groovy/io/micronaut/http/server/netty/cors/CorsFilterEnabledSpec.groovy new file mode 100644 index 00000000000..3a94796f892 --- /dev/null +++ b/http-server-netty/src/test/groovy/io/micronaut/http/server/netty/cors/CorsFilterEnabledSpec.groovy @@ -0,0 +1,19 @@ +package io.micronaut.http.server.netty.cors + +import io.micronaut.context.ApplicationContext +import io.micronaut.http.server.cors.CorsFilter +import spock.lang.AutoCleanup +import spock.lang.Shared +import spock.lang.Specification + +class CorsFilterEnabledSpec extends Specification { + + @AutoCleanup + @Shared + ApplicationContext applicationContext = ApplicationContext.run() + + void "CorsFilter is not enabled by default"() { + expect: + !applicationContext.containsBean(CorsFilter) + } +} diff --git a/http-server-netty/src/test/groovy/io/micronaut/http/server/netty/cors/CorsFilterSpec.groovy b/http-server-netty/src/test/groovy/io/micronaut/http/server/netty/cors/CorsFilterSpec.groovy index 72f5e124c25..8553e89adb1 100644 --- a/http-server-netty/src/test/groovy/io/micronaut/http/server/netty/cors/CorsFilterSpec.groovy +++ b/http-server-netty/src/test/groovy/io/micronaut/http/server/netty/cors/CorsFilterSpec.groovy @@ -16,19 +16,26 @@ package io.micronaut.http.server.netty.cors import io.micronaut.context.ApplicationContext +import io.micronaut.core.async.publisher.Publishers +import io.micronaut.core.util.StringUtils import io.micronaut.http.* import io.micronaut.http.annotation.Controller import io.micronaut.http.annotation.Get +import io.micronaut.http.filter.ServerFilterChain import io.micronaut.http.server.HttpServerConfiguration import io.micronaut.http.server.cors.CorsFilter import io.micronaut.http.server.cors.CorsOriginConfiguration import io.micronaut.runtime.server.EmbeddedServer import io.micronaut.web.router.RouteMatch import io.micronaut.web.router.Router +import io.micronaut.web.router.UriRouteMatch import org.apache.http.client.utils.URIBuilder +import org.reactivestreams.Publisher +import reactor.core.publisher.Mono import spock.lang.AutoCleanup import spock.lang.Shared import spock.lang.Specification +import spock.lang.Unroll import java.util.stream.Collectors @@ -36,70 +43,82 @@ import static io.micronaut.http.HttpHeaders.* class CorsFilterSpec extends Specification { - @Shared @AutoCleanup + @Shared + @AutoCleanup EmbeddedServer embeddedServer = ApplicationContext.run(EmbeddedServer) - CorsFilter buildCorsHandler(HttpServerConfiguration.CorsConfiguration config) { - new CorsFilter(config ?: new HttpServerConfiguration.CorsConfiguration()) - } - - void "test handleRequest for non CORS request"() { + void "non CORS request is passed through"() { given: - def config = new HttpServerConfiguration.CorsConfiguration() - HttpRequest request = Mock(HttpRequest) - HttpHeaders headers = Mock(HttpHeaders) - request.getHeaders() >> headers - headers.getOrigin() >> Optional.empty() + HttpServerConfiguration.CorsConfiguration config = enabledCorsConfiguration() CorsFilter corsHandler = buildCorsHandler(config) + HttpRequest request = createRequest(null as String) when: - def result = corsHandler.handleRequest(request) + Optional> result = Mono.from(corsHandler.doFilter(request, okChain())).blockOptional() then: "the request is passed through" - !result.isPresent() + result.isPresent() + + when: + MutableHttpResponse response = result.get() + + then: + HttpStatus.OK == response.status() + response.headers.names().isEmpty() } - void "test handleRequest with no matching configuration"() { + void "request with origin and no matching configuration"() { given: - HttpRequest request = Mock(HttpRequest) - HttpHeaders headers = Mock(HttpHeaders) - request.getHeaders() >> headers - - def config = new HttpServerConfiguration.CorsConfiguration() + String origin = 'http://www.bar.com' + HttpRequest request = createRequest(origin) CorsOriginConfiguration originConfig = new CorsOriginConfiguration() originConfig.allowedOrigins = ['http://www.foo.com'] - config.configurations = new LinkedHashMap() - config.configurations.put('foo', originConfig) + HttpServerConfiguration.CorsConfiguration config = enabledCorsConfiguration([foo: originConfig]) CorsFilter corsHandler = buildCorsHandler(config) when: - def result = corsHandler.handleRequest(request) + Optional> result = Mono.from(corsHandler.doFilter(request, okChain())).blockOptional() + + then: + result.isPresent() + + when: + MutableHttpResponse response = result.get() then: "the request is passed through because no configuration matches the origin" - 2 * headers.getOrigin() >> Optional.of('http://www.bar.com') - !result.isPresent() + HttpStatus.OK == response.status() + response.headers.names().isEmpty() } - void "test handleRequest with regex matching configuration"() { + @Unroll + void "regex matching configuration"(List regex, String origin) { given: - HttpRequest request = Mock(HttpRequest) - HttpHeaders headers = Mock(HttpHeaders) - request.getHeaders() >> headers + HttpRequest request = createRequest(origin) request.getAttribute(HttpAttributes.ROUTE_MATCH, RouteMatch.class) >> Optional.empty() - def config = new HttpServerConfiguration.CorsConfiguration() CorsOriginConfiguration originConfig = new CorsOriginConfiguration() originConfig.allowedOrigins = regex - config.configurations = new LinkedHashMap() - config.configurations.put('foo', originConfig) + HttpServerConfiguration.CorsConfiguration config = enabledCorsConfiguration([foo: originConfig]) CorsFilter corsHandler = buildCorsHandler(config) when: - def result = corsHandler.handleRequest(request) + Optional> result = Mono.from(corsHandler.doFilter(request, okChain())).blockOptional() - then: "the request is passed through because no configuration matches the origin" - 2 * headers.getOrigin() >> Optional.of(origin) - !result.isPresent() + then: + result.isPresent() + + when: + MutableHttpResponse response = result.get() + + then: + HttpStatus.OK == response.status() + response.headers.names().size() == 3 + response.headers.find { it.key == 'Access-Control-Allow-Origin' } + response.headers.find { it.key == 'Vary' } + response.headers.find { it.key == 'Access-Control-Allow-Credentials' } + response.headers.find { it.key == 'Access-Control-Allow-Origin' }.value == [origin] + response.headers.find { it.key == 'Vary' }.value == ['Origin'] + response.headers.find { it.key == 'Access-Control-Allow-Credentials' }.value == [StringUtils.TRUE] where: regex | origin @@ -112,198 +131,251 @@ class CorsFilterSpec extends Specification { void "test handleRequest with disallowed method"() { given: - HttpRequest request = Mock(HttpRequest) - HttpHeaders headers = Mock(HttpHeaders) - request.getHeaders() >> headers + String origin = 'http://www.foo.com' + HttpRequest request = createRequest(origin) - def config = new HttpServerConfiguration.CorsConfiguration() CorsOriginConfiguration originConfig = new CorsOriginConfiguration() originConfig.allowedOrigins = ['http://www.foo.com'] originConfig.allowedMethods = [HttpMethod.GET] - config.configurations = new LinkedHashMap() - config.configurations.put('foo', originConfig) + HttpServerConfiguration.CorsConfiguration config = enabledCorsConfiguration([foo: originConfig]) + CorsFilter corsHandler = buildCorsHandler(config) when: - def result = corsHandler.handleRequest(request) + Optional> result = Mono.from(corsHandler.doFilter(request, okChain())).blockOptional() - then: "the request is rejected because the method is not in the list of allowedMethods" - 2 * headers.getOrigin() >> Optional.of('http://www.foo.com') - 1 * request.getMethod() >> HttpMethod.POST + then: result.isPresent() - result.get().status == HttpStatus.FORBIDDEN + + when: + MutableHttpResponse response = result.get() + + then: + HttpStatus.FORBIDDEN == response.status() + response.headers.names().isEmpty() } - void "test handleRequest with disallowed header (not preflight)"() { + void "with disallowed header (not preflight) the request is passed through because allowed headers are only checked for preflight requests"() { given: - HttpRequest request = Mock(HttpRequest) - HttpHeaders headers = Mock(HttpHeaders) - request.getHeaders() >> headers + String origin = 'http://www.foo.com' + HttpRequest request = createRequest(origin) + request.getMethod() >> HttpMethod.GET - def config = new HttpServerConfiguration.CorsConfiguration() CorsOriginConfiguration originConfig = new CorsOriginConfiguration() originConfig.allowedOrigins = ['http://www.foo.com'] originConfig.allowedMethods = [HttpMethod.GET] - config.configurations = new LinkedHashMap() - config.configurations.put('foo', originConfig) + HttpServerConfiguration.CorsConfiguration config = enabledCorsConfiguration([foo: originConfig]) CorsFilter corsHandler = buildCorsHandler(config) when: - def result = corsHandler.handleRequest(request) + Optional> result = Mono.from(corsHandler.doFilter(request, okChain())).blockOptional() - then: "the request is passed through because allowed headers are only checked for preflight requests" - 2 * headers.getOrigin() >> Optional.of('http://www.foo.com') - 1 * request.getMethod() >> HttpMethod.GET - !result.isPresent() - 0 * headers.get(ACCESS_CONTROL_REQUEST_HEADERS, _) + then: + result.isPresent() + + when: + MutableHttpResponse response = result.get() + + then: + HttpStatus.OK == response.status() + response.headers.names().size() == 3 + response.headers.find { it.key == 'Access-Control-Allow-Origin' } + response.headers.find { it.key == 'Vary' } + response.headers.find { it.key == 'Access-Control-Allow-Credentials' } + response.headers.find { it.key == 'Access-Control-Allow-Origin' }.value == ['http://www.foo.com'] + response.headers.find { it.key == 'Vary' }.value == ['Origin'] + response.headers.find { it.key == 'Access-Control-Allow-Credentials' }.value == [StringUtils.TRUE] } void "test preflight handleRequest with disallowed header"() { given: - HttpRequest request = Mock(HttpRequest) - HttpHeaders headers = Mock(HttpHeaders) - request.getHeaders() >> headers - def config = new HttpServerConfiguration.CorsConfiguration() + String origin = 'http://www.foo.com' + HttpHeaders headers = Stub(HttpHeaders) { + getOrigin() >> Optional.of(origin) + getFirst(ACCESS_CONTROL_REQUEST_METHOD, _) >> Optional.of(HttpMethod.GET) + get(ACCESS_CONTROL_REQUEST_HEADERS, _) >> Optional.of(['foo', 'bar']) + contains(ACCESS_CONTROL_REQUEST_METHOD) >> true + } + HttpRequest request = createRequest(headers) + request.getMethod() >> HttpMethod.OPTIONS + request.getUri() >> new URIBuilder( '/example' ).build() + List> routes = embeddedServer.getApplicationContext().getBean(Router). + findAny(request.getUri().toString(), request) + .collect(Collectors.toList()) + + request.getAttribute(HttpAttributes.AVAILABLE_HTTP_METHODS, _) >> Optional.of(routes.stream().map(route->route.getHttpMethod()).collect(Collectors.toList())) + CorsOriginConfiguration originConfig = new CorsOriginConfiguration() originConfig.allowedOrigins = ['http://www.foo.com'] originConfig.allowedMethods = [HttpMethod.GET] originConfig.allowedHeaders = ['foo'] - config.configurations = new LinkedHashMap() - config.configurations.put('foo', originConfig) + + HttpServerConfiguration.CorsConfiguration config = enabledCorsConfiguration([foo: originConfig]) + CorsFilter corsHandler = buildCorsHandler(config) - request.getMethod() >> HttpMethod.OPTIONS - def uri = new URIBuilder( '/example' ) - request.getUri() >> uri.build() - def routes = embeddedServer.getApplicationContext().getBean(Router). - findAny(request.getUri().toString(), request) - .collect(Collectors.toList()) - request.getAttribute(HttpAttributes.AVAILABLE_HTTP_METHODS, _) >> Optional.of(routes.stream().map(route->route.getHttpMethod()).collect(Collectors.toList())) + when: + Optional> result = Mono.from(corsHandler.doFilter(request, okChain())).blockOptional() + + then: + result.isPresent() when: - headers.contains(ACCESS_CONTROL_REQUEST_METHOD) >> true - def result = corsHandler.handleRequest(request) + MutableHttpResponse response = result.get() then: "the request is rejected because bar is not allowed" - 2 * headers.getOrigin() >> Optional.of('http://www.foo.com') - 1 * headers.getFirst(ACCESS_CONTROL_REQUEST_METHOD, _) >> Optional.of(HttpMethod.GET) - 1 * headers.get(ACCESS_CONTROL_REQUEST_HEADERS, _) >> Optional.of(['foo', 'bar']) - result.get().status == HttpStatus.FORBIDDEN + HttpStatus.FORBIDDEN == response.status() } - void "test preflight handleRequest with allowed header"() { + void "test preflight with allowed header"() { given: - HttpRequest request = Mock(HttpRequest) - HttpHeaders headers = Mock(HttpHeaders) - request.getHeaders() >> headers - def config = new HttpServerConfiguration.CorsConfiguration() + String origin = 'http://www.foo.com' + + HttpHeaders headers = Stub(HttpHeaders) { + getOrigin() >> Optional.of(origin) + getFirst(ACCESS_CONTROL_REQUEST_METHOD, _) >> Optional.of(HttpMethod.GET) + get(ACCESS_CONTROL_REQUEST_HEADERS, _) >> Optional.of(['foo']) + contains(ACCESS_CONTROL_REQUEST_METHOD) >> true + } + HttpRequest request = createRequest(headers) + request.getMethod() >> HttpMethod.OPTIONS + request.getUri() >> new URIBuilder( '/example' ).build() + List> routes = embeddedServer.getApplicationContext().getBean(Router). + findAny(request.getUri().toString(), request) + .collect(Collectors.toList()) + request.getAttribute(HttpAttributes.AVAILABLE_HTTP_METHODS, _) >> Optional.of(routes.stream().map(route->route.getHttpMethod()).collect(Collectors.toList())) + CorsOriginConfiguration originConfig = new CorsOriginConfiguration() originConfig.allowedOrigins = ['http://www.foo.com'] originConfig.allowedMethods = [HttpMethod.GET] originConfig.allowedHeaders = ['foo', 'bar'] - config.configurations = new LinkedHashMap() - config.configurations.put('foo', originConfig) + + HttpServerConfiguration.CorsConfiguration config = enabledCorsConfiguration([foo: originConfig]) + CorsFilter corsHandler = buildCorsHandler(config) - request.getMethod() >> HttpMethod.OPTIONS - def uri = new URIBuilder( '/example' ) - request.getUri() >> uri.build() - def routes = embeddedServer.getApplicationContext().getBean(Router). - findAny(request.getUri().toString(), request) - .collect(Collectors.toList()) - request.getAttribute(HttpAttributes.AVAILABLE_HTTP_METHODS, _) >> Optional.of(routes.stream().map(route->route.getHttpMethod()).collect(Collectors.toList())) when: - headers.contains(ACCESS_CONTROL_REQUEST_METHOD) >> true - def result = corsHandler.handleRequest(request) - - then: "the request is successful" - 4 * headers.getOrigin() >> Optional.of('http://www.foo.com') - 2 * headers.getFirst(ACCESS_CONTROL_REQUEST_METHOD, _) >> Optional.of(HttpMethod.GET) - 2 * headers.get(ACCESS_CONTROL_REQUEST_HEADERS, _) >> Optional.of(['foo']) - result.get().status == HttpStatus.OK + Optional> result = Mono.from(corsHandler.doFilter(request, okChain())).blockOptional() + + then: + result.isPresent() + + when: + MutableHttpResponse response = result.get() + + then: + HttpStatus.OK == response.status() + response.headers.names().size() == 6 + response.headers.find { it.key == 'Access-Control-Allow-Origin' } + response.headers.find { it.key == 'Vary' } + response.headers.find { it.key == 'Access-Control-Allow-Credentials' } + response.headers.find { it.key == 'Access-Control-Allow-Methods' } + response.headers.find { it.key == 'Access-Control-Allow-Headers' } + response.headers.find { it.key == 'Access-Control-Max-Age' } + response.headers.find { it.key == 'Access-Control-Allow-Origin' }.value == ['http://www.foo.com'] + response.headers.find { it.key == 'Vary' }.value == ['Origin'] + response.headers.find { it.key == 'Access-Control-Allow-Credentials' }.value == [StringUtils.TRUE] + response.headers.find { it.key == 'Access-Control-Allow-Methods' }.value == ['GET'] + response.headers.find { it.key == 'Access-Control-Allow-Headers' }.value == ['foo'] + response.headers.find { it.key == 'Access-Control-Max-Age' }.value == ['1800'] } void "test handleResponse when configuration not present"() { given: - def config = new HttpServerConfiguration.CorsConfiguration() + String origin = 'http://www.bar.com' + HttpServerConfiguration.CorsConfiguration config = new HttpServerConfiguration.CorsConfiguration() CorsOriginConfiguration originConfig = new CorsOriginConfiguration() originConfig.allowedOrigins = ['http://www.foo.com'] - config.configurations = new LinkedHashMap() - config.configurations.put('foo', originConfig) + config.setConfigurations([foo: originConfig]) CorsFilter corsHandler = buildCorsHandler(config) - HttpRequest request = Mock(HttpRequest) - HttpHeaders headers = Mock(HttpHeaders) - request.getHeaders() >> headers - + HttpHeaders headers = Stub(HttpHeaders) { + getOrigin() >> Optional.of(origin) + } + HttpRequest request = Stub(HttpRequest) { + getHeaders() >> headers + } when: - def result = corsHandler.handleRequest(request) + Optional> result = corsHandler.handleRequest(request) then: "the response is not modified" - 2 * headers.getOrigin() >> Optional.of('http://www.bar.com') notThrown(NullPointerException) !result.isPresent() } - void "test handleResponse for normal request"() { + void "verify behaviour for normal request"() { given: - def config = new HttpServerConfiguration.CorsConfiguration() + String origin = 'http://www.foo.com' + HttpHeaders headers = Stub(HttpHeaders) { + getOrigin() >> Optional.of(origin) + contains(ACCESS_CONTROL_REQUEST_METHOD) >> true + } + HttpRequest request = Stub(HttpRequest) { + getHeaders() >> headers + } + CorsOriginConfiguration originConfig = new CorsOriginConfiguration() originConfig.exposedHeaders = ['Foo-Header', 'Bar-Header'] - config.configurations = new LinkedHashMap() - config.configurations.put('foo', originConfig) + + HttpServerConfiguration.CorsConfiguration config = enabledCorsConfiguration([foo: originConfig]) CorsFilter corsHandler = buildCorsHandler(config) - HttpRequest request = Mock(HttpRequest) - HttpHeaders headers = Mock(HttpHeaders) - request.getHeaders() >> headers - headers.getOrigin() >> Optional.of('http://www.foo.com') when: - headers.contains(ACCESS_CONTROL_REQUEST_METHOD) >> true - def result = corsHandler.handleRequest(request) + Optional> result = Mono.from(corsHandler.doFilter(request, okChain())).blockOptional() then: - !result.isPresent() + result.isPresent() when: - MutableHttpResponse response = HttpResponse.ok() - corsHandler.handleResponse(request, response) + MutableHttpResponse response = result.get() - then: "the response is not modified" + then: + HttpStatus.OK == response.status() + response.headers.names().size() == 5 response.getHeaders().get(ACCESS_CONTROL_ALLOW_ORIGIN) == 'http://www.foo.com' // The origin is echo'd response.getHeaders().get(VARY) == 'Origin' // The vary header is set response.getHeaders().getAll(ACCESS_CONTROL_EXPOSE_HEADERS) == ['Foo-Header', 'Bar-Header' ]// Expose headers are set from config response.getHeaders().get(ACCESS_CONTROL_ALLOW_CREDENTIALS) == 'true' // Allow credentials header is set + response.getHeaders().get(ACCESS_CONTROL_MAX_AGE) == '1800' } void "test handleResponse for preflight request"() { given: - def config = new HttpServerConfiguration.CorsConfiguration() + HttpHeaders headers = Stub(HttpHeaders) { + contains(ACCESS_CONTROL_REQUEST_METHOD) >> true + get(ACCESS_CONTROL_REQUEST_HEADERS, _) >> Optional.of(['X-Header', 'Y-Header']) + getFirst(ACCESS_CONTROL_REQUEST_METHOD, _) >> Optional.of(HttpMethod.GET) + getOrigin() >> Optional.of('http://www.foo.com') + } + URI uri = new URIBuilder('/example').build() + HttpRequest request = Stub(HttpRequest) { + getHeaders() >> headers + getMethod() >> HttpMethod.OPTIONS + getUri() >> uri + } + List> routes = embeddedServer.getApplicationContext().getBean(Router). + findAny(uri.toString(), request) + .collect(Collectors.toList()) + request.getAttribute(HttpAttributes.AVAILABLE_HTTP_METHODS, _) >> Optional.of(routes.stream().map(route -> route.getHttpMethod()).collect(Collectors.toList())) + CorsOriginConfiguration originConfig = new CorsOriginConfiguration() originConfig.exposedHeaders = ['Foo-Header', 'Bar-Header'] - config.configurations = new LinkedHashMap() - config.configurations.put('foo', originConfig) + HttpServerConfiguration.CorsConfiguration config = enabledCorsConfiguration([foo: originConfig]) + CorsFilter corsHandler = buildCorsHandler(config) - HttpRequest request = Mock(HttpRequest) - HttpHeaders headers = Mock(HttpHeaders) - request.getHeaders() >> headers - headers.getOrigin() >> Optional.of('http://www.foo.com') - request.getMethod() >> HttpMethod.OPTIONS - def uri = new URIBuilder( '/example' ) - request.getUri() >> uri.build() - def routes = embeddedServer.getApplicationContext().getBean(Router). - findAny(request.getUri().toString(), request) - .collect(Collectors.toList()) - request.getAttribute(HttpAttributes.AVAILABLE_HTTP_METHODS, _) >> Optional.of(routes.stream().map(route->route.getHttpMethod()).collect(Collectors.toList())) + when: + Optional> result = Mono.from(corsHandler.doFilter(request, okChain())).blockOptional() + + then: + result.isPresent() when: - headers.contains(ACCESS_CONTROL_REQUEST_METHOD) >> true - HttpResponse response = corsHandler.handleRequest(request).get() + MutableHttpResponse response = result.get() - then: "the response is not modified" - 2 * headers.get(ACCESS_CONTROL_REQUEST_HEADERS, _) >> Optional.of(['X-Header', 'Y-Header']) - 2 * headers.getFirst(ACCESS_CONTROL_REQUEST_METHOD, _) >> Optional.of(HttpMethod.GET) + then: + HttpStatus.OK == response.status() + response.headers.names().size() == 7 response.getHeaders().get(ACCESS_CONTROL_ALLOW_METHODS) == 'GET' response.getHeaders().get(ACCESS_CONTROL_ALLOW_ORIGIN) == 'http://www.foo.com' // The origin is echo'd response.getHeaders().get(VARY) == 'Origin' // The vary header is set @@ -315,32 +387,44 @@ class CorsFilterSpec extends Specification { void "test handleResponse for preflight request with single header"() { given: - def config = new HttpServerConfiguration.CorsConfiguration(singleHeader: true) CorsOriginConfiguration originConfig = new CorsOriginConfiguration() originConfig.exposedHeaders = ['Foo-Header', 'Bar-Header'] - config.configurations = new LinkedHashMap() - config.configurations.put('foo', originConfig) + + HttpServerConfiguration.CorsConfiguration config = new HttpServerConfiguration.CorsConfiguration(singleHeader: true, enabled: true) + config.setConfigurations([foo: originConfig]) + CorsFilter corsHandler = buildCorsHandler(config) - HttpRequest request = Mock(HttpRequest) - HttpHeaders headers = Mock(HttpHeaders) - request.getHeaders() >> headers - headers.getOrigin() >> Optional.of('http://www.foo.com') - request.getMethod() >> HttpMethod.OPTIONS - def uri = new URIBuilder( '/example' ) - request.getUri() >> uri.build() - def routes = embeddedServer.getApplicationContext().getBean(Router). + + HttpHeaders headers = Stub(HttpHeaders) { + getOrigin() >> Optional.of('http://www.foo.com') + contains(ACCESS_CONTROL_REQUEST_METHOD) >> true + get(ACCESS_CONTROL_REQUEST_HEADERS, _) >> Optional.of(['X-Header', 'Y-Header']) + getFirst(ACCESS_CONTROL_REQUEST_METHOD, _) >> Optional.of(HttpMethod.GET) + } + URI uri = new URIBuilder( '/example' ).build() + HttpRequest request = Stub(HttpRequest) { + getHeaders() >> headers + getMethod() >> HttpMethod.OPTIONS + getUri() >> uri + } + List> routes = embeddedServer.getApplicationContext().getBean(Router). findAny(request.getUri().toString(), request) .collect(Collectors.toList()) - request.getAttribute(HttpAttributes.AVAILABLE_HTTP_METHODS, _) >> Optional.of(routes.stream().map(route->route.getHttpMethod()).collect(Collectors.toList())) when: - headers.contains(ACCESS_CONTROL_REQUEST_METHOD) >> true - HttpResponse response = corsHandler.handleRequest(request).get() + Optional> result = Mono.from(corsHandler.doFilter(request, okChain())).blockOptional() + + then: + result.isPresent() + + when: + MutableHttpResponse response = result.get() + + then: + HttpStatus.OK == response.status() then: "the response is not modified" - 2 * headers.get(ACCESS_CONTROL_REQUEST_HEADERS, _) >> Optional.of(['X-Header', 'Y-Header']) - 2 * headers.getFirst(ACCESS_CONTROL_REQUEST_METHOD, _) >> Optional.of(HttpMethod.GET) response.getHeaders().get(ACCESS_CONTROL_ALLOW_METHODS) == 'GET' response.getHeaders().get(ACCESS_CONTROL_ALLOW_ORIGIN) == 'http://www.foo.com' // The origin is echo'd response.getHeaders().get(VARY) == 'Origin' // The vary header is set @@ -352,63 +436,84 @@ class CorsFilterSpec extends Specification { void "test preflight handleRequest on route that doesn't exists"() { given: - HttpRequest request = Mock(HttpRequest) - HttpHeaders headers = Mock(HttpHeaders) - request.getHeaders() >> headers - def uri = new URIBuilder( '/doesnt-exists-route' ) - request.getUri() >> uri.build() - def config = new HttpServerConfiguration.CorsConfiguration() + String origin = 'http://www.foo.com' + HttpHeaders headers = Stub(HttpHeaders) { + getFirst(ACCESS_CONTROL_REQUEST_METHOD, _) >> Optional.of(HttpMethod.GET) + getOrigin() >> Optional.of(origin) + contains(ACCESS_CONTROL_REQUEST_METHOD) >> true + } + URI uri = new URIBuilder( '/doesnt-exists-route' ).build() + HttpRequest request = Stub(HttpRequest) { + getHeaders() >> headers + getUri() >> uri + getMethod() >> HttpMethod.OPTIONS + } + List> routes = embeddedServer.getApplicationContext().getBean(Router). + findAny(uri.toString(), request) + .collect(Collectors.toList()) + request.getAttribute(HttpAttributes.AVAILABLE_HTTP_METHODS, _) >> Optional.of(routes.stream().map(route->route.getHttpMethod()).collect(Collectors.toList())) + CorsOriginConfiguration originConfig = new CorsOriginConfiguration() originConfig.allowedOrigins = ['http://www.foo.com'] originConfig.allowedMethods = [HttpMethod.GET] originConfig.allowedHeaders = ['foo', 'bar'] - config.configurations = new LinkedHashMap() - config.configurations.put('foo', originConfig) + + HttpServerConfiguration.CorsConfiguration config = enabledCorsConfiguration([foo: originConfig]) + CorsFilter corsHandler = buildCorsHandler(config) - request.getMethod() >> HttpMethod.OPTIONS - def routes = embeddedServer.getApplicationContext().getBean(Router). - findAny(request.getUri().toString(), request) - .collect(Collectors.toList()) - request.getAttribute(HttpAttributes.AVAILABLE_HTTP_METHODS, _) >> Optional.of(routes.stream().map(route->route.getHttpMethod()).collect(Collectors.toList())) + when: + Optional> result = Mono.from(corsHandler.doFilter(request, okChain())).blockOptional() + + then: + result.isPresent() when: - headers.contains(ACCESS_CONTROL_REQUEST_METHOD) >> true - 1 * headers.getFirst(ACCESS_CONTROL_REQUEST_METHOD, _) >> Optional.of(HttpMethod.GET) - def result = corsHandler.handleRequest(request) + MutableHttpResponse response = result.get() - then: "the request is successful" - 2 * headers.getOrigin() >> Optional.of('http://www.foo.com') - !result.isPresent() + then: + HttpStatus.OK == response.status() } void "test preflight handleRequest on route that does exist but doesn't handle requested HTTP Method"() { given: - def config = new HttpServerConfiguration.CorsConfiguration() + CorsOriginConfiguration originConfig = new CorsOriginConfiguration() originConfig.exposedHeaders = ['Foo-Header', 'Bar-Header'] - config.configurations = new LinkedHashMap() - config.configurations.put('foo', originConfig) + + HttpServerConfiguration.CorsConfiguration config = enabledCorsConfiguration([foo: originConfig]) + CorsFilter corsHandler = buildCorsHandler(config) - HttpRequest request = Mock(HttpRequest) - HttpHeaders headers = Mock(HttpHeaders) - request.getHeaders() >> headers - headers.getOrigin() >> Optional.of('http://www.foo.com') - request.getMethod() >> HttpMethod.OPTIONS - def uri = new URIBuilder( '/example' ) - request.getUri() >> uri.build() - def routes = embeddedServer.getApplicationContext().getBean(Router). + + String origin = 'http://www.foo.com' + HttpHeaders headers = Stub(HttpHeaders) { + getOrigin() >> Optional.of(origin) + getFirst(ACCESS_CONTROL_REQUEST_METHOD, _) >> Optional.of(HttpMethod.POST) + contains(ACCESS_CONTROL_REQUEST_METHOD) >> true + } + URI uri = new URIBuilder( '/example' ).build() + HttpRequest request = Stub(HttpRequest) { + getHeaders() >> headers + getMethod() >> HttpMethod.OPTIONS + getUri() >> uri + } + + List> routes = embeddedServer.getApplicationContext().getBean(Router). findAny(request.getUri().toString(), request) .collect(Collectors.toList()) - request.getAttribute(HttpAttributes.AVAILABLE_HTTP_METHODS, _) >> Optional.of(routes.stream().map(route->route.getHttpMethod()).collect(Collectors.toList())) + when: - headers.contains(ACCESS_CONTROL_REQUEST_METHOD) >> true - def result = corsHandler.handleRequest(request) + Optional> result = Mono.from(corsHandler.doFilter(request, okChain())).blockOptional() - then: "the request is successful" - 1 * headers.getFirst(ACCESS_CONTROL_REQUEST_METHOD, _) >> Optional.of(HttpMethod.POST) - !result.isPresent() + then: + result.isPresent() + + when: + MutableHttpResponse response = result.get() + + then: + HttpStatus.OK == response.status() } @Controller @@ -417,4 +522,43 @@ class CorsFilterSpec extends Specification { @Get("/example") String example() { return "Example"} } + + private HttpRequest createRequest(String originHeader) { + HttpHeaders headers = Stub(HttpHeaders) { + getOrigin() >> Optional.ofNullable(originHeader) + } + createRequest(headers) + } + + private HttpRequest createRequest(HttpHeaders headers) { + Stub(HttpRequest) { + getHeaders() >> headers + } + } + + private ServerFilterChain okChain() { + new ServerFilterChain() { + @Override + Publisher> proceed(HttpRequest req) { + Publishers.just(HttpResponse.ok()) + } + } + } + + private HttpServerConfiguration.CorsConfiguration enabledCorsConfiguration(Map corsConfigurationMap = null) { + HttpServerConfiguration.CorsConfiguration config = new HttpServerConfiguration.CorsConfiguration() { + @Override + boolean isEnabled() { + true + } + } + if (corsConfigurationMap != null) { + config.setConfigurations(corsConfigurationMap) + } + config + } + + private CorsFilter buildCorsHandler(HttpServerConfiguration.CorsConfiguration config) { + new CorsFilter(config ?: enabledCorsConfiguration()) + } } diff --git a/http-server-netty/src/test/groovy/io/micronaut/http/server/netty/cors/CorsOriginConverterEnabledSpec.groovy b/http-server-netty/src/test/groovy/io/micronaut/http/server/netty/cors/CorsOriginConverterEnabledSpec.groovy new file mode 100644 index 00000000000..2f98a0bfca4 --- /dev/null +++ b/http-server-netty/src/test/groovy/io/micronaut/http/server/netty/cors/CorsOriginConverterEnabledSpec.groovy @@ -0,0 +1,19 @@ +package io.micronaut.http.server.netty.cors + +import io.micronaut.context.ApplicationContext +import io.micronaut.http.server.cors.CorsOriginConverter +import spock.lang.AutoCleanup +import spock.lang.Shared +import spock.lang.Specification + +class CorsOriginConverterEnabledSpec extends Specification { + + @AutoCleanup + @Shared + ApplicationContext applicationContext = ApplicationContext.run() + + void "CorsOriginConverter is not enabled by default"() { + expect: + !applicationContext.containsBean(CorsOriginConverter) + } +} diff --git a/http-server/src/main/java/io/micronaut/http/server/cors/CorsFilter.java b/http-server/src/main/java/io/micronaut/http/server/cors/CorsFilter.java index 705b5c55fea..ddc8d63a0e5 100644 --- a/http-server/src/main/java/io/micronaut/http/server/cors/CorsFilter.java +++ b/http-server/src/main/java/io/micronaut/http/server/cors/CorsFilter.java @@ -15,6 +15,8 @@ */ package io.micronaut.http.server.cors; +import io.micronaut.core.annotation.NonNull; +import io.micronaut.core.annotation.Nullable; import io.micronaut.core.async.publisher.Publishers; import io.micronaut.core.convert.ArgumentConversionContext; import io.micronaut.core.convert.ConversionContext; @@ -31,11 +33,13 @@ import io.micronaut.http.filter.ServerFilterChain; import io.micronaut.http.filter.ServerFilterPhase; import io.micronaut.http.server.HttpServerConfiguration; +import org.jetbrains.annotations.NotNull; import org.reactivestreams.Publisher; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; import java.util.ArrayList; import java.util.List; -import java.util.Map; import java.util.Optional; import java.util.regex.Matcher; import java.util.regex.Pattern; @@ -43,6 +47,7 @@ import static io.micronaut.http.HttpAttributes.AVAILABLE_HTTP_METHODS; import static io.micronaut.http.HttpHeaders.*; +import static io.micronaut.http.annotation.Filter.MATCH_ALL_PATTERN; /** * Responsible for handling CORS requests and responses. @@ -51,9 +56,9 @@ * @author Graeme Rocher * @since 1.0 */ -@Filter("/**") +@Filter(MATCH_ALL_PATTERN) public class CorsFilter implements HttpServerFilter { - + private static final Logger LOG = LoggerFactory.getLogger(CorsFilter.class); private static final ArgumentConversionContext CONVERSION_CONTEXT_HTTP_METHOD = ImmutableArgumentConversionContext.of(HttpMethod.class); protected final HttpServerConfiguration.CorsConfiguration corsConfiguration; @@ -67,19 +72,23 @@ public CorsFilter(HttpServerConfiguration.CorsConfiguration corsConfiguration) { @Override public Publisher> doFilter(HttpRequest request, ServerFilterChain chain) { - boolean originHeaderPresent = request.getHeaders().getOrigin().isPresent(); - if (originHeaderPresent) { - MutableHttpResponse response = handleRequest(request).orElse(null); - if (response != null) { - return Publishers.just(response); - } else { - return Publishers.then(chain.proceed(request), mutableHttpResponse -> { - handleResponse(request, mutableHttpResponse); - }); - } - } else { + String origin = request.getHeaders().getOrigin().orElse(null); + if (origin == null) { + LOG.trace("Http Header " + HttpHeaders.ORIGIN + " not present. Proceeding with the request."); return chain.proceed(request); } + CorsOriginConfiguration corsOriginConfiguration = getConfiguration(origin).orElse(null); + if (corsOriginConfiguration != null) { + if (CorsUtil.isPreflightRequest(request)) { + return handlePreflightRequest(request, chain, corsOriginConfiguration); + } + if (!validateMethodToMatch(request, corsOriginConfiguration).isPresent()) { + return forbidden(); + } + return Publishers.then(chain.proceed(request), resp -> decorateResponseWithHeaders(request, resp, corsOriginConfiguration)); + } + LOG.trace("CORS configuration not found for {} origin", origin); + return chain.proceed(request); } @Override @@ -92,34 +101,10 @@ public int getOrder() { * * @param request The {@link HttpRequest} object * @param response The {@link MutableHttpResponse} object + * @deprecated not used */ + @Deprecated protected void handleResponse(HttpRequest request, MutableHttpResponse response) { - HttpHeaders headers = request.getHeaders(); - Optional originHeader = headers.getOrigin(); - originHeader.ifPresent(requestOrigin -> { - - Optional optionalConfig = getConfiguration(requestOrigin); - - if (optionalConfig.isPresent()) { - CorsOriginConfiguration config = optionalConfig.get(); - - if (CorsUtil.isPreflightRequest(request)) { - Optional result = headers.getFirst(ACCESS_CONTROL_REQUEST_METHOD, CONVERSION_CONTEXT_HTTP_METHOD); - setAllowMethods(result.get(), response); - Optional> allowedHeaders = headers.get(ACCESS_CONTROL_REQUEST_HEADERS, ConversionContext.LIST_OF_STRING); - allowedHeaders.ifPresent(val -> - setAllowHeaders(val, response) - ); - - setMaxAge(config.getMaxAge(), response); - } - - setOrigin(requestOrigin, response); - setVary(response); - setExposeHeaders(config.getExposedHeaders(), response); - setAllowCredentials(config, response); - } - }); } /** @@ -127,54 +112,21 @@ protected void handleResponse(HttpRequest request, MutableHttpResponse res * * @param request The {@link HttpRequest} object * @return An optional {@link MutableHttpResponse}. The request should proceed normally if empty + * @deprecated Not used any more. */ + @Deprecated protected Optional> handleRequest(HttpRequest request) { - HttpHeaders headers = request.getHeaders(); - Optional originHeader = headers.getOrigin(); - if (originHeader.isPresent()) { - - String requestOrigin = originHeader.get(); - boolean preflight = CorsUtil.isPreflightRequest(request); - - Optional optionalConfig = getConfiguration(requestOrigin); - - if (optionalConfig.isPresent()) { - CorsOriginConfiguration config = optionalConfig.get(); - - HttpMethod requestMethod = request.getMethod(); - - List allowedMethods = config.getAllowedMethods(); - HttpMethod methodToMatch = preflight ? headers.getFirst(ACCESS_CONTROL_REQUEST_METHOD, CONVERSION_CONTEXT_HTTP_METHOD).orElse(requestMethod) : requestMethod; - - if (!isAnyMethod(allowedMethods)) { - if (allowedMethods.stream().noneMatch(method -> method.equals(methodToMatch))) { - return Optional.of(HttpResponse.status(HttpStatus.FORBIDDEN)); - } - } - - Optional> availableHttpMethods = (Optional>) request.getAttribute(AVAILABLE_HTTP_METHODS, new ArrayList().getClass()); - - if (preflight && availableHttpMethods.isPresent() && availableHttpMethods.get().stream().anyMatch(method -> method.equals(methodToMatch))) { - Optional> accessControlHeaders = headers.get(ACCESS_CONTROL_REQUEST_HEADERS, ConversionContext.LIST_OF_STRING); - - List allowedHeaders = config.getAllowedHeaders(); - - if (!isAny(allowedHeaders) && accessControlHeaders.isPresent()) { - if (!accessControlHeaders.get().stream() - .allMatch(header -> allowedHeaders.stream() - .anyMatch(allowedHeader -> allowedHeader.equalsIgnoreCase(header.trim())))) { - return Optional.of(HttpResponse.status(HttpStatus.FORBIDDEN)); - } - } + return Optional.empty(); + } - MutableHttpResponse ok = HttpResponse.ok(); - handleResponse(request, ok); - return Optional.of(ok); - } - } + @NonNull + private Optional validateMethodToMatch(@NonNull HttpRequest request, + @NonNull CorsOriginConfiguration config) { + HttpMethod methodToMatch = methodToMatch(request); + if (!methodAllowed(config, methodToMatch)) { + return Optional.empty(); } - - return Optional.empty(); + return Optional.of(methodToMatch); } /** @@ -213,15 +165,17 @@ protected void setVary(MutableHttpResponse response) { * @param origin The origin * @param response The {@link MutableHttpResponse} object */ - protected void setOrigin(String origin, MutableHttpResponse response) { - response.header(ACCESS_CONTROL_ALLOW_ORIGIN, origin); + protected void setOrigin(@Nullable String origin, @NonNull MutableHttpResponse response) { + if (origin != null) { + response.header(ACCESS_CONTROL_ALLOW_ORIGIN, origin); + } } /** * @param method The {@link HttpMethod} object * @param response The {@link MutableHttpResponse} object */ - protected void setAllowMethods(HttpMethod method, MutableHttpResponse response) { + protected void setAllowMethods(HttpMethod method, MutableHttpResponse response) { response.header(ACCESS_CONTROL_ALLOW_METHODS, method); } @@ -229,7 +183,7 @@ protected void setAllowMethods(HttpMethod method, MutableHttpResponse response) * @param optionalAllowHeaders A list with optional allow headers * @param response The {@link MutableHttpResponse} object */ - protected void setAllowHeaders(List optionalAllowHeaders, MutableHttpResponse response) { + protected void setAllowHeaders(List optionalAllowHeaders, MutableHttpResponse response) { List allowHeaders = optionalAllowHeaders.stream().map(Object::toString).collect(Collectors.toList()); if (corsConfiguration.isSingleHeader()) { String headerValue = String.join(",", allowHeaders); @@ -248,38 +202,28 @@ protected void setAllowHeaders(List optionalAllowHeaders, MutableHttpResponse * @param maxAge The max age * @param response The {@link MutableHttpResponse} object */ - protected void setMaxAge(long maxAge, MutableHttpResponse response) { + protected void setMaxAge(long maxAge, MutableHttpResponse response) { if (maxAge > -1) { response.header(ACCESS_CONTROL_MAX_AGE, Long.toString(maxAge)); } } - private Optional getConfiguration(String requestOrigin) { - Map corsConfigurations = corsConfiguration.getConfigurations(); - for (Map.Entry config : corsConfigurations.entrySet()) { - List allowedOrigins = config.getValue().getAllowedOrigins(); - if (!allowedOrigins.isEmpty()) { - boolean matches = false; - if (isAny(allowedOrigins)) { - matches = true; - } - if (!matches) { - matches = allowedOrigins.stream().anyMatch(origin -> { - if (origin.equals(requestOrigin)) { - return true; - } - Pattern p = Pattern.compile(origin); - Matcher m = p.matcher(requestOrigin); - return m.matches(); - }); - } + @NonNull + private Optional getConfiguration(@NonNull String requestOrigin) { + return corsConfiguration.getConfigurations().values().stream() + .filter(config -> { + List allowedOrigins = config.getAllowedOrigins(); + return !allowedOrigins.isEmpty() && (isAny(allowedOrigins) || allowedOrigins.stream().anyMatch(origin -> matchesOrigin(origin, requestOrigin))); + }).findFirst(); + } - if (matches) { - return Optional.of(config.getValue()); - } - } + private boolean matchesOrigin(@NonNull String origin, @NonNull String requestOrigin) { + if (origin.equals(requestOrigin)) { + return true; } - return Optional.empty(); + Pattern p = Pattern.compile(origin); + Matcher m = p.matcher(requestOrigin); + return m.matches(); } private boolean isAny(List values) { @@ -289,4 +233,95 @@ private boolean isAny(List values) { private boolean isAnyMethod(List allowedMethods) { return allowedMethods == CorsOriginConfiguration.ANY_METHOD; } + + private boolean methodAllowed(@NonNull CorsOriginConfiguration config, + @NonNull HttpMethod methodToMatch) { + List allowedMethods = config.getAllowedMethods(); + return isAnyMethod(allowedMethods) || allowedMethods.stream().anyMatch(method -> method.equals(methodToMatch)); + } + + @NonNull + private HttpMethod methodToMatch(@NonNull HttpRequest request) { + HttpMethod requestMethod = request.getMethod(); + return CorsUtil.isPreflightRequest(request) ? request.getHeaders().getFirst(ACCESS_CONTROL_REQUEST_METHOD, CONVERSION_CONTEXT_HTTP_METHOD).orElse(requestMethod) : requestMethod; + } + + private boolean hasAllowedHeaders(@NonNull HttpRequest request, @NonNull CorsOriginConfiguration config) { + Optional> accessControlHeaders = request.getHeaders().get(ACCESS_CONTROL_REQUEST_HEADERS, ConversionContext.LIST_OF_STRING); + List allowedHeaders = config.getAllowedHeaders(); + return isAny(allowedHeaders) || ( + accessControlHeaders.isPresent() && + accessControlHeaders.get().stream().allMatch(header -> allowedHeaders.stream().anyMatch(allowedHeader -> allowedHeader.equalsIgnoreCase(header.trim()))) + ); + } + + @NotNull + private static Publisher> forbidden() { + return Publishers.just(HttpResponse.status(HttpStatus.FORBIDDEN)); + } + + @NonNull + private void decorateResponseWithHeadersForPreflightRequest(@NonNull HttpRequest request, + @NonNull MutableHttpResponse response, + @NonNull CorsOriginConfiguration config) { + HttpHeaders headers = request.getHeaders(); + headers.getFirst(ACCESS_CONTROL_REQUEST_METHOD, CONVERSION_CONTEXT_HTTP_METHOD) + .ifPresent(methods -> setAllowMethods(methods, response)); + headers.get(ACCESS_CONTROL_REQUEST_HEADERS, ConversionContext.LIST_OF_STRING) + .ifPresent(val -> setAllowHeaders(val, response)); + setMaxAge(config.getMaxAge(), response); + } + + @NonNull + private void decorateResponseWithHeaders(@NonNull HttpRequest request, + @NonNull MutableHttpResponse response, + @NonNull CorsOriginConfiguration config) { + HttpHeaders headers = request.getHeaders(); + setOrigin(headers.getOrigin().orElse(null), response); + setVary(response); + setExposeHeaders(config.getExposedHeaders(), response); + setAllowCredentials(config, response); + } + + @NonNull + private Publisher> handlePreflightRequest(@NonNull HttpRequest request, + @NonNull ServerFilterChain chain, + @NonNull CorsOriginConfiguration corsOriginConfiguration) { + Optional statusOptional = validatePreflightRequest(request, corsOriginConfiguration); + if (statusOptional.isPresent()) { + HttpStatus status = statusOptional.get(); + if (status.getCode() >= 400) { + return Publishers.just(HttpResponse.status(status)); + } + MutableHttpResponse resp = HttpResponse.status(status); + decorateResponseWithHeadersForPreflightRequest(request, resp, corsOriginConfiguration); + decorateResponseWithHeaders(request, resp, corsOriginConfiguration); + return Publishers.just(resp); + } + return Publishers.then(chain.proceed(request), resp -> { + decorateResponseWithHeadersForPreflightRequest(request, resp, corsOriginConfiguration); + decorateResponseWithHeaders(request, resp, corsOriginConfiguration); + }); + } + + @NonNull + private Optional validatePreflightRequest(@NonNull HttpRequest request, + @NonNull CorsOriginConfiguration config) { + Optional methodToMatchOptional = validateMethodToMatch(request, config); + if (!methodToMatchOptional.isPresent()) { + return Optional.of(HttpStatus.FORBIDDEN); + } + HttpMethod methodToMatch = methodToMatchOptional.get(); + + Optional> availableHttpMethods = (Optional>) request.getAttribute(AVAILABLE_HTTP_METHODS, new ArrayList().getClass()); + if (CorsUtil.isPreflightRequest(request) && + availableHttpMethods.isPresent() && + availableHttpMethods.get().stream().anyMatch(method -> method.equals(methodToMatch))) { + if (!hasAllowedHeaders(request, config)) { + return Optional.of(HttpStatus.FORBIDDEN); + } + return Optional.of(HttpStatus.OK); + } + return Optional.empty(); + } } diff --git a/http-server/src/test/groovy/io/micronaut/http/server/util/MockHttpHeaders.java b/http-server/src/test/groovy/io/micronaut/http/server/util/MockHttpHeaders.java index 3942ae6f3d8..4885e543ec8 100644 --- a/http-server/src/test/groovy/io/micronaut/http/server/util/MockHttpHeaders.java +++ b/http-server/src/test/groovy/io/micronaut/http/server/util/MockHttpHeaders.java @@ -89,7 +89,7 @@ public Collection> values() { @Override public Optional get(CharSequence name, ArgumentConversionContext conversionContext) { - return ConversionService.SHARED.convert(get(name), conversionContext); + return conversionService.convert(get(name), conversionContext); } @Override diff --git a/http-validation/src/main/java/io/micronaut/validation/routes/rules/MissingParameterRule.java b/http-validation/src/main/java/io/micronaut/validation/routes/rules/MissingParameterRule.java index 313dab5a412..02953bd2918 100644 --- a/http-validation/src/main/java/io/micronaut/validation/routes/rules/MissingParameterRule.java +++ b/http-validation/src/main/java/io/micronaut/validation/routes/rules/MissingParameterRule.java @@ -15,11 +15,13 @@ */ package io.micronaut.validation.routes.rules; +import io.micronaut.core.annotation.AnnotatedElement; import io.micronaut.core.bind.annotation.Bindable; +import io.micronaut.core.naming.Named; import io.micronaut.http.uri.UriMatchTemplate; +import io.micronaut.inject.ast.ClassElement; import io.micronaut.inject.ast.MethodElement; import io.micronaut.inject.ast.ParameterElement; -import io.micronaut.inject.ast.PropertyElement; import io.micronaut.validation.routes.RouteValidationResult; import java.util.ArrayList; @@ -27,8 +29,10 @@ import java.util.LinkedHashSet; import java.util.List; import java.util.Objects; +import java.util.Optional; import java.util.Set; import java.util.stream.Collectors; +import java.util.stream.Stream; /** * Validates all route uri variables are present in the route arguments. @@ -48,16 +52,16 @@ public RouteValidationResult validate(List templates, Paramete .filter(p -> p.hasAnnotation("io.micronaut.http.annotation.Body")) .map(ParameterElement::getType) .filter(Objects::nonNull) - .flatMap(t -> t.getBeanProperties().stream()) - .map(PropertyElement::getName) + .flatMap(MissingParameterRule::findProperties) + .map(Named::getName) .collect(Collectors.toList())); // RequestBean has properties inside routeVariables.addAll(Arrays.stream(parameters) .filter(p -> p.hasAnnotation("io.micronaut.http.annotation.RequestBean")) .map(ParameterElement::getType) - .flatMap(t -> t.getBeanProperties().stream()) - .filter(p -> p.hasStereotype(Bindable.class)) + .flatMap(MissingParameterRule::findProperties) + .filter(p -> p.getAnnotationMetadata().hasStereotype(Bindable.class)) .map(p -> p.getAnnotationMetadata().stringValue(Bindable.class).orElse(p.getName())) .collect(Collectors.toSet())); @@ -72,4 +76,13 @@ public RouteValidationResult validate(List templates, Paramete return new RouteValidationResult(errorMessages.toArray(new String[0])); } + private static Stream findProperties(ClassElement t) { + if (t.isRecord()) { + Optional primaryConstructor = t.getPrimaryConstructor(); + if (primaryConstructor.isPresent()) { + return Arrays.stream(primaryConstructor.get().getParameters()); + } + } + return t.getBeanProperties().stream(); + } } diff --git a/http-validation/src/main/java/io/micronaut/validation/routes/rules/RequestBeanParameterRule.java b/http-validation/src/main/java/io/micronaut/validation/routes/rules/RequestBeanParameterRule.java index aac4a23bd21..5aa576e9226 100644 --- a/http-validation/src/main/java/io/micronaut/validation/routes/rules/RequestBeanParameterRule.java +++ b/http-validation/src/main/java/io/micronaut/validation/routes/rules/RequestBeanParameterRule.java @@ -52,15 +52,17 @@ private List validate(ParameterElement parameterElement) { // @Creator constructor List constructorParameters = Arrays.asList(primaryConstructor.get().getParameters()); - // Check no constructor parameter has any @Bindable annotation - // We could allow this, but this would add some complexity, some annotations that can be used in combination - // with @Bindable works only on fields (e.g. bean validation annotations) and this might confuse Micronaut users - constructorParameters.stream() + if (!parameterElement.getType().isRecord()) { + // Check no constructor parameter has any @Bindable annotation + // We could allow this, but this would add some complexity, some annotations that can be used in combination + // with @Bindable works only on fields (e.g. bean validation annotations) and this might confuse Micronaut users + constructorParameters.stream() .filter(p -> p.hasStereotype(Bindable.class)) .forEach(p -> errors.add("Parameter of Primary Constructor (or @Creator Method) [" + p.getName() + "] for type [" - + parameterElement.getType().getName() + "] has one of @Bindable annotations. This is not supported." - + "\nNote1: Primary constructor is a constructor that have parameters or is annotated with @Creator." - + "\nNote2: In case you have multiple @Creator constructors, first is used as primary constructor.")); + + parameterElement.getType().getName() + "] has one of @Bindable annotations. This is not supported." + + "\nNote1: Primary constructor is a constructor that have parameters or is annotated with @Creator." + + "\nNote2: In case you have multiple @Creator constructors, first is used as primary constructor.")); + } // Check readonly bindable properties can be set via constructor beanProperties.stream() diff --git a/http-validation/src/test/groovy/io/micronaut/validation/routes/RequestBeanParameterRuleSpec.groovy b/http-validation/src/test/groovy/io/micronaut/validation/routes/RequestBeanParameterRuleSpec.groovy index fcb126bbfe6..11a988792a2 100644 --- a/http-validation/src/test/groovy/io/micronaut/validation/routes/RequestBeanParameterRuleSpec.groovy +++ b/http-validation/src/test/groovy/io/micronaut/validation/routes/RequestBeanParameterRuleSpec.groovy @@ -1,6 +1,7 @@ package io.micronaut.validation.routes import io.micronaut.annotation.processing.test.AbstractTypeElementSpec +import spock.lang.IgnoreIf class RequestBeanParameterRuleSpec extends AbstractTypeElementSpec { @@ -21,22 +22,22 @@ class Foo { String abc(@RequestBean Bean bean) { return ""; } - + @Introspected private static class Bean { - + @Nullable @QueryValue private final String abc; - + public Bean(String abc) { this.abc = abc; } - + public String getAbc() { return abc; } - + } - + } """) @@ -61,35 +62,35 @@ class Foo { String abc(@RequestBean Bean bean) { return ""; } - + @Introspected private static class Bean { - + @Nullable @QueryValue private final String abc; - + @Nullable @QueryValue private final String def; - + public Bean(String abc) { this.abc = abc; this.def = null; } - - @Creator + + @Creator public Bean(String abc, String def) { this.abc = abc; this.def = def; } - + public String getAbc() { return abc; } - + public String getDef() { return def; } - + } - + } """) @@ -114,25 +115,55 @@ class Foo { String abc(@RequestBean Bean bean) { return ""; } - + @Introspected private static class Bean { - + @Nullable @QueryValue private String abc; - + @Creator public static Bean of(String abc) { Bean bean = new Bean(); bean.abc = abc; return bean; } - + public String getAbc() { return abc; } - + + } + +} + +""") + then: + noExceptionThrown() + } + + @IgnoreIf({ !jvm.isJava14Compatible() }) + void "test RequestBean compiles with record"() { + when: + buildTypeElement(""" + +package test; + +import io.micronaut.http.annotation.*; +import io.micronaut.core.annotation.*; +import io.micronaut.core.annotation.Nullable; + +@Controller("/foo") +class Foo { + + @Get("/abc/{abc}") + String abc(@RequestBean Bean bean) { + return ""; + } + + @Introspected + public record Bean(@Nullable @PathVariable String abc) { } - + } """) @@ -157,20 +188,20 @@ class Foo { String abc(@RequestBean Bean bean) { return ""; } - + @Introspected private static class Bean { - + @Nullable @QueryValue private String abc; - + public String getAbc() { return abc; } - + public void setAbc(String abc) { this.abc = abc; } - + } - + } """) @@ -195,18 +226,18 @@ class Foo { String abc(@RequestBean Bean bean) { return ""; } - + @Introspected public static class Bean { - + @Nullable @QueryValue private String abc; - + public String getAbc() { return abc; } - + } - + } """) @@ -232,28 +263,28 @@ class Foo { String abc(@RequestBean Bean bean) { return ""; } - + @Introspected public static class Bean { - + @Nullable @QueryValue private String abc; - + @Nullable @QueryValue private String def; - + public Bean(String def) { this.def = def; } - + public String getAbc() { return abc; } - + public String getDef() { return def; } - + } - + } """) @@ -279,23 +310,23 @@ class Foo { String abc(@RequestBean Bean bean) { return ""; } - + @Introspected public static class Bean { - + @Nullable @QueryValue private String abc; - + @Creator public Bean(@Nullable @QueryValue String abc) { this.abc = abc; } - + public String getAbc() { return abc; } - + } - + } """) @@ -321,31 +352,31 @@ class Foo { String abc(@RequestBean Bean bean) { return ""; } - + @Introspected public static class Bean { - + @Nullable @QueryValue private String abc; - + @Nullable @QueryValue private String def; - + @Creator public Bean(String def) { this.def = def; } - + public String getAbc() { return abc; } - + public void setAbc() { this.abc = abc; } - + public String getDef() { return def; } - + } - + } """) diff --git a/http/build.gradle b/http/build.gradle index 1908fb32a8e..20149d14dff 100644 --- a/http/build.gradle +++ b/http/build.gradle @@ -19,6 +19,7 @@ dependencies { testImplementation project(":jackson-databind") testImplementation project(":inject") testImplementation project(":runtime") + testImplementation(libs.managed.logback.classic) } tasks.named("compileKotlin") { diff --git a/http/src/main/java/io/micronaut/http/uri/UriTemplate.java b/http/src/main/java/io/micronaut/http/uri/UriTemplate.java index c329d4a0f1d..64e844dde49 100644 --- a/http/src/main/java/io/micronaut/http/uri/UriTemplate.java +++ b/http/src/main/java/io/micronaut/http/uri/UriTemplate.java @@ -27,7 +27,6 @@ import java.util.Collections; import java.util.List; import java.util.Map; -import java.util.Objects; import java.util.Optional; import java.util.StringJoiner; import java.util.function.Predicate; @@ -995,7 +994,6 @@ public String expand(Map parameters, boolean previousHasContent, result = joiner.toString(); } else if (found instanceof Map) { Map map = (Map) found; - map.values().removeIf(Objects::isNull); if (map.isEmpty()) { return ""; } @@ -1020,6 +1018,9 @@ public String expand(Map parameters, boolean previousHasContent, } map.forEach((key, some) -> { + if (some == null) { + return; + } String ks = key.toString(); Iterable values = (some instanceof Iterable) ? (Iterable) some : Collections.singletonList(some); for (Object value: values) { @@ -1038,7 +1039,12 @@ public String expand(Map parameters, boolean previousHasContent, } } }); - result = joiner.toString(); + if (joiner.length() == 0) { + // only null entries + return ""; + } else { + result = joiner.toString(); + } } else { String str = found.toString(); str = applyModifier(modifierStr, modifierChar, str, str.length()); diff --git a/http/src/main/java/io/micronaut/http/util/HttpHeadersUtil.java b/http/src/main/java/io/micronaut/http/util/HttpHeadersUtil.java new file mode 100644 index 00000000000..efcbc4213b2 --- /dev/null +++ b/http/src/main/java/io/micronaut/http/util/HttpHeadersUtil.java @@ -0,0 +1,95 @@ +/* + * Copyright 2017-2022 original authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.micronaut.http.util; + +import io.micronaut.core.annotation.NonNull; +import io.micronaut.core.annotation.Nullable; +import io.micronaut.core.util.SupplierUtil; +import io.micronaut.http.HttpHeaders; +import org.slf4j.Logger; + +import java.util.List; +import java.util.Set; +import java.util.function.Function; +import java.util.function.Supplier; +import java.util.regex.Pattern; + +/** + * Utility class to work with {@link io.micronaut.http.HttpHeaders} or HTTP Headers. + * @author Sergio del Amo + * @since 3.8.0 + */ +public final class HttpHeadersUtil { + private static final Supplier HEADER_MASK_PATTERNS = SupplierUtil.memoized(() -> + Pattern.compile(".*(password|cred|cert|key|secret|token|auth|signat).*", Pattern.CASE_INSENSITIVE) + ); + + private HttpHeadersUtil() { + + } + + /** + * Trace HTTP Headers. + * @param log Logger + * @param httpHeaders HTTP Headers + */ + public static void trace(@NonNull Logger log, + @NonNull HttpHeaders httpHeaders) { + trace(log, httpHeaders.names(), httpHeaders::getAll); + } + + /** + * Trace HTTP Headers. + * @param log Logger + * @param names HTTP Header names + * @param getAllHeaders Function to get all the header values for a particular header name + */ + public static void trace(@NonNull Logger log, + @NonNull Set names, + @NonNull Function> getAllHeaders) { + names.forEach(name -> trace(log, name, getAllHeaders)); + } + + /** + * Trace HTTP Headers. + * @param log Logger + * @param name HTTP Header name + * @param getAllHeaders Function to get all the header values for a particular header name + */ + public static void trace(@NonNull Logger log, + @NonNull String name, + @NonNull Function> getAllHeaders) { + boolean isMasked = HEADER_MASK_PATTERNS.get().matcher(name).matches(); + List all = getAllHeaders.apply(name); + if (all.size() > 1) { + for (String value : all) { + String maskedValue = isMasked ? mask(value) : value; + log.trace("{}: {}", name, maskedValue); + } + } else if (!all.isEmpty()) { + String maskedValue = isMasked ? mask(all.get(0)) : all.get(0); + log.trace("{}: {}", name, maskedValue); + } + } + + @Nullable + private static String mask(@Nullable String value) { + if (value == null) { + return null; + } + return "*MASKED*"; + } +} diff --git a/http/src/test/groovy/io/micronaut/http/util/HttpHeadersUtilSpec.groovy b/http/src/test/groovy/io/micronaut/http/util/HttpHeadersUtilSpec.groovy new file mode 100644 index 00000000000..54098639c29 --- /dev/null +++ b/http/src/test/groovy/io/micronaut/http/util/HttpHeadersUtilSpec.groovy @@ -0,0 +1,78 @@ +package io.micronaut.http.util + +import ch.qos.logback.classic.Level +import ch.qos.logback.classic.spi.ILoggingEvent +import ch.qos.logback.core.AppenderBase +import io.micronaut.http.HttpHeaders +import org.slf4j.Logger +import org.slf4j.LoggerFactory +import spock.lang.Specification + +import java.util.concurrent.BlockingQueue +import java.util.concurrent.LinkedBlockingQueue + +class HttpHeadersUtilSpec extends Specification { + def "check masking works for #value"() { + expect: + expected == HttpHeadersUtil.mask(value) + + where: + value | expected + null | null + "foo" | "*MASKED*" + "Tim Yates" | "*MASKED*" + } + + def "check mask detects common security headers"() { + given: + MemoryAppender appender = new MemoryAppender() + Logger log = LoggerFactory.getLogger(HttpHeadersUtilSpec.class) + + expect: + log instanceof ch.qos.logback.classic.Logger + + when: + ch.qos.logback.classic.Logger logger = (ch.qos.logback.classic.Logger) log + logger.addAppender(appender) + logger.setLevel(Level.TRACE) + appender.start() + + HttpHeaders headers = new MockHttpHeaders([ + "Authorization": ["Bearer foo"], + "Proxy-Authorization": ["AWS4-HMAC-SHA256 bar"], + "Cookie": ["baz"], + "Set-Cookie": ["qux"], + "X-Forwarded-For": ["quux", "fred"], + "X-Forwarded-Host": ["quuz"], + "X-Real-IP": ["waldo"], + "Credential": ["foo"], + "Signature": ["bar probably secret"]]) + + HttpHeadersUtil.trace(log, headers) + + then: + appender.events.size() == headers.values().collect { it -> it.size() }.sum() + appender.events.contains("Authorization: *MASKED*") + appender.events.contains("Cookie: baz") + appender.events.contains("Credential: *MASKED*") + appender.events.contains("Set-Cookie: qux") + appender.events.contains("Proxy-Authorization: *MASKED*") + appender.events.contains("Signature: *MASKED*") + appender.events.contains("X-Forwarded-For: quux") + appender.events.contains("X-Forwarded-For: fred") + appender.events.contains("X-Forwarded-Host: quuz") + appender.events.contains("X-Real-IP: waldo") + + cleanup: + appender.stop() + } + + static class MemoryAppender extends AppenderBase { + final BlockingQueue events = new LinkedBlockingQueue<>() + + @Override + protected void append(ILoggingEvent e) { + events.add(e.formattedMessage) + } + } +} diff --git a/http/src/test/groovy/io/micronaut/http/util/MockHttpHeaders.java b/http/src/test/groovy/io/micronaut/http/util/MockHttpHeaders.java new file mode 100644 index 00000000000..55cb1f9bd94 --- /dev/null +++ b/http/src/test/groovy/io/micronaut/http/util/MockHttpHeaders.java @@ -0,0 +1,99 @@ +/* + * Copyright 2017-2020 original authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.micronaut.http.util; + +import io.micronaut.core.annotation.Nullable; +import io.micronaut.core.convert.ArgumentConversionContext; +import io.micronaut.core.convert.ConversionService; +import io.micronaut.http.MutableHttpHeaders; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; +import java.util.stream.Collectors; + +public class MockHttpHeaders implements MutableHttpHeaders { + + private final Map> headers; + private ConversionService conversionService = ConversionService.SHARED; + + public MockHttpHeaders(Map> headers) { + this.headers = headers; + } + + @Override + public MutableHttpHeaders add(CharSequence header, CharSequence value) { + headers.compute(header, (key, val) -> { + if (val == null) { + val = new ArrayList<>(); + } + val.add(value.toString()); + return val; + }); + return this; + } + + @Override + public MutableHttpHeaders remove(CharSequence header) { + headers.remove(header); + return this; + } + + @Override + public List getAll(CharSequence name) { + List values = headers.get(name); + if (values == null) { + return Collections.emptyList(); + } else { + return values; + } + } + + @Nullable + @Override + public String get(CharSequence name) { + List values = headers.get(name); + if (values == null || values.isEmpty()) { + return null; + } else { + return values.get(0); + } + } + + @Override + public Set names() { + return headers.keySet().stream().map(CharSequence::toString).collect(Collectors.toSet()); + } + + @Override + public Collection> values() { + return headers.values(); + } + + @Override + public Optional get(CharSequence name, ArgumentConversionContext conversionContext) { + return conversionService.convert(get(name), conversionContext); + } + + @Override + public void setConversionService(ConversionService conversionService) { + this.conversionService = conversionService; + } +} diff --git a/inject-java-test/src/main/groovy/io/micronaut/annotation/processing/test/AbstractTypeElementSpec.groovy b/inject-java-test/src/main/groovy/io/micronaut/annotation/processing/test/AbstractTypeElementSpec.groovy index 73febdf455f..59ce2579cab 100644 --- a/inject-java-test/src/main/groovy/io/micronaut/annotation/processing/test/AbstractTypeElementSpec.groovy +++ b/inject-java-test/src/main/groovy/io/micronaut/annotation/processing/test/AbstractTypeElementSpec.groovy @@ -70,6 +70,7 @@ import javax.tools.JavaFileObject import java.lang.annotation.Annotation import java.util.stream.Collectors import java.util.stream.StreamSupport + /** * Base class to extend from to allow compilation of Java sources * at runtime to allow testing of compile time behavior. @@ -332,7 +333,7 @@ class Test { return metadata } - protected TypeElement buildTypeElement(String cls) { + protected TypeElement buildTypeElement(@Language('java') String cls) { List elements = [] newJavaParser().parseLines("", diff --git a/jackson-databind/src/main/java/io/micronaut/jackson/modules/BeanIntrospectionModule.java b/jackson-databind/src/main/java/io/micronaut/jackson/modules/BeanIntrospectionModule.java index 9cb190d388c..6b5e268de32 100644 --- a/jackson-databind/src/main/java/io/micronaut/jackson/modules/BeanIntrospectionModule.java +++ b/jackson-databind/src/main/java/io/micronaut/jackson/modules/BeanIntrospectionModule.java @@ -27,7 +27,19 @@ import com.fasterxml.jackson.core.ObjectCodec; import com.fasterxml.jackson.core.SerializableString; import com.fasterxml.jackson.core.io.SerializedString; -import com.fasterxml.jackson.databind.*; +import com.fasterxml.jackson.databind.BeanDescription; +import com.fasterxml.jackson.databind.DeserializationConfig; +import com.fasterxml.jackson.databind.DeserializationContext; +import com.fasterxml.jackson.databind.JavaType; +import com.fasterxml.jackson.databind.JsonDeserializer; +import com.fasterxml.jackson.databind.JsonMappingException; +import com.fasterxml.jackson.databind.JsonSerializer; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.databind.PropertyMetadata; +import com.fasterxml.jackson.databind.PropertyName; +import com.fasterxml.jackson.databind.PropertyNamingStrategy; +import com.fasterxml.jackson.databind.SerializationConfig; +import com.fasterxml.jackson.databind.SerializerProvider; import com.fasterxml.jackson.databind.annotation.JsonDeserialize; import com.fasterxml.jackson.databind.annotation.JsonNaming; import com.fasterxml.jackson.databind.annotation.JsonSerialize; @@ -536,6 +548,22 @@ public Object setAndReturn(Object instance, Object value) throws IOException { } return null; } + + @Override + public JsonFormat.Value findPropertyFormat(MapperConfig config, Class baseType) { + JsonFormat.Value v1 = config.getDefaultPropertyFormat(baseType); + JsonFormat.Value v2 = null; + if (property != null) { + AnnotationValue formatAnnotation = property.getAnnotation(JsonFormat.class); + if (formatAnnotation != null) { + v2 = parseJsonFormat(formatAnnotation); + } + } + if (v1 == null) { + return (v2 == null) ? EMPTY_FORMAT : v2; + } + return (v2 == null) ? v1 : v1.withOverrides(v2); + } }; } } diff --git a/jackson-databind/src/main/java/io/micronaut/jackson/serialize/ConvertibleMultiValuesSerializer.java b/jackson-databind/src/main/java/io/micronaut/jackson/serialize/ConvertibleMultiValuesSerializer.java index 9626980c5d1..455412cd895 100644 --- a/jackson-databind/src/main/java/io/micronaut/jackson/serialize/ConvertibleMultiValuesSerializer.java +++ b/jackson-databind/src/main/java/io/micronaut/jackson/serialize/ConvertibleMultiValuesSerializer.java @@ -50,12 +50,12 @@ public void serialize(ConvertibleMultiValues value, JsonGenerator gen, Serial if (len > 0) { gen.writeFieldName(fieldName); if (len == 1) { - gen.writeObject(v.get(0)); + serializers.defaultSerializeValue(v.get(0), gen); } else { gen.writeStartArray(); for (Object o : v) { - gen.writeObject(o); + serializers.defaultSerializeValue(o, gen); } gen.writeEndArray(); } diff --git a/jackson-databind/src/main/java/io/micronaut/jackson/serialize/ConvertibleValuesSerializer.java b/jackson-databind/src/main/java/io/micronaut/jackson/serialize/ConvertibleValuesSerializer.java index b495d828249..3ebadecf0c9 100644 --- a/jackson-databind/src/main/java/io/micronaut/jackson/serialize/ConvertibleValuesSerializer.java +++ b/jackson-databind/src/main/java/io/micronaut/jackson/serialize/ConvertibleValuesSerializer.java @@ -47,7 +47,7 @@ public void serialize(ConvertibleValues value, JsonGenerator gen, SerializerP Object v = entry.getValue(); if (v != null) { gen.writeFieldName(fieldName); - gen.writeObject(v); + serializers.defaultSerializeValue(v, gen); } } gen.writeEndObject(); diff --git a/jackson-databind/src/test/groovy/io/micronaut/jackson/modules/BeanIntrospectionModuleRecordSpec.groovy b/jackson-databind/src/test/groovy/io/micronaut/jackson/modules/BeanIntrospectionModuleRecordSpec.groovy index cc5f5757ee8..c8f4821dcf3 100644 --- a/jackson-databind/src/test/groovy/io/micronaut/jackson/modules/BeanIntrospectionModuleRecordSpec.groovy +++ b/jackson-databind/src/test/groovy/io/micronaut/jackson/modules/BeanIntrospectionModuleRecordSpec.groovy @@ -8,6 +8,9 @@ import io.micronaut.context.annotation.Requires import io.micronaut.core.beans.BeanIntrospection import jakarta.inject.Singleton import spock.lang.IgnoreIf +import spock.lang.Issue + +import java.time.LocalDateTime @IgnoreIf({ !jvm.isJava14Compatible() }) class BeanIntrospectionModuleRecordSpec extends AbstractTypeElementSpec { @@ -36,6 +39,33 @@ record Test(String foo, String bar) { ignoreReflectiveProperties << [true, false] } + @Issue('https://github.com/micronaut-projects/micronaut-core/issues/8330') + def 'JsonFormat'() { + given: + BeanIntrospection introspection = buildBeanIntrospection('test.Test', ''' +package test; +import java.time.LocalDateTime; +import com.fasterxml.jackson.annotation.JsonFormat; +import io.micronaut.core.annotation.Introspected; + +@Introspected +record Test(@JsonFormat(pattern = "dd.MM.yyyy HH:mm:ss") LocalDateTime date) { +} +''') + def ctx = ApplicationContext.run(['spec.name': 'BeanIntrospectionModuleRecordSpec']) + ctx.getBean(StaticBeanIntrospectionModule).introspectionMap[introspection.beanType] = introspection + ctx.getBean(BeanIntrospectionModule).ignoreReflectiveProperties = ignoreReflectiveProperties + def mapper = ctx.getBean(ObjectMapper) + + when: + def value = mapper.readValue('{"date":"13.11.2022 22:44:55"}', introspection.beanType) + then: + value.date == LocalDateTime.of(2022, 11, 13, 22, 44, 55) + + where: + ignoreReflectiveProperties << [true, false] + } + @Singleton @Replaces(BeanIntrospectionModule) @Requires(property = "spec.name", value = 'BeanIntrospectionModuleRecordSpec') diff --git a/jackson-databind/src/test/groovy/io/micronaut/jackson/modules/BeanIntrospectionModuleSpec.groovy b/jackson-databind/src/test/groovy/io/micronaut/jackson/modules/BeanIntrospectionModuleSpec.groovy index cc99c764324..ce054da2a81 100644 --- a/jackson-databind/src/test/groovy/io/micronaut/jackson/modules/BeanIntrospectionModuleSpec.groovy +++ b/jackson-databind/src/test/groovy/io/micronaut/jackson/modules/BeanIntrospectionModuleSpec.groovy @@ -33,16 +33,121 @@ import io.micronaut.jackson.JacksonConfiguration import io.micronaut.jackson.modules.testcase.EmailTemplate import io.micronaut.jackson.modules.testcase.Notification import io.micronaut.jackson.modules.testclasses.HTTPCheck -import io.micronaut.jackson.modules.wrappers.* +import io.micronaut.jackson.modules.testclasses.InstanceInfo +import io.micronaut.jackson.modules.wrappers.BooleanWrapper +import io.micronaut.jackson.modules.wrappers.DoubleWrapper +import io.micronaut.jackson.modules.wrappers.IntWrapper +import io.micronaut.jackson.modules.wrappers.IntegerWrapper +import io.micronaut.jackson.modules.wrappers.LongWrapper +import io.micronaut.jackson.modules.wrappers.StringWrapper import spock.lang.Issue -import spock.lang.Unroll import spock.lang.Specification +import spock.lang.Unroll import java.beans.ConstructorProperties import java.time.LocalDateTime class BeanIntrospectionModuleSpec extends Specification { + void "test serialize/deserialize wrap/unwrap - simple"() { + given: + ApplicationContext ctx = ApplicationContext.run( + 'jackson.deserialization.UNWRAP_ROOT_VALUE': true, + 'jackson.serialization.WRAP_ROOT_VALUE': true + ) + ObjectMapper objectMapper = ctx.getBean(ObjectMapper) + + when: + Author author = new Author(name:"Bob") + + def result = objectMapper.writeValueAsString(author) + + then: + result == '{"Author":{"name":"Bob"}}' + + when: + def read = objectMapper.readValue(result, Author) + + then: + author == read + + } + + void "test serialize/deserialize wrap/unwrap -* complex"() { + given: + ApplicationContext ctx = ApplicationContext.run( + 'jackson.deserialization.UNWRAP_ROOT_VALUE': true, + 'jackson.serialization.WRAP_ROOT_VALUE': true + ) + ObjectMapper objectMapper = ctx.getBean(ObjectMapper) + + when: + HTTPCheck check = new HTTPCheck(headers:[ + Accept:['application/json', 'application/xml'] + ] ) + + def result = objectMapper.writeValueAsString(check) + + then: + result == '{"HTTPCheck":{"Header":{"Accept":["application/json","application/xml"]}}}' + + when: + def read = objectMapper.readValue(result, HTTPCheck) + + then: + check == read + + } + + void "test serialize/deserialize wrap/unwrap -* constructors"() { + given: + ApplicationContext ctx = ApplicationContext.run( + 'jackson.deserialization.UNWRAP_ROOT_VALUE': true, + 'jackson.serialization.WRAP_ROOT_VALUE': true + ) + ObjectMapper objectMapper = ctx.getBean(ObjectMapper) + + when: + IntrospectionCreator check = new IntrospectionCreator("test") + + def result = objectMapper.writeValueAsString(check) + + then: + result == '{"IntrospectionCreator":{"label":"TEST"}}' + + when: + def read = objectMapper.readValue(result, IntrospectionCreator) + + then: + check == read + + } + + void "test serialize/deserialize wrap/unwrap -* constructors & JsonRootName"() { + given: + ApplicationContext ctx = ApplicationContext.run( + 'jackson.deserialization.UNWRAP_ROOT_VALUE': true, + 'jackson.serialization.WRAP_ROOT_VALUE': true + ) + ObjectMapper objectMapper = ctx.getBean(ObjectMapper) + + when: + InstanceInfo check = new InstanceInfo("test") + + def result = objectMapper.writeValueAsString(check) + + then: + result == '{"instance":{"hostName":"test"}}' + + when: + def read = objectMapper.readValue(result, InstanceInfo) + + then: + check == read + + } + + void "test serialize/deserialize convertible values"() { given: ApplicationContext ctx = ApplicationContext.run() @@ -622,6 +727,7 @@ class BeanIntrospectionModuleSpec extends Specification { } @Introspected + @EqualsAndHashCode static class Author { String name } @@ -850,6 +956,7 @@ class BeanIntrospectionModuleSpec extends Specification { } @Introspected + @EqualsAndHashCode static class IntrospectionCreator { private final String name diff --git a/jackson-databind/src/test/groovy/io/micronaut/jackson/modules/testclasses/HTTPCheck.java b/jackson-databind/src/test/groovy/io/micronaut/jackson/modules/testclasses/HTTPCheck.java index 6ace4bbc12d..a011da667b3 100644 --- a/jackson-databind/src/test/groovy/io/micronaut/jackson/modules/testclasses/HTTPCheck.java +++ b/jackson-databind/src/test/groovy/io/micronaut/jackson/modules/testclasses/HTTPCheck.java @@ -8,6 +8,7 @@ import java.util.List; import java.util.Map; +import java.util.Objects; @JsonNaming(PropertyNamingStrategies.UpperCamelCaseStrategy.class) @Introspected @@ -29,4 +30,17 @@ public void setHeaders(Map> headers) { this.headers = ConvertibleMultiValues.of(headers); } } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + HTTPCheck httpCheck = (HTTPCheck) o; + return headers.equals(httpCheck.headers); + } + + @Override + public int hashCode() { + return Objects.hash(headers); + } } diff --git a/jackson-databind/src/test/groovy/io/micronaut/jackson/modules/testclasses/InstanceInfo.java b/jackson-databind/src/test/groovy/io/micronaut/jackson/modules/testclasses/InstanceInfo.java new file mode 100644 index 00000000000..57509ffc3e0 --- /dev/null +++ b/jackson-databind/src/test/groovy/io/micronaut/jackson/modules/testclasses/InstanceInfo.java @@ -0,0 +1,37 @@ +package io.micronaut.jackson.modules.testclasses; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.annotation.JsonRootName; +import io.micronaut.core.annotation.Introspected; + +import java.util.Objects; + +@JsonRootName("instance") +@Introspected +public class InstanceInfo { + private final String hostName; + + @JsonCreator + InstanceInfo( + @JsonProperty("hostName") String hostName) { + this.hostName = hostName; + } + + public String getHostName() { + return hostName; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + InstanceInfo that = (InstanceInfo) o; + return hostName.equals(that.hostName); + } + + @Override + public int hashCode() { + return Objects.hash(hostName); + } +} diff --git a/parent/build.gradle b/parent/build.gradle new file mode 100644 index 00000000000..e69de29bb2d