Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Issue-1160 Added java.lang.Error handling in WebSocketImpl and WebSocketServer #1223

Merged
merged 8 commits into from
Mar 25, 2022
9 changes: 9 additions & 0 deletions src/main/java/org/java_websocket/WebSocketImpl.java
Original file line number Diff line number Diff line change
Expand Up @@ -410,6 +410,15 @@ private void decodeFrames(ByteBuffer socketBuffer) {
log.error("Closing due to invalid data in frame", e);
wsl.onWebsocketError(this, e);
close(e);
} catch (VirtualMachineError | ThreadDeath | LinkageError e) {
log.error("Got fatal error during frame processing");
throw e;
} catch (Error e) {
log.error("Closing web socket due to an error during frame processing");
Exception exception = new Exception(e);
wsl.onWebsocketError(this, exception);
String errorMessage = "Got error " + e.getClass().getName();
close(CloseFrame.UNEXPECTED_CONDITION, errorMessage);
}
}

Expand Down
16 changes: 14 additions & 2 deletions src/main/java/org/java_websocket/server/WebSocketServer.java
Original file line number Diff line number Diff line change
Expand Up @@ -1079,8 +1079,20 @@ public void run() {
}
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
} catch (RuntimeException e) {
handleFatal(ws, e);
} catch (VirtualMachineError | ThreadDeath | LinkageError e) {
if (ws != null) {
ws.close();
}
log.error("Got fatal error in worker thread {}", getName());
Exception exception = new Exception(e);
handleFatal(ws, exception);
} catch (Throwable e) {
log.error("Uncaught exception in thread {}: {}", getName(), e);
if (ws != null) {
Exception exception = new Exception(e);
onWebsocketError(ws, exception);
ws.close();
}
}
}

Expand Down
159 changes: 159 additions & 0 deletions src/test/java/org/java_websocket/issues/Issue1160Test.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
package org.java_websocket.issues;

import org.java_websocket.WebSocket;
import org.java_websocket.client.WebSocketClient;
import org.java_websocket.handshake.ClientHandshake;
import org.java_websocket.handshake.ServerHandshake;
import org.java_websocket.server.WebSocketServer;
import org.java_websocket.util.SocketUtil;
import org.junit.Assert;
import org.junit.Test;

import java.net.InetSocketAddress;
import java.net.URI;
import java.nio.ByteBuffer;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.atomic.AtomicInteger;

public class Issue1160Test {
private final CountDownLatch countServerStart = new CountDownLatch(1);

static class TestClient extends WebSocketClient {
private final CountDownLatch onCloseLatch;

public TestClient(URI uri, CountDownLatch latch) {
super(uri);
onCloseLatch = latch;
}

@Override
public void onOpen(ServerHandshake handshakedata) {
}

@Override
public void onMessage(String message) {
}

@Override
public void onClose(int code, String reason, boolean remote) {
onCloseLatch.countDown();
}

@Override
public void onError(Exception ex) {
}
}


@Test(timeout = 5000)
public void nonFatalErrorShallBeHandledByServer() throws Exception {
final AtomicInteger isServerOnErrorCalledCounter = new AtomicInteger(0);

int port = SocketUtil.getAvailablePort();
WebSocketServer server = new WebSocketServer(new InetSocketAddress(port)) {
@Override
public void onOpen(WebSocket conn, ClientHandshake handshake) {
}

@Override
public void onClose(WebSocket conn, int code, String reason, boolean remote) {
}

@Override
public void onMessage(WebSocket conn, ByteBuffer message) {
throw new Error("Some error");
}

@Override
public void onMessage(WebSocket conn, String message) {
throw new Error("Some error");
}

@Override
public void onError(WebSocket conn, Exception ex) {
isServerOnErrorCalledCounter.incrementAndGet();
}

@Override
public void onStart() {
countServerStart.countDown();
}
};


server.setConnectionLostTimeout(10);
server.start();
countServerStart.await();

URI uri = new URI("ws://localhost:" + port);

int CONNECTION_COUNT = 3;
for (int i = 0; i < CONNECTION_COUNT; i++) {
CountDownLatch countClientDownLatch = new CountDownLatch(1);
WebSocketClient client = new TestClient(uri, countClientDownLatch);
client.setConnectionLostTimeout(10);

client.connectBlocking();
client.send(new byte[100]);
countClientDownLatch.await();
client.closeBlocking();
}

Assert.assertEquals(CONNECTION_COUNT, isServerOnErrorCalledCounter.get());

server.stop();
}

@Test(timeout = 5000)
public void fatalErrorShallNotBeHandledByServer() throws Exception {
int port = SocketUtil.getAvailablePort();

final CountDownLatch countServerDownLatch = new CountDownLatch(1);
WebSocketServer server = new WebSocketServer(new InetSocketAddress(port)) {
@Override
public void onOpen(WebSocket conn, ClientHandshake handshake) {
}

@Override
public void onClose(WebSocket conn, int code, String reason, boolean remote) {
countServerDownLatch.countDown();
}

@Override
public void onMessage(WebSocket conn, ByteBuffer message) {
throw new OutOfMemoryError("Some error");
}

@Override
public void onMessage(WebSocket conn, String message) {
throw new OutOfMemoryError("Some error");
}

@Override
public void onError(WebSocket conn, Exception ex) {
}

@Override
public void onStart() {
countServerStart.countDown();
}
};


server.setConnectionLostTimeout(10);
server.start();
countServerStart.await();

URI uri = new URI("ws://localhost:" + port);

CountDownLatch countClientDownLatch = new CountDownLatch(1);
WebSocketClient client = new TestClient(uri, countClientDownLatch);
client.setConnectionLostTimeout(10);

client.connectBlocking();
client.send(new byte[100]);
countClientDownLatch.await();
countServerDownLatch.await();
Assert.assertTrue(countClientDownLatch.getCount() == 0 && countServerDownLatch.getCount() == 0);
}
}