Skip to content

Commit

Permalink
Merge pull request square#1280 from square/jw/mws-ws
Browse files Browse the repository at this point in the history
Teach MockWebServer to speak WebSockets.
  • Loading branch information
JakeWharton committed Jan 4, 2015
2 parents 16e2b0d + 8b0bdf3 commit 7000edd
Show file tree
Hide file tree
Showing 9 changed files with 257 additions and 52 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
*/
package com.squareup.okhttp.mockwebserver;

import com.squareup.okhttp.internal.ws.WebSocketListener;
import java.io.InputStream;
import java.util.ArrayList;
import java.util.Iterator;
Expand Down Expand Up @@ -43,6 +44,7 @@ public final class MockResponse implements Cloneable {
private int bodyDelayTimeMs = 0;

private List<PushPromise> promises = new ArrayList<>();
private WebSocketListener webSocketListener;

/** Creates a new mock response with an empty body. */
public MockResponse() {
Expand All @@ -66,8 +68,7 @@ public String getStatus() {
}

public MockResponse setResponseCode(int code) {
this.status = "HTTP/1.1 " + code + " OK";
return this;
return setStatus("HTTP/1.1 " + code + " OK");
}

public MockResponse setStatus(String status) {
Expand Down Expand Up @@ -134,7 +135,11 @@ public Buffer getBody() {

/** Returns an input stream containing the raw HTTP payload. */
InputStream getBodyStream() {
return bodyStream != null ? bodyStream : getBody().inputStream();
if (bodyStream != null) {
return bodyStream;
}
Buffer body = getBody();
return body != null ? body.inputStream() : null;
}

public MockResponse setBody(byte[] body) {
Expand Down Expand Up @@ -251,6 +256,24 @@ public List<PushPromise> getPushPromises() {
return promises;
}

/**
* Attempts to perform a web socket upgrade on the connection. This will overwrite any previously
* set status or body.
*/
public MockResponse withWebSocketUpgrade(WebSocketListener listener) {
setStatus("HTTP/1.1 101 Switching Protocols");
setHeader("Connection", "Upgrade");
setHeader("Upgrade", "websocket");
body = null;
bodyStream = null;
webSocketListener = listener;
return this;
}

public WebSocketListener getWebSocketListener() {
return webSocketListener;
}

@Override public String toString() {
return status;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,10 @@

package com.squareup.okhttp.mockwebserver;

import com.squareup.okhttp.Headers;
import com.squareup.okhttp.Protocol;
import com.squareup.okhttp.Request;
import com.squareup.okhttp.Response;
import com.squareup.okhttp.internal.NamedRunnable;
import com.squareup.okhttp.internal.Platform;
import com.squareup.okhttp.internal.Util;
Expand All @@ -26,6 +29,9 @@
import com.squareup.okhttp.internal.spdy.IncomingStreamHandler;
import com.squareup.okhttp.internal.spdy.SpdyConnection;
import com.squareup.okhttp.internal.spdy.SpdyStream;
import com.squareup.okhttp.internal.ws.RealWebSocket;
import com.squareup.okhttp.internal.ws.WebSocketListener;
import com.squareup.okhttp.internal.ws.WebSocketProtocol;
import java.io.BufferedInputStream;
import java.io.BufferedOutputStream;
import java.io.ByteArrayOutputStream;
Expand All @@ -52,6 +58,7 @@
import java.util.Set;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.LinkedBlockingQueue;
Expand All @@ -67,6 +74,7 @@
import javax.net.ssl.X509TrustManager;
import okio.Buffer;
import okio.BufferedSink;
import okio.BufferedSource;
import okio.ByteString;
import okio.Okio;

Expand Down Expand Up @@ -433,8 +441,10 @@ private boolean processOneRequest(Socket socket, InputStream in, OutputStream ou
throws IOException, InterruptedException {
RecordedRequest request = readRequest(socket, in, out, sequenceNumber);
if (request == null) return false;

requestCount.incrementAndGet();
requestQueue.add(request);

MockResponse response = dispatcher.dispatch(request);
if (response.getSocketPolicy() == SocketPolicy.DISCONNECT_AFTER_REQUEST) {
socket.close();
Expand All @@ -445,7 +455,16 @@ private boolean processOneRequest(Socket socket, InputStream in, OutputStream ou
if (in.read() == -1) return false;
throw new ProtocolException("unexpected data");
}
writeResponse(socket, out, response);

boolean requestWantsWebSockets = "Upgrade".equalsIgnoreCase(request.getHeader("Connection"))
&& "websocket".equalsIgnoreCase(request.getHeader("Upgrade"));
boolean responseWantsWebSockets = response.getWebSocketListener() != null;
if (requestWantsWebSockets && responseWantsWebSockets) {
handleWebSocketUpgrade(socket, in, out, request, response);
} else {
writeHttpResponse(socket, out, response);
}

if (response.getSocketPolicy() == SocketPolicy.DISCONNECT_AT_END) {
in.close();
out.close();
Expand All @@ -458,6 +477,7 @@ private boolean processOneRequest(Socket socket, InputStream in, OutputStream ou
logger.info(MockWebServer.this + " received request: " + request
+ " and responded: " + response);
}

sequenceNumber++;
return true;
}
Expand Down Expand Up @@ -565,7 +585,79 @@ private RecordedRequest readRequest(Socket socket, InputStream in, OutputStream
new Buffer().write(requestBody.toByteArray()), sequenceNumber, socket);
}

private void writeResponse(Socket socket, OutputStream out, MockResponse response)
private void handleWebSocketUpgrade(Socket socket, InputStream in, OutputStream out,
RecordedRequest request, MockResponse response) throws IOException {
String key = request.getHeader("Sec-WebSocket-Key");
String acceptKey = Util.shaBase64(key + WebSocketProtocol.ACCEPT_MAGIC);
response.setHeader("Sec-WebSocket-Accept", acceptKey);

writeHttpResponse(socket, out, response);

BufferedSource source = Okio.buffer(Okio.source(in));
BufferedSink sink = Okio.buffer(Okio.sink(out));

final WebSocketListener listener = response.getWebSocketListener();
final CountDownLatch connectionClose = new CountDownLatch(1);
final RealWebSocket webSocket =
new RealWebSocket(false, source, sink, new SecureRandom(), listener,
request.getPath()) {
@Override protected void closeConnection() throws IOException {
connectionClose.countDown();
}
};

// Adapt the request and response into our Request and Response domain model.
Request.Builder fancyRequestBuilder = new Request.Builder()
.get().url(request.getPath());
List<String> requestHeaders = request.getHeaders();
Headers.Builder fancyRequestHeaders = new Headers.Builder();
for (int i = 0, size = requestHeaders.size(); i < size; i++) {
fancyRequestHeaders.add(requestHeaders.get(i));
}
fancyRequestBuilder.headers(fancyRequestHeaders.build());
final Request fancyRequest = fancyRequestBuilder.build();

Response.Builder fancyResponseBuilder = new Response.Builder()
.code(Integer.parseInt(response.getStatus().split(" ")[1]))
.message(response.getStatus().split(" ", 3)[2])
.request(fancyRequest)
.protocol(Protocol.HTTP_1_1);
List<String> responseHeaders = response.getHeaders();
Headers.Builder fancyResponseHeaders = new Headers.Builder();
for (int i = 0, size = responseHeaders.size(); i < size; i++) {
fancyRequestHeaders.add(responseHeaders.get(i));
}
fancyRequestBuilder.headers(fancyResponseHeaders.build());
final Response fancyResponse = fancyResponseBuilder.build();

// The callback might act synchronously. Give it its own thread.
new Thread(new Runnable() {
@Override public void run() {
try {
listener.onOpen(webSocket, fancyRequest, fancyResponse);
} catch (IOException e) {
// TODO try to write close frame?
connectionClose.countDown();
}
}
}, "MockWebServer WebSocket Writer " + request.getPath()).start();

// Use this thread to continuously read messages.
while (webSocket.readMessage()) {
}

// Even if messages are no longer being read we need to wait for the connection close signal.
try {
connectionClose.await();
} catch (InterruptedException e) {
throw new RuntimeException(e);
}

Util.closeQuietly(sink);
Util.closeQuietly(source);
}

private void writeHttpResponse(Socket socket, OutputStream out, MockResponse response)
throws IOException {
out.write((response.getStatus() + "\r\n").getBytes(Util.US_ASCII));
List<String> headers = response.getHeaders();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,14 @@
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference;
import okio.Buffer;
import okio.BufferedSink;
import okio.BufferedSource;
import org.junit.After;
import org.junit.Rule;
import org.junit.Test;

import static com.squareup.okhttp.internal.ws.WebSocket.PayloadType.TEXT;

public final class WebSocketCallTest {
@Rule public final MockWebServerRule server = new MockWebServerRule();

Expand All @@ -44,6 +47,66 @@ public final class WebSocketCallTest {
listener.assertExhausted();
}

@Test public void clientPingPong() throws IOException {
WebSocketListener serverListener = new EmptyWebSocketListener();
server.enqueue(new MockResponse().withWebSocketUpgrade(serverListener));

WebSocket webSocket = awaitCall().webSocket;
webSocket.sendPing(new Buffer().writeUtf8("Hello, WebSockets!"));
listener.assertPong(new Buffer().writeUtf8("Hello, WebSockets!"));
}

@Test public void clientMessage() throws IOException {
WebSocketRecorder serverListener = new WebSocketRecorder();
server.enqueue(new MockResponse().withWebSocketUpgrade(serverListener));

WebSocket webSocket = awaitCall().webSocket;
webSocket.sendMessage(TEXT, new Buffer().writeUtf8("Hello, WebSockets!"));
serverListener.assertTextMessage("Hello, WebSockets!");
}

@Test public void serverMessage() throws IOException {
WebSocketListener serverListener = new EmptyWebSocketListener() {
@Override public void onOpen(WebSocket webSocket, Request request, Response response)
throws IOException {
webSocket.sendMessage(TEXT, new Buffer().writeUtf8("Hello, WebSockets!"));
}
};
server.enqueue(new MockResponse().withWebSocketUpgrade(serverListener));

awaitCall();
listener.assertTextMessage("Hello, WebSockets!");
}

@Test public void clientStreamingMessage() throws IOException {
WebSocketRecorder serverListener = new WebSocketRecorder();
server.enqueue(new MockResponse().withWebSocketUpgrade(serverListener));

WebSocket webSocket = awaitCall().webSocket;
BufferedSink sink = webSocket.newMessageSink(TEXT);
sink.writeUtf8("Hello, ").flush();
sink.writeUtf8("WebSockets!").flush();
sink.close();

serverListener.assertTextMessage("Hello, WebSockets!");
}

@Test public void serverStreamingMessage() throws IOException {
WebSocketListener serverListener = new EmptyWebSocketListener() {
@Override public void onOpen(WebSocket webSocket, Request request, Response response)
throws IOException {
BufferedSink sink = webSocket.newMessageSink(TEXT);
sink.writeUtf8("Hello, ").flush();
sink.writeUtf8("WebSockets!").flush();
sink.close();
}
};
server.enqueue(new MockResponse().withWebSocketUpgrade(serverListener));

awaitCall();
listener.assertTextMessage("Hello, WebSockets!");
}

@Test public void okButNotOk() {
server.enqueue(new MockResponse());
awaitCall();
Expand Down Expand Up @@ -167,4 +230,23 @@ private RecordedResponse awaitCall() {
return new RecordedResponse(request, responseRef.get(), webSocketRef.get(), null,
failureRef.get());
}

private static class EmptyWebSocketListener implements WebSocketListener {
@Override public void onOpen(WebSocket webSocket, Request request, Response response)
throws IOException {
}

@Override public void onMessage(BufferedSource payload, WebSocket.PayloadType type)
throws IOException {
}

@Override public void onPong(Buffer payload) {
}

@Override public void onClose(int code, String reason) {
}

@Override public void onFailure(IOException e) {
}
}
}
Loading

0 comments on commit 7000edd

Please sign in to comment.