diff --git a/okhttp-tests/src/test/java/okhttp3/internal/SocketRecorder.java b/okhttp-tests/src/test/java/okhttp3/internal/SocketRecorder.java new file mode 100644 index 000000000000..ba3b84f91eb9 --- /dev/null +++ b/okhttp-tests/src/test/java/okhttp3/internal/SocketRecorder.java @@ -0,0 +1,206 @@ +/* + * Copyright (C) 2016 Square, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package okhttp3.internal; + +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.lang.reflect.InvocationTargetException; +import java.lang.reflect.Method; +import java.net.Socket; +import java.util.Deque; +import java.util.concurrent.LinkedBlockingDeque; +import javax.net.ssl.SSLSocket; +import javax.net.ssl.SSLSocketFactory; +import okhttp3.DelegatingSSLSocket; +import okhttp3.DelegatingSSLSocketFactory; +import okio.Buffer; +import okio.ByteString; + +/** Records all bytes written and read from a socket and makes them available for inspection. */ +public final class SocketRecorder { + private final Deque recordedSockets = new LinkedBlockingDeque<>(); + + /** Returns an SSLSocketFactory whose sockets will record all transmitted bytes. */ + public SSLSocketFactory sslSocketFactory(SSLSocketFactory delegate) { + return new DelegatingSSLSocketFactory(delegate) { + @Override protected SSLSocket configureSocket(SSLSocket sslSocket) throws IOException { + RecordedSocket recordedSocket = new RecordedSocket(); + recordedSockets.add(recordedSocket); + return new RecordingSSLSocket(sslSocket, recordedSocket); + } + }; + } + + public RecordedSocket takeSocket() { + return recordedSockets.remove(); + } + + /** A bidirectional transfer of unadulterated bytes over a socket. */ + public static final class RecordedSocket { + private final Buffer bytesWritten = new Buffer(); + private final Buffer bytesRead = new Buffer(); + + synchronized void byteWritten(int b) { + bytesWritten.writeByte(b); + } + + synchronized void byteRead(int b) { + bytesRead.writeByte(b); + } + + synchronized void bytesWritten(byte[] bytes, int offset, int length) { + bytesWritten.write(bytes, offset, length); + } + + synchronized void bytesRead(byte[] bytes, int offset, int length) { + bytesRead.write(bytes, offset, length); + } + + /** Returns all bytes that have been written to this socket. */ + public synchronized ByteString bytesWritten() { + return bytesWritten.readByteString(); + } + + /** Returns all bytes that have been read from this socket. */ + public synchronized ByteString bytesRead() { + return bytesRead.readByteString(); + } + } + + static final class RecordingInputStream extends InputStream { + private final Socket socket; + private final RecordedSocket recordedSocket; + + RecordingInputStream(Socket socket, RecordedSocket recordedSocket) { + this.socket = socket; + this.recordedSocket = recordedSocket; + } + + @Override public int read() throws IOException { + int b = socket.getInputStream().read(); + if (b == -1) return -1; + recordedSocket.byteRead(b); + return b; + } + + @Override public int read(byte[] b, int off, int len) throws IOException { + int read = socket.getInputStream().read(b, off, len); + if (read == -1) return -1; + recordedSocket.bytesRead(b, off, read); + return read; + } + + @Override public void close() throws IOException { + socket.getInputStream().close(); + } + } + + static final class RecordingOutputStream extends OutputStream { + private final Socket socket; + private final RecordedSocket recordedSocket; + + RecordingOutputStream(Socket socket, RecordedSocket recordedSocket) { + this.socket = socket; + this.recordedSocket = recordedSocket; + } + + @Override public void write(int b) throws IOException { + socket.getOutputStream().write(b); + recordedSocket.byteWritten(b); + } + + @Override public void write(byte[] b, int off, int len) throws IOException { + socket.getOutputStream().write(b, off, len); + recordedSocket.bytesWritten(b, off, len); + } + + @Override public void close() throws IOException { + socket.getOutputStream().close(); + } + + @Override public void flush() throws IOException { + socket.getOutputStream().flush(); + } + } + + static final class RecordingSSLSocket extends DelegatingSSLSocket { + private final InputStream inputStream; + private final OutputStream outputStream; + + RecordingSSLSocket(SSLSocket delegate, RecordedSocket recordedSocket) { + super(delegate); + inputStream = new RecordingInputStream(delegate, recordedSocket); + outputStream = new RecordingOutputStream(delegate, recordedSocket); + } + + @Override public void startHandshake() throws IOException { + // Intercept the handshake to properly configure TLS extensions with Jetty ALPN. Jetty ALPN + // expects the real SSLSocket to be placed in the global map. Because we are wrapping the real + // SSLSocket, it confuses Jetty ALPN. This patches that up so things work as expected. + Class alpn = null; + Class provider = null; + try { + alpn = Class.forName("org.eclipse.jetty.alpn.ALPN"); + provider = Class.forName("org.eclipse.jetty.alpn.ALPN$Provider"); + } catch (ClassNotFoundException ignored) { + } + + if (alpn == null || provider == null) { + // No Jetty, so nothing to worry about. + super.startHandshake(); + return; + } + + Object providerInstance = null; + Method putMethod = null; + try { + Method getMethod = alpn.getMethod("get", SSLSocket.class); + putMethod = alpn.getMethod("put", SSLSocket.class, provider); + providerInstance = getMethod.invoke(null, this); + if (providerInstance == null) { + // Jetty's on the classpath but TLS extensions weren't used. + super.startHandshake(); + return; + } + + // TLS extensions were used; replace with the real SSLSocket to make Jetty ALPN happy. + putMethod.invoke(null, delegate, providerInstance); + super.startHandshake(); + } catch (NoSuchMethodException | IllegalAccessException | InvocationTargetException e) { + throw new AssertionError(); + } finally { + // If we replaced the SSLSocket in the global map, we must put the original back for + // everything to work inside OkHttp. + if (providerInstance != null) { + try { + putMethod.invoke(null, this, providerInstance); + } catch (IllegalAccessException | InvocationTargetException e) { + throw new AssertionError(); + } + } + } + } + + @Override public InputStream getInputStream() throws IOException { + return inputStream; + } + + @Override public OutputStream getOutputStream() throws IOException { + return outputStream; + } + } +} diff --git a/okhttp-tests/src/test/java/okhttp3/internal/http2/HttpOverHttp2Test.java b/okhttp-tests/src/test/java/okhttp3/internal/http2/HttpOverHttp2Test.java index 4c2e05844b66..b6dcd10ec25f 100644 --- a/okhttp-tests/src/test/java/okhttp3/internal/http2/HttpOverHttp2Test.java +++ b/okhttp-tests/src/test/java/okhttp3/internal/http2/HttpOverHttp2Test.java @@ -20,7 +20,10 @@ import java.net.Authenticator; import java.net.HttpURLConnection; import java.net.SocketTimeoutException; +import java.util.ArrayList; import java.util.Arrays; +import java.util.Collections; +import java.util.List; import java.util.concurrent.CountDownLatch; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; @@ -42,6 +45,7 @@ import okhttp3.internal.DoubleInetAddressDns; import okhttp3.internal.RecordingOkAuthenticator; import okhttp3.internal.SingleInetAddressDns; +import okhttp3.internal.SocketRecorder; import okhttp3.internal.Util; import okhttp3.internal.connection.RealConnection; import okhttp3.internal.tls.SslClient; @@ -52,6 +56,7 @@ import okhttp3.mockwebserver.SocketPolicy; import okio.Buffer; import okio.BufferedSink; +import okio.BufferedSource; import okio.GzipSink; import okio.Okio; import org.junit.After; @@ -67,6 +72,7 @@ import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; /** Test how SPDY interacts with HTTP/2 features. */ @@ -760,6 +766,70 @@ private void noRecoveryFromErrorWithRetryDisabled(ErrorCode errorCode) throws Ex assertEquals("bar", pushedRequest.getHeader("foo")); } + @Test public void noDataFramesSentWithNullRequestBody() throws Exception { + server.enqueue(new MockResponse() + .setBody("ABC")); + + SocketRecorder socketRecorder = new SocketRecorder(); + client = client.newBuilder() + .sslSocketFactory(socketRecorder.sslSocketFactory(sslClient.socketFactory), + sslClient.trustManager) + .build(); + + Call call = client.newCall(new Request.Builder() + .url(server.url("/")) + .method("DELETE", null) + .build()); + Response response = call.execute(); + assertEquals("ABC", response.body().string()); + + // Replay the bytes written by the client to confirm no data frames were sent. + SocketRecorder.RecordedSocket recordedSocket = socketRecorder.takeSocket(); + Buffer buffer = new Buffer(); + buffer.write(recordedSocket.bytesWritten()); + + RecordingHandler handler = new RecordingHandler(); + Http2Reader reader = new Http2Reader(buffer, false); + reader.readConnectionPreface(null); + while (reader.nextFrame(false, handler)) { + } + + assertEquals(1, handler.headerFrameCount); + assertTrue(handler.dataFrames.isEmpty()); + } + + @Test public void emptyDataFrameSentWithEmptyBody() throws Exception { + server.enqueue(new MockResponse() + .setBody("ABC")); + + SocketRecorder socketRecorder = new SocketRecorder(); + client = client.newBuilder() + .sslSocketFactory(socketRecorder.sslSocketFactory(sslClient.socketFactory), + sslClient.trustManager) + .build(); + + Call call = client.newCall(new Request.Builder() + .url(server.url("/")) + .method("DELETE", Util.EMPTY_REQUEST) + .build()); + Response response = call.execute(); + assertEquals("ABC", response.body().string()); + + // Replay the bytes written by the client to confirm an empty data frame was sent. + SocketRecorder.RecordedSocket recordedSocket = socketRecorder.takeSocket(); + Buffer buffer = new Buffer(); + buffer.write(recordedSocket.bytesWritten()); + + RecordingHandler handler = new RecordingHandler(); + Http2Reader reader = new Http2Reader(buffer, false); + reader.readConnectionPreface(null); + while (reader.nextFrame(false, handler)) { + } + + assertEquals(1, handler.headerFrameCount); + assertEquals(Collections.singletonList(0), handler.dataFrames); + } + /** * Push a setting that permits up to 2 concurrent streams, then make 3 concurrent requests and * confirm that the third concurrent request prepared a new connection. @@ -904,4 +974,28 @@ public AsyncRequest(String path, CountDownLatch countDownLatch) { } } } + + static final class RecordingHandler extends BaseTestHandler { + int headerFrameCount; + final List dataFrames = new ArrayList<>(); + + @Override public void settings(boolean clearPrevious, Settings settings) { + } + + @Override public void ackSettings() { + } + + @Override public void windowUpdate(int streamId, long windowSizeIncrement) { + } + + @Override public void data(boolean inFinished, int streamId, BufferedSource source, int length) + throws IOException { + dataFrames.add(length); + } + + @Override public void headers(boolean inFinished, int streamId, int associatedStreamId, + List
headerBlock) { + headerFrameCount++; + } + } }