diff --git a/mockwebserver/src/main/java/com/squareup/okhttp/mockwebserver/MockWebServer.java b/mockwebserver/src/main/java/com/squareup/okhttp/mockwebserver/MockWebServer.java index e79831a9afb2..86c314f19743 100644 --- a/mockwebserver/src/main/java/com/squareup/okhttp/mockwebserver/MockWebServer.java +++ b/mockwebserver/src/main/java/com/squareup/okhttp/mockwebserver/MockWebServer.java @@ -499,11 +499,13 @@ private boolean processOneRequest(Socket socket, BufferedSource source, Buffered throw new ProtocolException("unexpected data"); } + boolean reuseSocket = true; boolean requestWantsWebSockets = "Upgrade".equalsIgnoreCase(request.getHeader("Connection")) && "websocket".equalsIgnoreCase(request.getHeader("Upgrade")); boolean responseWantsWebSockets = response.getWebSocketListener() != null; if (requestWantsWebSockets && responseWantsWebSockets) { handleWebSocketUpgrade(socket, source, sink, request, response); + reuseSocket = false; } else { writeHttpResponse(socket, sink, response); } @@ -523,7 +525,7 @@ private boolean processOneRequest(Socket socket, BufferedSource source, Buffered } sequenceNumber++; - return true; + return reuseSocket; } }); } diff --git a/okhttp-tests/src/test/java/com/squareup/okhttp/internal/RecordingHostnameVerifier.java b/okhttp-testing-support/src/main/java/com/squareup/okhttp/testing/RecordingHostnameVerifier.java similarity index 96% rename from okhttp-tests/src/test/java/com/squareup/okhttp/internal/RecordingHostnameVerifier.java rename to okhttp-testing-support/src/main/java/com/squareup/okhttp/testing/RecordingHostnameVerifier.java index c9d914f5f2c4..d4d343a5f32f 100644 --- a/okhttp-tests/src/test/java/com/squareup/okhttp/internal/RecordingHostnameVerifier.java +++ b/okhttp-testing-support/src/main/java/com/squareup/okhttp/testing/RecordingHostnameVerifier.java @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package com.squareup.okhttp.internal; +package com.squareup.okhttp.testing; import java.util.ArrayList; import java.util.List; diff --git a/okhttp-tests/src/test/java/com/squareup/okhttp/CallTest.java b/okhttp-tests/src/test/java/com/squareup/okhttp/CallTest.java index 63adddda1ccc..8600cdbeb6f6 100644 --- a/okhttp-tests/src/test/java/com/squareup/okhttp/CallTest.java +++ b/okhttp-tests/src/test/java/com/squareup/okhttp/CallTest.java @@ -17,7 +17,7 @@ import com.squareup.okhttp.internal.DoubleInetAddressNetwork; import com.squareup.okhttp.internal.Internal; -import com.squareup.okhttp.internal.RecordingHostnameVerifier; +import com.squareup.okhttp.testing.RecordingHostnameVerifier; import com.squareup.okhttp.internal.RecordingOkAuthenticator; import com.squareup.okhttp.internal.SingleInetAddressNetwork; import com.squareup.okhttp.internal.SslContextBuilder; diff --git a/okhttp-tests/src/test/java/com/squareup/okhttp/ConnectionPoolTest.java b/okhttp-tests/src/test/java/com/squareup/okhttp/ConnectionPoolTest.java index 4e8ec7a043e0..64164b136c44 100644 --- a/okhttp-tests/src/test/java/com/squareup/okhttp/ConnectionPoolTest.java +++ b/okhttp-tests/src/test/java/com/squareup/okhttp/ConnectionPoolTest.java @@ -16,7 +16,7 @@ package com.squareup.okhttp; import com.squareup.okhttp.internal.Internal; -import com.squareup.okhttp.internal.RecordingHostnameVerifier; +import com.squareup.okhttp.testing.RecordingHostnameVerifier; import com.squareup.okhttp.internal.SslContextBuilder; import com.squareup.okhttp.internal.Util; import com.squareup.okhttp.internal.http.AuthenticatorAdapter; diff --git a/okhttp-tests/src/test/java/com/squareup/okhttp/internal/http/URLConnectionTest.java b/okhttp-tests/src/test/java/com/squareup/okhttp/internal/http/URLConnectionTest.java index 77724a55a89c..bf22e1501003 100644 --- a/okhttp-tests/src/test/java/com/squareup/okhttp/internal/http/URLConnectionTest.java +++ b/okhttp-tests/src/test/java/com/squareup/okhttp/internal/http/URLConnectionTest.java @@ -33,7 +33,7 @@ import com.squareup.okhttp.TlsVersion; import com.squareup.okhttp.internal.Internal; import com.squareup.okhttp.internal.RecordingAuthenticator; -import com.squareup.okhttp.internal.RecordingHostnameVerifier; +import com.squareup.okhttp.testing.RecordingHostnameVerifier; import com.squareup.okhttp.internal.RecordingOkAuthenticator; import com.squareup.okhttp.internal.SingleInetAddressNetwork; import com.squareup.okhttp.internal.SslContextBuilder; diff --git a/okhttp-ws-tests/src/test/java/com/squareup/okhttp/ws/WebSocketCallTest.java b/okhttp-ws-tests/src/test/java/com/squareup/okhttp/ws/WebSocketCallTest.java index 63d21cb73f93..561d20943966 100644 --- a/okhttp-ws-tests/src/test/java/com/squareup/okhttp/ws/WebSocketCallTest.java +++ b/okhttp-ws-tests/src/test/java/com/squareup/okhttp/ws/WebSocketCallTest.java @@ -18,14 +18,17 @@ import com.squareup.okhttp.OkHttpClient; import com.squareup.okhttp.Request; import com.squareup.okhttp.Response; +import com.squareup.okhttp.internal.SslContextBuilder; import com.squareup.okhttp.mockwebserver.MockResponse; import com.squareup.okhttp.mockwebserver.rule.MockWebServerRule; +import com.squareup.okhttp.testing.RecordingHostnameVerifier; import java.io.IOException; import java.net.ProtocolException; import java.util.Random; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicReference; +import javax.net.ssl.SSLContext; import okio.Buffer; import okio.BufferedSink; import okio.BufferedSource; @@ -36,6 +39,7 @@ import static com.squareup.okhttp.ws.WebSocket.PayloadType.TEXT; public final class WebSocketCallTest { + private static final SSLContext sslContext = SslContextBuilder.localhost(); @Rule public final MockWebServerRule server = new MockWebServerRule(); private final WebSocketRecorder listener = new WebSocketRecorder(); @@ -181,8 +185,48 @@ public final class WebSocketCallTest { "Expected 'Sec-WebSocket-Accept' header value 'ujmZX4KXZqjwy6vi1aQFH5p4Ygk=' but was 'magic'"); } + @Test public void wsScheme() throws IOException { + websocketScheme("ws"); + } + + @Test public void wsUppercaseScheme() throws IOException { + websocketScheme("WS"); + } + + @Test public void wssScheme() throws IOException { + server.get().useHttps(sslContext.getSocketFactory(), false); + client.setSslSocketFactory(sslContext.getSocketFactory()); + client.setHostnameVerifier(new RecordingHostnameVerifier()); + + websocketScheme("wss"); + } + + @Test public void httpsScheme() throws IOException { + server.get().useHttps(sslContext.getSocketFactory(), false); + client.setSslSocketFactory(sslContext.getSocketFactory()); + client.setHostnameVerifier(new RecordingHostnameVerifier()); + + websocketScheme("https"); + } + + private void websocketScheme(String scheme) throws IOException { + WebSocketRecorder serverListener = new WebSocketRecorder(); + server.enqueue(new MockResponse().withWebSocketUpgrade(serverListener)); + + Request request1 = new Request.Builder() + .url(scheme + "://" + server.getHostName() + ":" + server.getPort() + "/") + .build(); + + WebSocket webSocket = awaitWebSocket(request1); + webSocket.sendMessage(TEXT, new Buffer().writeUtf8("abc")); + serverListener.assertTextMessage("abc"); + } + private WebSocket awaitWebSocket() { - Request request = new Request.Builder().get().url(server.getUrl("/")).build(); + return awaitWebSocket(new Request.Builder().get().url(server.getUrl("/")).build()); + } + + private WebSocket awaitWebSocket(Request request) { WebSocketCall call = new WebSocketCall(client, request, random); final AtomicReference responseRef = new AtomicReference<>(); diff --git a/okhttp-ws/src/main/java/com/squareup/okhttp/ws/WebSocketCall.java b/okhttp-ws/src/main/java/com/squareup/okhttp/ws/WebSocketCall.java index b499485afac2..80b3f0558f07 100644 --- a/okhttp-ws/src/main/java/com/squareup/okhttp/ws/WebSocketCall.java +++ b/okhttp-ws/src/main/java/com/squareup/okhttp/ws/WebSocketCall.java @@ -61,19 +61,6 @@ public static WebSocketCall create(OkHttpClient client, Request request) { if (!"GET".equals(request.method())) { throw new IllegalArgumentException("Request must be GET: " + request.method()); } - String url = request.urlString(); - String httpUrl; - if (url.startsWith("ws://")) { - httpUrl = "http://" + url.substring(5); - } else if (url.startsWith("wss://")) { - httpUrl = "https://" + url.substring(6); - } else if (url.startsWith("http://") || url.startsWith("https://")) { - httpUrl = url; - } else { - throw new IllegalArgumentException( - "Request url must use 'ws', 'wss', 'http', or 'https' scheme: " + url); - } - this.random = random; byte[] nonce = new byte[16]; @@ -87,7 +74,6 @@ public static WebSocketCall create(OkHttpClient client, Request request) { client.setProtocols(Collections.singletonList(com.squareup.okhttp.Protocol.HTTP_1_1)); request = request.newBuilder() - .url(httpUrl) .header("Upgrade", "websocket") .header("Connection", "Upgrade") .header("Sec-WebSocket-Key", key) diff --git a/okhttp/src/main/java/com/squareup/okhttp/Request.java b/okhttp/src/main/java/com/squareup/okhttp/Request.java index 8b98a5adf6bf..9b0169dbb1fa 100644 --- a/okhttp/src/main/java/com/squareup/okhttp/Request.java +++ b/okhttp/src/main/java/com/squareup/okhttp/Request.java @@ -145,6 +145,14 @@ public Builder url(HttpUrl url) { public Builder url(String url) { if (url == null) throw new IllegalArgumentException("url == null"); + + // Silently replace websocket URLs with HTTP URLs. + if (url.regionMatches(true, 0, "ws:", 0, 3)) { + url = "http:" + url.substring(3); + } else if (url.regionMatches(true, 0, "wss:", 0, 4)) { + url = "https:" + url.substring(4); + } + HttpUrl parsed = HttpUrl.parse(url); if (parsed == null) throw new IllegalArgumentException("unexpected url: " + url); return url(parsed);