Skip to content

Commit e308511

Browse files
committed
zmq4: add option for automatic reconnect
1 parent 16d169c commit e308511

File tree

3 files changed

+83
-9
lines changed

3 files changed

+83
-9
lines changed

options.go

+8
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,14 @@ func WithLogger(msg *log.Logger) Option {
5151
}
5252
}
5353

54+
// WithAutomaticReconnect allows to configure a socket to automatically
55+
// reconnect on connection loss.
56+
func WithAutomaticReconnect(automaticReconnect bool) Option {
57+
return func(s *socket) {
58+
s.autoReconnect = automaticReconnect
59+
}
60+
}
61+
5462
/*
5563
// TODO(sbinet)
5664

socket.go

+18-9
Original file line numberDiff line numberDiff line change
@@ -30,12 +30,13 @@ var (
3030

3131
// socket implements the ZeroMQ socket interface
3232
type socket struct {
33-
ep string // socket end-point
34-
typ SocketType
35-
id SocketIdentity
36-
retry time.Duration
37-
sec Security
38-
log *log.Logger
33+
ep string // socket end-point
34+
typ SocketType
35+
id SocketIdentity
36+
retry time.Duration
37+
sec Security
38+
log *log.Logger
39+
autoReconnect bool
3940

4041
mu sync.RWMutex
4142
ids map[string]*Conn // ZMTP connection IDs
@@ -50,8 +51,9 @@ type socket struct {
5051
listener net.Listener
5152
dialer net.Dialer
5253

53-
closedConns []*Conn
54-
reaperCond *sync.Cond
54+
closedConns []*Conn
55+
reaperCond *sync.Cond
56+
reaperStarted bool
5557
}
5658

5759
func newDefaultSocket(ctx context.Context, sockType SocketType) *socket {
@@ -266,7 +268,10 @@ connect:
266268
return fmt.Errorf("zmq4: got a nil ZMTP connection to %q", endpoint)
267269
}
268270

269-
go sck.connReaper()
271+
if !sck.reaperStarted {
272+
go sck.connReaper()
273+
sck.reaperStarted = true
274+
}
270275
sck.addConn(zconn)
271276
return nil
272277
}
@@ -319,6 +324,10 @@ func (sck *socket) scheduleRmConn(c *Conn) {
319324
sck.closedConns = append(sck.closedConns, c)
320325
sck.reaperCond.Signal()
321326
sck.reaperCond.L.Unlock()
327+
328+
if sck.autoReconnect {
329+
sck.Dial(sck.ep)
330+
}
322331
}
323332

324333
// Type returns the type of this Socket (PUB, SUB, ...)

socket_test.go

+57
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import (
99
"fmt"
1010
"io"
1111
"net"
12+
"sync"
1213
"testing"
1314
"time"
1415

@@ -220,3 +221,59 @@ func TestConnReaperDeadlock(t *testing.T) {
220221
clients[i].Close()
221222
}
222223
}
224+
225+
func TestSocketSendSubscriptionOnConnect(t *testing.T) {
226+
endpoint := "tcp://*:1234"
227+
// endpoint := "inproc://test-resub"
228+
message := "test"
229+
230+
ctx, cancel := context.WithCancel(context.Background())
231+
232+
sub := zmq4.NewSub(context.Background())
233+
pub := zmq4.NewPub(context.Background(), zmq4.WithAutomaticReconnect(true))
234+
defer pub.Close()
235+
if err := sub.Listen(endpoint); err != nil {
236+
t.Fatalf("Sub Dial failed: %v", err)
237+
}
238+
if err := pub.Dial("tcp://localhost:1234"); err != nil {
239+
t.Fatalf("Pub Dial failed: %v", err)
240+
}
241+
sub.SetOption(zmq4.OptionSubscribe, message)
242+
243+
wg := new(sync.WaitGroup)
244+
defer wg.Wait()
245+
defer cancel()
246+
wg.Add(1)
247+
go func(t *testing.T) {
248+
defer wg.Done()
249+
for {
250+
pub.Send(zmq4.NewMsgFromString([]string{message}))
251+
if ctx.Err() != nil {
252+
return
253+
}
254+
time.Sleep(1 * time.Millisecond)
255+
}
256+
}(t)
257+
258+
checkConnectionWorking := func(socket zmq4.Socket) {
259+
msg, err := socket.Recv()
260+
if err != nil {
261+
t.Fatalf("Recv failed: %v", err)
262+
}
263+
if string(msg.Frames[0]) != message {
264+
t.Fatalf("invalid message received: got '%s', wanted '%s'", msg.Frames[0], message)
265+
}
266+
}
267+
268+
checkConnectionWorking(sub)
269+
sub.Close()
270+
271+
sub2 := zmq4.NewSub(context.Background())
272+
defer sub2.Close()
273+
if err := sub2.Listen(endpoint); err != nil {
274+
t.Fatalf("Sub Listen failed: %v", err)
275+
}
276+
time.Sleep(10 * time.Millisecond)
277+
sub2.SetOption(zmq4.OptionSubscribe, message)
278+
checkConnectionWorking(sub2)
279+
}

0 commit comments

Comments
 (0)