Skip to content

Commit a3fa9c9

Browse files
committed
Add check for unused WebSocket sessions
Sessions connected to a STOMP endpoint are expected to receive some client messages. Having received none after successfully connecting could be an indication of proxy or network issue. This change adds periodic checks to see if we have not received any messages on a session which is an indication the session isn't going anywhere most likely due to a proxy issue (or unreliable network) and close those sessions. Issue: SPR-11884
1 parent 98d6f7b commit a3fa9c9

File tree

2 files changed

+153
-21
lines changed

2 files changed

+153
-21
lines changed

spring-websocket/src/main/java/org/springframework/web/socket/messaging/SubProtocolWebSocketHandler.java

Lines changed: 117 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
package org.springframework.web.socket.messaging;
1818

19+
import java.io.IOException;
1920
import java.util.ArrayList;
2021
import java.util.Arrays;
2122
import java.util.HashSet;
@@ -24,6 +25,7 @@
2425
import java.util.Set;
2526
import java.util.TreeMap;
2627
import java.util.concurrent.ConcurrentHashMap;
28+
import java.util.concurrent.locks.ReentrantLock;
2729

2830
import org.apache.commons.logging.Log;
2931
import org.apache.commons.logging.LogFactory;
@@ -64,8 +66,18 @@
6466
public class SubProtocolWebSocketHandler implements WebSocketHandler,
6567
SubProtocolCapable, MessageHandler, SmartLifecycle {
6668

69+
/**
70+
* Sessions connected to this handler use a sub-protocol. Hence we expect to
71+
* receive some client messages. If we don't receive any within a minute, the
72+
* connection isn't doing well (proxy issue, slow network?) and can be closed.
73+
* @see #checkSessions()
74+
*/
75+
private final int TIME_TO_FIRST_MESSAGE = 60 * 1000;
76+
77+
6778
private final Log logger = LogFactory.getLog(SubProtocolWebSocketHandler.class);
6879

80+
6981
private final MessageChannel clientInboundChannel;
7082

7183
private final SubscribableChannel clientOutboundChannel;
@@ -75,12 +87,16 @@ public class SubProtocolWebSocketHandler implements WebSocketHandler,
7587

7688
private SubProtocolHandler defaultProtocolHandler;
7789

78-
private final Map<String, WebSocketSession> sessions = new ConcurrentHashMap<String, WebSocketSession>();
90+
private final Map<String, WebSocketSessionHolder> sessions = new ConcurrentHashMap<String, WebSocketSessionHolder>();
7991

8092
private int sendTimeLimit = 10 * 1000;
8193

8294
private int sendBufferSizeLimit = 512 * 1024;
8395

96+
private volatile long lastSessionCheckTime = System.currentTimeMillis();
97+
98+
private final ReentrantLock sessionCheckLock = new ReentrantLock();
99+
84100
private final Object lifecycleMonitor = new Object();
85101

86102
private volatile boolean running = false;
@@ -214,12 +230,12 @@ public final void stop() {
214230
this.clientOutboundChannel.unsubscribe(this);
215231

216232
// Notify sessions to stop flushing messages
217-
for (WebSocketSession session : this.sessions.values()) {
233+
for (WebSocketSessionHolder holder : this.sessions.values()) {
218234
try {
219-
session.close(CloseStatus.GOING_AWAY);
235+
holder.getSession().close(CloseStatus.GOING_AWAY);
220236
}
221237
catch (Throwable t) {
222-
logger.error("Failed to close session id '" + session.getId() + "': " + t.getMessage());
238+
logger.error("Failed to close '" + holder.getSession() + "': " + t.getMessage());
223239
}
224240
}
225241
}
@@ -235,15 +251,11 @@ public final void stop(Runnable callback) {
235251

236252
@Override
237253
public void afterConnectionEstablished(WebSocketSession session) throws Exception {
238-
239254
session = new ConcurrentWebSocketSessionDecorator(session, getSendTimeLimit(), getSendBufferSizeLimit());
240-
241-
this.sessions.put(session.getId(), session);
255+
this.sessions.put(session.getId(), new WebSocketSessionHolder(session));
242256
if (logger.isDebugEnabled()) {
243-
logger.debug("Started WebSocket session=" + session.getId() +
244-
", number of sessions=" + this.sessions.size());
257+
logger.debug("Started session " + session.getId() + ", number of sessions=" + this.sessions.size());
245258
}
246-
247259
findProtocolHandler(session).afterSessionStarted(session, this.clientInboundChannel);
248260
}
249261

@@ -283,41 +295,49 @@ protected final SubProtocolHandler findProtocolHandler(WebSocketSession session)
283295

284296
@Override
285297
public void handleMessage(WebSocketSession session, WebSocketMessage<?> message) throws Exception {
286-
findProtocolHandler(session).handleMessageFromClient(session, message, this.clientInboundChannel);
298+
SubProtocolHandler protocolHandler = findProtocolHandler(session);
299+
protocolHandler.handleMessageFromClient(session, message, this.clientInboundChannel);
300+
WebSocketSessionHolder holder = this.sessions.get(session.getId());
301+
if (holder != null) {
302+
holder.setHasHandledMessages();
303+
}
304+
else {
305+
// Should never happen
306+
throw new IllegalStateException("Session not found: " + session);
307+
}
308+
checkSessions();
287309
}
288310

289311
@Override
290312
public void handleMessage(Message<?> message) throws MessagingException {
291-
292313
String sessionId = resolveSessionId(message);
293314
if (sessionId == null) {
294315
logger.error("sessionId not found in message " + message);
295316
return;
296317
}
297-
298-
WebSocketSession session = this.sessions.get(sessionId);
299-
if (session == null) {
318+
WebSocketSessionHolder holder = this.sessions.get(sessionId);
319+
if (holder == null) {
300320
logger.error("Session not found for session with id '" + sessionId + "', ignoring message " + message);
301321
return;
302322
}
303-
323+
WebSocketSession session = holder.getSession();
304324
try {
305325
findProtocolHandler(session).handleMessageToClient(session, message);
306326
}
307327
catch (SessionLimitExceededException ex) {
308328
try {
309-
logger.error("Terminating session id '" + sessionId + "'", ex);
329+
logger.error("Terminating '" + session + "'", ex);
310330

311331
// Session may be unresponsive so clear first
312332
clearSession(session, ex.getStatus());
313333
session.close(ex.getStatus());
314334
}
315335
catch (Exception secondException) {
316-
logger.error("Exception terminating session id '" + sessionId + "'", secondException);
336+
logger.error("Exception terminating '" + sessionId + "'", secondException);
317337
}
318338
}
319339
catch (Exception e) {
320-
logger.error("Failed to send message to client " + message, e);
340+
logger.error("Failed to send message to client " + message + " in " + session, e);
321341
}
322342
}
323343

@@ -337,6 +357,43 @@ private String resolveSessionId(Message<?> message) {
337357
return null;
338358
}
339359

360+
/**
361+
* Periodically check sessions to ensure they have received at least one
362+
* message or otherwise close them.
363+
*/
364+
private void checkSessions() throws IOException {
365+
long currentTime = System.currentTimeMillis();
366+
if (!isRunning() && currentTime - this.lastSessionCheckTime < TIME_TO_FIRST_MESSAGE) {
367+
return;
368+
}
369+
try {
370+
if (this.sessionCheckLock.tryLock()) {
371+
for (WebSocketSessionHolder holder : this.sessions.values()) {
372+
if (holder.hasHandledMessages()) {
373+
continue;
374+
}
375+
long timeSinceCreated = currentTime - holder.getCreateTime();
376+
if (holder.hasHandledMessages() || timeSinceCreated < TIME_TO_FIRST_MESSAGE) {
377+
continue;
378+
}
379+
WebSocketSession session = holder.getSession();
380+
if (logger.isErrorEnabled()) {
381+
logger.error("No messages received after " + timeSinceCreated + " ms. Closing " + holder);
382+
}
383+
try {
384+
session.close(CloseStatus.PROTOCOL_ERROR);
385+
}
386+
catch (Throwable t) {
387+
logger.error("Failed to close " + session, t);
388+
}
389+
}
390+
}
391+
}
392+
finally {
393+
this.sessionCheckLock.unlock();
394+
}
395+
}
396+
340397
@Override
341398
public void handleTransportError(WebSocketSession session, Throwable exception) throws Exception {
342399
}
@@ -356,4 +413,45 @@ public boolean supportsPartialMessages() {
356413
return false;
357414
}
358415

416+
417+
private static class WebSocketSessionHolder {
418+
419+
private final WebSocketSession session;
420+
421+
private final long createTime = System.currentTimeMillis();
422+
423+
private volatile boolean handledMessages;
424+
425+
426+
private WebSocketSessionHolder(WebSocketSession session) {
427+
this.session = session;
428+
}
429+
430+
public WebSocketSession getSession() {
431+
return this.session;
432+
}
433+
434+
public long getCreateTime() {
435+
return this.createTime;
436+
}
437+
438+
public void setHasHandledMessages() {
439+
this.handledMessages = true;
440+
}
441+
442+
public boolean hasHandledMessages() {
443+
return this.handledMessages;
444+
}
445+
446+
@Override
447+
public String toString() {
448+
if (this.session instanceof ConcurrentWebSocketSessionDecorator) {
449+
return ((ConcurrentWebSocketSessionDecorator) this.session).getLastSession().toString();
450+
}
451+
else {
452+
return this.session.toString();
453+
}
454+
}
455+
}
456+
359457
}

spring-websocket/src/test/java/org/springframework/web/socket/messaging/SubProtocolWebSocketHandlerTests.java

Lines changed: 36 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,16 +17,24 @@
1717
package org.springframework.web.socket.messaging;
1818

1919
import java.util.Arrays;
20+
import java.util.Map;
2021

2122
import org.junit.Before;
2223
import org.junit.Test;
2324
import org.mockito.Mock;
2425
import org.mockito.MockitoAnnotations;
26+
import org.springframework.beans.DirectFieldAccessor;
2527
import org.springframework.messaging.MessageChannel;
2628
import org.springframework.messaging.SubscribableChannel;
29+
import org.springframework.web.socket.CloseStatus;
30+
import org.springframework.web.socket.TextMessage;
2731
import org.springframework.web.socket.handler.ConcurrentWebSocketSessionDecorator;
2832
import org.springframework.web.socket.handler.TestWebSocketSession;
2933

34+
import static org.junit.Assert.assertEquals;
35+
import static org.junit.Assert.assertFalse;
36+
import static org.junit.Assert.assertNull;
37+
import static org.junit.Assert.assertTrue;
3038
import static org.mockito.Mockito.*;
3139

3240
/**
@@ -56,11 +64,9 @@ public class SubProtocolWebSocketHandlerTests {
5664
@Before
5765
public void setup() {
5866
MockitoAnnotations.initMocks(this);
59-
6067
this.webSocketHandler = new SubProtocolWebSocketHandler(this.inClientChannel, this.outClientChannel);
6168
when(stompHandler.getSupportedProtocols()).thenReturn(Arrays.asList("v10.stomp", "v11.stomp", "v12.stomp"));
6269
when(mqttHandler.getSupportedProtocols()).thenReturn(Arrays.asList("MQTT"));
63-
6470
this.session = new TestWebSocketSession();
6571
this.session.setId("1");
6672
}
@@ -140,4 +146,32 @@ public void noSubProtocolNoDefaultHandler() throws Exception {
140146
this.webSocketHandler.afterConnectionEstablished(session);
141147
}
142148

149+
@Test
150+
public void checkSession() throws Exception {
151+
TestWebSocketSession session1 = new TestWebSocketSession("id1");
152+
TestWebSocketSession session2 = new TestWebSocketSession("id2");
153+
session1.setAcceptedProtocol("v12.stomp");
154+
session2.setAcceptedProtocol("v12.stomp");
155+
156+
this.webSocketHandler.setProtocolHandlers(Arrays.asList(this.stompHandler));
157+
this.webSocketHandler.afterConnectionEstablished(session1);
158+
this.webSocketHandler.afterConnectionEstablished(session2);
159+
session1.setOpen(true);
160+
session2.setOpen(true);
161+
162+
long sixtyOneSecondsAgo = System.currentTimeMillis() - 61 * 1000;
163+
new DirectFieldAccessor(this.webSocketHandler).setPropertyValue("lastSessionCheckTime", sixtyOneSecondsAgo);
164+
Map<String, ?> sessions = (Map<String, ?>) new DirectFieldAccessor(this.webSocketHandler).getPropertyValue("sessions");
165+
new DirectFieldAccessor(sessions.get("id1")).setPropertyValue("createTime", sixtyOneSecondsAgo);
166+
new DirectFieldAccessor(sessions.get("id2")).setPropertyValue("createTime", sixtyOneSecondsAgo);
167+
168+
this.webSocketHandler.handleMessage(session1, new TextMessage("foo"));
169+
170+
assertTrue(session1.isOpen());
171+
assertFalse(session2.isOpen());
172+
assertNull(session1.getCloseStatus());
173+
assertEquals(CloseStatus.PROTOCOL_ERROR, session2.getCloseStatus());
174+
}
175+
176+
143177
}

0 commit comments

Comments
 (0)