Skip to content

Commit

Permalink
wrap client http response body in buffered input stream to support ma…
Browse files Browse the repository at this point in the history
…rk/reset (zalando#963) (zalando#1041)
  • Loading branch information
noffke authored Feb 21, 2023
1 parent 7d2d192 commit 4f4f1a6
Show file tree
Hide file tree
Showing 6 changed files with 208 additions and 39 deletions.
21 changes: 21 additions & 0 deletions logbook-spring/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@
<groupId>org.zalando</groupId>
<artifactId>logbook-test</artifactId>
</dependency>

<dependency>
<groupId>org.springframework</groupId>
<artifactId>spring-test</artifactId>
Expand All @@ -69,5 +70,25 @@
</exclusion>
</exclusions>
</dependency>

<dependency>
<groupId>com.github.tomakehurst</groupId>
<artifactId>wiremock-jre8</artifactId>
<version>2.28.0</version>
<scope>test</scope>
</dependency>

<dependency>
<groupId>org.apache.httpcomponents</groupId>
<artifactId>httpclient</artifactId>
<version>4.5.13</version>
<scope>test</scope>
<exclusions>
<exclusion>
<groupId>commons-logging</groupId>
<artifactId>commons-logging</artifactId>
</exclusion>
</exclusions>
</dependency>
</dependencies>
</project>
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
package org.zalando.logbook.spring;

import java.io.BufferedInputStream;
import java.io.IOException;
import java.io.InputStream;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpStatus;
import org.springframework.http.client.ClientHttpResponse;

public class BufferingClientHttpResponseWrapper implements ClientHttpResponse {

private final ClientHttpResponse delegate;
private final InputStream body;

public BufferingClientHttpResponseWrapper(ClientHttpResponse delegate) throws IOException {
this.delegate = delegate;
final InputStream delegateBody = delegate.getBody();
this.body = delegateBody.markSupported() ? delegateBody : new BufferedInputStream(delegateBody);
}

@Override
public HttpStatus getStatusCode() throws IOException {
return delegate.getStatusCode();
}

@Override
public int getRawStatusCode() throws IOException {
return delegate.getRawStatusCode();
}

@Override
public String getStatusText() throws IOException {
return delegate.getStatusText();
}

@Override
public void close() {
try {
body.close();
} catch (IOException e){
throw new RuntimeException(e);
}
delegate.close();
}

@Override
public InputStream getBody() {
return body;
}

@Override
public HttpHeaders getHeaders() {
return delegate.getHeaders();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ public ClientHttpResponse intercept(HttpRequest request, byte[] body, ClientHttp
final org.zalando.logbook.HttpRequest httpRequest = new LocalRequest(request, body);
final Logbook.ResponseProcessingStage stage = logbook.process(httpRequest).write();

ClientHttpResponse response = execution.execute(request, body);
ClientHttpResponse response = new BufferingClientHttpResponseWrapper(execution.execute(request, body));

final HttpResponse httpResponse = new RemoteResponse(response);
stage.process(httpResponse).write();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ public State without() {
@Override
public State buffer(final ClientHttpResponse response) throws IOException {
InputStream responseBodyStream = response.getBody();
responseBodyStream.mark(Integer.MAX_VALUE);
byte[] data = ByteStreams.toByteArray(responseBodyStream);
responseBodyStream.reset();
return new Buffering(data);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
package org.zalando.logbook.spring;

import java.io.BufferedInputStream;
import java.io.IOException;
import java.io.InputStream;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpStatus;
import org.springframework.http.client.ClientHttpResponse;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

@ExtendWith(MockitoExtension.class)
class BufferingClientHttpResponseWrapperTest {

@Mock
private ClientHttpResponse delegate;

@Mock
private InputStream inputStream;

private BufferingClientHttpResponseWrapper wrapper;

@BeforeEach
void setUp() throws IOException {
when(delegate.getBody()).thenReturn(inputStream);
wrapper = new BufferingClientHttpResponseWrapper(delegate);
}

@Test
void wrapBodyInBufferedInputStreamWhenMarkNotSupported() throws IOException {
when(inputStream.markSupported()).thenReturn(false);

assertTrue(new BufferingClientHttpResponseWrapper(delegate).getBody() instanceof BufferedInputStream);
}

@Test
void dontWrapBodyInBufferedInputStreamWhenMarkSupported() throws IOException {
when(inputStream.markSupported()).thenReturn(true);

assertEquals(inputStream, new BufferingClientHttpResponseWrapper(delegate).getBody());
}

@Test
void getStatusCode() throws IOException {
when(delegate.getStatusCode()).thenReturn(HttpStatus.OK);

assertEquals(HttpStatus.OK, wrapper.getStatusCode());
}

@Test
void getRawStatusCode() throws IOException {
when(delegate.getRawStatusCode()).thenReturn(200);

assertEquals(200, wrapper.getRawStatusCode());
}

@Test
void getStatusText() throws IOException {
when(delegate.getStatusText()).thenReturn("OK");

assertEquals("OK", wrapper.getStatusText());
}

@Test
void close() {
wrapper.close();
verify(delegate).close();
}

@Test
void close_throws() throws IOException {
doThrow(new IOException()).when(inputStream).close();

assertThrows(RuntimeException.class, () -> wrapper.close());
}

@Test
void getBody() {
assertTrue(wrapper.getBody().markSupported());
}

@Test
void getHeaders() {
final HttpHeaders httpHeaders = new HttpHeaders();
when(delegate.getHeaders()).thenReturn(httpHeaders);

assertEquals(httpHeaders, wrapper.getHeaders());
}
}
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
package org.zalando.logbook.spring;

import com.github.tomakehurst.wiremock.WireMockServer;
import java.io.IOException;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
Expand All @@ -8,8 +10,7 @@
import org.mockito.Captor;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
import org.springframework.http.HttpMethod;
import org.springframework.test.web.client.MockRestServiceServer;
import org.springframework.http.client.HttpComponentsClientHttpRequestFactory;
import org.springframework.web.client.HttpClientErrorException;
import org.springframework.web.client.RestTemplate;
import org.zalando.logbook.Correlation;
Expand All @@ -19,24 +20,25 @@
import org.zalando.logbook.Logbook;
import org.zalando.logbook.Precorrelation;
import org.zalando.logbook.TestStrategy;

import java.io.IOException;

import static com.github.tomakehurst.wiremock.client.WireMock.aResponse;
import static com.github.tomakehurst.wiremock.client.WireMock.equalTo;
import static com.github.tomakehurst.wiremock.client.WireMock.get;
import static com.github.tomakehurst.wiremock.client.WireMock.post;
import static com.github.tomakehurst.wiremock.client.WireMock.postRequestedFor;
import static com.github.tomakehurst.wiremock.client.WireMock.urlEqualTo;
import static com.github.tomakehurst.wiremock.core.WireMockConfiguration.options;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import static org.springframework.test.web.client.ExpectedCount.once;
import static org.springframework.test.web.client.match.MockRestRequestMatchers.method;
import static org.springframework.test.web.client.match.MockRestRequestMatchers.requestTo;
import static org.springframework.test.web.client.response.MockRestResponseCreators.withBadRequest;
import static org.springframework.test.web.client.response.MockRestResponseCreators.withSuccess;

@ExtendWith(MockitoExtension.class)
class LogbookClientHttpRequestInterceptorTest {

private final WireMockServer server = new WireMockServer(options().dynamicPort());

@Mock
private HttpLogWriter writer;

Expand All @@ -53,34 +55,30 @@ class LogbookClientHttpRequestInterceptorTest {
private ArgumentCaptor<Correlation> correlationCaptor;

private RestTemplate restTemplate;
private MockRestServiceServer serviceServer;
private Logbook logbook;

private LogbookClientHttpRequestInterceptor interceptor;

@BeforeEach
void setup() {
server.start();
when(writer.isActive()).thenReturn(true);
logbook = Logbook.builder()
Logbook logbook = Logbook.builder()
.strategy(new TestStrategy())
.sink(new DefaultSink(new DefaultHttpLogFormatter(), writer))
.build();
interceptor = new LogbookClientHttpRequestInterceptor(logbook);
restTemplate = new RestTemplate();
LogbookClientHttpRequestInterceptor interceptor = new LogbookClientHttpRequestInterceptor(logbook);
restTemplate = new RestTemplate(new HttpComponentsClientHttpRequestFactory());
restTemplate.getInterceptors().add(interceptor);
serviceServer = MockRestServiceServer.createServer(restTemplate);
}

@AfterEach
void mockHttpVerify() {
serviceServer.verify();
void tearDown() {
server.stop();
}

@Test
void get200() throws IOException {
serviceServer.expect(once(), requestTo("/test/get")).andExpect(method(HttpMethod.GET))
.andRespond(withSuccess().body("response"));
restTemplate.getForObject("/test/get", String.class);
server.stubFor(get("/test/get/withcontent").willReturn(aResponse().withStatus(200).withBody("response")));

restTemplate.getForObject(server.baseUrl() + "/test/get/withcontent", String.class);

verify(writer).write(precorrelationCaptor.capture(), requestCaptor.capture());
verify(writer).write(correlationCaptor.capture(), responseCaptor.capture());
Expand All @@ -98,34 +96,28 @@ void get200() throws IOException {

@Test
void get200WithEmptyResponseBody(){
serviceServer.expect(once(), requestTo("/test/get")).andExpect(method(HttpMethod.GET))
.andRespond(withSuccess());

restTemplate.getForObject("/test/get", Void.class);
server.stubFor(get("/test/get/withoutcontent").willReturn(aResponse().withStatus(200)));
restTemplate.getForObject(server.baseUrl() + "/test/get/withoutcontent", Void.class);
}

@Test
void get200WithNonEmptyResponseBody() {
String expectedResponseBody = "response";
serviceServer.expect(once(), requestTo("/test/get")).andExpect(method(HttpMethod.GET))
.andRespond(withSuccess().body(expectedResponseBody));

String actualResponseBody = restTemplate.getForObject("/test/get", String.class);
server.stubFor(get("/test/get/withcontent").willReturn(aResponse().withStatus(200).withBody("response")));
String actualResponseBody = restTemplate.getForObject(server.baseUrl() + "/test/get/withcontent", String.class);

assertNotNull(actualResponseBody);
assertEquals(expectedResponseBody, actualResponseBody);
assertEquals("response", actualResponseBody);
}

@Test
void post400() throws IOException {
serviceServer.expect(once(), requestTo("/test/post")).andExpect(method(HttpMethod.POST))
.andRespond(withBadRequest().body("response"));
assertThrows(HttpClientErrorException.class, () -> restTemplate.postForObject("/test/post", "request", Void.class));
server.stubFor(post("/test/post/withcontent").willReturn(aResponse().withStatus(400).withBody("response")));
assertThrows(HttpClientErrorException.class, () -> restTemplate.postForObject(server.baseUrl() + "/test/post/withcontent", "request", String.class));

verify(writer).write(precorrelationCaptor.capture(), requestCaptor.capture());
verify(writer).write(correlationCaptor.capture(), responseCaptor.capture());

assertTrue(requestCaptor.getValue().contains("/test/post"));
assertTrue(requestCaptor.getValue().contains("/test/post/withcontent"));
assertTrue(requestCaptor.getValue().contains("POST"));
assertTrue(requestCaptor.getValue().contains("Remote: localhost"));
assertTrue(requestCaptor.getValue().contains(precorrelationCaptor.getValue().getId()));
Expand All @@ -135,5 +127,7 @@ void post400() throws IOException {
assertTrue(responseCaptor.getValue().contains(precorrelationCaptor.getValue().getId()));
assertTrue(responseCaptor.getValue().contains("400 Bad Request"));
assertTrue(responseCaptor.getValue().contains("response"));

server.verify(postRequestedFor(urlEqualTo("/test/post/withcontent")).withRequestBody(equalTo("request")));
}
}

0 comments on commit 4f4f1a6

Please sign in to comment.