16
16
17
17
package org .springframework .web .socket .messaging ;
18
18
19
+ import java .io .IOException ;
19
20
import java .util .ArrayList ;
20
21
import java .util .Arrays ;
21
22
import java .util .HashSet ;
24
25
import java .util .Set ;
25
26
import java .util .TreeMap ;
26
27
import java .util .concurrent .ConcurrentHashMap ;
28
+ import java .util .concurrent .locks .ReentrantLock ;
27
29
28
30
import org .apache .commons .logging .Log ;
29
31
import org .apache .commons .logging .LogFactory ;
64
66
public class SubProtocolWebSocketHandler implements WebSocketHandler ,
65
67
SubProtocolCapable , MessageHandler , SmartLifecycle {
66
68
69
+ /**
70
+ * Sessions connected to this handler use a sub-protocol. Hence we expect to
71
+ * receive some client messages. If we don't receive any within a minute, the
72
+ * connection isn't doing well (proxy issue, slow network?) and can be closed.
73
+ * @see #checkSessions()
74
+ */
75
+ private final int TIME_TO_FIRST_MESSAGE = 60 * 1000 ;
76
+
77
+
67
78
private final Log logger = LogFactory .getLog (SubProtocolWebSocketHandler .class );
68
79
80
+
69
81
private final MessageChannel clientInboundChannel ;
70
82
71
83
private final SubscribableChannel clientOutboundChannel ;
@@ -75,12 +87,16 @@ public class SubProtocolWebSocketHandler implements WebSocketHandler,
75
87
76
88
private SubProtocolHandler defaultProtocolHandler ;
77
89
78
- private final Map <String , WebSocketSession > sessions = new ConcurrentHashMap <String , WebSocketSession >();
90
+ private final Map <String , WebSocketSessionHolder > sessions = new ConcurrentHashMap <String , WebSocketSessionHolder >();
79
91
80
92
private int sendTimeLimit = 10 * 1000 ;
81
93
82
94
private int sendBufferSizeLimit = 512 * 1024 ;
83
95
96
+ private volatile long lastSessionCheckTime = System .currentTimeMillis ();
97
+
98
+ private final ReentrantLock sessionCheckLock = new ReentrantLock ();
99
+
84
100
private final Object lifecycleMonitor = new Object ();
85
101
86
102
private volatile boolean running = false ;
@@ -214,12 +230,12 @@ public final void stop() {
214
230
this .clientOutboundChannel .unsubscribe (this );
215
231
216
232
// Notify sessions to stop flushing messages
217
- for (WebSocketSession session : this .sessions .values ()) {
233
+ for (WebSocketSessionHolder holder : this .sessions .values ()) {
218
234
try {
219
- session .close (CloseStatus .GOING_AWAY );
235
+ holder . getSession () .close (CloseStatus .GOING_AWAY );
220
236
}
221
237
catch (Throwable t ) {
222
- logger .error ("Failed to close session id '" + session . getId () + "': " + t .getMessage ());
238
+ logger .error ("Failed to close '" + holder . getSession () + "': " + t .getMessage ());
223
239
}
224
240
}
225
241
}
@@ -235,15 +251,11 @@ public final void stop(Runnable callback) {
235
251
236
252
@ Override
237
253
public void afterConnectionEstablished (WebSocketSession session ) throws Exception {
238
-
239
254
session = new ConcurrentWebSocketSessionDecorator (session , getSendTimeLimit (), getSendBufferSizeLimit ());
240
-
241
- this .sessions .put (session .getId (), session );
255
+ this .sessions .put (session .getId (), new WebSocketSessionHolder (session ));
242
256
if (logger .isDebugEnabled ()) {
243
- logger .debug ("Started WebSocket session=" + session .getId () +
244
- ", number of sessions=" + this .sessions .size ());
257
+ logger .debug ("Started session " + session .getId () + ", number of sessions=" + this .sessions .size ());
245
258
}
246
-
247
259
findProtocolHandler (session ).afterSessionStarted (session , this .clientInboundChannel );
248
260
}
249
261
@@ -283,41 +295,49 @@ protected final SubProtocolHandler findProtocolHandler(WebSocketSession session)
283
295
284
296
@ Override
285
297
public void handleMessage (WebSocketSession session , WebSocketMessage <?> message ) throws Exception {
286
- findProtocolHandler (session ).handleMessageFromClient (session , message , this .clientInboundChannel );
298
+ SubProtocolHandler protocolHandler = findProtocolHandler (session );
299
+ protocolHandler .handleMessageFromClient (session , message , this .clientInboundChannel );
300
+ WebSocketSessionHolder holder = this .sessions .get (session .getId ());
301
+ if (holder != null ) {
302
+ holder .setHasHandledMessages ();
303
+ }
304
+ else {
305
+ // Should never happen
306
+ throw new IllegalStateException ("Session not found: " + session );
307
+ }
308
+ checkSessions ();
287
309
}
288
310
289
311
@ Override
290
312
public void handleMessage (Message <?> message ) throws MessagingException {
291
-
292
313
String sessionId = resolveSessionId (message );
293
314
if (sessionId == null ) {
294
315
logger .error ("sessionId not found in message " + message );
295
316
return ;
296
317
}
297
-
298
- WebSocketSession session = this .sessions .get (sessionId );
299
- if (session == null ) {
318
+ WebSocketSessionHolder holder = this .sessions .get (sessionId );
319
+ if (holder == null ) {
300
320
logger .error ("Session not found for session with id '" + sessionId + "', ignoring message " + message );
301
321
return ;
302
322
}
303
-
323
+ WebSocketSession session = holder . getSession ();
304
324
try {
305
325
findProtocolHandler (session ).handleMessageToClient (session , message );
306
326
}
307
327
catch (SessionLimitExceededException ex ) {
308
328
try {
309
- logger .error ("Terminating session id '" + sessionId + "'" , ex );
329
+ logger .error ("Terminating '" + session + "'" , ex );
310
330
311
331
// Session may be unresponsive so clear first
312
332
clearSession (session , ex .getStatus ());
313
333
session .close (ex .getStatus ());
314
334
}
315
335
catch (Exception secondException ) {
316
- logger .error ("Exception terminating session id '" + sessionId + "'" , secondException );
336
+ logger .error ("Exception terminating '" + sessionId + "'" , secondException );
317
337
}
318
338
}
319
339
catch (Exception e ) {
320
- logger .error ("Failed to send message to client " + message , e );
340
+ logger .error ("Failed to send message to client " + message + " in " + session , e );
321
341
}
322
342
}
323
343
@@ -337,6 +357,43 @@ private String resolveSessionId(Message<?> message) {
337
357
return null ;
338
358
}
339
359
360
+ /**
361
+ * Periodically check sessions to ensure they have received at least one
362
+ * message or otherwise close them.
363
+ */
364
+ private void checkSessions () throws IOException {
365
+ long currentTime = System .currentTimeMillis ();
366
+ if (!isRunning () && currentTime - this .lastSessionCheckTime < TIME_TO_FIRST_MESSAGE ) {
367
+ return ;
368
+ }
369
+ try {
370
+ if (this .sessionCheckLock .tryLock ()) {
371
+ for (WebSocketSessionHolder holder : this .sessions .values ()) {
372
+ if (holder .hasHandledMessages ()) {
373
+ continue ;
374
+ }
375
+ long timeSinceCreated = currentTime - holder .getCreateTime ();
376
+ if (holder .hasHandledMessages () || timeSinceCreated < TIME_TO_FIRST_MESSAGE ) {
377
+ continue ;
378
+ }
379
+ WebSocketSession session = holder .getSession ();
380
+ if (logger .isErrorEnabled ()) {
381
+ logger .error ("No messages received after " + timeSinceCreated + " ms. Closing " + holder );
382
+ }
383
+ try {
384
+ session .close (CloseStatus .PROTOCOL_ERROR );
385
+ }
386
+ catch (Throwable t ) {
387
+ logger .error ("Failed to close " + session , t );
388
+ }
389
+ }
390
+ }
391
+ }
392
+ finally {
393
+ this .sessionCheckLock .unlock ();
394
+ }
395
+ }
396
+
340
397
@ Override
341
398
public void handleTransportError (WebSocketSession session , Throwable exception ) throws Exception {
342
399
}
@@ -356,4 +413,45 @@ public boolean supportsPartialMessages() {
356
413
return false ;
357
414
}
358
415
416
+
417
+ private static class WebSocketSessionHolder {
418
+
419
+ private final WebSocketSession session ;
420
+
421
+ private final long createTime = System .currentTimeMillis ();
422
+
423
+ private volatile boolean handledMessages ;
424
+
425
+
426
+ private WebSocketSessionHolder (WebSocketSession session ) {
427
+ this .session = session ;
428
+ }
429
+
430
+ public WebSocketSession getSession () {
431
+ return this .session ;
432
+ }
433
+
434
+ public long getCreateTime () {
435
+ return this .createTime ;
436
+ }
437
+
438
+ public void setHasHandledMessages () {
439
+ this .handledMessages = true ;
440
+ }
441
+
442
+ public boolean hasHandledMessages () {
443
+ return this .handledMessages ;
444
+ }
445
+
446
+ @ Override
447
+ public String toString () {
448
+ if (this .session instanceof ConcurrentWebSocketSessionDecorator ) {
449
+ return ((ConcurrentWebSocketSessionDecorator ) this .session ).getLastSession ().toString ();
450
+ }
451
+ else {
452
+ return this .session .toString ();
453
+ }
454
+ }
455
+ }
456
+
359
457
}
0 commit comments