From 8331b2d75a51631bee622daf2175854fe49ce29a Mon Sep 17 00:00:00 2001 From: jwilson Date: Mon, 2 Jan 2017 12:36:30 -0500 Subject: [PATCH] Fix RealConnection to guard allocationLimit by connectionPool. I'm working towards making OkHttp limit itself to a single HTTP/2 connection to a single host. In this work I found we're not sufficiently safe on allocationLimit - connections are added to the pool when this is 0, and the value is updated without any synchronization. This change also reduces the visibility of some connection fields in RealConnection and organizes the fields into two sets: those that are immutable after connect and those that are guarded by connectionPool. https://github.com/square/okhttp/issues/373 --- .../test/java/okhttp3/ConnectionPoolTest.java | 42 +++--- .../internal/connection/RealConnection.java | 130 ++++++++++++------ .../internal/connection/StreamAllocation.java | 18 +-- .../okhttp3/internal/ws/RealWebSocket.java | 15 +- 4 files changed, 109 insertions(+), 96 deletions(-) diff --git a/okhttp-tests/src/test/java/okhttp3/ConnectionPoolTest.java b/okhttp-tests/src/test/java/okhttp3/ConnectionPoolTest.java index 80346ba443e2..0d07c26b2de8 100644 --- a/okhttp-tests/src/test/java/okhttp3/ConnectionPoolTest.java +++ b/okhttp-tests/src/test/java/okhttp3/ConnectionPoolTest.java @@ -54,27 +54,27 @@ public final class ConnectionPoolTest { // Running at time 50, the pool returns that nothing can be evicted until time 150. assertEquals(100L, pool.cleanup(50L)); assertEquals(1, pool.connectionCount()); - assertFalse(c1.socket.isClosed()); + assertFalse(c1.socket().isClosed()); // Running at time 60, the pool returns that nothing can be evicted until time 150. assertEquals(90L, pool.cleanup(60L)); assertEquals(1, pool.connectionCount()); - assertFalse(c1.socket.isClosed()); + assertFalse(c1.socket().isClosed()); // Running at time 149, the pool returns that nothing can be evicted until time 150. assertEquals(1L, pool.cleanup(149L)); assertEquals(1, pool.connectionCount()); - assertFalse(c1.socket.isClosed()); + assertFalse(c1.socket().isClosed()); // Running at time 150, the pool evicts. assertEquals(0, pool.cleanup(150L)); assertEquals(0, pool.connectionCount()); - assertTrue(c1.socket.isClosed()); + assertTrue(c1.socket().isClosed()); // Running again, the pool reports that no further runs are necessary. assertEquals(-1, pool.cleanup(150L)); assertEquals(0, pool.connectionCount()); - assertTrue(c1.socket.isClosed()); + assertTrue(c1.socket().isClosed()); } @Test public void inUseConnectionsNotEvicted() throws Exception { @@ -90,17 +90,17 @@ public final class ConnectionPoolTest { // Running at time 50, the pool returns that nothing can be evicted until time 150. assertEquals(100L, pool.cleanup(50L)); assertEquals(1, pool.connectionCount()); - assertFalse(c1.socket.isClosed()); + assertFalse(c1.socket().isClosed()); // Running at time 60, the pool returns that nothing can be evicted until time 160. assertEquals(100L, pool.cleanup(60L)); assertEquals(1, pool.connectionCount()); - assertFalse(c1.socket.isClosed()); + assertFalse(c1.socket().isClosed()); // Running at time 160, the pool returns that nothing can be evicted until time 260. assertEquals(100L, pool.cleanup(160L)); assertEquals(1, pool.connectionCount()); - assertFalse(c1.socket.isClosed()); + assertFalse(c1.socket().isClosed()); } @Test public void cleanupPrioritizesEarliestEviction() throws Exception { @@ -121,8 +121,8 @@ public final class ConnectionPoolTest { // Running at time 150, the pool evicts c2. assertEquals(0L, pool.cleanup(150L)); assertEquals(1, pool.connectionCount()); - assertFalse(c1.socket.isClosed()); - assertTrue(c2.socket.isClosed()); + assertFalse(c1.socket().isClosed()); + assertTrue(c2.socket().isClosed()); // Running at time 150, the pool returns that nothing can be evicted until time 175. assertEquals(25L, pool.cleanup(150L)); @@ -131,8 +131,8 @@ public final class ConnectionPoolTest { // Running at time 175, the pool evicts c1. assertEquals(0L, pool.cleanup(175L)); assertEquals(0, pool.connectionCount()); - assertTrue(c1.socket.isClosed()); - assertTrue(c2.socket.isClosed()); + assertTrue(c1.socket().isClosed()); + assertTrue(c2.socket().isClosed()); } @Test public void oldestConnectionsEvictedIfIdleLimitExceeded() throws Exception { @@ -145,8 +145,8 @@ public final class ConnectionPoolTest { // With 2 connections, there's no need to evict until the connections time out. assertEquals(50L, pool.cleanup(100L)); assertEquals(2, pool.connectionCount()); - assertFalse(c1.socket.isClosed()); - assertFalse(c2.socket.isClosed()); + assertFalse(c1.socket().isClosed()); + assertFalse(c2.socket().isClosed()); // Add a third connection RealConnection c3 = newConnection(pool, routeC1, 75L); @@ -154,9 +154,9 @@ public final class ConnectionPoolTest { // The third connection bounces the first. assertEquals(0L, pool.cleanup(100L)); assertEquals(2, pool.connectionCount()); - assertTrue(c1.socket.isClosed()); - assertFalse(c2.socket.isClosed()); - assertFalse(c3.socket.isClosed()); + assertTrue(c1.socket().isClosed()); + assertFalse(c2.socket().isClosed()); + assertFalse(c3.socket().isClosed()); } @Test public void leakedAllocation() throws Exception { @@ -182,13 +182,11 @@ private void allocateAndLeakAllocation(ConnectionPool pool, RealConnection conne } private RealConnection newConnection(ConnectionPool pool, Route route, long idleAtNanos) { - RealConnection connection = new RealConnection(route); - connection.idleAtNanos = idleAtNanos; - connection.socket = new Socket(); + RealConnection result = RealConnection.testConnection(pool, route, new Socket(), idleAtNanos); synchronized (pool) { - pool.put(connection); + pool.put(result); } - return connection; + return result; } private Address newAddress(String name) { diff --git a/okhttp/src/main/java/okhttp3/internal/connection/RealConnection.java b/okhttp/src/main/java/okhttp3/internal/connection/RealConnection.java index 0583b6c71b5e..898b9af5ec6e 100644 --- a/okhttp/src/main/java/okhttp3/internal/connection/RealConnection.java +++ b/okhttp/src/main/java/okhttp3/internal/connection/RealConnection.java @@ -22,6 +22,7 @@ import java.net.ProtocolException; import java.net.Proxy; import java.net.Socket; +import java.net.SocketException; import java.net.SocketTimeoutException; import java.net.UnknownServiceException; import java.security.cert.X509Certificate; @@ -34,22 +35,27 @@ import okhttp3.Address; import okhttp3.CertificatePinner; import okhttp3.Connection; +import okhttp3.ConnectionPool; import okhttp3.ConnectionSpec; import okhttp3.Handshake; import okhttp3.HttpUrl; +import okhttp3.OkHttpClient; import okhttp3.Protocol; import okhttp3.Request; import okhttp3.Response; import okhttp3.Route; import okhttp3.internal.Util; import okhttp3.internal.Version; +import okhttp3.internal.http.HttpCodec; import okhttp3.internal.http.HttpHeaders; import okhttp3.internal.http1.Http1Codec; import okhttp3.internal.http2.ErrorCode; +import okhttp3.internal.http2.Http2Codec; import okhttp3.internal.http2.Http2Connection; import okhttp3.internal.http2.Http2Stream; import okhttp3.internal.platform.Platform; import okhttp3.internal.tls.OkHostnameVerifier; +import okhttp3.internal.ws.RealWebSocket; import okio.BufferedSink; import okio.BufferedSource; import okio.Okio; @@ -61,8 +67,11 @@ import static okhttp3.internal.Util.closeQuietly; public final class RealConnection extends Http2Connection.Listener implements Connection { + private final ConnectionPool connectionPool; private final Route route; + // The fields below are initialized by connect() and never reassigned. + /** The low-level TCP socket. */ private Socket rawSocket; @@ -70,22 +79,45 @@ public final class RealConnection extends Http2Connection.Listener implements Co * The application layer socket. Either an {@link SSLSocket} layered over {@link #rawSocket}, or * {@link #rawSocket} itself if this connection does not use SSL. */ - public Socket socket; + private Socket socket; private Handshake handshake; private Protocol protocol; - public volatile Http2Connection http2Connection; + private Http2Connection http2Connection; + private BufferedSource source; + private BufferedSink sink; + + // The fields below track connection state and are guarded by connectionPool. + + /** If true, no new streams can be created on this connection. Once true this is always true. */ + public boolean noNewStreams; + public int successCount; - public BufferedSource source; - public BufferedSink sink; - public int allocationLimit; + + /** + * The maximum number of concurrent streams that can be carried by this connection. If {@code + * allocations.size() < allocationLimit} then new streams can be created on this connection. + */ + public int allocationLimit = 1; + + /** Current streams carried by this connection. */ public final List> allocations = new ArrayList<>(); - public boolean noNewStreams; + + /** Nanotime timestamp when {@code allocations.size()} reached zero. */ public long idleAtNanos = Long.MAX_VALUE; - public RealConnection(Route route) { + public RealConnection(ConnectionPool connectionPool, Route route) { + this.connectionPool = connectionPool; this.route = route; } + public static RealConnection testConnection( + ConnectionPool connectionPool, Route route, Socket socket, long idleAtNanos) { + RealConnection result = new RealConnection(connectionPool, route); + result.socket = socket; + result.idleAtNanos = idleAtNanos; + return result; + } + public void connect(int connectTimeout, int readTimeout, int writeTimeout, List connectionSpecs, boolean connectionRetryEnabled) { if (protocol != null) throw new IllegalStateException("already connected"); @@ -105,14 +137,15 @@ public void connect(int connectTimeout, int readTimeout, int writeTimeout, } } - while (protocol == null) { + while (true) { try { if (route.requiresTunnel()) { - buildTunneledConnection(connectTimeout, readTimeout, writeTimeout, - connectionSpecSelector); + connectTunnel(connectTimeout, readTimeout, writeTimeout); } else { - buildConnection(connectTimeout, readTimeout, writeTimeout, connectionSpecSelector); + connectSocket(connectTimeout, readTimeout); } + establishProtocol(connectionSpecSelector); + break; } catch (IOException e) { closeQuietly(socket); closeQuietly(rawSocket); @@ -122,6 +155,7 @@ public void connect(int connectTimeout, int readTimeout, int writeTimeout, sink = null; handshake = null; protocol = null; + http2Connection = null; if (routeException == null) { routeException = new RouteException(e); @@ -134,14 +168,20 @@ public void connect(int connectTimeout, int readTimeout, int writeTimeout, } } } + + if (http2Connection != null) { + synchronized (connectionPool) { + allocationLimit = http2Connection.maxConcurrentStreams(); + } + } } /** * Does all the work to build an HTTPS connection over a proxy tunnel. The catch here is that a * proxy server can issue an auth challenge and then close the connection. */ - private void buildTunneledConnection(int connectTimeout, int readTimeout, int writeTimeout, - ConnectionSpecSelector connectionSpecSelector) throws IOException { + private void connectTunnel(int connectTimeout, int readTimeout, int writeTimeout) + throws IOException { Request tunnelRequest = createTunnelRequest(); HttpUrl url = tunnelRequest.url(); int attemptedConnections = 0; @@ -163,17 +203,9 @@ private void buildTunneledConnection(int connectTimeout, int readTimeout, int wr sink = null; source = null; } - - establishProtocol(readTimeout, writeTimeout, connectionSpecSelector); } /** Does all the work necessary to build a full HTTP or HTTPS connection on a raw socket. */ - private void buildConnection(int connectTimeout, int readTimeout, int writeTimeout, - ConnectionSpecSelector connectionSpecSelector) throws IOException { - connectSocket(connectTimeout, readTimeout); - establishProtocol(readTimeout, writeTimeout, connectionSpecSelector); - } - private void connectSocket(int connectTimeout, int readTimeout) throws IOException { Proxy proxy = route.proxy(); Address address = route.address(); @@ -194,34 +226,26 @@ private void connectSocket(int connectTimeout, int readTimeout) throws IOExcepti sink = Okio.buffer(Okio.sink(rawSocket)); } - private void establishProtocol(int readTimeout, int writeTimeout, - ConnectionSpecSelector connectionSpecSelector) throws IOException { - if (route.address().sslSocketFactory() != null) { - connectTls(readTimeout, writeTimeout, connectionSpecSelector); - } else { + private void establishProtocol(ConnectionSpecSelector connectionSpecSelector) throws IOException { + if (route.address().sslSocketFactory() == null) { protocol = Protocol.HTTP_1_1; socket = rawSocket; + return; } + connectTls(connectionSpecSelector); + if (protocol == Protocol.HTTP_2) { socket.setSoTimeout(0); // Framed connection timeouts are set per-stream. - - Http2Connection http2Connection = new Http2Connection.Builder(true) + http2Connection = new Http2Connection.Builder(true) .socket(socket, route.address().url().host(), source, sink) .listener(this) .build(); http2Connection.start(); - - // Only assign the framed connection once the preface has been sent successfully. - this.allocationLimit = http2Connection.maxConcurrentStreams(); - this.http2Connection = http2Connection; - } else { - this.allocationLimit = 1; } } - private void connectTls(int readTimeout, int writeTimeout, - ConnectionSpecSelector connectionSpecSelector) throws IOException { + private void connectTls(ConnectionSpecSelector connectionSpecSelector) throws IOException { Address address = route.address(); SSLSocketFactory sslSocketFactory = address.sslSocketFactory(); boolean success = false; @@ -343,11 +367,31 @@ private Request createTunnelRequest() { return new Request.Builder() .url(route.address().url()) .header("Host", Util.hostHeader(route.address().url(), true)) - .header("Proxy-Connection", "Keep-Alive") - .header("User-Agent", Version.userAgent()) // For HTTP/1.0 proxies like Squid. + .header("Proxy-Connection", "Keep-Alive") // For HTTP/1.0 proxies like Squid. + .header("User-Agent", Version.userAgent()) .build(); } + public HttpCodec newCodec( + OkHttpClient client, StreamAllocation streamAllocation) throws SocketException { + if (http2Connection != null) { + return new Http2Codec(client, streamAllocation, http2Connection); + } else { + socket.setSoTimeout(client.readTimeoutMillis()); + source.timeout().timeout(client.readTimeoutMillis(), MILLISECONDS); + sink.timeout().timeout(client.writeTimeoutMillis(), MILLISECONDS); + return new Http1Codec(client, streamAllocation, source, sink); + } + } + + public RealWebSocket.Streams newWebSocketStreams(final StreamAllocation streamAllocation) { + return new RealWebSocket.Streams(true, source, sink) { + @Override public void close() throws IOException { + streamAllocation.streamFinished(true, streamAllocation.codec()); + } + }; + } + @Override public Route route() { return route; } @@ -400,7 +444,9 @@ public boolean isHealthy(boolean doExtensiveChecks) { /** When settings are received, adjust the allocation limit. */ @Override public void onSettings(Http2Connection connection) { - allocationLimit = connection.maxConcurrentStreams(); + synchronized (connectionPool) { + allocationLimit = connection.maxConcurrentStreams(); + } } @Override public Handshake handshake() { @@ -416,11 +462,7 @@ public boolean isMultiplexed() { } @Override public Protocol protocol() { - if (http2Connection == null) { - return protocol != null ? protocol : Protocol.HTTP_1_1; - } else { - return Protocol.HTTP_2; - } + return protocol; } @Override public String toString() { diff --git a/okhttp/src/main/java/okhttp3/internal/connection/StreamAllocation.java b/okhttp/src/main/java/okhttp3/internal/connection/StreamAllocation.java index a2213dfc9b5a..d95185ae1803 100644 --- a/okhttp/src/main/java/okhttp3/internal/connection/StreamAllocation.java +++ b/okhttp/src/main/java/okhttp3/internal/connection/StreamAllocation.java @@ -25,14 +25,10 @@ import okhttp3.internal.Internal; import okhttp3.internal.Util; import okhttp3.internal.http.HttpCodec; -import okhttp3.internal.http1.Http1Codec; import okhttp3.internal.http2.ConnectionShutdownException; import okhttp3.internal.http2.ErrorCode; -import okhttp3.internal.http2.Http2Codec; import okhttp3.internal.http2.StreamResetException; -import static java.util.concurrent.TimeUnit.MILLISECONDS; - /** * This class coordinates the relationship between three entities: * @@ -100,17 +96,7 @@ public HttpCodec newStream(OkHttpClient client, boolean doExtensiveHealthChecks) try { RealConnection resultConnection = findHealthyConnection(connectTimeout, readTimeout, writeTimeout, connectionRetryEnabled, doExtensiveHealthChecks); - - HttpCodec resultCodec; - if (resultConnection.http2Connection != null) { - resultCodec = new Http2Codec(client, this, resultConnection.http2Connection); - } else { - resultConnection.socket().setSoTimeout(readTimeout); - resultConnection.source.timeout().timeout(readTimeout, MILLISECONDS); - resultConnection.sink.timeout().timeout(writeTimeout, MILLISECONDS); - resultCodec = new Http1Codec( - client, this, resultConnection.source, resultConnection.sink); - } + HttpCodec resultCodec = resultConnection.newCodec(client, this); synchronized (connectionPool) { codec = resultCodec; @@ -184,7 +170,7 @@ private RealConnection findConnection(int connectTimeout, int readTimeout, int w refusedStreamCount = 0; } } - RealConnection newConnection = new RealConnection(selectedRoute); + RealConnection newConnection = new RealConnection(connectionPool, selectedRoute); synchronized (connectionPool) { acquire(newConnection); diff --git a/okhttp/src/main/java/okhttp3/internal/ws/RealWebSocket.java b/okhttp/src/main/java/okhttp3/internal/ws/RealWebSocket.java index b05d9b77e540..4be038120268 100644 --- a/okhttp/src/main/java/okhttp3/internal/ws/RealWebSocket.java +++ b/okhttp/src/main/java/okhttp3/internal/ws/RealWebSocket.java @@ -190,7 +190,7 @@ public void connect(OkHttpClient client) { // Promote the HTTP streams into web socket streams. StreamAllocation streamAllocation = Internal.instance.streamAllocation(call); streamAllocation.noNewStreams(); // Prevent connection pooling! - Streams streams = new ClientStreams(streamAllocation); + Streams streams = streamAllocation.connection().newWebSocketStreams(streamAllocation); // Process all web socket messages. try { @@ -569,19 +569,6 @@ public Streams(boolean client, BufferedSource source, BufferedSink sink) { } } - static final class ClientStreams extends Streams { - private final StreamAllocation streamAllocation; - - ClientStreams(StreamAllocation streamAllocation) { - super(true, streamAllocation.connection().source, streamAllocation.connection().sink); - this.streamAllocation = streamAllocation; - } - - @Override public void close() { - streamAllocation.streamFinished(true, streamAllocation.codec()); - } - } - final class CancelRunnable implements Runnable { @Override public void run() { cancel();