Skip to content

Commit

Permalink
Remove createRestRequest changes in favor of new security rest channe…
Browse files Browse the repository at this point in the history
…l in security plugin

Signed-off-by: Craig Perkins <cwperx@amazon.com>
  • Loading branch information
cwperks committed Oct 5, 2023
1 parent 91fc5bc commit ddaca29
Show file tree
Hide file tree
Showing 5 changed files with 39 additions and 176 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,9 @@

package org.opensearch.http.netty4;

import org.opensearch.core.common.bytes.BytesArray;
import org.opensearch.core.common.bytes.BytesReference;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.http.HttpRequest;
import org.opensearch.rest.RestRequest;
import org.opensearch.transport.netty4.Netty4Utils;

Expand All @@ -55,47 +55,52 @@
import io.netty.handler.codec.http.HttpHeaderNames;
import io.netty.handler.codec.http.HttpHeaders;
import io.netty.handler.codec.http.HttpMethod;
import io.netty.handler.codec.http.HttpRequest;
import io.netty.handler.codec.http.cookie.Cookie;
import io.netty.handler.codec.http.cookie.ServerCookieDecoder;
import io.netty.handler.codec.http.cookie.ServerCookieEncoder;

public class Netty4HttpRequest implements org.opensearch.http.HttpRequest {
public class Netty4HttpRequest implements HttpRequest {

private final HttpRequest request;
private final FullHttpRequest request;
private final BytesReference content;
private final HttpHeadersMap headers;
private final AtomicBoolean released;
private final Exception inboundException;
private final boolean pooled;

public Netty4HttpRequest(HttpRequest request) {
Netty4HttpRequest(FullHttpRequest request) {
this(
request,
new HttpHeadersMap(request.headers()),
new AtomicBoolean(false),
true,
(request instanceof FullHttpRequest) ? Netty4Utils.toBytesReference(((FullHttpRequest) request).content()) : BytesArray.EMPTY
Netty4Utils.toBytesReference(request.content())
);
}

Netty4HttpRequest(HttpRequest request, Exception inboundException) {
Netty4HttpRequest(FullHttpRequest request, Exception inboundException) {
this(
request,
new HttpHeadersMap(request.headers()),
new AtomicBoolean(false),
true,
(request instanceof FullHttpRequest) ? Netty4Utils.toBytesReference(((FullHttpRequest) request).content()) : BytesArray.EMPTY,
Netty4Utils.toBytesReference(request.content()),
inboundException
);
}

private Netty4HttpRequest(HttpRequest request, HttpHeadersMap headers, AtomicBoolean released, boolean pooled, BytesReference content) {
private Netty4HttpRequest(
FullHttpRequest request,
HttpHeadersMap headers,
AtomicBoolean released,
boolean pooled,
BytesReference content
) {
this(request, headers, released, pooled, content, null);
}

private Netty4HttpRequest(
HttpRequest request,
FullHttpRequest request,
HttpHeadersMap headers,
AtomicBoolean released,
boolean pooled,
Expand Down Expand Up @@ -157,32 +162,27 @@ public BytesReference content() {

@Override
public void release() {
assert request instanceof FullHttpRequest : "release can only be called when underlying request object is of type FullHttpRequest";
if (pooled && released.compareAndSet(false, true)) {
FullHttpRequest req = (FullHttpRequest) request;
req.release();
request.release();
}
}

@Override
public org.opensearch.http.HttpRequest releaseAndCopy() {
assert request instanceof FullHttpRequest
: "releaseAndCopy can only be called when underlying request object is of type FullHttpRequest";
public HttpRequest releaseAndCopy() {
assert released.get() == false;
if (pooled == false) {
return this;
}
FullHttpRequest req = (FullHttpRequest) request;
try {
final ByteBuf copiedContent = Unpooled.copiedBuffer(req.content());
final ByteBuf copiedContent = Unpooled.copiedBuffer(request.content());
return new Netty4HttpRequest(
new DefaultFullHttpRequest(
req.protocolVersion(),
req.method(),
req.uri(),
request.protocolVersion(),
request.method(),
request.uri(),
copiedContent,
req.headers(),
req.trailingHeaders()
request.headers(),
request.trailingHeaders()
),
headers,
new AtomicBoolean(false),
Expand Down Expand Up @@ -214,29 +214,27 @@ public List<String> strictCookies() {
@Override
public HttpVersion protocolVersion() {
if (request.protocolVersion().equals(io.netty.handler.codec.http.HttpVersion.HTTP_1_0)) {
return org.opensearch.http.HttpRequest.HttpVersion.HTTP_1_0;
return HttpRequest.HttpVersion.HTTP_1_0;
} else if (request.protocolVersion().equals(io.netty.handler.codec.http.HttpVersion.HTTP_1_1)) {
return org.opensearch.http.HttpRequest.HttpVersion.HTTP_1_1;
return HttpRequest.HttpVersion.HTTP_1_1;
} else {
throw new IllegalArgumentException("Unexpected http protocol version: " + request.protocolVersion());
}
}

@Override
public org.opensearch.http.HttpRequest removeHeader(String header) {
public HttpRequest removeHeader(String header) {
HttpHeaders headersWithoutContentTypeHeader = new DefaultHttpHeaders();
headersWithoutContentTypeHeader.add(request.headers());
headersWithoutContentTypeHeader.remove(header);
HttpHeaders trailingHeaders = new DefaultHttpHeaders();
if (request instanceof FullHttpRequest) {
trailingHeaders.add(((FullHttpRequest) request).trailingHeaders());
trailingHeaders.remove(header);
}
trailingHeaders.add(request.trailingHeaders());
trailingHeaders.remove(header);
FullHttpRequest requestWithoutHeader = new DefaultFullHttpRequest(
request.protocolVersion(),
request.method(),
request.uri(),
(request instanceof FullHttpRequest) ? ((FullHttpRequest) request).content() : Unpooled.EMPTY_BUFFER,
request.content(),
headersWithoutContentTypeHeader,
trailingHeaders
);
Expand All @@ -253,7 +251,7 @@ public Exception getInboundException() {
return inboundException;
}

public HttpRequest nettyRequest() {
public FullHttpRequest nettyRequest() {
return request;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -412,13 +412,13 @@ private void handleIncomingRequest(final HttpRequest httpRequest, final HttpChan
{
RestRequest innerRestRequest;
try {
innerRestRequest = RestRequest.request(xContentRegistry, httpRequest, httpChannel, true);
innerRestRequest = RestRequest.request(xContentRegistry, httpRequest, httpChannel);
} catch (final RestRequest.ContentTypeHeaderException e) {
badRequestCause = ExceptionsHelper.useOrSuppress(badRequestCause, e);
innerRestRequest = requestWithoutContentTypeHeader(httpRequest, httpChannel, badRequestCause, true);
innerRestRequest = requestWithoutContentTypeHeader(httpRequest, httpChannel, badRequestCause);
} catch (final RestRequest.BadParameterException e) {
badRequestCause = ExceptionsHelper.useOrSuppress(badRequestCause, e);
innerRestRequest = RestRequest.requestWithoutParameters(xContentRegistry, httpRequest, httpChannel, true);
innerRestRequest = RestRequest.requestWithoutParameters(xContentRegistry, httpRequest, httpChannel);
}
restRequest = innerRestRequest;
}
Expand Down Expand Up @@ -466,75 +466,13 @@ private void handleIncomingRequest(final HttpRequest httpRequest, final HttpChan
dispatchRequest(restRequest, channel, badRequestCause);
}

public static RestRequest createRestRequest(
final NamedXContentRegistry xContentRegistry,
final HttpRequest httpRequest,
final HttpChannel httpChannel
) {
Exception badRequestCause = httpRequest.getInboundException();

/*
* We want to create a REST request from the incoming request from Netty. However, creating this request could fail if there
* are incorrectly encoded parameters, or the Content-Type header is invalid. If one of these specific failures occurs, we
* attempt to create a REST request again without the input that caused the exception (e.g., we remove the Content-Type header,
* or skip decoding the parameters). Once we have a request in hand, we then dispatch the request as a bad request with the
* underlying exception that caused us to treat the request as bad.
*/
final RestRequest restRequest;
{
RestRequest innerRestRequest;
try {
innerRestRequest = RestRequest.request(xContentRegistry, httpRequest, httpChannel, false);
} catch (final RestRequest.ContentTypeHeaderException e) {
badRequestCause = ExceptionsHelper.useOrSuppress(badRequestCause, e);
innerRestRequest = requestWithoutContentTypeHeader(xContentRegistry, httpRequest, httpChannel, badRequestCause, false);
} catch (final RestRequest.BadParameterException e) {
badRequestCause = ExceptionsHelper.useOrSuppress(badRequestCause, e);
innerRestRequest = RestRequest.requestWithoutParameters(xContentRegistry, httpRequest, httpChannel, false);
}
restRequest = innerRestRequest;
}
return restRequest;
}

private static RestRequest requestWithoutContentTypeHeader(
NamedXContentRegistry xContentRegistry,
HttpRequest httpRequest,
HttpChannel httpChannel,
Exception badRequestCause,
boolean shouldGenerateRequestId
) {
private RestRequest requestWithoutContentTypeHeader(HttpRequest httpRequest, HttpChannel httpChannel, Exception badRequestCause) {
HttpRequest httpRequestWithoutContentType = httpRequest.removeHeader("Content-Type");
try {
return RestRequest.request(xContentRegistry, httpRequestWithoutContentType, httpChannel, shouldGenerateRequestId);
return RestRequest.request(xContentRegistry, httpRequestWithoutContentType, httpChannel);
} catch (final RestRequest.BadParameterException e) {
badRequestCause.addSuppressed(e);
return RestRequest.requestWithoutParameters(
xContentRegistry,
httpRequestWithoutContentType,
httpChannel,
shouldGenerateRequestId
);
}
}

private RestRequest requestWithoutContentTypeHeader(
HttpRequest httpRequest,
HttpChannel httpChannel,
Exception badRequestCause,
boolean shouldGenerateRequestId
) {
HttpRequest httpRequestWithoutContentType = httpRequest.removeHeader("Content-Type");
try {
return RestRequest.request(xContentRegistry, httpRequestWithoutContentType, httpChannel, shouldGenerateRequestId);
} catch (final RestRequest.BadParameterException e) {
badRequestCause.addSuppressed(e);
return RestRequest.requestWithoutParameters(
xContentRegistry,
httpRequestWithoutContentType,
httpChannel,
shouldGenerateRequestId
);
return RestRequest.requestWithoutParameters(xContentRegistry, httpRequestWithoutContentType, httpChannel);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ public RestController(
this.headersToCopy = headersToCopy;
this.usageService = usageService;
if (handlerWrapper == null) {
handlerWrapper = (h) -> h;
handlerWrapper = h -> h; // passthrough if no wrapper set
}
this.handlerWrapper = handlerWrapper;
this.client = client;
Expand Down
61 changes: 2 additions & 59 deletions server/src/main/java/org/opensearch/rest/RestRequest.java
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ public class RestRequest implements ToXContent.Params {

// tchar pattern as defined by RFC7230 section 3.2.6
private static final Pattern TCHAR_PATTERN = Pattern.compile("[a-zA-z0-9!#$%&'*+\\-.\\^_`|~]+");

private static final AtomicLong requestIdGenerator = new AtomicLong();

private final NamedXContentRegistry xContentRegistry;
Expand Down Expand Up @@ -151,7 +152,7 @@ protected RestRequest(RestRequest restRequest) {
* with an unpooled copy. This is supposed to be used before passing requests to {@link RestHandler} instances that can not safely
* handle http requests that use pooled buffers as determined by {@link RestHandler#allowsUnsafeBuffers()}.
*/
protected void ensureSafeBuffers() {
void ensureSafeBuffers() {
httpRequest = httpRequest.releaseAndCopy();
}

Expand Down Expand Up @@ -179,36 +180,6 @@ public static RestRequest request(NamedXContentRegistry xContentRegistry, HttpRe
);
}

/**
* Creates a new REST request. This method will throw {@link BadParameterException} if the path cannot be
* decoded
*
* @param xContentRegistry the content registry
* @param httpRequest the http request
* @param httpChannel the http channel
* @param shouldGenerateRequestId should generate a new request id
* @throws BadParameterException if the parameters can not be decoded
* @throws ContentTypeHeaderException if the Content-Type header can not be parsed
*/
public static RestRequest request(
NamedXContentRegistry xContentRegistry,
HttpRequest httpRequest,
HttpChannel httpChannel,
boolean shouldGenerateRequestId
) {
Map<String, String> params = params(httpRequest.uri());
String path = path(httpRequest.uri());
return new RestRequest(
xContentRegistry,
params,
path,
httpRequest.getHeaders(),
httpRequest,
httpChannel,
shouldGenerateRequestId ? requestIdGenerator.incrementAndGet() : -1
);
}

private static Map<String, String> params(final String uri) {
final Map<String, String> params = new HashMap<>();
int index = uri.indexOf('?');
Expand Down Expand Up @@ -257,34 +228,6 @@ public static RestRequest requestWithoutParameters(
);
}

/**
* Creates a new REST request. The path is not decoded so this constructor will not throw a
* {@link BadParameterException}.
*
* @param xContentRegistry the content registry
* @param httpRequest the http request
* @param httpChannel the http channel
* @param shouldGenerateRequestId should generate new request id
* @throws ContentTypeHeaderException if the Content-Type header can not be parsed
*/
public static RestRequest requestWithoutParameters(
NamedXContentRegistry xContentRegistry,
HttpRequest httpRequest,
HttpChannel httpChannel,
boolean shouldGenerateRequestId
) {
Map<String, String> params = Collections.emptyMap();
return new RestRequest(
xContentRegistry,
params,
httpRequest.uri(),
httpRequest.getHeaders(),
httpRequest,
httpChannel,
shouldGenerateRequestId ? requestIdGenerator.incrementAndGet() : -1
);
}

/**
* The method used.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,6 @@

import java.net.InetSocketAddress;
import java.net.UnknownHostException;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
Expand Down Expand Up @@ -150,21 +149,6 @@ public void testHttpPublishPort() throws Exception {
}
}

public void testCreateRestRequestDoesNotGenerateRequestID() {
FakeRestRequest fakeRestRequest = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY).withContent(
new BytesArray("bar".getBytes(StandardCharsets.UTF_8)),
null
).withPath("/foo").withHeaders(Collections.singletonMap("Content-Type", Collections.singletonList("text/plain"))).build();

RestRequest request = AbstractHttpServerTransport.createRestRequest(
xContentRegistry(),
fakeRestRequest.getHttpRequest(),
fakeRestRequest.getHttpChannel()
);

assertEquals("request should not generate id", -1, request.getRequestId());
}

public void testDispatchDoesNotModifyThreadContext() {
final HttpServerTransport.Dispatcher dispatcher = new HttpServerTransport.Dispatcher() {

Expand Down

0 comments on commit ddaca29

Please sign in to comment.