diff --git a/common/src/main/java/com/google/tsunami/common/net/http/HttpClient.java b/common/src/main/java/com/google/tsunami/common/net/http/HttpClient.java index ce0c05ee..b96de8eb 100644 --- a/common/src/main/java/com/google/tsunami/common/net/http/HttpClient.java +++ b/common/src/main/java/com/google/tsunami/common/net/http/HttpClient.java @@ -21,7 +21,9 @@ import static com.google.common.util.concurrent.MoreExecutors.directExecutor; import com.google.common.base.Ascii; +import com.google.common.collect.ImmutableSet; import com.google.common.flogger.GoogleLogger; +import com.google.common.io.ByteSource; import com.google.common.util.concurrent.ListenableFuture; import com.google.common.util.concurrent.SettableFuture; import com.google.protobuf.ByteString; @@ -76,10 +78,6 @@ public final class HttpClient { * @throws IOException if an I/O error occurs during the HTTP request. */ public HttpResponse sendAsIs(HttpRequest httpRequest) throws IOException { - if (!httpRequest.method().equals(HttpMethod.GET)) { - throw new IllegalArgumentException("sendAsIs method should only be used for GET method."); - } - HttpURLConnection connection = connectionFactory.openConnection(httpRequest.url()); connection.setRequestMethod(httpRequest.method().toString()); httpRequest.headers().names().stream() @@ -93,7 +91,12 @@ public HttpResponse sendAsIs(HttpRequest httpRequest) throws IOException { headerValue -> connection.setRequestProperty(headerName, headerValue))); connection.setRequestProperty(USER_AGENT, TSUNAMI_USER_AGENT); - connection.connect(); + if (ImmutableSet.of(HttpMethod.POST, HttpMethod.PUT, HttpMethod.DELETE) + .contains(httpRequest.method())) { + connection.setDoOutput(true); + ByteSource.wrap(httpRequest.requestBody().orElse(ByteString.EMPTY).toByteArray()) + .copyTo(connection.getOutputStream()); + } int responseCode = connection.getResponseCode(); HttpHeaders.Builder responseHeadersBuilder = HttpHeaders.builder(); diff --git a/common/src/test/java/com/google/tsunami/common/net/http/HttpClientTest.java b/common/src/test/java/com/google/tsunami/common/net/http/HttpClientTest.java index 8ae2d1d5..6b19e6df 100644 --- a/common/src/test/java/com/google/tsunami/common/net/http/HttpClientTest.java +++ b/common/src/test/java/com/google/tsunami/common/net/http/HttpClientTest.java @@ -84,22 +84,23 @@ public void tearDown() throws IOException { @Test public void sendAsIs_always_returnsExpectedHttpResponse() throws IOException, InterruptedException { - String responseBody = "test response"; - mockWebServer.enqueue( - new MockResponse() - .setResponseCode(HttpStatus.OK.code()) - .setHeader(CONTENT_TYPE, MediaType.PLAIN_TEXT_UTF_8.toString()) - .setBody(responseBody)); + mockWebServer.setDispatcher(new SendAsIsTestDispatcher()); mockWebServer.start(); + String expectedResponseBody = SendAsIsTestDispatcher.buildBody("GET", ""); HttpUrl baseUrl = mockWebServer.url("/"); String requestUrl = - new URL(baseUrl.scheme(), baseUrl.host(), baseUrl.port(), "/%2e%2e/%2e%2e/etc/passwd") + new URL( + baseUrl.scheme(), + baseUrl.host(), + baseUrl.port(), + "/send-as-is/%2e%2e/%2e%2e/etc/passwd") .toString(); HttpResponse response = httpClient.sendAsIs(get(requestUrl).withEmptyHeaders().build()); - assertThat(mockWebServer.takeRequest().getPath()).isEqualTo("/%2e%2e/%2e%2e/etc/passwd"); + assertThat(mockWebServer.takeRequest().getPath()) + .isEqualTo("/send-as-is/%2e%2e/%2e%2e/etc/passwd"); assertThat(response) .isEqualTo( HttpResponse.builder() @@ -108,24 +109,45 @@ public void sendAsIs_always_returnsExpectedHttpResponse() HttpHeaders.builder() .addHeader(CONTENT_TYPE, MediaType.PLAIN_TEXT_UTF_8.toString()) // MockWebServer always adds this response header. - .addHeader(CONTENT_LENGTH, String.valueOf(responseBody.length())) + .addHeader(CONTENT_LENGTH, String.valueOf(expectedResponseBody.length())) .build()) - .setBodyBytes(ByteString.copyFrom(responseBody, UTF_8)) + .setBodyBytes(ByteString.copyFrom(expectedResponseBody, UTF_8)) .build()); } @Test - public void sendAsIs_withNonGetRequest_throws() throws IOException, InterruptedException { + public void sendAsIs_withPostRequest_returnsExpectedHttpResponse() + throws IOException, InterruptedException { + mockWebServer.setDispatcher(new SendAsIsTestDispatcher()); mockWebServer.start(); + String requestBody = "POST BODY"; + String expectedResponseBody = SendAsIsTestDispatcher.buildBody("POST", requestBody); HttpUrl baseUrl = mockWebServer.url("/"); String requestUrl = - new URL(baseUrl.scheme(), baseUrl.host(), baseUrl.port(), "/%2e%2e/%2e%2e/etc/passwd") + new URL(baseUrl.scheme(), baseUrl.host(), baseUrl.port(), "/send-as-is/%2e%2e/%2e%2e/path") .toString(); - assertThrows( - IllegalArgumentException.class, - () -> httpClient.sendAsIs(post(requestUrl).withEmptyHeaders().build())); + HttpResponse response = + httpClient.sendAsIs( + post(requestUrl) + .setRequestBody(ByteString.copyFrom(requestBody, UTF_8)) + .withEmptyHeaders() + .build()); + + assertThat(mockWebServer.takeRequest().getPath()).isEqualTo("/send-as-is/%2e%2e/%2e%2e/path"); + assertThat(response) + .isEqualTo( + HttpResponse.builder() + .setStatus(HttpStatus.OK) + .setHeaders( + HttpHeaders.builder() + .addHeader(CONTENT_TYPE, MediaType.PLAIN_TEXT_UTF_8.toString()) + // MockWebServer always adds this response header. + .addHeader(CONTENT_LENGTH, String.valueOf(expectedResponseBody.length())) + .build()) + .setBodyBytes(ByteString.copyFrom(expectedResponseBody, UTF_8)) + .build()); } @Test @@ -763,4 +785,23 @@ public MockResponse dispatch(RecordedRequest recordedRequest) { return new MockResponse().setResponseCode(HttpStatus.NOT_FOUND.code()); } } + + static final class SendAsIsTestDispatcher extends Dispatcher { + static final String SEND_AS_IS_PATH = "/send-as-is/"; + + static String buildBody(String method, String requestBody) { + return String.format("Method: %s\nRequest Body: %s", method, requestBody); + } + + @Override + public MockResponse dispatch(RecordedRequest recordedRequest) { + if (recordedRequest.getPath().startsWith(SEND_AS_IS_PATH)) { + return new MockResponse() + .setHeader(CONTENT_TYPE, MediaType.PLAIN_TEXT_UTF_8.toString()) + .setBody(buildBody(recordedRequest.getMethod(), recordedRequest.getBody().readUtf8())) + .setResponseCode(HttpStatus.OK.code()); + } + return new MockResponse().setResponseCode(HttpStatus.NOT_FOUND.code()); + } + } }