From 49527dae14a9f112e8f0b6e2f37ef49f23452ec5 Mon Sep 17 00:00:00 2001 From: jwilson Date: Sat, 5 Nov 2016 14:01:57 -0400 Subject: [PATCH] New non-blocking API for websockets. Currently this uses a placholder name, NewWebSocket. I'd like to get this reviewed, expand it to cover MockWebServer's needs, and then I'll delete the current blocking API. Still undecided is which APIs to add - if any - to expose the state of the web socket. https://github.com/square/okhttp/issues/2902 --- .../test/java/okhttp3/NewWebSocketTest.java | 347 +++++++++++++ .../internal/ws/NewWebSocketRecorder.java | 400 ++++++++++++++ .../src/main/java/okhttp3/NewWebSocket.java | 151 ++++++ .../src/main/java/okhttp3/OkHttpClient.java | 22 +- .../main/java/okhttp3/internal/Internal.java | 6 + .../okhttp3/internal/ws/RealNewWebSocket.java | 488 ++++++++++++++++++ 6 files changed, 1413 insertions(+), 1 deletion(-) create mode 100644 okhttp-tests/src/test/java/okhttp3/NewWebSocketTest.java create mode 100644 okhttp-tests/src/test/java/okhttp3/internal/ws/NewWebSocketRecorder.java create mode 100644 okhttp/src/main/java/okhttp3/NewWebSocket.java create mode 100644 okhttp/src/main/java/okhttp3/internal/ws/RealNewWebSocket.java diff --git a/okhttp-tests/src/test/java/okhttp3/NewWebSocketTest.java b/okhttp-tests/src/test/java/okhttp3/NewWebSocketTest.java new file mode 100644 index 000000000000..fc4f1bc74c7d --- /dev/null +++ b/okhttp-tests/src/test/java/okhttp3/NewWebSocketTest.java @@ -0,0 +1,347 @@ +/* + * Copyright (C) 2014 Square, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package okhttp3; + +import java.io.IOException; +import java.net.ProtocolException; +import java.util.Random; +import java.util.logging.Logger; +import okhttp3.internal.tls.SslClient; +import okhttp3.internal.ws.NewWebSocketRecorder; +import okhttp3.internal.ws.RealNewWebSocket; +import okhttp3.internal.ws.WebSocketRecorder; +import okhttp3.mockwebserver.MockResponse; +import okhttp3.mockwebserver.MockWebServer; +import okio.ByteString; +import org.junit.After; +import org.junit.Ignore; +import org.junit.Rule; +import org.junit.Test; + +import static okhttp3.TestUtil.defaultClient; +import static okhttp3.WebSocket.TEXT; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.fail; + +public final class NewWebSocketTest { + @Rule public final MockWebServer webServer = new MockWebServer(); + + private final SslClient sslClient = SslClient.localhost(); + private final NewWebSocketRecorder clientListener = new NewWebSocketRecorder("client"); + private final WebSocketRecorder serverListener = new WebSocketRecorder("server"); + private final Random random = new Random(0); + private OkHttpClient client = defaultClient().newBuilder() + .addInterceptor(new Interceptor() { + @Override public Response intercept(Chain chain) throws IOException { + Response response = chain.proceed(chain.request()); + assertNotNull(response.body()); // Ensure application interceptors never see a null body. + return response; + } + }) + .build(); + + @After public void tearDown() { + clientListener.assertExhausted(); + } + + @Test public void textMessage() throws IOException { + webServer.enqueue(new MockResponse().withWebSocketUpgrade(serverListener)); + NewWebSocket client = enqueueClientWebSocket(); + + clientListener.assertOpen(); + serverListener.assertOpen(); + + client.send("Hello, WebSockets!"); + serverListener.assertTextMessage("Hello, WebSockets!"); + } + + @Test public void binaryMessage() throws IOException { + webServer.enqueue(new MockResponse().withWebSocketUpgrade(serverListener)); + RealNewWebSocket client = enqueueClientWebSocket(); + + clientListener.assertOpen(); + serverListener.assertOpen(); + + client.send(ByteString.encodeUtf8("Hello!")); + serverListener.assertBinaryMessage(new byte[] {'H', 'e', 'l', 'l', 'o', '!'}); + } + + @Test public void nullStringThrows() throws IOException { + webServer.enqueue(new MockResponse().withWebSocketUpgrade(serverListener)); + RealNewWebSocket client = enqueueClientWebSocket(); + + clientListener.assertOpen(); + try { + client.send((String) null); + fail(); + } catch (NullPointerException e) { + assertEquals("text == null", e.getMessage()); + } + } + + @Test public void nullByteStringThrows() throws IOException { + webServer.enqueue(new MockResponse().withWebSocketUpgrade(serverListener)); + RealNewWebSocket client = enqueueClientWebSocket(); + + clientListener.assertOpen(); + try { + client.send((ByteString) null); + fail(); + } catch (NullPointerException e) { + assertEquals("bytes == null", e.getMessage()); + } + } + + @Test public void serverMessage() throws IOException { + webServer.enqueue(new MockResponse().withWebSocketUpgrade(serverListener)); + enqueueClientWebSocket(); + + clientListener.assertOpen(); + WebSocket server = serverListener.assertOpen(); + + server.message(RequestBody.create(TEXT, "Hello, WebSockets!")); + clientListener.assertTextMessage("Hello, WebSockets!"); + } + + @Test public void throwingOnOpenFailsImmediately() { + webServer.enqueue(new MockResponse().withWebSocketUpgrade(serverListener)); + + final RuntimeException e = new RuntimeException(); + clientListener.setNextEventDelegate(new NewWebSocket.Listener() { + @Override public void onOpen(NewWebSocket webSocket, Response response) { + throw e; + } + }); + enqueueClientWebSocket(); + + serverListener.assertOpen(); + serverListener.assertExhausted(); + clientListener.assertFailure(e); + } + + @Ignore("AsyncCall currently lets runtime exceptions propagate.") + @Test public void throwingOnFailLogs() throws InterruptedException { + TestLogHandler logs = new TestLogHandler(); + Logger logger = Logger.getLogger(OkHttpClient.class.getName()); + logger.addHandler(logs); + + webServer.enqueue(new MockResponse().setResponseCode(200).setBody("Body")); + + final RuntimeException e = new RuntimeException(); + clientListener.setNextEventDelegate(new NewWebSocket.Listener() { + @Override public void onFailure(NewWebSocket webSocket, Throwable t, Response response) { + throw e; + } + }); + + enqueueClientWebSocket(); + + assertEquals("", logs.take()); + logger.removeHandler(logs); + } + + @Test public void throwingOnMessageClosesImmediatelyAndFails() throws IOException { + webServer.enqueue(new MockResponse().withWebSocketUpgrade(serverListener)); + enqueueClientWebSocket(); + + clientListener.assertOpen(); + WebSocket server = serverListener.assertOpen(); + + final RuntimeException e = new RuntimeException(); + clientListener.setNextEventDelegate(new NewWebSocket.Listener() { + @Override public void onMessage(NewWebSocket webSocket, String text) { + throw e; + } + }); + + server.message(RequestBody.create(TEXT, "Hello, WebSockets!")); + clientListener.assertFailure(e); + serverListener.assertExhausted(); + } + + @Test public void throwingOnClosingClosesImmediatelyAndFails() throws IOException { + webServer.enqueue(new MockResponse().withWebSocketUpgrade(serverListener)); + enqueueClientWebSocket(); + + clientListener.assertOpen(); + WebSocket server = serverListener.assertOpen(); + + final RuntimeException e = new RuntimeException(); + clientListener.setNextEventDelegate(new NewWebSocket.Listener() { + @Override public void onClosing(NewWebSocket webSocket, int code, String reason) { + throw e; + } + }); + + server.close(1000, "bye"); + clientListener.assertFailure(e); + serverListener.assertExhausted(); + } + + @Test public void non101RetainsBody() throws IOException { + webServer.enqueue(new MockResponse().setResponseCode(200).setBody("Body")); + enqueueClientWebSocket(); + + clientListener.assertFailure(200, "Body", ProtocolException.class, + "Expected HTTP 101 response but was '200 OK'"); + } + + @Test public void notFound() throws IOException { + webServer.enqueue(new MockResponse().setStatus("HTTP/1.1 404 Not Found")); + enqueueClientWebSocket(); + + clientListener.assertFailure(404, null, ProtocolException.class, + "Expected HTTP 101 response but was '404 Not Found'"); + } + + @Test public void clientTimeoutClosesBody() throws IOException { + webServer.enqueue(new MockResponse().setResponseCode(408)); + webServer.enqueue(new MockResponse().withWebSocketUpgrade(serverListener)); + RealNewWebSocket client = enqueueClientWebSocket(); + + clientListener.assertOpen(); + WebSocket server = serverListener.assertOpen(); + + client.send("abc"); + serverListener.assertTextMessage("abc"); + + server.message(RequestBody.create(TEXT, "def")); + clientListener.assertTextMessage("def"); + } + + @Test public void missingConnectionHeader() throws IOException { + webServer.enqueue(new MockResponse() + .setResponseCode(101) + .setHeader("Upgrade", "websocket") + .setHeader("Sec-WebSocket-Accept", "ujmZX4KXZqjwy6vi1aQFH5p4Ygk=")); + enqueueClientWebSocket(); + + clientListener.assertFailure(101, null, ProtocolException.class, + "Expected 'Connection' header value 'Upgrade' but was 'null'"); + } + + @Test public void wrongConnectionHeader() throws IOException { + webServer.enqueue(new MockResponse() + .setResponseCode(101) + .setHeader("Upgrade", "websocket") + .setHeader("Connection", "Downgrade") + .setHeader("Sec-WebSocket-Accept", "ujmZX4KXZqjwy6vi1aQFH5p4Ygk=")); + enqueueClientWebSocket(); + + clientListener.assertFailure(101, null, ProtocolException.class, + "Expected 'Connection' header value 'Upgrade' but was 'Downgrade'"); + } + + @Test public void missingUpgradeHeader() throws IOException { + webServer.enqueue(new MockResponse() + .setResponseCode(101) + .setHeader("Connection", "Upgrade") + .setHeader("Sec-WebSocket-Accept", "ujmZX4KXZqjwy6vi1aQFH5p4Ygk=")); + enqueueClientWebSocket(); + + clientListener.assertFailure(101, null, ProtocolException.class, + "Expected 'Upgrade' header value 'websocket' but was 'null'"); + } + + @Test public void wrongUpgradeHeader() throws IOException { + webServer.enqueue(new MockResponse() + .setResponseCode(101) + .setHeader("Connection", "Upgrade") + .setHeader("Upgrade", "Pepsi") + .setHeader("Sec-WebSocket-Accept", "ujmZX4KXZqjwy6vi1aQFH5p4Ygk=")); + enqueueClientWebSocket(); + + clientListener.assertFailure(101, null, ProtocolException.class, + "Expected 'Upgrade' header value 'websocket' but was 'Pepsi'"); + } + + @Test public void missingMagicHeader() throws IOException { + webServer.enqueue(new MockResponse() + .setResponseCode(101) + .setHeader("Connection", "Upgrade") + .setHeader("Upgrade", "websocket")); + enqueueClientWebSocket(); + + clientListener.assertFailure(101, null, ProtocolException.class, + "Expected 'Sec-WebSocket-Accept' header value 'ujmZX4KXZqjwy6vi1aQFH5p4Ygk=' but was 'null'"); + } + + @Test public void wrongMagicHeader() throws IOException { + webServer.enqueue(new MockResponse() + .setResponseCode(101) + .setHeader("Connection", "Upgrade") + .setHeader("Upgrade", "websocket") + .setHeader("Sec-WebSocket-Accept", "magic")); + enqueueClientWebSocket(); + + clientListener.assertFailure(101, null, ProtocolException.class, + "Expected 'Sec-WebSocket-Accept' header value 'ujmZX4KXZqjwy6vi1aQFH5p4Ygk=' but was 'magic'"); + } + + @Test public void wsScheme() throws IOException { + websocketScheme("ws"); + } + + @Test public void wsUppercaseScheme() throws IOException { + websocketScheme("WS"); + } + + @Test public void wssScheme() throws IOException { + webServer.useHttps(sslClient.socketFactory, false); + client = client.newBuilder() + .sslSocketFactory(sslClient.socketFactory, sslClient.trustManager) + .hostnameVerifier(new RecordingHostnameVerifier()) + .build(); + + websocketScheme("wss"); + } + + @Test public void httpsScheme() throws IOException { + webServer.useHttps(sslClient.socketFactory, false); + client = client.newBuilder() + .sslSocketFactory(sslClient.socketFactory, sslClient.trustManager) + .hostnameVerifier(new RecordingHostnameVerifier()) + .build(); + + websocketScheme("https"); + } + + private void websocketScheme(String scheme) throws IOException { + webServer.enqueue(new MockResponse().withWebSocketUpgrade(serverListener)); + + Request request = new Request.Builder() + .url(scheme + "://" + webServer.getHostName() + ":" + webServer.getPort() + "/") + .build(); + + RealNewWebSocket webSocket = enqueueClientWebSocket(request); + clientListener.assertOpen(); + serverListener.assertOpen(); + + webSocket.send("abc"); + serverListener.assertTextMessage("abc"); + } + + private RealNewWebSocket enqueueClientWebSocket() { + return enqueueClientWebSocket(new Request.Builder().get().url(webServer.url("/")).build()); + } + + private RealNewWebSocket enqueueClientWebSocket(Request request) { + RealNewWebSocket webSocket = new RealNewWebSocket(client, request, clientListener, random); + webSocket.connnect(); + return webSocket; + } +} diff --git a/okhttp-tests/src/test/java/okhttp3/internal/ws/NewWebSocketRecorder.java b/okhttp-tests/src/test/java/okhttp3/internal/ws/NewWebSocketRecorder.java new file mode 100644 index 000000000000..7f4ad50f68eb --- /dev/null +++ b/okhttp-tests/src/test/java/okhttp3/internal/ws/NewWebSocketRecorder.java @@ -0,0 +1,400 @@ +/* + * Copyright (C) 2016 Square, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package okhttp3.internal.ws; + +import java.io.IOException; +import java.util.concurrent.BlockingQueue; +import java.util.concurrent.LinkedBlockingQueue; +import java.util.concurrent.TimeUnit; +import okhttp3.NewWebSocket; +import okhttp3.Response; +import okhttp3.ResponseBody; +import okhttp3.WebSocket; +import okhttp3.internal.Util; +import okhttp3.internal.platform.Platform; +import okio.ByteString; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertSame; +import static org.junit.Assert.assertTrue; + +public final class NewWebSocketRecorder extends NewWebSocket.Listener { + private final String name; + private final BlockingQueue events = new LinkedBlockingQueue<>(); + private NewWebSocket.Listener delegate; + + public NewWebSocketRecorder(String name) { + this.name = name; + } + + /** Sets a delegate for handling the next callback to this listener. Cleared after invoked. */ + public void setNextEventDelegate(NewWebSocket.Listener delegate) { + this.delegate = delegate; + } + + @Override public void onOpen(NewWebSocket webSocket, Response response) { + Platform.get().log(Platform.INFO, "[WS " + name + "] onOpen", null); + + NewWebSocket.Listener delegate = this.delegate; + if (delegate != null) { + this.delegate = null; + delegate.onOpen(webSocket, response); + } else { + events.add(new Open(webSocket, response)); + } + } + + @Override public void onMessage(NewWebSocket webSocket, ByteString bytes) { + Platform.get().log(Platform.INFO, "[WS " + name + "] onMessage", null); + + NewWebSocket.Listener delegate = this.delegate; + if (delegate != null) { + this.delegate = null; + delegate.onMessage(webSocket, bytes); + } else { + Message event = new Message(bytes); + events.add(event); + } + } + + @Override public void onMessage(NewWebSocket webSocket, String text) { + Platform.get().log(Platform.INFO, "[WS " + name + "] onMessage", null); + + NewWebSocket.Listener delegate = this.delegate; + if (delegate != null) { + this.delegate = null; + delegate.onMessage(webSocket, text); + } else { + Message event = new Message(text); + events.add(event); + } + } + + @Override public void onClosing(NewWebSocket webSocket, int code, String reason) { + Platform.get().log(Platform.INFO, "[WS " + name + "] onClose " + code, null); + + NewWebSocket.Listener delegate = this.delegate; + if (delegate != null) { + this.delegate = null; + delegate.onClosing(webSocket, code, reason); + } else { + events.add(new Closing(code, reason)); + } + } + + @Override public void onClosed(NewWebSocket webSocket, int code, String reason) { + Platform.get().log(Platform.INFO, "[WS " + name + "] onClose " + code, null); + + NewWebSocket.Listener delegate = this.delegate; + if (delegate != null) { + this.delegate = null; + delegate.onClosed(webSocket, code, reason); + } else { + events.add(new Closed(code, reason)); + } + } + + @Override public void onFailure(NewWebSocket webSocket, Throwable t, Response response) { + Platform.get().log(Platform.INFO, "[WS " + name + "] onFailure", t); + + NewWebSocket.Listener delegate = this.delegate; + if (delegate != null) { + this.delegate = null; + delegate.onFailure(webSocket, t, response); + } else { + events.add(new Failure(t, response)); + } + } + + private Object nextEvent() { + try { + Object event = events.poll(10, TimeUnit.SECONDS); + if (event == null) { + throw new AssertionError("Timed out waiting for event."); + } + return event; + } catch (InterruptedException e) { + throw new AssertionError(e); + } + } + + public void assertTextMessage(String payload) { + Object actual = nextEvent(); + assertEquals(new Message(payload), actual); + } + + public void assertBinaryMessage(byte[] payload) { + Object actual = nextEvent(); + assertEquals(new Message(ByteString.of(payload)), actual); + } + + public void assertPong(ByteString payload) { + Object actual = nextEvent(); + assertEquals(new Pong(payload), actual); + } + + public void assertClose(int code, String reason) { + Object actual = nextEvent(); + assertEquals(new Closing(code, reason), actual); + } + + public void assertExhausted() { + assertTrue("Remaining events: " + events, events.isEmpty()); + } + + public NewWebSocket assertOpen() { + Object event = nextEvent(); + if (!(event instanceof Open)) { + throw new AssertionError("Expected Open but was " + event); + } + return ((Open) event).webSocket; + } + + public void assertFailure(Throwable t) { + Object event = nextEvent(); + if (!(event instanceof Failure)) { + throw new AssertionError("Expected Failure but was " + event); + } + Failure failure = (Failure) event; + assertNull(failure.response); + assertSame(t, failure.t); + } + + public void assertFailure(Class cls, String message) { + Object event = nextEvent(); + if (!(event instanceof Failure)) { + throw new AssertionError("Expected Failure but was " + event); + } + Failure failure = (Failure) event; + assertNull(failure.response); + assertEquals(cls, failure.t.getClass()); + assertEquals(message, failure.t.getMessage()); + } + + public void assertFailure(int code, String body, Class cls, String message) + throws IOException { + Object event = nextEvent(); + if (!(event instanceof Failure)) { + throw new AssertionError("Expected Failure but was " + event); + } + Failure failure = (Failure) event; + assertEquals(code, failure.response.code()); + if (body != null) { + assertEquals(body, failure.responseBody); + } + assertEquals(cls, failure.t.getClass()); + assertEquals(message, failure.t.getMessage()); + } + + static final class Open { + final NewWebSocket webSocket; + final Response response; + + Open(NewWebSocket webSocket, Response response) { + this.webSocket = webSocket; + this.response = response; + } + + @Override public String toString() { + return "Open[" + response + "]"; + } + } + + static final class Failure { + final Throwable t; + final Response response; + final String responseBody; + + Failure(Throwable t, Response response) { + this.t = t; + this.response = response; + String responseBody = null; + if (response != null) { + try { + responseBody = response.body().string(); + } catch (IOException ignored) { + } + } + this.responseBody = responseBody; + } + + @Override public String toString() { + if (response == null) { + return "Failure[" + t + "]"; + } + return "Failure[" + response + "]"; + } + } + + static final class Message { + public final ByteString bytes; + public final String string; + + public Message(ByteString bytes) { + this.bytes = bytes; + this.string = null; + } + + public Message(String string) { + this.bytes = null; + this.string = string; + } + + @Override public String toString() { + return "Message[" + (bytes != null ? bytes : string) + "]"; + } + + @Override public int hashCode() { + return (bytes != null ? bytes : string).hashCode(); + } + + @Override public boolean equals(Object other) { + return other instanceof Message + && Util.equal(((Message) other).bytes, bytes) + && Util.equal(((Message) other).string, string); + } + } + + static final class Pong { + public final ByteString payload; + + Pong(ByteString payload) { + this.payload = payload; + } + + @Override public String toString() { + return "Pong[" + payload + "]"; + } + + @Override public int hashCode() { + return payload.hashCode(); + } + + @Override public boolean equals(Object obj) { + if (obj instanceof Pong) { + Pong other = (Pong) obj; + return payload == null ? other.payload == null : payload.equals(other.payload); + } + return false; + } + } + + static final class Closing { + public final int code; + public final String reason; + + Closing(int code, String reason) { + this.code = code; + this.reason = reason; + } + + @Override public String toString() { + return "Closing[" + code + " " + reason + "]"; + } + + @Override public int hashCode() { + return code * 37 + reason.hashCode(); + } + + @Override public boolean equals(Object other) { + return other instanceof Closing + && ((Closing) other).code == code + && ((Closing) other).reason.equals(reason); + } + } + + static final class Closed { + public final int code; + public final String reason; + + Closed(int code, String reason) { + this.code = code; + this.reason = reason; + } + + @Override public String toString() { + return "Closed[" + code + " " + reason + "]"; + } + + @Override public int hashCode() { + return code * 37 + reason.hashCode(); + } + + @Override public boolean equals(Object other) { + return other instanceof Closed + && ((Closed) other).code == code + && ((Closed) other).reason.equals(reason); + } + } + + /** Expose this recorder as a frame callback and shim in "ping" events. */ + WebSocketReader.FrameCallback asFrameCallback() { + return new WebSocketReader.FrameCallback() { + @Override public void onReadMessage(ResponseBody body) throws IOException { + if (body.contentType().equals(WebSocket.TEXT)) { + String text = body.source().readUtf8(); + onMessage(null, text); + } else if (body.contentType().equals(WebSocket.BINARY)) { + ByteString bytes = body.source().readByteString(); + onMessage(null, bytes); + } else { + throw new IllegalArgumentException(); + } + } + + @Override public void onReadPing(ByteString payload) { + events.add(new Ping(payload)); + } + + @Override public void onReadPong(ByteString padload) { + } + + @Override public void onReadClose(int code, String reason) { + onClosing(null, code, reason); + } + }; + } + + void assertPing(ByteString payload) { + Object actual = nextEvent(); + assertEquals(new Ping(payload), actual); + } + + static final class Ping { + public final ByteString buffer; + + Ping(ByteString buffer) { + this.buffer = buffer; + } + + @Override public String toString() { + return "Ping[" + buffer + "]"; + } + + @Override public int hashCode() { + return buffer.hashCode(); + } + + @Override public boolean equals(Object obj) { + if (obj instanceof Ping) { + Ping other = (Ping) obj; + return buffer == null ? other.buffer == null : buffer.equals(other.buffer); + } + return false; + } + } +} diff --git a/okhttp/src/main/java/okhttp3/NewWebSocket.java b/okhttp/src/main/java/okhttp3/NewWebSocket.java new file mode 100644 index 000000000000..4baff4ff76dc --- /dev/null +++ b/okhttp/src/main/java/okhttp3/NewWebSocket.java @@ -0,0 +1,151 @@ +/* + * Copyright (C) 2016 Square, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package okhttp3; + +import okio.ByteString; + +/** + * A non-blocking interface to a web socket. Use the {@linkplain NewWebSocket.Factory factory} to + * create instances; usually this is {@link OkHttpClient}. + * + *

Web Socket Lifecycle

+ * + * Upon normal operation each web socket progresses through a sequence of states: + * + * + * + * Web sockets may fail due to HTTP upgrade problems, connectivity problems, or if either peer + * chooses to short-circuit the graceful shutdown process: + * + * + * + * Note that the state progression is independent for each peer. Arriving at a gracefully-closed + * state indicates that a peer has sent all of its outgoing messages and received all of its + * incoming messages. But it does not guarantee that the other peer will successfully receive all of + * its incoming messages. + */ +public interface NewWebSocket { + /** Returns the original request that initiated this web socket. */ + Request request(); + + /** + * Returns the size in bytes of all messages enqueued to be transmitted to the server. This + * doesn't include framing overhead. It also doesn't include any bytes buffered by the operating + * system or network intermediaries. This method returns 0 if no messages are waiting + * in the queue. If may return a nonzero value after the web socket has been canceled; this + * indicates that enqueued messages were not transmitted. + */ + long queueSize(); + + /** + * Attempts to enqueue {@code text} to be UTF-8 encoded and sent as a the data of a text (type + * {@code 0x1}) message. + * + *

This method returns true if the message was enqueued. Messages that would overflow the + * outgoing message buffer will be rejected and trigger a {@linkplain #close graceful shutdown} of + * this web socket. This method returns false in that case, and in any other case where this + * web socket is closing, closed, or canceled. + * + *

This method returns immediately. + */ + boolean send(String text); + + /** + * Attempts to enqueue {@code bytes} to be sent as a the data of a binary (type {@code 0x2}) + * message. + * + *

This method returns true if the message was enqueued. Messages that would overflow the + * outgoing message buffer will be rejected and trigger a {@linkplain #close graceful shutdown} of + * this web socket. This method returns false in that case, and in any other case where this + * web socket is closing, closed, or canceled. + * + *

This method returns immediately. + */ + boolean send(ByteString bytes); + + /** + * Attempts to initiate a graceful shutdown of this web socket. Any already-enqueued messages will + * be transmitted before the close message is sent but subsequent calls to {@link #send} will + * return false and their messages will not be enqueued. + * + *

This returns true if a graceful shutdown was initiated by this call. It returns false and if + * a graceful shutdown was already underway or if the web socket is already closed or canceled. + * + * @param code Status code as defined by Section 7.4 of RFC 6455 or {@code 0}. + * @param reason Reason for shutting down or {@code null}. + */ + boolean close(int code, String reason); + + /** + * Immediately and violently release resources held by this web socket, discarding any enqueued + * messages. This does nothing if the web socket has already been closed or canceled. + */ + void cancel(); + + interface Factory { + NewWebSocket newWebSocket(Request request, Listener listener); + } + + abstract class Listener { + /** + * Invoked when a web socket has been accepted by the remote peer and may begin transmitting + * messages. + */ + public void onOpen(NewWebSocket webSocket, Response response) { + } + + /** Invoked when a text (type {@code 0x1}) message has been received. */ + public void onMessage(NewWebSocket webSocket, String text) { + } + + /** Invoked when a binary (type {@code 0x2}) message has been received. */ + public void onMessage(NewWebSocket webSocket, ByteString bytes) { + } + + /** Invoked when the peer has indicated that no more incoming messages will be transmitted. */ + public void onClosing(NewWebSocket webSocket, int code, String reason) { + } + + /** + * Invoked when both peers have indicated that no more messages will be transmitted and the + * connection has been successfully released. No further calls to this listener will be made. + */ + public void onClosed(NewWebSocket webSocket, int code, String reason) { + } + + /** + * Invoked when a web socket has been closed due to an error reading from or writing to the + * network. Both outgoing and incoming messages may have been lost. No further calls to this + * listener will be made. + */ + public void onFailure(NewWebSocket webSocket, Throwable t, Response response) { + } + } +} diff --git a/okhttp/src/main/java/okhttp3/OkHttpClient.java b/okhttp/src/main/java/okhttp3/OkHttpClient.java index 225896d11dfa..91160a56f192 100644 --- a/okhttp/src/main/java/okhttp3/OkHttpClient.java +++ b/okhttp/src/main/java/okhttp3/OkHttpClient.java @@ -21,6 +21,7 @@ import java.net.UnknownHostException; import java.security.GeneralSecurityException; import java.security.KeyStore; +import java.security.SecureRandom; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; @@ -44,6 +45,7 @@ import okhttp3.internal.platform.Platform; import okhttp3.internal.tls.CertificateChainCleaner; import okhttp3.internal.tls.OkHostnameVerifier; +import okhttp3.internal.ws.RealNewWebSocket; /** * Factory for {@linkplain Call calls}, which can be used to send HTTP requests and read their @@ -115,7 +117,8 @@ *

OkHttp also uses daemon threads for HTTP/2 connections. These will exit automatically if they * remain idle. */ -public class OkHttpClient implements Cloneable, Call.Factory, WebSocketCall.Factory { +public class OkHttpClient + implements Cloneable, Call.Factory, WebSocketCall.Factory, NewWebSocket.Factory { private static final List DEFAULT_PROTOCOLS = Util.immutableList( Protocol.HTTP_2, Protocol.HTTP_1_1); @@ -163,6 +166,14 @@ public void apply(ConnectionSpec tlsConfiguration, SSLSocket sslSocket, boolean throws MalformedURLException, UnknownHostException { return HttpUrl.getChecked(url); } + + @Override public StreamAllocation streamAllocation(Call call) { + return ((RealCall) call).streamAllocation(); + } + + @Override public Call newWebSocketCall(OkHttpClient client, Request originalRequest) { + return new RealCall(client, originalRequest, true); + } }; } @@ -387,6 +398,15 @@ public List networkInterceptors() { return new RealWebSocketCall(this, request); } + /** + * Uses {@code request} to connect a new web socket. + */ + @Override public NewWebSocket newWebSocket(Request request, NewWebSocket.Listener listener) { + RealNewWebSocket webSocket = new RealNewWebSocket(this, request, listener, new SecureRandom()); + webSocket.connnect(); + return webSocket; + } + public Builder newBuilder() { return new Builder(this); } diff --git a/okhttp/src/main/java/okhttp3/internal/Internal.java b/okhttp/src/main/java/okhttp3/internal/Internal.java index 6a25d3cdf87f..2610a4451635 100644 --- a/okhttp/src/main/java/okhttp3/internal/Internal.java +++ b/okhttp/src/main/java/okhttp3/internal/Internal.java @@ -19,11 +19,13 @@ import java.net.UnknownHostException; import javax.net.ssl.SSLSocket; import okhttp3.Address; +import okhttp3.Call; import okhttp3.ConnectionPool; import okhttp3.ConnectionSpec; import okhttp3.Headers; import okhttp3.HttpUrl; import okhttp3.OkHttpClient; +import okhttp3.Request; import okhttp3.internal.cache.InternalCache; import okhttp3.internal.connection.RealConnection; import okhttp3.internal.connection.RouteDatabase; @@ -62,4 +64,8 @@ public abstract void apply(ConnectionSpec tlsConfiguration, SSLSocket sslSocket, public abstract HttpUrl getHttpUrlChecked(String url) throws MalformedURLException, UnknownHostException; + + public abstract StreamAllocation streamAllocation(Call call); + + public abstract Call newWebSocketCall(OkHttpClient client, Request request); } diff --git a/okhttp/src/main/java/okhttp3/internal/ws/RealNewWebSocket.java b/okhttp/src/main/java/okhttp3/internal/ws/RealNewWebSocket.java new file mode 100644 index 000000000000..380c9939a2b7 --- /dev/null +++ b/okhttp/src/main/java/okhttp3/internal/ws/RealNewWebSocket.java @@ -0,0 +1,488 @@ +/* + * Copyright (C) 2016 Square, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package okhttp3.internal.ws; + +import java.io.Closeable; +import java.io.IOException; +import java.net.ProtocolException; +import java.util.ArrayDeque; +import java.util.Collections; +import java.util.List; +import java.util.Random; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.SynchronousQueue; +import java.util.concurrent.ThreadPoolExecutor; +import java.util.concurrent.TimeUnit; +import okhttp3.Call; +import okhttp3.Callback; +import okhttp3.NewWebSocket; +import okhttp3.OkHttpClient; +import okhttp3.Protocol; +import okhttp3.Request; +import okhttp3.Response; +import okhttp3.ResponseBody; +import okhttp3.WebSocket; +import okhttp3.internal.Internal; +import okhttp3.internal.NamedRunnable; +import okhttp3.internal.Util; +import okhttp3.internal.connection.StreamAllocation; +import okio.BufferedSink; +import okio.BufferedSource; +import okio.ByteString; +import okio.Okio; + +import static java.util.concurrent.TimeUnit.SECONDS; +import static okhttp3.internal.Util.closeQuietly; +import static okhttp3.internal.ws.WebSocketProtocol.CLOSE_CLIENT_GOING_AWAY; +import static okhttp3.internal.ws.WebSocketProtocol.OPCODE_BINARY; +import static okhttp3.internal.ws.WebSocketProtocol.OPCODE_TEXT; + +public final class RealNewWebSocket implements NewWebSocket, WebSocketReader.FrameCallback { + private static final List ONLY_HTTP1 = Collections.singletonList(Protocol.HTTP_1_1); + + /** + * The maximum number of bytes to enqueue. Rather than enqueueing beyond this limit we tear down + * the web socket! It's possible that we're writing faster than the peer can read. + */ + private static final long MAX_QUEUE_SIZE = 1024 * 1024; // 1 MiB. + + /** A shared executor for all web sockets. */ + private static final ExecutorService executor = new ThreadPoolExecutor(0, + Integer.MAX_VALUE, 60, TimeUnit.SECONDS, new SynchronousQueue(), + Util.threadFactory("OkHttp WebSocket", true)); + + /** The application's original request unadulterated by web socket headers. */ + private final Request originalRequest; + + private final Listener listener; + private final Random random; + private final String key; + private final Call call; + + /** This runnable processes the outgoing queues. Call {@link #runWriter()} to after enqueueing. */ + private final NamedRunnable writerRunnable; + + // All mutable web socket state is guarded by this. + + /** + * True if {@link #writerRunnable} is active. Because writing is single-threaded we only enqueue + * it if it isn't already enqueued. + */ + private boolean writerRunning; + + /** Null until this web Socket is connected. Note that messages may be enqueued before that. */ + private WebSocketWriter writer; + + /** + * The streams held by this web socket. This is non-null until all incoming messages have been + * read and all outgoing messages have been written. It is closed when both reader and writer are + * exhausted, or if there is any failure. + */ + private Streams streams; + + /** Outgoing pongs in the order they should be written. */ + private final ArrayDeque pongQueue = new ArrayDeque<>(); + + /** Outgoing messages and close frames in the order they should be written. */ + private final ArrayDeque messageAndCloseQueue = new ArrayDeque<>(); + + /** The total size in bytes of enqueued but not yet transmitted messages. */ + private long queueSize; + + /** True if we've enqueued a close frame. No further message frames will be enqueued. */ + private boolean enqueuedClose; + + /** The close code from the peer, or -1 if this web socket has not yet read a close frame. */ + private int receivedCloseCode = -1; + + /** The close reason from the peer, or null if this web socket has not yet read a close frame. */ + private String receivedCloseReason; + + /** True if this web socket failed and the listener has been notified. */ + private boolean failed; + + public RealNewWebSocket(OkHttpClient client, Request request, Listener listener, Random random) { + if (!"GET".equals(request.method())) { + throw new IllegalArgumentException("Request must be GET: " + request.method()); + } + this.originalRequest = request; + this.listener = listener; + this.random = random; + + byte[] nonce = new byte[16]; + random.nextBytes(nonce); + this.key = ByteString.of(nonce).base64(); + + this.writerRunnable = new NamedRunnable("OkHttp WebSocket %s", request.url().redact()) { + @Override protected void execute() { + try { + while (writeOneFrame()) { + } + } catch (IOException e) { + failWebSocket(e, null); + } + } + }; + + client = client.newBuilder() + .readTimeout(0, SECONDS) // i.e., no timeout because this is a long-lived connection. + .writeTimeout(0, SECONDS) // i.e., no timeout because this is a long-lived connection. + .protocols(ONLY_HTTP1) + .build(); + request = request.newBuilder() + .header("Upgrade", "websocket") + .header("Connection", "Upgrade") + .header("Sec-WebSocket-Key", key) + .header("Sec-WebSocket-Version", "13") + .build(); + this.call = Internal.instance.newWebSocketCall(client, request); + } + + @Override public Request request() { + return originalRequest; + } + + @Override public synchronized long queueSize() { + return queueSize; + } + + @Override public void cancel() { + call.cancel(); + } + + public void connnect() { + call.enqueue(new Callback() { + @Override public void onResponse(Call call, Response response) { + try { + checkResponse(response); + } catch (ProtocolException e) { + failWebSocket(e, response); + closeQuietly(response); + return; + } + + // Promote the HTTP streams into web socket streams. + StreamAllocation streamAllocation = Internal.instance.streamAllocation(call); + streamAllocation.noNewStreams(); // Prevent connection pooling! + Streams streams = new ClientStreams(streamAllocation); + + try { + readWebsocket(streams, response); + } catch (Exception e) { + failWebSocket(e, null); + } + } + + @Override public void onFailure(Call call, IOException e) { + failWebSocket(e, null); + } + }); + } + + private void checkResponse(Response response) throws ProtocolException { + if (response.code() != 101) { + throw new ProtocolException("Expected HTTP 101 response but was '" + + response.code() + " " + response.message() + "'"); + } + + String headerConnection = response.header("Connection"); + if (!"Upgrade".equalsIgnoreCase(headerConnection)) { + throw new ProtocolException("Expected 'Connection' header value 'Upgrade' but was '" + + headerConnection + "'"); + } + + String headerUpgrade = response.header("Upgrade"); + if (!"websocket".equalsIgnoreCase(headerUpgrade)) { + throw new ProtocolException( + "Expected 'Upgrade' header value 'websocket' but was '" + headerUpgrade + "'"); + } + + String headerAccept = response.header("Sec-WebSocket-Accept"); + String acceptExpected = Util.shaBase64(key + WebSocketProtocol.ACCEPT_MAGIC); + if (!acceptExpected.equals(headerAccept)) { + throw new ProtocolException("Expected 'Sec-WebSocket-Accept' header value '" + + acceptExpected + "' but was '" + headerAccept + "'"); + } + } + + void readWebsocket(Streams streams, Response response) throws IOException { + synchronized (this) { + this.streams = streams; + this.writer = new WebSocketWriter(streams.client, streams.sink, random); + if (!messageAndCloseQueue.isEmpty()) { + runWriter(); // Send messages that were enqueued before we were connected. + } + } + + // Receive frames until there are no more. + WebSocketReader reader = new WebSocketReader(streams.client, streams.source, this); + listener.onOpen(this, response); + while (receivedCloseCode == -1) { + // This method call results in one or more onRead* methods being called on this thread. + reader.processNextFrame(); + } + } + + @Override public void onReadMessage(ResponseBody body) throws IOException { + try { + if (body.contentType().equals(WebSocket.TEXT)) { + String text = body.source().readUtf8(); + listener.onMessage(this, text); + } else if (body.contentType().equals(WebSocket.BINARY)) { + ByteString bytes = body.source().readByteString(); + listener.onMessage(this, bytes); + } else { + throw new IllegalArgumentException(); + } + } finally { + Util.closeQuietly(body); + } + } + + @Override public synchronized void onReadPing(final ByteString payload) { + // Don't respond to pings after we've failed or sent the close frame. + if (failed || (enqueuedClose && messageAndCloseQueue.isEmpty())) return; + + pongQueue.add(payload); + runWriter(); + } + + @Override public void onReadPong(ByteString buffer) { + // This API doesn't expose pings. + } + + @Override public void onReadClose(int code, String reason) { + if (code == -1) throw new IllegalArgumentException(); + + Streams toClose = null; + synchronized (this) { + if (receivedCloseCode != -1) throw new IllegalStateException("already closed"); + receivedCloseCode = code; + receivedCloseReason = reason; + if (enqueuedClose && messageAndCloseQueue.isEmpty()) { + toClose = this.streams; + this.streams = null; + } + } + + try { + listener.onClosing(this, code, reason); + + if (toClose != null) { + listener.onClosed(this, code, reason); + } + } finally { + closeQuietly(toClose); + } + } + + // Writer methods to enqueue frames. They'll be sent asynchronously by the writer thread. + + @Override public boolean send(String text) { + if (text == null) throw new NullPointerException("text == null"); + return send(ByteString.encodeUtf8(text), OPCODE_TEXT); + } + + @Override public boolean send(ByteString bytes) { + if (bytes == null) throw new NullPointerException("bytes == null"); + return send(bytes, OPCODE_BINARY); + } + + private synchronized boolean send(final ByteString data, final int formatOpcode) { + // Don't send new frames after we've failed or enqueued a close frame. + if (failed || enqueuedClose) return false; + + // If this frame overflows the buffer, reject it and close the web socket. + if (queueSize + data.size() > MAX_QUEUE_SIZE) { + close(CLOSE_CLIENT_GOING_AWAY, null); + return false; + } + + // Enqueue the message frame. + queueSize += data.size(); + messageAndCloseQueue.add(new Message(formatOpcode, data)); + runWriter(); + return true; + } + + @Override public synchronized boolean close(final int code, final String reason) { + // TODO(jwilson): confirm reason is well-formed. (<=123 bytes, etc.) + + if (failed || enqueuedClose) return false; + + // Immediately prevent further frames from being enqueued. + enqueuedClose = true; + + // Enqueue the close frame. + messageAndCloseQueue.add(new Close(code, reason)); + runWriter(); + return true; + } + + private void runWriter() { + assert (Thread.holdsLock(this)); + + if (!writerRunning) { + writerRunning = true; + executor.execute(writerRunnable); + } + } + + /** + * Attempts to remove a single frame from a queue and send it. This prefers to write urgent pongs + * before less urgent messages and close frames. For example it's possible that a caller will + * enqueue messages followed by pongs, but this sends pongs followed by messages. Pongs are always + * written in the order they were enqueued. + * + *

If a frame cannot be sent - because there are none enqueued or because the web socket is not + * connected - this does nothing and returns false. Otherwise this returns true and the caller + * should immediately invoke this method again until it returns false. + * + *

This method may only be invoked by the writer thread. There may be only thread invoking this + * method at a time. + */ + private boolean writeOneFrame() throws IOException { + WebSocketWriter writer; + ByteString pong; + Object messageOrClose = null; + int receivedCloseCode = -1; + String receivedCloseReason = null; + Streams streamsToClose = null; + + synchronized (RealNewWebSocket.this) { + if (failed) { + writerRunning = false; + return false; // Failed web socket. + } + + writer = this.writer; + if (writer == null) { + writerRunning = false; + return false; // Not yet connected. + } + + pong = pongQueue.poll(); + if (pong == null) { + messageOrClose = messageAndCloseQueue.poll(); + if (messageOrClose instanceof Close) { + receivedCloseCode = this.receivedCloseCode; + receivedCloseReason = this.receivedCloseReason; + if (receivedCloseCode != -1) { + streamsToClose = this.streams; + this.streams = null; + } + + } else if (messageOrClose == null) { + writerRunning = false; + return false; // The queue is exhausted. + } + } + } + + try { + if (pong != null) { + writer.writePong(pong); + + } else if (messageOrClose instanceof Message) { + ByteString data = ((Message) messageOrClose).data; + BufferedSink sink = Okio.buffer(writer.newMessageSink( + ((Message) messageOrClose).formatOpcode, data.size())); + sink.write(data); + sink.close(); + synchronized (this) { + queueSize -= data.size(); + } + + } else if (messageOrClose instanceof Close) { + Close close = (Close) messageOrClose; + writer.writeClose(close.code, close.reason); + + // We closed the writer: now both reader and writer are closed. + if (streamsToClose != null) { + listener.onClosed(this, receivedCloseCode, receivedCloseReason); + } + + } else { + throw new AssertionError(); + } + + return true; + } finally { + closeQuietly(streamsToClose); + } + } + + private void failWebSocket(Exception e, Response response) { + Streams streamsToClose; + synchronized (this) { + if (failed) return; // Already failed. + failed = true; + streamsToClose = this.streams; + this.streams = null; + } + + try { + listener.onFailure(this, e, response); + } finally { + closeQuietly(streamsToClose); + } + } + + static final class Message { + final int formatOpcode; + final ByteString data; + + public Message(int formatOpcode, ByteString data) { + this.formatOpcode = formatOpcode; + this.data = data; + } + } + + static final class Close { + final int code; + final String reason; + + public Close(int code, String reason) { + this.code = code; + this.reason = reason; + } + } + + abstract static class Streams implements Closeable { + final boolean client; + final BufferedSource source; + final BufferedSink sink; + + public Streams(boolean client, BufferedSource source, BufferedSink sink) { + this.client = client; + this.source = source; + this.sink = sink; + } + } + + static final class ClientStreams extends Streams { + private final StreamAllocation streamAllocation; + + public ClientStreams(StreamAllocation streamAllocation) { + super(true, streamAllocation.connection().source, streamAllocation.connection().sink); + this.streamAllocation = streamAllocation; + } + + @Override public void close() { + streamAllocation.streamFinished(true, streamAllocation.codec()); + } + } +}