Skip to content

Commit 8903a75

Browse files
committed
zmq4: add option for automatic reconnect
1 parent 04c84de commit 8903a75

File tree

3 files changed

+93
-10
lines changed

3 files changed

+93
-10
lines changed

options.go

+8
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,14 @@ func WithLogger(msg *log.Logger) Option {
5151
}
5252
}
5353

54+
// WithAutomaticReconnect allows to configure a socket to automatically
55+
// reconnect on connection loss.
56+
func WithAutomaticReconnect(automaticReconnect bool) Option {
57+
return func(s *socket) {
58+
s.autoReconnect = automaticReconnect
59+
}
60+
}
61+
5462
/*
5563
// TODO(sbinet)
5664

socket.go

+19-10
Original file line numberDiff line numberDiff line change
@@ -30,13 +30,14 @@ var (
3030

3131
// socket implements the ZeroMQ socket interface
3232
type socket struct {
33-
ep string // socket end-point
34-
typ SocketType
35-
id SocketIdentity
36-
retry time.Duration
37-
sec Security
38-
log *log.Logger
39-
subTopics func() []string
33+
ep string // socket end-point
34+
typ SocketType
35+
id SocketIdentity
36+
retry time.Duration
37+
sec Security
38+
log *log.Logger
39+
subTopics func() []string
40+
autoReconnect bool
4041

4142
mu sync.RWMutex
4243
ids map[string]*Conn // ZMTP connection IDs
@@ -51,8 +52,9 @@ type socket struct {
5152
listener net.Listener
5253
dialer net.Dialer
5354

54-
closedConns []*Conn
55-
reaperCond *sync.Cond
55+
closedConns []*Conn
56+
reaperCond *sync.Cond
57+
reaperStarted bool
5658
}
5759

5860
func newDefaultSocket(ctx context.Context, sockType SocketType) *socket {
@@ -267,7 +269,10 @@ connect:
267269
return fmt.Errorf("zmq4: got a nil ZMTP connection to %q", endpoint)
268270
}
269271

270-
go sck.connReaper()
272+
if !sck.reaperStarted {
273+
go sck.connReaper()
274+
sck.reaperStarted = true
275+
}
271276
sck.addConn(zconn)
272277
return nil
273278
}
@@ -326,6 +331,10 @@ func (sck *socket) scheduleRmConn(c *Conn) {
326331
sck.closedConns = append(sck.closedConns, c)
327332
sck.reaperCond.Signal()
328333
sck.reaperCond.L.Unlock()
334+
335+
if sck.autoReconnect {
336+
sck.Dial(sck.ep)
337+
}
329338
}
330339

331340
// Type returns the type of this Socket (PUB, SUB, ...)

socket_test.go

+66
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ package zmq4_test
66

77
import (
88
"context"
9+
"errors"
910
"fmt"
1011
"io"
1112
"net"
@@ -260,3 +261,68 @@ func TestSocketSendSubscriptionOnConnect(t *testing.T) {
260261
t.Fatalf("invalid message received: got '%s', wanted '%s'", msg.Frames[0], message)
261262
}
262263
}
264+
265+
func TestSocketAutomaticReconnect(t *testing.T) {
266+
listenEndpoint := "tcp://*:1234"
267+
dialEndpoint := "tcp://localhost:1234"
268+
message := "test"
269+
270+
ctx, cancel := context.WithCancel(context.Background())
271+
272+
wg := new(sync.WaitGroup)
273+
defer wg.Wait()
274+
defer cancel()
275+
sendMessages := func(socket zmq4.Socket) {
276+
wg.Add(1)
277+
go func(t *testing.T) {
278+
defer wg.Done()
279+
for {
280+
socket.Send(zmq4.NewMsgFromString([]string{message}))
281+
if ctx.Err() != nil {
282+
return
283+
}
284+
time.Sleep(1 * time.Millisecond)
285+
}
286+
}(t)
287+
}
288+
289+
sub := zmq4.NewSub(context.Background(), zmq4.WithAutomaticReconnect(true))
290+
defer sub.Close()
291+
sub.SetOption(zmq4.OptionSubscribe, message)
292+
pub := zmq4.NewPub(context.Background())
293+
if err := pub.Listen(dialEndpoint); err != nil {
294+
t.Fatalf("Pub Dial failed: %v", err)
295+
}
296+
if err := sub.Dial(listenEndpoint); err != nil {
297+
t.Fatalf("Sub Dial failed: %v", err)
298+
}
299+
300+
sendMessages(pub)
301+
302+
checkConnectionWorking := func(socket zmq4.Socket) {
303+
for {
304+
msg, err := socket.Recv()
305+
if errors.Is(err, io.EOF) {
306+
continue
307+
}
308+
if err != nil {
309+
t.Fatalf("Recv failed: %v", err)
310+
}
311+
if string(msg.Frames[0]) != message {
312+
t.Fatalf("invalid message received: got '%s', wanted '%s'", msg.Frames[0], message)
313+
}
314+
return
315+
}
316+
}
317+
318+
checkConnectionWorking(sub)
319+
pub.Close()
320+
321+
pub2 := zmq4.NewPub(context.Background())
322+
defer pub2.Close()
323+
if err := pub2.Listen(listenEndpoint); err != nil {
324+
t.Fatalf("Sub Listen failed: %v", err)
325+
}
326+
sendMessages(pub2)
327+
checkConnectionWorking(sub)
328+
}

0 commit comments

Comments
 (0)