-
Notifications
You must be signed in to change notification settings - Fork 120
Add support for request and response compression #765
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -22,9 +22,12 @@ | |
| import io.trino.gateway.ha.config.ProxyResponseConfiguration; | ||
| import io.trino.gateway.proxyserver.ProxyResponseHandler.ProxyResponse; | ||
|
|
||
| import java.io.ByteArrayInputStream; | ||
| import java.io.IOException; | ||
| import java.nio.charset.StandardCharsets; | ||
| import java.io.InputStream; | ||
| import java.util.zip.GZIPInputStream; | ||
|
|
||
| import static java.nio.charset.StandardCharsets.UTF_8; | ||
| import static java.util.Objects.requireNonNull; | ||
|
|
||
| public class ProxyResponseHandler | ||
|
|
@@ -47,7 +50,9 @@ public ProxyResponse handleException(Request request, Exception exception) | |
| public ProxyResponse handle(Request request, Response response) | ||
| { | ||
| try { | ||
| return new ProxyResponse(response.getStatusCode(), response.getHeaders(), new String(response.getInputStream().readNBytes((int) responseSize.toBytes()), StandardCharsets.UTF_8)); | ||
| // Store raw bytes to preserve compression | ||
| byte[] responseBodyBytes = response.getInputStream().readNBytes((int) responseSize.toBytes()); | ||
| return new ProxyResponse(response.getStatusCode(), response.getHeaders(), responseBodyBytes); | ||
| } | ||
| catch (IOException e) { | ||
| throw new ProxyException("Failed reading response from remote Trino server", e); | ||
|
|
@@ -57,11 +62,36 @@ public ProxyResponse handle(Request request, Response response) | |
| public record ProxyResponse( | ||
| int statusCode, | ||
| ListMultimap<HeaderName, String> headers, | ||
| String body) | ||
| byte[] body) | ||
| { | ||
| public ProxyResponse | ||
| { | ||
| requireNonNull(headers, "headers is null"); | ||
| requireNonNull(body, "body is null"); | ||
| } | ||
|
|
||
| /** | ||
| * Get the response body as a decompressed string for JSON parsing and logging. | ||
| * Only call this when you need to parse the content, not when passing through | ||
| * to clients. | ||
| */ | ||
| public String decompressedBody() | ||
| { | ||
| // Check if the response is gzip-compressed | ||
| String contentEncoding = headers.get(HeaderName.of("Content-Encoding")).stream().findFirst().orElse(null); | ||
|
|
||
| if ("gzip".equalsIgnoreCase(contentEncoding)) { | ||
| try (InputStream inputStream = new GZIPInputStream(new ByteArrayInputStream(body))) { | ||
| return new String(inputStream.readAllBytes(), UTF_8); | ||
| } | ||
| catch (IOException e) { | ||
| // If decompression fails, return the body as UTF-8 string | ||
| return new String(body, UTF_8); | ||
| } | ||
| } | ||
|
|
||
| // Not compressed, convert bytes to string | ||
| return new String(body, UTF_8); | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think it would be great if we could deal with just bytes, but it could have some other side effects. |
||
| } | ||
| } | ||
| } | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -51,17 +51,16 @@ | |
| @TestInstance(PER_CLASS) | ||
| final class TestProxyRequestHandler | ||
| { | ||
| private static final String OK = "OK"; | ||
| private static final int NOT_FOUND = 404; | ||
| private static final MediaType MEDIA_TYPE = MediaType.parse("application/json; charset=utf-8"); | ||
Chaho12 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
Chaho12 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| private final OkHttpClient httpClient = new OkHttpClient(); | ||
| private final MockWebServer mockTrinoServer = new MockWebServer(); | ||
| private final PostgreSQLContainer postgresql = createPostgreSqlContainer(); | ||
|
|
||
| private final int routerPort = 21001 + (int) (Math.random() * 1000); | ||
| private final int customBackendPort = 21000 + (int) (Math.random() * 1000); | ||
|
|
||
| private static final String OK = "OK"; | ||
| private static final int NOT_FOUND = 404; | ||
| private static final MediaType MEDIA_TYPE = MediaType.parse("application/json; charset=utf-8"); | ||
|
|
||
| private final String customPutEndpoint = "/v1/custom"; // this is enabled in test-config-template.yml | ||
| private final String healthCheckEndpoint = "/v1/info"; | ||
|
|
||
|
|
@@ -70,7 +69,8 @@ void setup() | |
| throws Exception | ||
| { | ||
| prepareMockBackend(mockTrinoServer, customBackendPort, "default custom response"); | ||
| mockTrinoServer.setDispatcher(new Dispatcher() { | ||
| mockTrinoServer.setDispatcher(new Dispatcher() | ||
| { | ||
| @Override | ||
| public MockResponse dispatch(RecordedRequest request) | ||
| { | ||
|
|
@@ -80,6 +80,14 @@ public MockResponse dispatch(RecordedRequest request) | |
| .setBody("{\"starting\": false}"); | ||
| } | ||
|
|
||
| if (request.getPath().equals(healthCheckEndpoint + "?test-compression")) { | ||
Chaho12 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| // Return the Accept-Encoding header value for compression testing | ||
| String acceptEncoding = request.getHeader("Accept-Encoding"); | ||
| return new MockResponse().setResponseCode(200) | ||
| .setHeader(CONTENT_TYPE, JSON_UTF_8) | ||
| .setBody(acceptEncoding != null ? acceptEncoding : "null"); | ||
| } | ||
|
|
||
| if (request.getMethod().equals("PUT") && request.getPath().equals(customPutEndpoint)) { | ||
| return new MockResponse().setResponseCode(200) | ||
| .setHeader(CONTENT_TYPE, JSON_UTF_8) | ||
|
|
@@ -131,18 +139,18 @@ void testGetQueryDetailsFromRequest() | |
| { | ||
| // A sample query longer than 200 characters to test against truncation. | ||
| String longQuery = """ | ||
| SELECT | ||
| c.customer_name, | ||
| c.customer_region, | ||
| COUNT(o.order_id) AS total_orders, | ||
| SUM(o.order_value) AS total_revenue | ||
| FROM | ||
| hive.sales_data.customers AS c | ||
| JOIN | ||
| hive.sales_data.orders AS o | ||
| ON c.customer_id = o.customer_id | ||
| WHERE | ||
| o.order_date >= date '2023-01-01'"""; | ||
| SELECT | ||
| c.customer_name, | ||
| c.customer_region, | ||
| COUNT(o.order_id) AS total_orders, | ||
| SUM(o.order_value) AS total_revenue | ||
| FROM | ||
| hive.sales_data.customers AS c | ||
| JOIN | ||
| hive.sales_data.orders AS o | ||
| ON c.customer_id = o.customer_id | ||
| WHERE | ||
| o.order_date >= date '2023-01-01'"""; | ||
|
|
||
| io.airlift.http.client.Request request = preparePost() | ||
| .setUri(URI.create("http://localhost:" + routerPort + V1_STATEMENT_PATH)) | ||
|
|
@@ -159,4 +167,49 @@ void testGetQueryDetailsFromRequest() | |
| assertThat(queryDetail.getSource()).isEqualTo("trino-cli"); | ||
| assertThat(queryDetail.getBackendUrl()).isEqualTo("http://localhost:" + routerPort); | ||
| } | ||
|
|
||
| @Test | ||
| void testAcceptEncodingHeaderForwarding() | ||
| throws Exception | ||
| { | ||
| // Test that Accept-Encoding header is properly forwarded to backends | ||
| String url = "http://localhost:" + routerPort + healthCheckEndpoint + "?test-compression"; | ||
| String expectedAcceptEncoding = "gzip, deflate, br"; | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could you please add a test with just |
||
|
|
||
| Request request = new Request.Builder() | ||
| .url(url) | ||
| .get() | ||
| .addHeader("Accept-Encoding", expectedAcceptEncoding) | ||
| .build(); | ||
|
|
||
| try (Response response = httpClient.newCall(request).execute()) { | ||
| assertThat(response.code()).isEqualTo(200); | ||
| assertThat(response.body()).isNotNull(); | ||
|
|
||
| // The mock backend returns the Accept-Encoding header value in the response body | ||
| assertThat(response.body().string()).isEqualTo(expectedAcceptEncoding); | ||
| } | ||
| } | ||
|
|
||
| @Test | ||
| void testDefaultAcceptEncodingHeaderForwarding() | ||
| throws Exception | ||
| { | ||
| // Test that requests without explicit Accept-Encoding header work correctly | ||
| // Note: OkHttp automatically adds "Accept-Encoding: gzip" when none is specified | ||
| String url = "http://localhost:" + routerPort + healthCheckEndpoint + "?test-compression"; | ||
|
|
||
| Request request = new Request.Builder() | ||
| .url(url) | ||
| .get() | ||
| .build(); // No explicit Accept-Encoding header | ||
|
|
||
| try (Response response = httpClient.newCall(request).execute()) { | ||
| assertThat(response.code()).isEqualTo(200); | ||
| assertThat(response.body()).isNotNull(); | ||
|
|
||
| // OkHttp automatically adds "Accept-Encoding: gzip" when none is specified | ||
| assertThat(response.body().string()).isEqualTo("gzip"); | ||
| } | ||
| } | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could also add a test which breaks current gateway max limit and make sure it work with your change? |
||
| } | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this the only encoding trino clients request? Then we should only allow that encoding here.
You can also test with
Accept-Encoding: gzip, deflateThe spooling protocol offers additional encodings.