Skip to content

Commit

Permalink
api: add context to connection create
Browse files Browse the repository at this point in the history
`connection.Connect` and `pool.Connect` no longer return non-working
connection objects. Those functions now accept context as their first
arguments, which user may cancel in process.

`connection.Connect` will block until either the working connection
created (and returned), `opts.MaxReconnects` creation attempts
were made (returns error) or the context is canceled by user
(returns error too).

Closes #136
  • Loading branch information
DerekBum committed Oct 3, 2023
1 parent d8df65d commit 4c4ac1c
Show file tree
Hide file tree
Showing 29 changed files with 416 additions and 176 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@ Versioning](http://semver.org/spec/v2.0.0.html) except to the first release.
decoded to a varbinary object (#313).
- Use objects of the Decimal type instead of pointers (#238)
- Use objects of the Datetime type instead of pointers (#238)
- `connection.Connect` and `pool.Connect` no longer return non-working
connection objects (#136). Those functions now accept context as their first
arguments, which user may cancel in process.

### Deprecated

Expand Down
4 changes: 3 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -105,13 +105,15 @@ about what it does.
package tarantool

import (
"context"
"fmt"
"github.com/tarantool/go-tarantool/v2"
)

func main() {
opts := tarantool.Opts{User: "guest"}
conn, err := tarantool.Connect("127.0.0.1:3301", opts)
ctx := context.Background()
conn, err := tarantool.Connect(ctx, "127.0.0.1:3301", opts)
if err != nil {
fmt.Println("Connection refused:", err)
}
Expand Down
67 changes: 32 additions & 35 deletions connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -381,10 +381,11 @@ func (opts Opts) Clone() Opts {
// - If opts.Reconnect is zero (default), then connection either already connected
// or error is returned.
//
// - If opts.Reconnect is non-zero, then error will be returned only if authorization
// fails. But if Tarantool is not reachable, then it will make an attempt to reconnect later
// and will not finish to make attempts on authorization failures.
func Connect(addr string, opts Opts) (conn *Connection, err error) {
// - If opts.Reconnect is non-zero, then error will be returned if authorization
// fails, or user has canceled context. If Tarantool is not reachable, then it
// will make attempts to reconnect and will not finish to make attempts on
// authorization failures.
func Connect(ctx context.Context, addr string, opts Opts) (conn *Connection, err error) {
conn = &Connection{
addr: addr,
requestId: 0,
Expand Down Expand Up @@ -432,25 +433,8 @@ func Connect(addr string, opts Opts) (conn *Connection, err error) {

conn.cond = sync.NewCond(&conn.mutex)

if err = conn.createConnection(false); err != nil {
ter, ok := err.(Error)
if conn.opts.Reconnect <= 0 {
return nil, err
} else if ok && (ter.Code == iproto.ER_NO_SUCH_USER ||
ter.Code == iproto.ER_CREDS_MISMATCH) {
// Reported auth errors immediately.
return nil, err
} else {
// Without SkipSchema it is useless.
go func(conn *Connection) {
conn.mutex.Lock()
defer conn.mutex.Unlock()
if err := conn.createConnection(true); err != nil {
conn.closeConnection(err, true)
}
}(conn)
err = nil
}
if err = conn.createConnection(ctx, false); err != nil {
return nil, err
}

go conn.pinger()
Expand Down Expand Up @@ -534,18 +518,19 @@ func (conn *Connection) cancelFuture(fut *Future, err error) {
}
}

func (conn *Connection) dial() (err error) {
func (conn *Connection) dial(ctx context.Context) (err error) {
opts := conn.opts
dialTimeout := opts.Reconnect / 2
if dialTimeout == 0 {
dialTimeout = 500 * time.Millisecond
} else if dialTimeout > 5*time.Second {
dialTimeout = 5 * time.Second
}
nestedCtx, cancel := context.WithTimeout(ctx, dialTimeout)
defer cancel()

var c Conn
c, err = conn.opts.Dialer.Dial(conn.addr, DialOpts{
DialTimeout: dialTimeout,
c, err = conn.opts.Dialer.Dial(nestedCtx, conn.addr, DialOpts{
IoTimeout: opts.Timeout,
Transport: opts.Transport,
Ssl: opts.Ssl,
Expand Down Expand Up @@ -658,34 +643,46 @@ func pack(h *smallWBuf, enc *msgpack.Encoder, reqid uint32,
return
}

func (conn *Connection) createConnection(reconnect bool) (err error) {
func (conn *Connection) createConnection(ctx context.Context,
reconnect bool) error {
var reconnects uint
for conn.c == nil && conn.state == connDisconnected {
now := time.Now()
err = conn.dial()
err := conn.dial(ctx)
if err == nil || !reconnect {
if err == nil {
conn.notify(Connected)
}
return
return err
}
if conn.opts.MaxReconnects > 0 && reconnects > conn.opts.MaxReconnects {
conn.opts.Logger.Report(LogLastReconnectFailed, conn, err)
err = ClientError{ErrConnectionClosed, "last reconnect failed"}
// mark connection as closed to avoid reopening by another goroutine
return
return ClientError{ErrConnectionClosed, "last reconnect failed"}
}
conn.opts.Logger.Report(LogReconnectFailed, conn, reconnects, err)
conn.notify(ReconnectFailed)
reconnects++
conn.mutex.Unlock()
time.Sleep(time.Until(now.Add(conn.opts.Reconnect)))

timer := time.NewTimer(time.Until(now.Add(conn.opts.Reconnect)))
waitLoop:
for {
select {
case <-ctx.Done():
conn.mutex.Lock()
return ClientError{ErrConnectionClosed, "context is canceled"}
case <-timer.C:
break waitLoop
}
}

conn.mutex.Lock()
}
if conn.state == connClosed {
err = ClientError{ErrConnectionClosed, "using closed connection"}
return ClientError{ErrConnectionClosed, "using closed connection"}
}
return
return nil
}

func (conn *Connection) closeConnection(neterr error, forever bool) (err error) {
Expand Down Expand Up @@ -731,7 +728,7 @@ func (conn *Connection) reconnectImpl(neterr error, c Conn) {
if conn.opts.Reconnect > 0 {
if c == conn.c {
conn.closeConnection(neterr, false)
if err := conn.createConnection(true); err != nil {
if err := conn.createConnection(context.Background(), true); err != nil {
conn.closeConnection(err, true)
}
}
Expand Down
3 changes: 2 additions & 1 deletion crud/example_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package crud_test

import (
"context"
"fmt"
"reflect"
"time"
Expand All @@ -21,7 +22,7 @@ var exampleOpts = tarantool.Opts{
}

func exampleConnect() *tarantool.Connection {
conn, err := tarantool.Connect(exampleServer, exampleOpts)
conn, err := tarantool.Connect(context.Background(), exampleServer, exampleOpts)
if err != nil {
panic("Connection is not established: " + err.Error())
}
Expand Down
3 changes: 2 additions & 1 deletion crud/tarantool_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package crud_test

import (
"context"
"fmt"
"log"
"os"
Expand Down Expand Up @@ -108,7 +109,7 @@ var object = crud.MapObject{

func connect(t testing.TB) *tarantool.Connection {
for i := 0; i < 10; i++ {
conn, err := tarantool.Connect(server, opts)
conn, err := tarantool.Connect(context.Background(), server, opts)
if err != nil {
t.Fatalf("Failed to connect: %s", err)
}
Expand Down
3 changes: 2 additions & 1 deletion datetime/example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
package datetime_test

import (
"context"
"fmt"
"time"

Expand All @@ -23,7 +24,7 @@ func Example() {
User: "test",
Pass: "test",
}
conn, err := tarantool.Connect("127.0.0.1:3013", opts)
conn, err := tarantool.Connect(context.Background(), "127.0.0.1:3013", opts)
if err != nil {
fmt.Printf("Error in connect is %v", err)
return
Expand Down
3 changes: 2 additions & 1 deletion decimal/example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
package decimal_test

import (
"context"
"log"
"time"

Expand All @@ -28,7 +29,7 @@ func Example() {
User: "test",
Pass: "test",
}
client, err := tarantool.Connect(server, opts)
client, err := tarantool.Connect(context.Background(), server, opts)
if err != nil {
log.Fatalf("Failed to connect: %s", err.Error())
}
Expand Down
16 changes: 8 additions & 8 deletions dial.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package tarantool
import (
"bufio"
"bytes"
"context"
"errors"
"fmt"
"io"
Expand Down Expand Up @@ -56,8 +57,6 @@ type Conn interface {

// DialOpts is a way to configure a Dial method to create a new Conn.
type DialOpts struct {
// DialTimeout is a timeout for an initial network dial.
DialTimeout time.Duration
// IoTimeout is a timeout per a network read/write.
IoTimeout time.Duration
// Transport is a connect transport type.
Expand Down Expand Up @@ -86,7 +85,7 @@ type DialOpts struct {
type Dialer interface {
// Dial connects to a Tarantool instance to the address with specified
// options.
Dial(address string, opts DialOpts) (Conn, error)
Dial(ctx context.Context, address string, opts DialOpts) (Conn, error)
}

type tntConn struct {
Expand All @@ -104,11 +103,11 @@ type TtDialer struct {

// Dial connects to a Tarantool instance to the address with specified
// options.
func (t TtDialer) Dial(address string, opts DialOpts) (Conn, error) {
func (t TtDialer) Dial(ctx context.Context, address string, opts DialOpts) (Conn, error) {
var err error
conn := new(tntConn)

if conn.net, err = dial(address, opts); err != nil {
if conn.net, err = dial(ctx, address, opts); err != nil {
return nil, fmt.Errorf("failed to dial: %w", err)
}

Expand Down Expand Up @@ -199,13 +198,14 @@ func (c *tntConn) ProtocolInfo() ProtocolInfo {
}

// dial connects to a Tarantool instance.
func dial(address string, opts DialOpts) (net.Conn, error) {
func dial(ctx context.Context, address string, opts DialOpts) (net.Conn, error) {
network, address := parseAddress(address)
switch opts.Transport {
case dialTransportNone:
return net.DialTimeout(network, address, opts.DialTimeout)
dialer := net.Dialer{}
return dialer.DialContext(ctx, network, address)
case dialTransportSsl:
return sslDialTimeout(network, address, opts.DialTimeout, opts.Ssl)
return sslDialContext(ctx, network, address, opts.Ssl)
default:
return nil, fmt.Errorf("unsupported transport type: %s", opts.Transport)
}
Expand Down
34 changes: 19 additions & 15 deletions dial_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package tarantool_test

import (
"bytes"
"context"
"errors"
"net"
"sync"
Expand All @@ -18,7 +19,7 @@ type mockErrorDialer struct {
err error
}

func (m mockErrorDialer) Dial(address string,
func (m mockErrorDialer) Dial(ctx context.Context, address string,
opts tarantool.DialOpts) (tarantool.Conn, error) {
return nil, m.err
}
Expand All @@ -29,9 +30,10 @@ func TestDialer_Dial_error(t *testing.T) {
err: errors.New(errMsg),
}

conn, err := tarantool.Connect("any", tarantool.Opts{
Dialer: dialer,
})
conn, err := tarantool.Connect(context.Background(), "any",
tarantool.Opts{
Dialer: dialer,
})
assert.Nil(t, conn)
assert.ErrorContains(t, err, errMsg)
}
Expand All @@ -41,7 +43,7 @@ type mockPassedDialer struct {
opts tarantool.DialOpts
}

func (m *mockPassedDialer) Dial(address string,
func (m *mockPassedDialer) Dial(ctx context.Context, address string,
opts tarantool.DialOpts) (tarantool.Conn, error) {
m.address = address
m.opts = opts
Expand All @@ -51,9 +53,8 @@ func (m *mockPassedDialer) Dial(address string,
func TestDialer_Dial_passedOpts(t *testing.T) {
const addr = "127.0.0.1:8080"
opts := tarantool.DialOpts{
DialTimeout: 500 * time.Millisecond,
IoTimeout: 2,
Transport: "any",
IoTimeout: 2,
Transport: "any",
Ssl: tarantool.SslOpts{
KeyFile: "a",
CertFile: "b",
Expand All @@ -73,7 +74,9 @@ func TestDialer_Dial_passedOpts(t *testing.T) {
}

dialer := &mockPassedDialer{}
conn, err := tarantool.Connect(addr, tarantool.Opts{
ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond)
defer cancel()
conn, err := tarantool.Connect(ctx, addr, tarantool.Opts{
Dialer: dialer,
Timeout: opts.IoTimeout,
Transport: opts.Transport,
Expand Down Expand Up @@ -187,7 +190,7 @@ func newMockIoConn() *mockIoConn {
return conn
}

func (m *mockIoDialer) Dial(address string,
func (m *mockIoDialer) Dial(ctx context.Context, address string,
opts tarantool.DialOpts) (tarantool.Conn, error) {
m.conn = newMockIoConn()
if m.init != nil {
Expand All @@ -203,11 +206,12 @@ func dialIo(t *testing.T,
dialer := mockIoDialer{
init: init,
}
conn, err := tarantool.Connect("any", tarantool.Opts{
Dialer: &dialer,
Timeout: 1000 * time.Second, // Avoid pings.
SkipSchema: true,
})
conn, err := tarantool.Connect(context.Background(), "any",
tarantool.Opts{
Dialer: &dialer,
Timeout: 1000 * time.Second, // Avoid pings.
SkipSchema: true,
})
require.Nil(t, err)
require.NotNil(t, conn)

Expand Down
3 changes: 2 additions & 1 deletion example_custom_unpacking_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package tarantool_test

import (
"context"
"fmt"
"log"
"time"
Expand Down Expand Up @@ -84,7 +85,7 @@ func Example_customUnpacking() {
User: "test",
Pass: "test",
}
conn, err := tarantool.Connect(server, opts)
conn, err := tarantool.Connect(context.Background(), server, opts)
if err != nil {
log.Fatalf("Failed to connect: %s", err.Error())
}
Expand Down
Loading

0 comments on commit 4c4ac1c

Please sign in to comment.