Skip to content

Commit

Permalink
Merge pull request square#3025 from square/dr.1207.http2-visibility
Browse files Browse the repository at this point in the history
Introduce SocketRecorder: a means to inspect all bytes written/read to a socket.
  • Loading branch information
swankjesse authored Dec 26, 2016
2 parents cddb441 + e3d5951 commit 525f922
Show file tree
Hide file tree
Showing 2 changed files with 300 additions and 0 deletions.
206 changes: 206 additions & 0 deletions okhttp-tests/src/test/java/okhttp3/internal/SocketRecorder.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,206 @@
/*
* Copyright (C) 2016 Square, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package okhttp3.internal;

import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.net.Socket;
import java.util.Deque;
import java.util.concurrent.LinkedBlockingDeque;
import javax.net.ssl.SSLSocket;
import javax.net.ssl.SSLSocketFactory;
import okhttp3.DelegatingSSLSocket;
import okhttp3.DelegatingSSLSocketFactory;
import okio.Buffer;
import okio.ByteString;

/** Records all bytes written and read from a socket and makes them available for inspection. */
public final class SocketRecorder {
private final Deque<RecordedSocket> recordedSockets = new LinkedBlockingDeque<>();

/** Returns an SSLSocketFactory whose sockets will record all transmitted bytes. */
public SSLSocketFactory sslSocketFactory(SSLSocketFactory delegate) {
return new DelegatingSSLSocketFactory(delegate) {
@Override protected SSLSocket configureSocket(SSLSocket sslSocket) throws IOException {
RecordedSocket recordedSocket = new RecordedSocket();
recordedSockets.add(recordedSocket);
return new RecordingSSLSocket(sslSocket, recordedSocket);
}
};
}

public RecordedSocket takeSocket() {
return recordedSockets.remove();
}

/** A bidirectional transfer of unadulterated bytes over a socket. */
public static final class RecordedSocket {
private final Buffer bytesWritten = new Buffer();
private final Buffer bytesRead = new Buffer();

synchronized void byteWritten(int b) {
bytesWritten.writeByte(b);
}

synchronized void byteRead(int b) {
bytesRead.writeByte(b);
}

synchronized void bytesWritten(byte[] bytes, int offset, int length) {
bytesWritten.write(bytes, offset, length);
}

synchronized void bytesRead(byte[] bytes, int offset, int length) {
bytesRead.write(bytes, offset, length);
}

/** Returns all bytes that have been written to this socket. */
public synchronized ByteString bytesWritten() {
return bytesWritten.readByteString();
}

/** Returns all bytes that have been read from this socket. */
public synchronized ByteString bytesRead() {
return bytesRead.readByteString();
}
}

static final class RecordingInputStream extends InputStream {
private final Socket socket;
private final RecordedSocket recordedSocket;

RecordingInputStream(Socket socket, RecordedSocket recordedSocket) {
this.socket = socket;
this.recordedSocket = recordedSocket;
}

@Override public int read() throws IOException {
int b = socket.getInputStream().read();
if (b == -1) return -1;
recordedSocket.byteRead(b);
return b;
}

@Override public int read(byte[] b, int off, int len) throws IOException {
int read = socket.getInputStream().read(b, off, len);
if (read == -1) return -1;
recordedSocket.bytesRead(b, off, read);
return read;
}

@Override public void close() throws IOException {
socket.getInputStream().close();
}
}

static final class RecordingOutputStream extends OutputStream {
private final Socket socket;
private final RecordedSocket recordedSocket;

RecordingOutputStream(Socket socket, RecordedSocket recordedSocket) {
this.socket = socket;
this.recordedSocket = recordedSocket;
}

@Override public void write(int b) throws IOException {
socket.getOutputStream().write(b);
recordedSocket.byteWritten(b);
}

@Override public void write(byte[] b, int off, int len) throws IOException {
socket.getOutputStream().write(b, off, len);
recordedSocket.bytesWritten(b, off, len);
}

@Override public void close() throws IOException {
socket.getOutputStream().close();
}

@Override public void flush() throws IOException {
socket.getOutputStream().flush();
}
}

static final class RecordingSSLSocket extends DelegatingSSLSocket {
private final InputStream inputStream;
private final OutputStream outputStream;

RecordingSSLSocket(SSLSocket delegate, RecordedSocket recordedSocket) {
super(delegate);
inputStream = new RecordingInputStream(delegate, recordedSocket);
outputStream = new RecordingOutputStream(delegate, recordedSocket);
}

@Override public void startHandshake() throws IOException {
// Intercept the handshake to properly configure TLS extensions with Jetty ALPN. Jetty ALPN
// expects the real SSLSocket to be placed in the global map. Because we are wrapping the real
// SSLSocket, it confuses Jetty ALPN. This patches that up so things work as expected.
Class<?> alpn = null;
Class<?> provider = null;
try {
alpn = Class.forName("org.eclipse.jetty.alpn.ALPN");
provider = Class.forName("org.eclipse.jetty.alpn.ALPN$Provider");
} catch (ClassNotFoundException ignored) {
}

if (alpn == null || provider == null) {
// No Jetty, so nothing to worry about.
super.startHandshake();
return;
}

Object providerInstance = null;
Method putMethod = null;
try {
Method getMethod = alpn.getMethod("get", SSLSocket.class);
putMethod = alpn.getMethod("put", SSLSocket.class, provider);
providerInstance = getMethod.invoke(null, this);
if (providerInstance == null) {
// Jetty's on the classpath but TLS extensions weren't used.
super.startHandshake();
return;
}

// TLS extensions were used; replace with the real SSLSocket to make Jetty ALPN happy.
putMethod.invoke(null, delegate, providerInstance);
super.startHandshake();
} catch (NoSuchMethodException | IllegalAccessException | InvocationTargetException e) {
throw new AssertionError();
} finally {
// If we replaced the SSLSocket in the global map, we must put the original back for
// everything to work inside OkHttp.
if (providerInstance != null) {
try {
putMethod.invoke(null, this, providerInstance);
} catch (IllegalAccessException | InvocationTargetException e) {
throw new AssertionError();
}
}
}
}

@Override public InputStream getInputStream() throws IOException {
return inputStream;
}

@Override public OutputStream getOutputStream() throws IOException {
return outputStream;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,10 @@
import java.net.Authenticator;
import java.net.HttpURLConnection;
import java.net.SocketTimeoutException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
Expand All @@ -42,6 +45,7 @@
import okhttp3.internal.DoubleInetAddressDns;
import okhttp3.internal.RecordingOkAuthenticator;
import okhttp3.internal.SingleInetAddressDns;
import okhttp3.internal.SocketRecorder;
import okhttp3.internal.Util;
import okhttp3.internal.connection.RealConnection;
import okhttp3.internal.tls.SslClient;
Expand All @@ -52,6 +56,7 @@
import okhttp3.mockwebserver.SocketPolicy;
import okio.Buffer;
import okio.BufferedSink;
import okio.BufferedSource;
import okio.GzipSink;
import okio.Okio;
import org.junit.After;
Expand All @@ -67,6 +72,7 @@
import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;

/** Test how SPDY interacts with HTTP/2 features. */
Expand Down Expand Up @@ -760,6 +766,70 @@ private void noRecoveryFromErrorWithRetryDisabled(ErrorCode errorCode) throws Ex
assertEquals("bar", pushedRequest.getHeader("foo"));
}

@Test public void noDataFramesSentWithNullRequestBody() throws Exception {
server.enqueue(new MockResponse()
.setBody("ABC"));

SocketRecorder socketRecorder = new SocketRecorder();
client = client.newBuilder()
.sslSocketFactory(socketRecorder.sslSocketFactory(sslClient.socketFactory),
sslClient.trustManager)
.build();

Call call = client.newCall(new Request.Builder()
.url(server.url("/"))
.method("DELETE", null)
.build());
Response response = call.execute();
assertEquals("ABC", response.body().string());

// Replay the bytes written by the client to confirm no data frames were sent.
SocketRecorder.RecordedSocket recordedSocket = socketRecorder.takeSocket();
Buffer buffer = new Buffer();
buffer.write(recordedSocket.bytesWritten());

RecordingHandler handler = new RecordingHandler();
Http2Reader reader = new Http2Reader(buffer, false);
reader.readConnectionPreface(null);
while (reader.nextFrame(false, handler)) {
}

assertEquals(1, handler.headerFrameCount);
assertTrue(handler.dataFrames.isEmpty());
}

@Test public void emptyDataFrameSentWithEmptyBody() throws Exception {
server.enqueue(new MockResponse()
.setBody("ABC"));

SocketRecorder socketRecorder = new SocketRecorder();
client = client.newBuilder()
.sslSocketFactory(socketRecorder.sslSocketFactory(sslClient.socketFactory),
sslClient.trustManager)
.build();

Call call = client.newCall(new Request.Builder()
.url(server.url("/"))
.method("DELETE", Util.EMPTY_REQUEST)
.build());
Response response = call.execute();
assertEquals("ABC", response.body().string());

// Replay the bytes written by the client to confirm an empty data frame was sent.
SocketRecorder.RecordedSocket recordedSocket = socketRecorder.takeSocket();
Buffer buffer = new Buffer();
buffer.write(recordedSocket.bytesWritten());

RecordingHandler handler = new RecordingHandler();
Http2Reader reader = new Http2Reader(buffer, false);
reader.readConnectionPreface(null);
while (reader.nextFrame(false, handler)) {
}

assertEquals(1, handler.headerFrameCount);
assertEquals(Collections.singletonList(0), handler.dataFrames);
}

/**
* Push a setting that permits up to 2 concurrent streams, then make 3 concurrent requests and
* confirm that the third concurrent request prepared a new connection.
Expand Down Expand Up @@ -904,4 +974,28 @@ public AsyncRequest(String path, CountDownLatch countDownLatch) {
}
}
}

static final class RecordingHandler extends BaseTestHandler {
int headerFrameCount;
final List<Integer> dataFrames = new ArrayList<>();

@Override public void settings(boolean clearPrevious, Settings settings) {
}

@Override public void ackSettings() {
}

@Override public void windowUpdate(int streamId, long windowSizeIncrement) {
}

@Override public void data(boolean inFinished, int streamId, BufferedSource source, int length)
throws IOException {
dataFrames.add(length);
}

@Override public void headers(boolean inFinished, int streamId, int associatedStreamId,
List<Header> headerBlock) {
headerFrameCount++;
}
}
}

0 comments on commit 525f922

Please sign in to comment.