Skip to content

Commit b1d1280

Browse files
authored
Refactor Socket Reconnect (davidstump#145)
* Handling abnormal close * Added specs for socket reconnect changes
1 parent d5c8765 commit b1d1280

File tree

3 files changed

+176
-46
lines changed

3 files changed

+176
-46
lines changed

Sources/client/Socket.swift

Lines changed: 46 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ public class Socket {
5151

5252

5353

54+
5455
//----------------------------------------------------------------------
5556
// MARK: - Public Attributes
5657
//----------------------------------------------------------------------
@@ -83,8 +84,11 @@ public class Socket {
8384
/// Interval between sending a heartbeat
8485
public var heartbeatInterval: TimeInterval = Defaults.heartbeatInterval
8586

86-
/// Internval between socket reconnect attempts
87-
public var reconnectAfter: (Int) -> TimeInterval = Defaults.steppedBackOff
87+
/// Interval between socket reconnect attempts, in seconds
88+
public var reconnectAfter: (Int) -> TimeInterval = Defaults.reconnectSteppedBackOff
89+
90+
/// Interval between channel rejoin attempts, in seconds
91+
public var rejoinAfter: (Int) -> TimeInterval = Defaults.rejoinSteppedBackOff
8892

8993
/// The optional function to receive logs
9094
public var logger: ((String) -> Void)?
@@ -133,6 +137,9 @@ public class Socket {
133137
/// Timer to use when attempting to reconnect
134138
var reconnectTimer: TimeoutTimer
135139

140+
/// True if the Socket closed cleaned. False if not (connection timeout, heartbeat, etc)
141+
var closeWasClean: Bool = false
142+
136143
/// Websocket connection to the server
137144
var connection: WebSocketClient?
138145

@@ -224,6 +231,9 @@ public class Socket {
224231
// Do not attempt to reconnect if the socket is currently connected
225232
guard !isConnected else { return }
226233

234+
// Reset the clean close flag when attempting to connect
235+
self.closeWasClean = false
236+
227237
self.connection = self.transport(endPointUrl)
228238
self.connection?.delegate = self
229239
self.connection?.disableSSLCertValidation = disableSSLCertValidation
@@ -243,6 +253,10 @@ public class Socket {
243253
/// - paramter callback: Optional. Called when disconnected
244254
public func disconnect(code: CloseCode = CloseCode.normal,
245255
callback: (() -> Void)? = nil) {
256+
// The socket was closed cleanly by the User
257+
self.closeWasClean = true
258+
259+
// Reset any reconnects and teardown the socket connection
246260
self.reconnectTimer.reset()
247261
self.teardown(code: code, callback: callback)
248262
}
@@ -520,6 +534,9 @@ public class Socket {
520534
internal func onConnectionOpen() {
521535
self.logItems("transport", "Connected to \(endPoint)")
522536

537+
// Reset the closeWasClean flag now that the socket has been connected
538+
self.closeWasClean = false
539+
523540
// Send any messages that were waiting for a connection
524541
self.flushSendBuffer()
525542

@@ -541,13 +558,12 @@ public class Socket {
541558
self.heartbeatTimer?.invalidate()
542559
self.heartbeatTimer = nil
543560

544-
self.stateChangeCallbacks.close.forEach({ $0.call() })
561+
// Only attempt to reconnect if the socket did not close normally
562+
if (!self.closeWasClean) {
563+
self.reconnectTimer.scheduleTimeout()
564+
}
545565

546-
// If there was a non-normal event when the connection closed, attempt
547-
// to schedule a reconnect attempt
548-
let closeCode = CloseCode.init(rawValue: UInt16(code ?? 0))
549-
guard closeCode != CloseCode.normal else { return }
550-
self.reconnectTimer.scheduleTimeout()
566+
self.stateChangeCallbacks.close.forEach({ $0.call() })
551567
}
552568

553569
internal func onConnectionError(_ error: Error) {
@@ -585,7 +601,12 @@ public class Socket {
585601

586602
/// Triggers an error event to all of the connected Channels
587603
internal func triggerChannelError() {
588-
self.channels.forEach( { $0.trigger(event: ChannelEvent.error) } )
604+
self.channels.forEach { (channel) in
605+
// Only trigger a channel error if it is in an "opened" state
606+
if !(channel.isErrored || channel.isLeaving || channel.isClosed) {
607+
channel.trigger(event: ChannelEvent.error)
608+
}
609+
}
589610
}
590611

591612
/// Send all messages that were buffered before the socket opened
@@ -637,6 +658,9 @@ public class Socket {
637658
self.logItems("transport",
638659
"heartbeat timeout. Attempting to re-establish connection")
639660

661+
// Close the socket, flagging the closure as abnormal
662+
self.abnormalClose("heartbeat timeout")
663+
640664
// Disconnect the socket manually. Do not use `teardown` or
641665
// `disconnect` as they will nil out the websocket delegate
642666
self.connection?.disconnect(forceTimeout: nil,
@@ -651,6 +675,19 @@ public class Socket {
651675
payload: [:],
652676
ref: self.pendingHeartbeatRef)
653677
}
678+
679+
internal func abnormalClose(_ reason: String) {
680+
self.closeWasClean = false
681+
682+
/*
683+
We use NORMAL here since the client is the one determining to close the
684+
connection. However, we keep a flag `closeWasClean` set to false so that
685+
the client knows that it should attempt to reconnect.
686+
*/
687+
self.connection?.disconnect(forceTimeout: nil,
688+
closeCode: CloseCode.normal.rawValue)
689+
690+
}
654691
}
655692

656693

Sources/client/Utilities/Defaults.swift

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,11 +28,17 @@ public class Defaults {
2828
/// Default interval to send heartbeats on
2929
public static let heartbeatInterval: TimeInterval = 30.0
3030

31-
/// Default reconnect function
32-
public static let steppedBackOff: (Int) -> TimeInterval = { tries in
33-
return tries > 4 ? 10 : [1, 2, 5, 10][tries - 1]
31+
/// Default reconnect algorithm for the socket
32+
public static let reconnectSteppedBackOff: (Int) -> TimeInterval = { tries in
33+
return tries > 9 ? 5.0 : [0.01, 0.05, 0.1, 0.15, 0.2, 0.25, 0.5, 1.0, 2.0][tries - 1]
3434
}
3535

36+
/** Default rejoin algorithm for individual channels */
37+
public static let rejoinSteppedBackOff: (Int) -> TimeInterval = { tries in
38+
return tries > 3 ? 10 : [1, 2, 5][tries - 1]
39+
}
40+
41+
3642
/// Default encode function, utilizing JSONSerialization.data
3743
public static let encode: ([String: Any]) -> Data = { json in
3844
return try! JSONSerialization

Tests/client/SocketSpec.swift

Lines changed: 121 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -32,11 +32,17 @@ class SocketSpec: QuickSpec {
3232
expect(socket.timeout).to(equal(Defaults.timeoutInterval))
3333
expect(socket.heartbeatInterval).to(equal(Defaults.heartbeatInterval))
3434
expect(socket.logger).to(beNil())
35-
expect(socket.reconnectAfter(1)).to(equal(1))
36-
expect(socket.reconnectAfter(2)).to(equal(2))
37-
expect(socket.reconnectAfter(3)).to(equal(5))
38-
expect(socket.reconnectAfter(4)).to(equal(10))
39-
expect(socket.reconnectAfter(5)).to(equal(10))
35+
expect(socket.reconnectAfter(1)).to(equal(0.010)) // 10ms
36+
expect(socket.reconnectAfter(2)).to(equal(0.050)) // 50ms
37+
expect(socket.reconnectAfter(3)).to(equal(0.100)) // 100ms
38+
expect(socket.reconnectAfter(4)).to(equal(0.150)) // 150ms
39+
expect(socket.reconnectAfter(5)).to(equal(0.200)) // 200ms
40+
expect(socket.reconnectAfter(6)).to(equal(0.250)) // 250ms
41+
expect(socket.reconnectAfter(7)).to(equal(0.500)) // 500ms
42+
expect(socket.reconnectAfter(8)).to(equal(1.000)) // 1_000ms (1s)
43+
expect(socket.reconnectAfter(9)).to(equal(2.000)) // 2_000ms (2s)
44+
expect(socket.reconnectAfter(10)).to(equal(5.00)) // 5_000ms (5s)
45+
expect(socket.reconnectAfter(11)).to(equal(5.00)) // 5_000ms (5s)
4046
})
4147

4248
it("overrides some defaults", closure: {
@@ -110,6 +116,7 @@ class SocketSpec: QuickSpec {
110116
})
111117
}
112118

119+
113120
describe("websocketProtocol") {
114121
it("returns wss when protocol is https", closure: {
115122
let socket = Socket("https://example.com/")
@@ -268,6 +275,13 @@ class SocketSpec: QuickSpec {
268275
.to(equal(CloseCode.normal.rawValue))
269276
})
270277

278+
it("flags the socket as closed cleanly", closure: {
279+
expect(socket.closeWasClean).to(beFalse())
280+
281+
socket.disconnect()
282+
expect(socket.closeWasClean).to(beTrue())
283+
})
284+
271285
it("calls callback", closure: {
272286
var callCount = 0
273287
socket.connect()
@@ -599,6 +613,13 @@ class SocketSpec: QuickSpec {
599613
expect(socket.pendingHeartbeatRef).to(beNil())
600614
})
601615

616+
it("does not schedule heartbeat if skipHeartbeat == true", closure: {
617+
socket.skipHeartbeat = true
618+
socket.resetHeartbeat()
619+
620+
expect(socket.heartbeatTimer).to(beNil())
621+
})
622+
602623
it("creates a timer and sends a heartbeat", closure: {
603624
mockWebSocket.isConnected = true
604625
socket.connect()
@@ -647,13 +668,19 @@ class SocketSpec: QuickSpec {
647668
mockWebSocket = WebSocketClientMock()
648669
mockTimeoutTimer = TimeoutTimerMock()
649670
socket = Socket(endPoint: "/socket", transport: mockWebSocketTransport)
650-
socket.reconnectAfter = { _ in return 10 }
671+
// socket.reconnectAfter = { _ in return 10 }
651672
socket.reconnectTimer = mockTimeoutTimer
652-
socket.skipHeartbeat = true
673+
// socket.skipHeartbeat = true
653674
}
654675

655-
it("does not schedule reconnectTimer timeout if normal close", closure: {
676+
it("schedules reconnectTimer timeout if normal close", closure: {
656677
socket.onConnectionClosed(code: Int(CloseCode.normal.rawValue))
678+
expect(mockTimeoutTimer.scheduleTimeoutCalled).to(beTrue())
679+
})
680+
681+
682+
it("does not schedule reconnectTimer timeout if normal close after explicit disconnect", closure: {
683+
socket.disconnect()
657684
expect(mockTimeoutTimer.scheduleTimeoutCalled).to(beFalse())
658685
})
659686

@@ -662,6 +689,58 @@ class SocketSpec: QuickSpec {
662689
expect(mockTimeoutTimer.scheduleTimeoutCalled).to(beTrue())
663690
})
664691

692+
it("schedules reconnectTimer timeout if connection cannot be made after a previous clean disconnect", closure: {
693+
socket.disconnect()
694+
socket.connect()
695+
696+
socket.onConnectionClosed(code: 1001)
697+
expect(mockTimeoutTimer.scheduleTimeoutCalled).to(beTrue())
698+
})
699+
700+
it("triggers channel error if joining", closure: {
701+
let channel = socket.channel("topic")
702+
var errorCalled = false
703+
channel.on(ChannelEvent.error, callback: { _ in
704+
errorCalled = true
705+
})
706+
707+
channel.join()
708+
expect(channel.state).to(equal(.joining))
709+
710+
socket.onConnectionClosed(code: 1001)
711+
expect(errorCalled).to(beTrue())
712+
})
713+
714+
it("triggers channel error if joined", closure: {
715+
let channel = socket.channel("topic")
716+
var errorCalled = false
717+
channel.on(ChannelEvent.error, callback: { _ in
718+
errorCalled = true
719+
})
720+
721+
channel.join().trigger("ok", payload: [:])
722+
expect(channel.state).to(equal(.joined))
723+
724+
socket.onConnectionClosed(code: 1001)
725+
expect(errorCalled).to(beTrue())
726+
})
727+
728+
it("does not trigger channel error after leave", closure: {
729+
let channel = socket.channel("topic")
730+
var errorCalled = false
731+
channel.on(ChannelEvent.error, callback: { _ in
732+
errorCalled = true
733+
})
734+
735+
channel.join().trigger("ok", payload: [:])
736+
channel.leave()
737+
expect(channel.state).to(equal(.closed))
738+
739+
socket.onConnectionClosed(code: 1001)
740+
expect(errorCalled).to(beFalse())
741+
})
742+
743+
665744
it("triggers onClose callbacks", closure: {
666745
var oneCalled = 0
667746
socket.onClose { oneCalled += 1 }
@@ -675,24 +754,6 @@ class SocketSpec: QuickSpec {
675754
expect(twoCalled).to(equal(1))
676755
expect(threeCalled).to(equal(0))
677756
})
678-
679-
it("triggers channel error", closure: {
680-
let channel = socket.channel("topic")
681-
682-
var errorCalled = false
683-
var errorEvent: String?
684-
var closeCalled = false
685-
channel.on(ChannelEvent.error, callback: { (msg) in
686-
errorEvent = msg.event
687-
errorCalled = true
688-
})
689-
channel.on(ChannelEvent.close, callback: { (_) in closeCalled = true })
690-
691-
socket.onConnectionClosed(code: 1000)
692-
expect(errorCalled).to(beTrue())
693-
expect(errorEvent).to(equal(ChannelEvent.error))
694-
expect(closeCalled).to(beFalse())
695-
})
696757
}
697758

698759
describe("onConnectionError") {
@@ -724,22 +785,48 @@ class SocketSpec: QuickSpec {
724785
expect(lastError).toNot(beNil())
725786
})
726787

727-
it("triggers channel error", closure: {
788+
789+
it("triggers channel error if joining", closure: {
728790
let channel = socket.channel("topic")
791+
var errorCalled = false
792+
channel.on(ChannelEvent.error, callback: { _ in
793+
errorCalled = true
794+
})
795+
796+
channel.join()
797+
expect(channel.state).to(equal(.joining))
729798

799+
socket.onConnectionError(TestError.stub)
800+
expect(errorCalled).to(beTrue())
801+
})
802+
803+
it("triggers channel error if joined", closure: {
804+
let channel = socket.channel("topic")
730805
var errorCalled = false
731-
var errorEvent: String?
732-
var closeCalled = false
733-
channel.on(ChannelEvent.error, callback: { (msg) in
734-
errorEvent = msg.event
806+
channel.on(ChannelEvent.error, callback: { _ in
735807
errorCalled = true
736808
})
737-
channel.on(ChannelEvent.close, callback: { (_) in closeCalled = true })
809+
810+
channel.join().trigger("ok", payload: [:])
811+
expect(channel.state).to(equal(.joined))
738812

739813
socket.onConnectionError(TestError.stub)
740814
expect(errorCalled).to(beTrue())
741-
expect(errorEvent).to(equal(ChannelEvent.error))
742-
expect(closeCalled).to(beFalse())
815+
})
816+
817+
it("does not trigger channel error after leave", closure: {
818+
let channel = socket.channel("topic")
819+
var errorCalled = false
820+
channel.on(ChannelEvent.error, callback: { _ in
821+
errorCalled = true
822+
})
823+
824+
channel.join().trigger("ok", payload: [:])
825+
channel.leave()
826+
expect(channel.state).to(equal(.closed))
827+
828+
socket.onConnectionError(TestError.stub)
829+
expect(errorCalled).to(beFalse())
743830
})
744831
}
745832

0 commit comments

Comments
 (0)