Skip to content

api: block Connect() on failure if Reconnect > 0 #437

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Mar 28, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,19 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/)
and this project adheres to [Semantic
Versioning](http://semver.org/spec/v2.0.0.html) except to the first release.

## [Unreleased]

### Added

### Changed

- Connect() now retries the connection if a failure occurs and opts.Reconnect > 0.
The number of attempts is equal to opts.MaxReconnects, or unlimited if
opts.MaxReconnects == 0. Connect() blocks until a connection is established,
the context is cancelled, or the number of attempts is exhausted (#436).

### Fixed

## [v2.3.0] - 2025-03-11

The release extends box.info responses and ConnectionPool.GetInfo return data.
Expand Down
68 changes: 54 additions & 14 deletions connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,12 +92,24 @@ func (d defaultLogger) Report(event ConnLogKind, conn *Connection, v ...interfac
case LogReconnectFailed:
reconnects := v[0].(uint)
err := v[1].(error)
log.Printf("tarantool: reconnect (%d/%d) to %s failed: %s",
reconnects, conn.opts.MaxReconnects, conn.Addr(), err)
addr := conn.Addr()
if addr == nil {
log.Printf("tarantool: connect (%d/%d) failed: %s",
reconnects, conn.opts.MaxReconnects, err)
} else {
log.Printf("tarantool: reconnect (%d/%d) to %s failed: %s",
reconnects, conn.opts.MaxReconnects, addr, err)
}
case LogLastReconnectFailed:
err := v[0].(error)
log.Printf("tarantool: last reconnect to %s failed: %s, giving it up",
conn.Addr(), err)
addr := conn.Addr()
if addr == nil {
log.Printf("tarantool: last connect failed: %s, giving it up",
err)
} else {
log.Printf("tarantool: last reconnect to %s failed: %s, giving it up",
addr, err)
}
case LogUnexpectedResultId:
header := v[0].(Header)
log.Printf("tarantool: connection %s got unexpected request ID (%d) in response "+
Expand Down Expand Up @@ -362,8 +374,20 @@ func Connect(ctx context.Context, dialer Dialer, opts Opts) (conn *Connection, e

conn.cond = sync.NewCond(&conn.mutex)

if err = conn.createConnection(ctx); err != nil {
return nil, err
if conn.opts.Reconnect > 0 {
// We don't need these mutex.Lock()/mutex.Unlock() here, but
// runReconnects() expects mutex.Lock() to be set, so it's
// easier to add them instead of reworking runReconnects().
conn.mutex.Lock()
err = conn.runReconnects(ctx)
conn.mutex.Unlock()
if err != nil {
return nil, err
}
} else {
if err = conn.connect(ctx); err != nil {
return nil, err
}
}

go conn.pinger()
Expand Down Expand Up @@ -553,7 +577,7 @@ func pack(h *smallWBuf, enc *msgpack.Encoder, reqid uint32,
return
}

func (conn *Connection) createConnection(ctx context.Context) error {
func (conn *Connection) connect(ctx context.Context) error {
var err error
if conn.c == nil && conn.state == connDisconnected {
if err = conn.dial(ctx); err == nil {
Expand Down Expand Up @@ -616,19 +640,30 @@ func (conn *Connection) getDialTimeout() time.Duration {
return dialTimeout
}

func (conn *Connection) runReconnects() error {
func (conn *Connection) runReconnects(ctx context.Context) error {
dialTimeout := conn.getDialTimeout()
var reconnects uint
var err error

t := time.NewTicker(conn.opts.Reconnect)
defer t.Stop()
for conn.opts.MaxReconnects == 0 || reconnects <= conn.opts.MaxReconnects {
now := time.Now()

ctx, cancel := context.WithTimeout(context.Background(), dialTimeout)
err = conn.createConnection(ctx)
localCtx, cancel := context.WithTimeout(ctx, dialTimeout)
err = conn.connect(localCtx)
cancel()

if err != nil {
// Most likely this is the error the Dialer returned because the
// context was cancelled, but that is not guaranteed. For example,
// the dialer may return some other error before it checks an
// already cancelled context, or the context may be cancelled
// after the dialer has returned its error but before the check
// below.
if ctx.Err() != nil {
return err
}
if clientErr, ok := err.(ClientError); ok &&
clientErr.Code == ErrConnectionClosed {
return err
Expand All @@ -642,7 +677,12 @@ func (conn *Connection) runReconnects() error {
reconnects++
conn.mutex.Unlock()

time.Sleep(time.Until(now.Add(conn.opts.Reconnect)))
select {
case <-ctx.Done():
// Since the context is cancelled, we don't need to do anything.
// Conn.connect() will return the correct error.
case <-t.C:
}

conn.mutex.Lock()
}
Expand All @@ -656,7 +696,7 @@ func (conn *Connection) reconnectImpl(neterr error, c Conn) {
if conn.opts.Reconnect > 0 {
if c == conn.c {
conn.closeConnection(neterr, false)
if err := conn.runReconnects(); err != nil {
if err := conn.runReconnects(context.Background()); err != nil {
conn.closeConnection(err, true)
}
}
Expand Down
80 changes: 80 additions & 0 deletions tarantool_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3972,6 +3972,86 @@ func TestConnect_context_cancel(t *testing.T) {
}
}

// mockSlowDialer is a test Dialer that fails every Dial call until its
// shared call counter reaches 10 and then delegates to the wrapped dialer.
type mockSlowDialer struct {
	// counter is incremented on every Dial call. It is a pointer so the
	// test can read the number of calls even though Dial has a value
	// receiver.
	counter *int
	// original is the dialer used once the counter reaches the threshold.
	original NetDialer
}

// Dial bumps the call counter and returns an error while the counter is
// below 10; otherwise it forwards the call to the original dialer.
func (m mockSlowDialer) Dial(ctx context.Context, opts DialOpts) (Conn, error) {
	*m.counter++
	if attempt := *m.counter; attempt < 10 {
		return nil, fmt.Errorf("Too early: %v", attempt)
	}
	return m.original.Dial(ctx, opts)
}

// TestConnectIsBlocked starts a Tarantool instance and connects to it through
// mockSlowDialer, which fails Dial calls until its counter reaches 10. With
// Reconnect > 0 the test expects Connect() to keep retrying and return a
// usable connection only after at least 10 dial attempts were made.
func TestConnectIsBlocked(t *testing.T) {
	const server = "127.0.0.1:3015"
	testDialer := dialer
	testDialer.Address = server

	inst, err := test_helpers.StartTarantool(test_helpers.StartOpts{
		Dialer:       testDialer,
		InitScript:   "config.lua",
		Listen:       server,
		WaitStart:    100 * time.Millisecond,
		ConnectRetry: 10,
		RetryTimeout: 500 * time.Millisecond,
	})
	defer test_helpers.StopTarantoolWithCleanup(inst)
	if err != nil {
		t.Fatalf("Unable to start Tarantool: %s", err)
	}

	var counter int
	mockDialer := mockSlowDialer{original: testDialer, counter: &counter}
	ctx, cancel := context.WithTimeout(context.TODO(), 10*time.Second)
	defer cancel()

	reconnectOpts := opts
	reconnectOpts.Reconnect = 100 * time.Millisecond
	reconnectOpts.MaxReconnects = 100
	conn, err := Connect(ctx, mockDialer, reconnectOpts)
	if err != nil {
		// Connect() returns a nil connection together with the error, so
		// the test must stop here: a non-fatal assertion would let
		// conn.Close() run on a nil *Connection.
		t.Fatalf("Connect() failed after %d dial attempts: %s", counter, err)
	}
	conn.Close()
	assert.GreaterOrEqual(t, counter, 10)
}

// TestConnectIsBlockedUntilContextExpires calls Connect() with reconnects
// enabled (up to 100 retries, 100ms apart) against 127.0.0.1:3015 and expects
// it to return a dial timeout error.
// NOTE(review): presumably nothing listens on that address during this test
// and the context from GetConnectContext() expires before the retries run
// out — confirm against test_helpers.
func TestConnectIsBlockedUntilContextExpires(t *testing.T) {
	const server = "127.0.0.1:3015"

	target := dialer
	target.Address = server

	retryOpts := opts
	retryOpts.MaxReconnects = 100
	retryOpts.Reconnect = 100 * time.Millisecond

	ctx, cancel := test_helpers.GetConnectContext()
	defer cancel()

	_, connectErr := Connect(ctx, target, retryOpts)
	assert.NotNil(t, connectErr)
	assert.ErrorContains(t, connectErr,
		"failed to dial: dial tcp 127.0.0.1:3015: i/o timeout")
}

// TestConnectIsUnblockedAfterMaxAttempts calls Connect() with reconnects
// enabled but MaxReconnects limited to 1 and expects it to give up with a
// "last reconnect failed" error instead of blocking further.
// NOTE(review): presumably nothing listens on 127.0.0.1:3015 during this
// test — confirm against the test setup.
func TestConnectIsUnblockedAfterMaxAttempts(t *testing.T) {
	const server = "127.0.0.1:3015"

	target := dialer
	target.Address = server

	retryOpts := opts
	retryOpts.MaxReconnects = 1
	retryOpts.Reconnect = 100 * time.Millisecond

	ctx, cancel := test_helpers.GetConnectContext()
	defer cancel()

	_, connectErr := Connect(ctx, target, retryOpts)
	assert.NotNil(t, connectErr)
	assert.ErrorContains(t, connectErr, "last reconnect failed")
}

func buildSidecar(dir string) error {
goPath, err := exec.LookPath("go")
if err != nil {
Expand Down
Loading