Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -167,8 +167,7 @@ private void performRequest(
for (String name : list(servletRequest.getHeaderNames())) {
for (String value : list(servletRequest.getHeaders(name))) {
// TODO: decide what should and shouldn't be forwarded
if (!name.equalsIgnoreCase("Accept-Encoding")
&& !name.equalsIgnoreCase("Host")
if (!name.equalsIgnoreCase("Host")
&& (addXForwardedHeaders || !name.startsWith("X-Forwarded"))) {
requestBuilder.addHeader(name, value);
}
Expand Down Expand Up @@ -270,26 +269,27 @@ private static WebApplicationException badRequest(String message)
private ProxyResponse recordBackendForQueryId(Request request, ProxyResponse response, Optional<String> username,
RoutingDestination routingDestination)
{
log.debug("For Request [%s] got Response [%s]", request.getUri(), response.body());
String body = response.decompressedBody();
log.debug("For Request [%s] got Response [%s]", request.getUri(), body);

QueryHistoryManager.QueryDetail queryDetail = getQueryDetailsFromRequest(request, username);

log.debug("Extracting proxy destination : [%s] for request : [%s]", queryDetail.getBackendUrl(), request.getUri());

if (response.statusCode() == OK.getStatusCode()) {
try {
HashMap<String, String> results = OBJECT_MAPPER.readValue(response.body(), HashMap.class);
HashMap<String, String> results = OBJECT_MAPPER.readValue(body, HashMap.class);
queryDetail.setQueryId(results.get("id"));
routingManager.setBackendForQueryId(queryDetail.getQueryId(), queryDetail.getBackendUrl());
routingManager.setRoutingGroupForQueryId(queryDetail.getQueryId(), routingDestination.routingGroup());
log.debug("QueryId [%s] mapped with proxy [%s]", queryDetail.getQueryId(), queryDetail.getBackendUrl());
}
catch (IOException e) {
log.error("Failed to get QueryId from response [%s] , Status code [%s]", response.body(), response.statusCode());
log.error("Failed to get QueryId from response [%s] , Status code [%s]", body, response.statusCode());
}
}
else {
log.error("Non OK HTTP Status code with response [%s] , Status code [%s], user: [%s]", response.body(), response.statusCode(), username.orElse(null));
log.error("Non OK HTTP Status code with response [%s] , Status code [%s], user: [%s]", body, response.statusCode(), username.orElse(null));
}
queryDetail.setRoutingGroup(routingDestination.routingGroup());
queryDetail.setExternalUrl(routingDestination.externalUrl());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,12 @@
import io.trino.gateway.ha.config.ProxyResponseConfiguration;
import io.trino.gateway.proxyserver.ProxyResponseHandler.ProxyResponse;

import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.io.InputStream;
import java.util.zip.GZIPInputStream;

import static java.nio.charset.StandardCharsets.UTF_8;
import static java.util.Objects.requireNonNull;

public class ProxyResponseHandler
Expand All @@ -47,7 +50,9 @@ public ProxyResponse handleException(Request request, Exception exception)
public ProxyResponse handle(Request request, Response response)
{
try {
return new ProxyResponse(response.getStatusCode(), response.getHeaders(), new String(response.getInputStream().readNBytes((int) responseSize.toBytes()), StandardCharsets.UTF_8));
// Store raw bytes to preserve compression
byte[] responseBodyBytes = response.getInputStream().readNBytes((int) responseSize.toBytes());
return new ProxyResponse(response.getStatusCode(), response.getHeaders(), responseBodyBytes);
}
catch (IOException e) {
throw new ProxyException("Failed reading response from remote Trino server", e);
Expand All @@ -57,11 +62,36 @@ public ProxyResponse handle(Request request, Response response)
public record ProxyResponse(
int statusCode,
ListMultimap<HeaderName, String> headers,
String body)
byte[] body)
{
public ProxyResponse
{
requireNonNull(headers, "headers is null");
requireNonNull(body, "body is null");
}

/**
* Get the response body as a decompressed string for JSON parsing and logging.
* Only call this when you need to parse the content, not when passing through
* to clients.
*/
public String decompressedBody()
{
// Check if the response is gzip-compressed
String contentEncoding = headers.get(HeaderName.of("Content-Encoding")).stream().findFirst().orElse(null);

if ("gzip".equalsIgnoreCase(contentEncoding)) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this the only encoding trino clients request? Then we should only allow that encoding here.
You can also test with Accept-Encoding: gzip, deflate
The spooling protocol offers additional encodings.

try (InputStream inputStream = new GZIPInputStream(new ByteArrayInputStream(body))) {
return new String(inputStream.readAllBytes(), UTF_8);
}
catch (IOException e) {
// If decompression fails, return the body as UTF-8 string
return new String(body, UTF_8);
}
}

// Not compressed, convert bytes to string
return new String(body, UTF_8);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it would be great if we could deal with just bytes, but it could have some other side effects.

}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -51,17 +51,16 @@
@TestInstance(PER_CLASS)
final class TestProxyRequestHandler
{
private static final String OK = "OK";
private static final int NOT_FOUND = 404;
private static final MediaType MEDIA_TYPE = MediaType.parse("application/json; charset=utf-8");

private final OkHttpClient httpClient = new OkHttpClient();
private final MockWebServer mockTrinoServer = new MockWebServer();
private final PostgreSQLContainer postgresql = createPostgreSqlContainer();

private final int routerPort = 21001 + (int) (Math.random() * 1000);
private final int customBackendPort = 21000 + (int) (Math.random() * 1000);

private static final String OK = "OK";
private static final int NOT_FOUND = 404;
private static final MediaType MEDIA_TYPE = MediaType.parse("application/json; charset=utf-8");

private final String customPutEndpoint = "/v1/custom"; // this is enabled in test-config-template.yml
private final String healthCheckEndpoint = "/v1/info";

Expand All @@ -70,7 +69,8 @@ void setup()
throws Exception
{
prepareMockBackend(mockTrinoServer, customBackendPort, "default custom response");
mockTrinoServer.setDispatcher(new Dispatcher() {
mockTrinoServer.setDispatcher(new Dispatcher()
{
@Override
public MockResponse dispatch(RecordedRequest request)
{
Expand All @@ -80,6 +80,14 @@ public MockResponse dispatch(RecordedRequest request)
.setBody("{\"starting\": false}");
}

if (request.getPath().equals(healthCheckEndpoint + "?test-compression")) {
// Return the Accept-Encoding header value for compression testing
String acceptEncoding = request.getHeader("Accept-Encoding");
return new MockResponse().setResponseCode(200)
.setHeader(CONTENT_TYPE, JSON_UTF_8)
.setBody(acceptEncoding != null ? acceptEncoding : "null");
}

if (request.getMethod().equals("PUT") && request.getPath().equals(customPutEndpoint)) {
return new MockResponse().setResponseCode(200)
.setHeader(CONTENT_TYPE, JSON_UTF_8)
Expand Down Expand Up @@ -131,18 +139,18 @@ void testGetQueryDetailsFromRequest()
{
// A sample query longer than 200 characters to test against truncation.
String longQuery = """
SELECT
c.customer_name,
c.customer_region,
COUNT(o.order_id) AS total_orders,
SUM(o.order_value) AS total_revenue
FROM
hive.sales_data.customers AS c
JOIN
hive.sales_data.orders AS o
ON c.customer_id = o.customer_id
WHERE
o.order_date >= date '2023-01-01'""";
SELECT
c.customer_name,
c.customer_region,
COUNT(o.order_id) AS total_orders,
SUM(o.order_value) AS total_revenue
FROM
hive.sales_data.customers AS c
JOIN
hive.sales_data.orders AS o
ON c.customer_id = o.customer_id
WHERE
o.order_date >= date '2023-01-01'""";

io.airlift.http.client.Request request = preparePost()
.setUri(URI.create("http://localhost:" + routerPort + V1_STATEMENT_PATH))
Expand All @@ -159,4 +167,49 @@ void testGetQueryDetailsFromRequest()
assertThat(queryDetail.getSource()).isEqualTo("trino-cli");
assertThat(queryDetail.getBackendUrl()).isEqualTo("http://localhost:" + routerPort);
}

@Test
void testAcceptEncodingHeaderForwarding()
throws Exception
{
// Test that Accept-Encoding header is properly forwarded to backends
String url = "http://localhost:" + routerPort + healthCheckEndpoint + "?test-compression";
String expectedAcceptEncoding = "gzip, deflate, br";
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you please add a test with just String expectedAcceptEncoding = "deflate, br";


Request request = new Request.Builder()
.url(url)
.get()
.addHeader("Accept-Encoding", expectedAcceptEncoding)
.build();

try (Response response = httpClient.newCall(request).execute()) {
assertThat(response.code()).isEqualTo(200);
assertThat(response.body()).isNotNull();

// The mock backend returns the Accept-Encoding header value in the response body
assertThat(response.body().string()).isEqualTo(expectedAcceptEncoding);
}
}

@Test
void testDefaultAcceptEncodingHeaderForwarding()
throws Exception
{
// Test that requests without explicit Accept-Encoding header work correctly
// Note: OkHttp automatically adds "Accept-Encoding: gzip" when none is specified
String url = "http://localhost:" + routerPort + healthCheckEndpoint + "?test-compression";

Request request = new Request.Builder()
.url(url)
.get()
.build(); // No explicit Accept-Encoding header

try (Response response = httpClient.newCall(request).execute()) {
assertThat(response.code()).isEqualTo(200);
assertThat(response.body()).isNotNull();

// OkHttp automatically adds "Accept-Encoding: gzip" when none is specified
assertThat(response.body().string()).isEqualTo("gzip");
}
}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could also add a test which breaks current gateway max limit and make sure it work with your change?

}