diff --git a/CHANGELOG.md b/CHANGELOG.md index 211ea25a2..58e6878c1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -28,6 +28,9 @@ Versioning](http://semver.org/spec/v2.0.0.html) except to the first release. decoded to a varbinary object (#313). - Use objects of the Decimal type instead of pointers (#238) - Use objects of the Datetime type instead of pointers (#238) +- `connection.Connect` and `pool.Connect` no longer return non-working + connection objects (#136). Those functions now accept context as their first + arguments, which user may cancel in process. ### Deprecated diff --git a/README.md b/README.md index aa4c6deac..0b4520206 100644 --- a/README.md +++ b/README.md @@ -105,13 +105,15 @@ about what it does. package tarantool import ( + "context" "fmt" "github.com/tarantool/go-tarantool/v2" ) func main() { opts := tarantool.Opts{User: "guest"} - conn, err := tarantool.Connect("127.0.0.1:3301", opts) + ctx := context.Background() + conn, err := tarantool.Connect(ctx, "127.0.0.1:3301", opts) if err != nil { fmt.Println("Connection refused:", err) } diff --git a/connection.go b/connection.go index 9bb42626a..97cb9f4c5 100644 --- a/connection.go +++ b/connection.go @@ -381,10 +381,11 @@ func (opts Opts) Clone() Opts { // - If opts.Reconnect is zero (default), then connection either already connected // or error is returned. // -// - If opts.Reconnect is non-zero, then error will be returned only if authorization -// fails. But if Tarantool is not reachable, then it will make an attempt to reconnect later -// and will not finish to make attempts on authorization failures. -func Connect(addr string, opts Opts) (conn *Connection, err error) { +// - If opts.Reconnect is non-zero, then error will be returned if authorization +// fails, or user has canceled context. If Tarantool is not reachable, then it +// will make attempts to reconnect and will not finish to make attempts on +// authorization failures. +func Connect(ctx context.Context, addr string, opts Opts) (conn *Connection, err error) { conn = &Connection{ addr: addr, requestId: 0, @@ -432,25 +433,8 @@ func Connect(addr string, opts Opts) (conn *Connection, err error) { conn.cond = sync.NewCond(&conn.mutex) - if err = conn.createConnection(false); err != nil { - ter, ok := err.(Error) - if conn.opts.Reconnect <= 0 { - return nil, err - } else if ok && (ter.Code == iproto.ER_NO_SUCH_USER || - ter.Code == iproto.ER_CREDS_MISMATCH) { - // Reported auth errors immediately. - return nil, err - } else { - // Without SkipSchema it is useless. - go func(conn *Connection) { - conn.mutex.Lock() - defer conn.mutex.Unlock() - if err := conn.createConnection(true); err != nil { - conn.closeConnection(err, true) - } - }(conn) - err = nil - } + if err = conn.createConnection(ctx, false); err != nil { + return nil, err } go conn.pinger() @@ -534,7 +518,7 @@ func (conn *Connection) cancelFuture(fut *Future, err error) { } } -func (conn *Connection) dial() (err error) { +func (conn *Connection) dial(ctx context.Context) (err error) { opts := conn.opts dialTimeout := opts.Reconnect / 2 if dialTimeout == 0 { @@ -542,10 +526,11 @@ func (conn *Connection) dial() (err error) { } else if dialTimeout > 5*time.Second { dialTimeout = 5 * time.Second } + nestedCtx, cancel := context.WithTimeout(ctx, dialTimeout) + defer cancel() var c Conn - c, err = conn.opts.Dialer.Dial(conn.addr, DialOpts{ - DialTimeout: dialTimeout, + c, err = conn.opts.Dialer.Dial(nestedCtx, conn.addr, DialOpts{ IoTimeout: opts.Timeout, Transport: opts.Transport, Ssl: opts.Ssl, @@ -658,34 +643,46 @@ func pack(h *smallWBuf, enc *msgpack.Encoder, reqid uint32, return } -func (conn *Connection) createConnection(reconnect bool) (err error) { +func (conn *Connection) createConnection(ctx context.Context, + reconnect bool) error { var reconnects uint for conn.c == nil && conn.state == connDisconnected { now := time.Now() - err = conn.dial() + err := conn.dial(ctx) if err == nil || !reconnect { if err == nil { conn.notify(Connected) } - return + return err } if conn.opts.MaxReconnects > 0 && reconnects > conn.opts.MaxReconnects { conn.opts.Logger.Report(LogLastReconnectFailed, conn, err) - err = ClientError{ErrConnectionClosed, "last reconnect failed"} // mark connection as closed to avoid reopening by another goroutine - return + return ClientError{ErrConnectionClosed, "last reconnect failed"} } conn.opts.Logger.Report(LogReconnectFailed, conn, reconnects, err) conn.notify(ReconnectFailed) reconnects++ conn.mutex.Unlock() - time.Sleep(time.Until(now.Add(conn.opts.Reconnect))) + + timer := time.NewTimer(time.Until(now.Add(conn.opts.Reconnect))) + waitLoop: + for { + select { + case <-ctx.Done(): + conn.mutex.Lock() + return ClientError{ErrConnectionClosed, "operation was canceled"} + case <-timer.C: + break waitLoop + } + } + conn.mutex.Lock() } if conn.state == connClosed { - err = ClientError{ErrConnectionClosed, "using closed connection"} + return ClientError{ErrConnectionClosed, "using closed connection"} } - return + return nil } func (conn *Connection) closeConnection(neterr error, forever bool) (err error) { @@ -731,7 +728,7 @@ func (conn *Connection) reconnectImpl(neterr error, c Conn) { if conn.opts.Reconnect > 0 { if c == conn.c { conn.closeConnection(neterr, false) - if err := conn.createConnection(true); err != nil { + if err := conn.createConnection(context.Background(), true); err != nil { conn.closeConnection(err, true) } } diff --git a/crud/example_test.go b/crud/example_test.go index 363d0570d..567d2dfeb 100644 --- a/crud/example_test.go +++ b/crud/example_test.go @@ -1,6 +1,7 @@ package crud_test import ( + "context" "fmt" "reflect" "time" @@ -21,7 +22,7 @@ var exampleOpts = tarantool.Opts{ } func exampleConnect() *tarantool.Connection { - conn, err := tarantool.Connect(exampleServer, exampleOpts) + conn, err := tarantool.Connect(context.Background(), exampleServer, exampleOpts) if err != nil { panic("Connection is not established: " + err.Error()) } diff --git a/crud/tarantool_test.go b/crud/tarantool_test.go index 5cf29f66a..c88a228ac 100644 --- a/crud/tarantool_test.go +++ b/crud/tarantool_test.go @@ -1,6 +1,7 @@ package crud_test import ( + "context" "fmt" "log" "os" @@ -108,7 +109,7 @@ var object = crud.MapObject{ func connect(t testing.TB) *tarantool.Connection { for i := 0; i < 10; i++ { - conn, err := tarantool.Connect(server, opts) + conn, err := tarantool.Connect(context.Background(), server, opts) if err != nil { t.Fatalf("Failed to connect: %s", err) } diff --git a/datetime/example_test.go b/datetime/example_test.go index 346551629..2f63208a2 100644 --- a/datetime/example_test.go +++ b/datetime/example_test.go @@ -9,6 +9,7 @@ package datetime_test import ( + "context" "fmt" "time" @@ -23,7 +24,7 @@ func Example() { User: "test", Pass: "test", } - conn, err := tarantool.Connect("127.0.0.1:3013", opts) + conn, err := tarantool.Connect(context.Background(), "127.0.0.1:3013", opts) if err != nil { fmt.Printf("Error in connect is %v", err) return diff --git a/decimal/example_test.go b/decimal/example_test.go index a355767f1..0e32a51d2 100644 --- a/decimal/example_test.go +++ b/decimal/example_test.go @@ -9,6 +9,7 @@ package decimal_test import ( + "context" "log" "time" @@ -28,7 +29,7 @@ func Example() { User: "test", Pass: "test", } - client, err := tarantool.Connect(server, opts) + client, err := tarantool.Connect(context.Background(), server, opts) if err != nil { log.Fatalf("Failed to connect: %s", err.Error()) } diff --git a/dial.go b/dial.go index 3ba493ac7..5b17c0534 100644 --- a/dial.go +++ b/dial.go @@ -3,6 +3,7 @@ package tarantool import ( "bufio" "bytes" + "context" "errors" "fmt" "io" @@ -56,8 +57,6 @@ type Conn interface { // DialOpts is a way to configure a Dial method to create a new Conn. type DialOpts struct { - // DialTimeout is a timeout for an initial network dial. - DialTimeout time.Duration // IoTimeout is a timeout per a network read/write. IoTimeout time.Duration // Transport is a connect transport type. @@ -86,7 +85,7 @@ type DialOpts struct { type Dialer interface { // Dial connects to a Tarantool instance to the address with specified // options. - Dial(address string, opts DialOpts) (Conn, error) + Dial(ctx context.Context, address string, opts DialOpts) (Conn, error) } type tntConn struct { @@ -104,11 +103,11 @@ type TtDialer struct { // Dial connects to a Tarantool instance to the address with specified // options. -func (t TtDialer) Dial(address string, opts DialOpts) (Conn, error) { +func (t TtDialer) Dial(ctx context.Context, address string, opts DialOpts) (Conn, error) { var err error conn := new(tntConn) - if conn.net, err = dial(address, opts); err != nil { + if conn.net, err = dial(ctx, address, opts); err != nil { return nil, fmt.Errorf("failed to dial: %w", err) } @@ -199,13 +198,14 @@ func (c *tntConn) ProtocolInfo() ProtocolInfo { } // dial connects to a Tarantool instance. -func dial(address string, opts DialOpts) (net.Conn, error) { +func dial(ctx context.Context, address string, opts DialOpts) (net.Conn, error) { network, address := parseAddress(address) switch opts.Transport { case dialTransportNone: - return net.DialTimeout(network, address, opts.DialTimeout) + dialer := net.Dialer{} + return dialer.DialContext(ctx, network, address) case dialTransportSsl: - return sslDialTimeout(network, address, opts.DialTimeout, opts.Ssl) + return sslDialContext(ctx, network, address, opts.Ssl) default: return nil, fmt.Errorf("unsupported transport type: %s", opts.Transport) } diff --git a/dial_test.go b/dial_test.go index ff8a50aab..03c470cde 100644 --- a/dial_test.go +++ b/dial_test.go @@ -2,6 +2,7 @@ package tarantool_test import ( "bytes" + "context" "errors" "net" "sync" @@ -18,7 +19,7 @@ type mockErrorDialer struct { err error } -func (m mockErrorDialer) Dial(address string, +func (m mockErrorDialer) Dial(ctx context.Context, address string, opts tarantool.DialOpts) (tarantool.Conn, error) { return nil, m.err } @@ -29,9 +30,10 @@ func TestDialer_Dial_error(t *testing.T) { err: errors.New(errMsg), } - conn, err := tarantool.Connect("any", tarantool.Opts{ - Dialer: dialer, - }) + conn, err := tarantool.Connect(context.Background(), "any", + tarantool.Opts{ + Dialer: dialer, + }) assert.Nil(t, conn) assert.ErrorContains(t, err, errMsg) } @@ -41,7 +43,7 @@ type mockPassedDialer struct { opts tarantool.DialOpts } -func (m *mockPassedDialer) Dial(address string, +func (m *mockPassedDialer) Dial(ctx context.Context, address string, opts tarantool.DialOpts) (tarantool.Conn, error) { m.address = address m.opts = opts @@ -51,9 +53,8 @@ func (m *mockPassedDialer) Dial(address string, func TestDialer_Dial_passedOpts(t *testing.T) { const addr = "127.0.0.1:8080" opts := tarantool.DialOpts{ - DialTimeout: 500 * time.Millisecond, - IoTimeout: 2, - Transport: "any", + IoTimeout: 2, + Transport: "any", Ssl: tarantool.SslOpts{ KeyFile: "a", CertFile: "b", @@ -73,7 +74,9 @@ func TestDialer_Dial_passedOpts(t *testing.T) { } dialer := &mockPassedDialer{} - conn, err := tarantool.Connect(addr, tarantool.Opts{ + ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) + defer cancel() + conn, err := tarantool.Connect(ctx, addr, tarantool.Opts{ Dialer: dialer, Timeout: opts.IoTimeout, Transport: opts.Transport, @@ -187,7 +190,7 @@ func newMockIoConn() *mockIoConn { return conn } -func (m *mockIoDialer) Dial(address string, +func (m *mockIoDialer) Dial(ctx context.Context, address string, opts tarantool.DialOpts) (tarantool.Conn, error) { m.conn = newMockIoConn() if m.init != nil { @@ -203,11 +206,12 @@ func dialIo(t *testing.T, dialer := mockIoDialer{ init: init, } - conn, err := tarantool.Connect("any", tarantool.Opts{ - Dialer: &dialer, - Timeout: 1000 * time.Second, // Avoid pings. - SkipSchema: true, - }) + conn, err := tarantool.Connect(context.Background(), "any", + tarantool.Opts{ + Dialer: &dialer, + Timeout: 1000 * time.Second, // Avoid pings. + SkipSchema: true, + }) require.Nil(t, err) require.NotNil(t, conn) diff --git a/example_custom_unpacking_test.go b/example_custom_unpacking_test.go index 1189e16a3..bd60fd901 100644 --- a/example_custom_unpacking_test.go +++ b/example_custom_unpacking_test.go @@ -1,6 +1,7 @@ package tarantool_test import ( + "context" "fmt" "log" "time" @@ -84,7 +85,7 @@ func Example_customUnpacking() { User: "test", Pass: "test", } - conn, err := tarantool.Connect(server, opts) + conn, err := tarantool.Connect(context.Background(), server, opts) if err != nil { log.Fatalf("Failed to connect: %s", err.Error()) } diff --git a/example_test.go b/example_test.go index d871d578a..64d471e93 100644 --- a/example_test.go +++ b/example_test.go @@ -19,7 +19,7 @@ type Tuple struct { } func exampleConnect(opts tarantool.Opts) *tarantool.Connection { - conn, err := tarantool.Connect(server, opts) + conn, err := tarantool.Connect(context.Background(), server, opts) if err != nil { panic("Connection is not established: " + err.Error()) } @@ -38,7 +38,7 @@ func ExampleSslOpts() { CaFile: "testdata/ca.crt", }, } - _, err := tarantool.Connect("127.0.0.1:3013", opts) + _, err := tarantool.Connect(context.Background(), "127.0.0.1:3013", opts) if err != nil { panic("Connection is not established: " + err.Error()) } @@ -913,12 +913,51 @@ func ExampleFuture_GetIterator() { } func ExampleConnect() { - conn, err := tarantool.Connect("127.0.0.1:3013", tarantool.Opts{ - Timeout: 5 * time.Second, - User: "test", - Pass: "test", - Concurrency: 32, - }) + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + conn, err := tarantool.Connect(ctx, "127.0.0.1:3013", + tarantool.Opts{ + Timeout: 5 * time.Second, + User: "test", + Pass: "test", + Concurrency: 32, + }) + if err != nil { + fmt.Println("No connection available") + return + } + defer conn.Close() + if conn != nil { + fmt.Println("Connection is ready") + } + // Output: + // Connection is ready +} + +func ExampleConnect_reconnects() { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + reconnect := 500 * time.Millisecond + maxReconnects := 10 + + var conn *tarantool.Connection + var err error + + for i := 0; i < maxReconnects; i++ { + conn, err = tarantool.Connect(ctx, "127.0.0.1:3013", + tarantool.Opts{ + Timeout: 5 * time.Second, + User: "test", + Pass: "test", + Concurrency: 32, + }) + if err == nil { + break + } + time.Sleep(reconnect) + } if err != nil { fmt.Println("No connection available") return @@ -1081,7 +1120,7 @@ func ExampleConnection_NewPrepared() { User: "test", Pass: "test", } - conn, err := tarantool.Connect(server, opts) + conn, err := tarantool.Connect(context.Background(), server, opts) if err != nil { fmt.Printf("Failed to connect: %s", err.Error()) } @@ -1127,7 +1166,7 @@ func ExampleConnection_NewWatcher() { Features: []tarantool.ProtocolFeature{tarantool.WatchersFeature}, }, } - conn, err := tarantool.Connect(server, opts) + conn, err := tarantool.Connect(context.Background(), server, opts) if err != nil { fmt.Printf("Failed to connect: %s\n", err) return diff --git a/export_test.go b/export_test.go index 10d194840..adb67976c 100644 --- a/export_test.go +++ b/export_test.go @@ -1,6 +1,7 @@ package tarantool import ( + "context" "net" "time" @@ -12,6 +13,11 @@ func SslDialTimeout(network, address string, timeout time.Duration, return sslDialTimeout(network, address, timeout, opts) } +func SslDialContext(ctx context.Context, network, address string, + opts SslOpts) (connection net.Conn, err error) { + return sslDialContext(ctx, network, address, opts) +} + func SslCreateContext(opts SslOpts) (ctx interface{}, err error) { return sslCreateContext(opts) } diff --git a/go.mod b/go.mod index bd848308c..68f46c634 100644 --- a/go.mod +++ b/go.mod @@ -8,7 +8,7 @@ require ( github.com/shopspring/decimal v1.3.1 github.com/stretchr/testify v1.7.1 github.com/tarantool/go-iproto v0.1.0 - github.com/tarantool/go-openssl v0.0.8-0.20230801114713-b452431f934a + github.com/tarantool/go-openssl v0.0.8-0.20231002130016-e88579e113cf github.com/vmihailenco/msgpack/v5 v5.3.5 golang.org/x/sys v0.0.0-20210119212857-b64e53b001e4 // indirect gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f // indirect diff --git a/go.sum b/go.sum index 1810c2b3a..414b13c56 100644 --- a/go.sum +++ b/go.sum @@ -21,8 +21,8 @@ github.com/stretchr/testify v1.7.1 h1:5TQK59W5E3v0r2duFAb7P95B6hEeOyEnHRa8MjYSMT github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/tarantool/go-iproto v0.1.0 h1:zHN9AA8LDawT+JBD0/Nxgr/bIsWkkpDzpcMuaNPSIAQ= github.com/tarantool/go-iproto v0.1.0/go.mod h1:LNCtdyZxojUed8SbOiYHoc3v9NvaZTB7p96hUySMlIo= -github.com/tarantool/go-openssl v0.0.8-0.20230801114713-b452431f934a h1:eeElglRXJ3xWKkHmDbeXrQWlZyQ4t3Ca1YlZsrfdXFU= -github.com/tarantool/go-openssl v0.0.8-0.20230801114713-b452431f934a/go.mod h1:M7H4xYSbzqpW/ZRBMyH0eyqQBsnhAMfsYk5mv0yid7A= +github.com/tarantool/go-openssl v0.0.8-0.20231002130016-e88579e113cf h1:oCQZliFthJ2j/4TgD3PFwazfZjsn+wCA4xOLi2yO7cI= +github.com/tarantool/go-openssl v0.0.8-0.20231002130016-e88579e113cf/go.mod h1:M7H4xYSbzqpW/ZRBMyH0eyqQBsnhAMfsYk5mv0yid7A= github.com/vmihailenco/msgpack/v5 v5.3.5 h1:5gO0H1iULLWGhs2H5tbAHIZTV8/cYafcFOr9znI5mJU= github.com/vmihailenco/msgpack/v5 v5.3.5/go.mod h1:7xyJ9e+0+9SaZT0Wt1RGleJXzli6Q/V5KbhBonMG9jc= github.com/vmihailenco/tagparser/v2 v2.0.0 h1:y09buUbR+b5aycVFQs/g70pqKVZNBmxwAhO7/IwNM9g= diff --git a/pool/connection_pool.go b/pool/connection_pool.go index 26e2199e9..5ea426cd5 100644 --- a/pool/connection_pool.go +++ b/pool/connection_pool.go @@ -11,6 +11,7 @@ package pool import ( + "context" "errors" "log" "sync" @@ -103,6 +104,7 @@ type ConnectionPool struct { anyPool *roundRobinStrategy poolsMutex sync.RWMutex watcherContainer watcherContainer + ctxCancels []context.CancelFunc } var _ Pooler = (*ConnectionPool)(nil) @@ -133,7 +135,8 @@ func newEndpoint(addr string) *endpoint { // ConnectWithOpts creates pool for instances with addresses addrs // with options opts. -func ConnectWithOpts(addrs []string, connOpts tarantool.Opts, opts Opts) (*ConnectionPool, error) { +func ConnectWithOpts(ctx context.Context, addrs []string, + connOpts tarantool.Opts, opts Opts) (*ConnectionPool, error) { if len(addrs) == 0 { return nil, ErrEmptyAddrs } @@ -161,7 +164,7 @@ func ConnectWithOpts(addrs []string, connOpts tarantool.Opts, opts Opts) (*Conne connPool.addrs[addr] = nil } - somebodyAlive := connPool.fillPools() + somebodyAlive := connPool.fillPools(ctx) if !somebodyAlive { connPool.state.set(closedState) return nil, ErrNoConnection @@ -170,7 +173,9 @@ func ConnectWithOpts(addrs []string, connOpts tarantool.Opts, opts Opts) (*Conne connPool.state.set(connectedState) for _, s := range connPool.addrs { - go connPool.controller(s) + controllerCtx, cancel := context.WithCancel(context.Background()) + connPool.ctxCancels = append(connPool.ctxCancels, cancel) + go connPool.controller(controllerCtx, s) } return connPool, nil @@ -181,11 +186,12 @@ func ConnectWithOpts(addrs []string, connOpts tarantool.Opts, opts Opts) (*Conne // It is useless to set up tarantool.Opts.Reconnect value for a connection. // The connection pool has its own reconnection logic. See // Opts.CheckTimeout description. -func Connect(addrs []string, connOpts tarantool.Opts) (*ConnectionPool, error) { +func Connect(ctx context.Context, addrs []string, + connOpts tarantool.Opts) (*ConnectionPool, error) { opts := Opts{ CheckTimeout: 1 * time.Second, } - return ConnectWithOpts(addrs, connOpts, opts) + return ConnectWithOpts(ctx, addrs, connOpts, opts) } // ConnectedNow gets connected status of pool. @@ -224,7 +230,7 @@ func (p *ConnectionPool) ConfiguredTimeout(mode Mode) (time.Duration, error) { // Add adds a new endpoint with the address into the pool. This function // adds the endpoint only after successful connection. -func (p *ConnectionPool) Add(addr string) error { +func (p *ConnectionPool) Add(ctx context.Context, addr string) error { e := newEndpoint(addr) p.addrsMutex.Lock() @@ -240,7 +246,7 @@ func (p *ConnectionPool) Add(addr string) error { p.addrs[addr] = e p.addrsMutex.Unlock() - if err := p.tryConnect(e); err != nil { + if err := p.tryConnect(ctx, e); err != nil { p.addrsMutex.Lock() delete(p.addrs, addr) p.addrsMutex.Unlock() @@ -248,7 +254,9 @@ func (p *ConnectionPool) Add(addr string) error { return err } - go p.controller(e) + controllerCtx, cancel := context.WithCancel(context.Background()) + p.ctxCancels = append(p.ctxCancels, cancel) + go p.controller(controllerCtx, e) return nil } @@ -306,6 +314,9 @@ func (p *ConnectionPool) Close() []error { } p.addrsMutex.RUnlock() } + for _, cancel := range p.ctxCancels { + cancel() + } return p.waitClose() } @@ -1109,7 +1120,7 @@ func (p *ConnectionPool) handlerDeactivated(conn *tarantool.Connection, } } -func (p *ConnectionPool) fillPools() bool { +func (p *ConnectionPool) fillPools(ctx context.Context) bool { somebodyAlive := false // It is called before controller() goroutines so we don't expect @@ -1120,7 +1131,7 @@ func (p *ConnectionPool) fillPools() bool { connOpts := p.connOpts connOpts.Notify = end.notify - conn, err := tarantool.Connect(addr, connOpts) + conn, err := tarantool.Connect(ctx, addr, connOpts) if err != nil { log.Printf("tarantool: connect to %s failed: %s\n", addr, err.Error()) } else if conn != nil { @@ -1213,7 +1224,7 @@ func (p *ConnectionPool) updateConnection(e *endpoint) { } } -func (p *ConnectionPool) tryConnect(e *endpoint) error { +func (p *ConnectionPool) tryConnect(ctx context.Context, e *endpoint) error { p.poolsMutex.Lock() if p.state.get() != connectedState { @@ -1226,7 +1237,7 @@ func (p *ConnectionPool) tryConnect(e *endpoint) error { connOpts := p.connOpts connOpts.Notify = e.notify - conn, err := tarantool.Connect(e.addr, connOpts) + conn, err := tarantool.Connect(ctx, e.addr, connOpts) if err == nil { role, err := p.getConnectionRole(conn) p.poolsMutex.Unlock() @@ -1265,7 +1276,7 @@ func (p *ConnectionPool) tryConnect(e *endpoint) error { return err } -func (p *ConnectionPool) reconnect(e *endpoint) { +func (p *ConnectionPool) reconnect(ctx context.Context, e *endpoint) { p.poolsMutex.Lock() if p.state.get() != connectedState { @@ -1280,10 +1291,10 @@ func (p *ConnectionPool) reconnect(e *endpoint) { e.conn = nil e.role = UnknownRole - p.tryConnect(e) + p.tryConnect(ctx, e) } -func (p *ConnectionPool) controller(e *endpoint) { +func (p *ConnectionPool) controller(ctx context.Context, e *endpoint) { timer := time.NewTicker(p.opts.CheckTimeout) defer timer.Stop() @@ -1367,11 +1378,11 @@ func (p *ConnectionPool) controller(e *endpoint) { // Relocate connection between subpools // if ro/rw was updated. if e.conn == nil { - p.tryConnect(e) + p.tryConnect(ctx, e) } else if !e.conn.ClosedNow() { p.updateConnection(e) } else { - p.reconnect(e) + p.reconnect(ctx, e) } } } diff --git a/pool/connection_pool_test.go b/pool/connection_pool_test.go index dd0210d62..63a953851 100644 --- a/pool/connection_pool_test.go +++ b/pool/connection_pool_test.go @@ -1,6 +1,7 @@ package pool_test import ( + "context" "fmt" "log" "os" @@ -46,17 +47,18 @@ var defaultTimeoutRetry = 500 * time.Millisecond var instances []test_helpers.TarantoolInstance func TestConnError_IncorrectParams(t *testing.T) { - connPool, err := pool.Connect([]string{}, tarantool.Opts{}) + connPool, err := pool.Connect(context.Background(), []string{}, tarantool.Opts{}) require.Nilf(t, connPool, "conn is not nil with incorrect param") require.NotNilf(t, err, "err is nil with incorrect params") require.Equal(t, "addrs (first argument) should not be empty", err.Error()) - connPool, err = pool.Connect([]string{"err1", "err2"}, connOpts) + connPool, err = pool.Connect(context.Background(), []string{"err1", "err2"}, connOpts) require.Nilf(t, connPool, "conn is not nil with incorrect param") require.NotNilf(t, err, "err is nil with incorrect params") require.Equal(t, "no active connections", err.Error()) - connPool, err = pool.ConnectWithOpts(servers, tarantool.Opts{}, pool.Opts{}) + connPool, err = pool.ConnectWithOpts(context.Background(), servers, + tarantool.Opts{}, pool.Opts{}) require.Nilf(t, connPool, "conn is not nil with incorrect param") require.NotNilf(t, err, "err is nil with incorrect params") require.Equal(t, "wrong check timeout, must be greater than 0", err.Error()) @@ -64,7 +66,7 @@ func TestConnError_IncorrectParams(t *testing.T) { func TestConnSuccessfully(t *testing.T) { server := servers[0] - connPool, err := pool.Connect([]string{"err", server}, connOpts) + connPool, err := pool.Connect(context.Background(), []string{"err", server}, connOpts) require.Nilf(t, err, "failed to connect") require.NotNilf(t, connPool, "conn is nil after Connect") @@ -84,9 +86,41 @@ func TestConnSuccessfully(t *testing.T) { require.Nil(t, err) } +func TestConnErrorAfterCtxCancel(t *testing.T) { + var connLongReconnectOpts = tarantool.Opts{ + Timeout: 5 * time.Second, + User: "test", + Pass: "test", + Reconnect: time.Second, + MaxReconnects: 100, + } + + ctx, cancel := context.WithCancel(context.Background()) + + var connPool *pool.ConnectionPool + var err error + + var wg sync.WaitGroup + wg.Add(1) + + go func() { + defer wg.Done() + cancel() + connPool, err = pool.Connect(ctx, []string{"err"}, + connLongReconnectOpts) + }() + + wg.Wait() + + if connPool != nil || err == nil { + t.Fatalf("ConnectionPool was created after cancel") + } +} + func TestConnSuccessfullyDuplicates(t *testing.T) { server := servers[0] - connPool, err := pool.Connect([]string{server, server, server, server}, connOpts) + connPool, err := pool.Connect(context.Background(), + []string{server, server, server, server}, connOpts) require.Nilf(t, err, "failed to connect") require.NotNilf(t, connPool, "conn is nil after Connect") @@ -112,7 +146,7 @@ func TestConnSuccessfullyDuplicates(t *testing.T) { func TestReconnect(t *testing.T) { server := servers[0] - connPool, err := pool.Connect(servers, connOpts) + connPool, err := pool.Connect(context.Background(), servers, connOpts) require.Nilf(t, err, "failed to connect") require.NotNilf(t, connPool, "conn is nil after Connect") @@ -158,7 +192,7 @@ func TestDisconnect_withReconnect(t *testing.T) { opts := connOpts opts.Reconnect = 10 * time.Second - connPool, err := pool.Connect([]string{servers[serverId]}, opts) + connPool, err := pool.Connect(context.Background(), []string{servers[serverId]}, opts) require.Nilf(t, err, "failed to connect") require.NotNilf(t, connPool, "conn is nil after Connect") @@ -202,7 +236,7 @@ func TestDisconnectAll(t *testing.T) { server1 := servers[0] server2 := servers[1] - connPool, err := pool.Connect([]string{server1, server2}, connOpts) + connPool, err := pool.Connect(context.Background(), []string{server1, server2}, connOpts) require.Nilf(t, err, "failed to connect") require.NotNilf(t, connPool, "conn is nil after Connect") @@ -249,14 +283,14 @@ func TestDisconnectAll(t *testing.T) { } func TestAdd(t *testing.T) { - connPool, err := pool.Connect([]string{servers[0]}, connOpts) + connPool, err := pool.Connect(context.Background(), []string{servers[0]}, connOpts) require.Nilf(t, err, "failed to connect") require.NotNilf(t, connPool, "conn is nil after Connect") defer connPool.Close() for _, server := range servers[1:] { - err = connPool.Add(server) + err = connPool.Add(context.Background(), server) require.Nil(t, err) } @@ -280,13 +314,13 @@ func TestAdd(t *testing.T) { } func TestAdd_exist(t *testing.T) { - connPool, err := pool.Connect([]string{servers[0]}, connOpts) + connPool, err := pool.Connect(context.Background(), []string{servers[0]}, connOpts) require.Nilf(t, err, "failed to connect") require.NotNilf(t, connPool, "conn is nil after Connect") defer connPool.Close() - err = connPool.Add(servers[0]) + err = connPool.Add(context.Background(), servers[0]) require.Equal(t, pool.ErrExists, err) args := test_helpers.CheckStatusesArgs{ @@ -305,13 +339,13 @@ func TestAdd_exist(t *testing.T) { } func TestAdd_unreachable(t *testing.T) { - connPool, err := pool.Connect([]string{servers[0]}, connOpts) + connPool, err := pool.Connect(context.Background(), []string{servers[0]}, connOpts) require.Nilf(t, err, "failed to connect") require.NotNilf(t, connPool, "conn is nil after Connect") defer connPool.Close() - err = connPool.Add("127.0.0.2:6667") + err = connPool.Add(context.Background(), "127.0.0.2:6667") // The OS-dependent error so we just check for existence. require.NotNil(t, err) @@ -331,17 +365,17 @@ func TestAdd_unreachable(t *testing.T) { } func TestAdd_afterClose(t *testing.T) { - connPool, err := pool.Connect([]string{servers[0]}, connOpts) + connPool, err := pool.Connect(context.Background(), []string{servers[0]}, connOpts) require.Nilf(t, err, "failed to connect") require.NotNilf(t, connPool, "conn is nil after Connect") connPool.Close() - err = connPool.Add(servers[0]) + err = connPool.Add(context.Background(), servers[0]) assert.Equal(t, err, pool.ErrClosed) } func TestAdd_Close_concurrent(t *testing.T) { - connPool, err := pool.Connect([]string{servers[0]}, connOpts) + connPool, err := pool.Connect(context.Background(), []string{servers[0]}, connOpts) require.Nilf(t, err, "failed to connect") require.NotNilf(t, connPool, "conn is nil after Connect") @@ -350,7 +384,7 @@ func TestAdd_Close_concurrent(t *testing.T) { go func() { defer wg.Done() - err = connPool.Add(servers[1]) + err = connPool.Add(context.Background(), servers[1]) if err != nil { assert.Equal(t, pool.ErrClosed, err) } @@ -362,7 +396,7 @@ func TestAdd_Close_concurrent(t *testing.T) { } func TestAdd_CloseGraceful_concurrent(t *testing.T) { - connPool, err := pool.Connect([]string{servers[0]}, connOpts) + connPool, err := pool.Connect(context.Background(), []string{servers[0]}, connOpts) require.Nilf(t, err, "failed to connect") require.NotNilf(t, connPool, "conn is nil after Connect") @@ -371,7 +405,7 @@ func TestAdd_CloseGraceful_concurrent(t *testing.T) { go func() { defer wg.Done() - err = connPool.Add(servers[1]) + err = connPool.Add(context.Background(), servers[1]) if err != nil { assert.Equal(t, pool.ErrClosed, err) } @@ -383,7 +417,7 @@ func TestAdd_CloseGraceful_concurrent(t *testing.T) { } func TestRemove(t *testing.T) { - connPool, err := pool.Connect(servers, connOpts) + connPool, err := pool.Connect(context.Background(), servers, connOpts) require.Nilf(t, err, "failed to connect") require.NotNilf(t, connPool, "conn is nil after Connect") @@ -410,7 +444,7 @@ func TestRemove(t *testing.T) { } func TestRemove_double(t *testing.T) { - connPool, err := pool.Connect([]string{servers[0], servers[1]}, connOpts) + connPool, err := pool.Connect(context.Background(), []string{servers[0], servers[1]}, connOpts) require.Nilf(t, err, "failed to connect") require.NotNilf(t, connPool, "conn is nil after Connect") @@ -437,7 +471,7 @@ func TestRemove_double(t *testing.T) { } func TestRemove_unknown(t *testing.T) { - connPool, err := pool.Connect([]string{servers[0], servers[1]}, connOpts) + connPool, err := pool.Connect(context.Background(), []string{servers[0], servers[1]}, connOpts) require.Nilf(t, err, "failed to connect") require.NotNilf(t, connPool, "conn is nil after Connect") @@ -463,7 +497,7 @@ func TestRemove_unknown(t *testing.T) { } func TestRemove_concurrent(t *testing.T) { - connPool, err := pool.Connect([]string{servers[0], servers[1]}, connOpts) + connPool, err := pool.Connect(context.Background(), []string{servers[0], servers[1]}, connOpts) require.Nilf(t, err, "failed to connect") require.NotNilf(t, connPool, "conn is nil after Connect") @@ -510,7 +544,7 @@ func TestRemove_concurrent(t *testing.T) { } func TestRemove_Close_concurrent(t *testing.T) { - connPool, err := pool.Connect([]string{servers[0], servers[1]}, connOpts) + connPool, err := pool.Connect(context.Background(), []string{servers[0], servers[1]}, connOpts) require.Nilf(t, err, "failed to connect") require.NotNilf(t, connPool, "conn is nil after Connect") @@ -529,7 +563,7 @@ func TestRemove_Close_concurrent(t *testing.T) { } func TestRemove_CloseGraceful_concurrent(t *testing.T) { - connPool, err := pool.Connect([]string{servers[0], servers[1]}, connOpts) + connPool, err := pool.Connect(context.Background(), []string{servers[0], servers[1]}, connOpts) require.Nilf(t, err, "failed to connect") require.NotNilf(t, connPool, "conn is nil after Connect") @@ -551,7 +585,7 @@ func TestClose(t *testing.T) { server1 := servers[0] server2 := servers[1] - connPool, err := pool.Connect([]string{server1, server2}, connOpts) + connPool, err := pool.Connect(context.Background(), []string{server1, server2}, connOpts) require.Nilf(t, err, "failed to connect") require.NotNilf(t, connPool, "conn is nil after Connect") @@ -591,7 +625,7 @@ func TestCloseGraceful(t *testing.T) { server1 := servers[0] server2 := servers[1] - connPool, err := pool.Connect([]string{server1, server2}, connOpts) + connPool, err := pool.Connect(context.Background(), []string{server1, server2}, connOpts) require.Nilf(t, err, "failed to connect") require.NotNilf(t, connPool, "conn is nil after Connect") @@ -730,7 +764,7 @@ func TestConnectionHandlerOpenUpdateClose(t *testing.T) { CheckTimeout: 100 * time.Microsecond, ConnectionHandler: h, } - connPool, err := pool.ConnectWithOpts(poolServers, connOpts, poolOpts) + connPool, err := pool.ConnectWithOpts(context.Background(), poolServers, connOpts, poolOpts) require.Nilf(t, err, "failed to connect") require.NotNilf(t, connPool, "conn is nil after Connect") @@ -804,7 +838,7 @@ func TestConnectionHandlerOpenError(t *testing.T) { CheckTimeout: 100 * time.Microsecond, ConnectionHandler: h, } - connPool, err := pool.ConnectWithOpts(poolServers, connOpts, poolOpts) + connPool, err := pool.ConnectWithOpts(context.Background(), poolServers, connOpts, poolOpts) if err == nil { defer connPool.Close() } @@ -846,7 +880,7 @@ func TestConnectionHandlerUpdateError(t *testing.T) { CheckTimeout: 100 * time.Microsecond, ConnectionHandler: h, } - connPool, err := pool.ConnectWithOpts(poolServers, connOpts, poolOpts) + connPool, err := pool.ConnectWithOpts(context.Background(), poolServers, connOpts, poolOpts) require.Nilf(t, err, "failed to connect") require.NotNilf(t, connPool, "conn is nil after Connect") defer connPool.Close() @@ -891,7 +925,7 @@ func TestRequestOnClosed(t *testing.T) { server1 := servers[0] server2 := servers[1] - connPool, err := pool.Connect([]string{server1, server2}, connOpts) + connPool, err := pool.Connect(context.Background(), []string{server1, server2}, connOpts) require.Nilf(t, err, "failed to connect") require.NotNilf(t, connPool, "conn is nil after Connect") @@ -930,7 +964,7 @@ func TestGetPoolInfo(t *testing.T) { srvs := []string{server1, server2} expected := []string{server1, server2} - connPool, err := pool.Connect(srvs, connOpts) + connPool, err := pool.Connect(context.Background(), srvs, connOpts) require.Nilf(t, err, "failed to connect") require.NotNilf(t, connPool, "conn is nil after Connect") @@ -948,7 +982,7 @@ func TestCall(t *testing.T) { err := test_helpers.SetClusterRO(servers, connOpts, roles) require.Nilf(t, err, "fail to set roles for cluster") - connPool, err := pool.Connect(servers, connOpts) + connPool, err := pool.Connect(context.Background(), servers, connOpts) require.Nilf(t, err, "failed to connect") require.NotNilf(t, connPool, "conn is nil after Connect") @@ -1005,7 +1039,7 @@ func TestCall16(t *testing.T) { err := test_helpers.SetClusterRO(servers, connOpts, roles) require.Nilf(t, err, "fail to set roles for cluster") - connPool, err := pool.Connect(servers, connOpts) + connPool, err := pool.Connect(context.Background(), servers, connOpts) require.Nilf(t, err, "failed to connect") require.NotNilf(t, connPool, "conn is nil after Connect") @@ -1062,7 +1096,7 @@ func TestCall17(t *testing.T) { err := test_helpers.SetClusterRO(servers, connOpts, roles) require.Nilf(t, err, "fail to set roles for cluster") - connPool, err := pool.Connect(servers, connOpts) + connPool, err := pool.Connect(context.Background(), servers, connOpts) require.Nilf(t, err, "failed to connect") require.NotNilf(t, connPool, "conn is nil after Connect") @@ -1119,7 +1153,7 @@ func TestEval(t *testing.T) { err := test_helpers.SetClusterRO(servers, connOpts, roles) require.Nilf(t, err, "fail to set roles for cluster") - connPool, err := pool.Connect(servers, connOpts) + connPool, err := pool.Connect(context.Background(), servers, connOpts) require.Nilf(t, err, "failed to connect") require.NotNilf(t, connPool, "conn is nil after Connect") @@ -1197,7 +1231,7 @@ func TestExecute(t *testing.T) { err := test_helpers.SetClusterRO(servers, connOpts, roles) require.Nilf(t, err, "fail to set roles for cluster") - connPool, err := pool.Connect(servers, connOpts) + connPool, err := pool.Connect(context.Background(), servers, connOpts) require.Nilf(t, err, "failed to connect") require.NotNilf(t, connPool, "conn is nil after Connect") @@ -1254,7 +1288,7 @@ func TestRoundRobinStrategy(t *testing.T) { err := test_helpers.SetClusterRO(servers, connOpts, roles) require.Nilf(t, err, "fail to set roles for cluster") - connPool, err := pool.Connect(servers, connOpts) + connPool, err := pool.Connect(context.Background(), servers, connOpts) require.Nilf(t, err, "failed to connect") require.NotNilf(t, connPool, "conn is nil after Connect") @@ -1331,7 +1365,7 @@ func TestRoundRobinStrategy_NoReplica(t *testing.T) { err := test_helpers.SetClusterRO(servers, connOpts, roles) require.Nilf(t, err, "fail to set roles for cluster") - connPool, err := pool.Connect(servers, connOpts) + connPool, err := pool.Connect(context.Background(), servers, connOpts) require.Nilf(t, err, "failed to connect") require.NotNilf(t, connPool, "conn is nil after Connect") @@ -1402,7 +1436,7 @@ func TestRoundRobinStrategy_NoMaster(t *testing.T) { err := test_helpers.SetClusterRO(servers, connOpts, roles) require.Nilf(t, err, "fail to set roles for cluster") - connPool, err := pool.Connect(servers, connOpts) + connPool, err := pool.Connect(context.Background(), servers, connOpts) require.Nilf(t, err, "failed to connect") require.NotNilf(t, connPool, "conn is nil after Connect") @@ -1485,7 +1519,7 @@ func TestUpdateInstancesRoles(t *testing.T) { err := test_helpers.SetClusterRO(servers, connOpts, roles) require.Nilf(t, err, "fail to set roles for cluster") - connPool, err := pool.Connect(servers, connOpts) + connPool, err := pool.Connect(context.Background(), servers, connOpts) require.Nilf(t, err, "failed to connect") require.NotNilf(t, connPool, "conn is nil after Connect") @@ -1629,7 +1663,7 @@ func TestInsert(t *testing.T) { err := test_helpers.SetClusterRO(servers, connOpts, roles) require.Nilf(t, err, "fail to set roles for cluster") - connPool, err := pool.Connect(servers, connOpts) + connPool, err := pool.Connect(context.Background(), servers, connOpts) require.Nilf(t, err, "failed to connect") require.NotNilf(t, connPool, "conn is nil after Connect") @@ -1728,7 +1762,7 @@ func TestDelete(t *testing.T) { err := test_helpers.SetClusterRO(servers, connOpts, roles) require.Nilf(t, err, "fail to set roles for cluster") - connPool, err := pool.Connect(servers, connOpts) + connPool, err := pool.Connect(context.Background(), servers, connOpts) require.Nilf(t, err, "failed to connect") require.NotNilf(t, connPool, "conn is nil after Connect") @@ -1792,7 +1826,7 @@ func TestUpsert(t *testing.T) { err := test_helpers.SetClusterRO(servers, connOpts, roles) require.Nilf(t, err, "fail to set roles for cluster") - connPool, err := pool.Connect(servers, connOpts) + connPool, err := pool.Connect(context.Background(), servers, connOpts) require.Nilf(t, err, "failed to connect") require.NotNilf(t, connPool, "conn is nil after Connect") @@ -1864,7 +1898,7 @@ func TestUpdate(t *testing.T) { err := test_helpers.SetClusterRO(servers, connOpts, roles) require.Nilf(t, err, "fail to set roles for cluster") - connPool, err := pool.Connect(servers, connOpts) + connPool, err := pool.Connect(context.Background(), servers, connOpts) require.Nilf(t, err, "failed to connect") require.NotNilf(t, connPool, "conn is nil after Connect") @@ -1954,7 +1988,7 @@ func TestReplace(t *testing.T) { err := test_helpers.SetClusterRO(servers, connOpts, roles) require.Nilf(t, err, "fail to set roles for cluster") - connPool, err := pool.Connect(servers, connOpts) + connPool, err := pool.Connect(context.Background(), servers, connOpts) require.Nilf(t, err, "failed to connect") require.NotNilf(t, connPool, "conn is nil after Connect") @@ -2040,7 +2074,7 @@ func TestSelect(t *testing.T) { err := test_helpers.SetClusterRO(servers, connOpts, roles) require.Nilf(t, err, "fail to set roles for cluster") - connPool, err := pool.Connect(servers, connOpts) + connPool, err := pool.Connect(context.Background(), servers, connOpts) require.Nilf(t, err, "failed to connect") require.NotNilf(t, connPool, "conn is nil after Connect") @@ -2160,7 +2194,7 @@ func TestPing(t *testing.T) { err := test_helpers.SetClusterRO(servers, connOpts, roles) require.Nilf(t, err, "fail to set roles for cluster") - connPool, err := pool.Connect(servers, connOpts) + connPool, err := pool.Connect(context.Background(), servers, connOpts) require.Nilf(t, err, "failed to connect") require.NotNilf(t, connPool, "conn is nil after Connect") @@ -2198,7 +2232,7 @@ func TestDo(t *testing.T) { err := test_helpers.SetClusterRO(servers, connOpts, roles) require.Nilf(t, err, "fail to set roles for cluster") - connPool, err := pool.Connect(servers, connOpts) + connPool, err := pool.Connect(context.Background(), servers, connOpts) require.Nilf(t, err, "failed to connect") require.NotNilf(t, connPool, "conn is nil after Connect") @@ -2237,7 +2271,7 @@ func TestDo_concurrent(t *testing.T) { err := test_helpers.SetClusterRO(servers, connOpts, roles) require.Nilf(t, err, "fail to set roles for cluster") - connPool, err := pool.Connect(servers, connOpts) + connPool, err := pool.Connect(context.Background(), servers, connOpts) require.Nilf(t, err, "failed to connect") require.NotNilf(t, connPool, "conn is nil after Connect") @@ -2268,7 +2302,7 @@ func TestNewPrepared(t *testing.T) { err := test_helpers.SetClusterRO(servers, connOpts, roles) require.Nilf(t, err, "fail to set roles for cluster") - connPool, err := pool.Connect(servers, connOpts) + connPool, err := pool.Connect(context.Background(), servers, connOpts) require.Nilf(t, err, "failed to connect") require.NotNilf(t, connPool, "conn is nil after Connect") @@ -2336,7 +2370,7 @@ func TestDoWithStrangerConn(t *testing.T) { err := test_helpers.SetClusterRO(servers, connOpts, roles) require.Nilf(t, err, "fail to set roles for cluster") - connPool, err := pool.Connect(servers, connOpts) + connPool, err := pool.Connect(context.Background(), servers, connOpts) require.Nilf(t, err, "failed to connect") require.NotNilf(t, connPool, "conn is nil after Connect") @@ -2365,7 +2399,7 @@ func TestStream_Commit(t *testing.T) { err = test_helpers.SetClusterRO(servers, connOpts, roles) require.Nilf(t, err, "fail to set roles for cluster") - connPool, err := pool.Connect(servers, connOpts) + connPool, err := pool.Connect(context.Background(), servers, connOpts) require.Nilf(t, err, "failed to connect") require.NotNilf(t, connPool, "conn is nil after Connect") defer connPool.Close() @@ -2464,7 +2498,7 @@ func TestStream_Rollback(t *testing.T) { err = test_helpers.SetClusterRO(servers, connOpts, roles) require.Nilf(t, err, "fail to set roles for cluster") - connPool, err := pool.Connect(servers, connOpts) + connPool, err := pool.Connect(context.Background(), servers, connOpts) require.Nilf(t, err, "failed to connect") require.NotNilf(t, connPool, "conn is nil after Connect") defer connPool.Close() @@ -2563,7 +2597,7 @@ func TestStream_TxnIsolationLevel(t *testing.T) { err = test_helpers.SetClusterRO(servers, connOpts, roles) require.Nilf(t, err, "fail to set roles for cluster") - connPool, err := pool.Connect(servers, connOpts) + connPool, err := pool.Connect(context.Background(), servers, connOpts) require.Nilf(t, err, "failed to connect") require.NotNilf(t, connPool, "conn is nil after Connect") defer connPool.Close() @@ -2657,7 +2691,7 @@ func TestConnectionPool_NewWatcher_noWatchersFeature(t *testing.T) { err := test_helpers.SetClusterRO(servers, opts, roles) require.Nilf(t, err, "fail to set roles for cluster") - connPool, err := pool.Connect(servers, opts) + connPool, err := pool.Connect(context.Background(), servers, opts) require.Nilf(t, err, "failed to connect") require.NotNilf(t, connPool, "conn is nil after Connect") defer connPool.Close() @@ -2685,7 +2719,7 @@ func TestConnectionPool_NewWatcher_modes(t *testing.T) { err := test_helpers.SetClusterRO(servers, opts, roles) require.Nilf(t, err, "fail to set roles for cluster") - connPool, err := pool.Connect(servers, opts) + connPool, err := pool.Connect(context.Background(), servers, opts) require.Nilf(t, err, "failed to connect") require.NotNilf(t, connPool, "conn is nil after Connect") defer connPool.Close() @@ -2767,7 +2801,7 @@ func TestConnectionPool_NewWatcher_update(t *testing.T) { poolOpts := pool.Opts{ CheckTimeout: 500 * time.Millisecond, } - pool, err := pool.ConnectWithOpts(servers, opts, poolOpts) + pool, err := pool.ConnectWithOpts(context.Background(), servers, opts, poolOpts) require.Nilf(t, err, "failed to connect") require.NotNilf(t, pool, "conn is nil after Connect") @@ -2851,7 +2885,7 @@ func TestWatcher_Unregister(t *testing.T) { err := test_helpers.SetClusterRO(servers, opts, roles) require.Nilf(t, err, "fail to set roles for cluster") - pool, err := pool.Connect(servers, opts) + pool, err := pool.Connect(context.Background(), servers, opts) require.Nilf(t, err, "failed to connect") require.NotNilf(t, pool, "conn is nil after Connect") defer pool.Close() @@ -2910,7 +2944,7 @@ func TestConnectionPool_NewWatcher_concurrent(t *testing.T) { err := test_helpers.SetClusterRO(servers, opts, roles) require.Nilf(t, err, "fail to set roles for cluster") - connPool, err := pool.Connect(servers, opts) + connPool, err := pool.Connect(context.Background(), servers, opts) require.Nilf(t, err, "failed to connect") require.NotNilf(t, connPool, "conn is nil after Connect") defer connPool.Close() @@ -2950,7 +2984,7 @@ func TestWatcher_Unregister_concurrent(t *testing.T) { err := test_helpers.SetClusterRO(servers, opts, roles) require.Nilf(t, err, "fail to set roles for cluster") - connPool, err := pool.Connect(servers, opts) + connPool, err := pool.Connect(context.Background(), servers, opts) require.Nilf(t, err, "failed to connect") require.NotNilf(t, connPool, "conn is nil after Connect") defer connPool.Close() diff --git a/pool/example_test.go b/pool/example_test.go index 84a41ff7b..75204f22b 100644 --- a/pool/example_test.go +++ b/pool/example_test.go @@ -1,6 +1,7 @@ package pool_test import ( + "context" "fmt" "time" @@ -24,7 +25,8 @@ func examplePool(roles []bool, connOpts tarantool.Opts) (*pool.ConnectionPool, e if err != nil { return nil, fmt.Errorf("ConnectionPool is not established") } - connPool, err := pool.Connect(servers, connOpts) + ctx := context.Background() + connPool, err := pool.Connect(ctx, servers, connOpts) if err != nil || connPool == nil { return nil, fmt.Errorf("ConnectionPool is not established") } diff --git a/queue/example_connection_pool_test.go b/queue/example_connection_pool_test.go index 51fb967a5..5795b2682 100644 --- a/queue/example_connection_pool_test.go +++ b/queue/example_connection_pool_test.go @@ -1,6 +1,7 @@ package queue_test import ( + "context" "fmt" "sync" "sync/atomic" @@ -164,7 +165,7 @@ func Example_connectionPool() { CheckTimeout: 5 * time.Second, ConnectionHandler: h, } - connPool, err := pool.ConnectWithOpts(servers, connOpts, poolOpts) + connPool, err := pool.ConnectWithOpts(context.Background(), servers, connOpts, poolOpts) if err != nil { fmt.Printf("Unable to connect to the pool: %s", err) return diff --git a/queue/example_msgpack_test.go b/queue/example_msgpack_test.go index 6fd101e09..4737c1cf6 100644 --- a/queue/example_msgpack_test.go +++ b/queue/example_msgpack_test.go @@ -9,6 +9,7 @@ package queue_test import ( + "context" "fmt" "log" "time" @@ -55,7 +56,7 @@ func Example_simpleQueueCustomMsgPack() { User: "test", Pass: "test", } - conn, err := tarantool.Connect("127.0.0.1:3013", opts) + conn, err := tarantool.Connect(context.Background(), "127.0.0.1:3013", opts) if err != nil { log.Fatalf("connection: %s", err) return diff --git a/queue/example_test.go b/queue/example_test.go index 711ee31d4..7eeac8c43 100644 --- a/queue/example_test.go +++ b/queue/example_test.go @@ -9,6 +9,7 @@ package queue_test import ( + "context" "fmt" "time" @@ -31,7 +32,7 @@ func Example_simpleQueue() { Pass: "test", } - conn, err := tarantool.Connect("127.0.0.1:3013", opts) + conn, err := tarantool.Connect(context.Background(), "127.0.0.1:3013", opts) if err != nil { fmt.Printf("error in prepare is %v", err) return diff --git a/settings/example_test.go b/settings/example_test.go index b1d0e5d4f..1116443e2 100644 --- a/settings/example_test.go +++ b/settings/example_test.go @@ -1,6 +1,7 @@ package settings_test import ( + "context" "fmt" "github.com/tarantool/go-tarantool/v2" @@ -9,7 +10,7 @@ import ( ) func example_connect(opts tarantool.Opts) *tarantool.Connection { - conn, err := tarantool.Connect(server, opts) + conn, err := tarantool.Connect(context.Background(), server, opts) if err != nil { panic("Connection is not established: " + err.Error()) } diff --git a/shutdown_test.go b/shutdown_test.go index bb4cfa099..066393810 100644 --- a/shutdown_test.go +++ b/shutdown_test.go @@ -6,6 +6,7 @@ package tarantool_test import ( + "context" "fmt" "sync" "syscall" @@ -462,7 +463,7 @@ func TestGracefulShutdownCloseConcurrent(t *testing.T) { // Do not wait till Tarantool register out watcher, // test everything is ok even on async. - conn, err := Connect(shtdnServer, shtdnClntOpts) + conn, err := Connect(context.Background(), shtdnServer, shtdnClntOpts) if err != nil { t.Errorf("Failed to connect: %s", err) } else { diff --git a/ssl.go b/ssl.go index a23238849..a0743edd1 100644 --- a/ssl.go +++ b/ssl.go @@ -5,6 +5,7 @@ package tarantool import ( "bufio" + "context" "errors" "io/ioutil" "net" @@ -25,6 +26,16 @@ func sslDialTimeout(network, address string, timeout time.Duration, return openssl.DialTimeout(network, address, timeout, ctx.(*openssl.Ctx), 0) } +func sslDialContext(ctx context.Context, network, address string, + opts SslOpts) (connection net.Conn, err error) { + var sslCtx interface{} + if sslCtx, err = sslCreateContext(opts); err != nil { + return + } + + return openssl.DialContext(ctx, network, address, sslCtx.(*openssl.Ctx), 0) +} + // interface{} is a hack. It helps to avoid dependency of go-openssl in build // of tests with the tag 'go_tarantool_ssl_disable'. func sslCreateContext(opts SslOpts) (ctx interface{}, err error) { diff --git a/ssl_test.go b/ssl_test.go index 30078703c..f1716c436 100644 --- a/ssl_test.go +++ b/ssl_test.go @@ -4,6 +4,7 @@ package tarantool_test import ( + "context" "errors" "fmt" "io/ioutil" @@ -150,7 +151,7 @@ func serverTntStop(inst test_helpers.TarantoolInstance) { } func checkTntConn(clientOpts SslOpts) error { - conn, err := Connect(tntHost, Opts{ + conn, err := Connect(context.Background(), tntHost, Opts{ Auth: AutoAuth, Timeout: 500 * time.Millisecond, User: "test", @@ -645,6 +646,79 @@ func TestSslOpts(t *testing.T) { } } +func TestSslDialContext(t *testing.T) { + serverOpts := SslOpts{ + KeyFile: "testdata/localhost.key", + CertFile: "testdata/localhost.crt", + CaFile: "testdata/ca.crt", + Ciphers: "ECDHE-RSA-AES256-GCM-SHA384", + } + clientOpts := SslOpts{ + KeyFile: "testdata/localhost.key", + CertFile: "testdata/localhost.crt", + CaFile: "testdata/ca.crt", + Ciphers: "ECDHE-RSA-AES256-GCM-SHA384", + } + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + l, err := serverSsl("tcp", sslHost+":0", serverOpts) + if err != nil { + t.Fatalf("Unable to create server, error %q", err.Error()) + } + + msgs, errs := serverSslAccept(l) + + port := l.Addr().(*net.TCPAddr).Port + c, err := SslDialContext(ctx, "tcp", sslHost+":"+strconv.Itoa(port), clientOpts) + if err != nil { + t.Fatalf("Error while creating a client: %v", err) + } + + const message = "any test string" + c.Write([]byte(message)) + c.Close() + + recv, err := serverSslRecv(msgs, errs) + + if err != nil { + t.Errorf("An unexpected server error: %q", err.Error()) + } else if recv != message { + t.Errorf("An unexpected server message: %q, expected %q", recv, message) + } +} + +func TestSslDialContextCancel(t *testing.T) { + serverOpts := SslOpts{ + KeyFile: "testdata/localhost.key", + CertFile: "testdata/localhost.crt", + CaFile: "testdata/ca.crt", + Ciphers: "ECDHE-RSA-AES256-GCM-SHA384", + } + clientOpts := SslOpts{ + KeyFile: "testdata/localhost.key", + CertFile: "testdata/localhost.crt", + CaFile: "testdata/ca.crt", + Ciphers: "ECDHE-RSA-AES256-GCM-SHA384", + } + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + l, err := serverSsl("tcp", sslHost+":0", serverOpts) + if err != nil { + t.Fatalf("Unable to create server, error %q", err.Error()) + } + + serverSslAccept(l) + + port := l.Addr().(*net.TCPAddr).Port + _, err = SslDialContext(ctx, "tcp", sslHost+":"+strconv.Itoa(port), clientOpts) + + if err == nil { + t.Fatalf("Expected error, dial was not canceled") + } +} + func TestOpts_PapSha256Auth(t *testing.T) { isTntSsl := isTestTntSsl() if !isTntSsl { diff --git a/tarantool_test.go b/tarantool_test.go index 6339164f1..f153d53b4 100644 --- a/tarantool_test.go +++ b/tarantool_test.go @@ -708,7 +708,7 @@ func TestTtDialer(t *testing.T) { assert := assert.New(t) require := require.New(t) - conn, err := TtDialer{}.Dial(server, DialOpts{}) + conn, err := TtDialer{}.Dial(context.Background(), server, DialOpts{}) require.Nil(err) require.NotNil(conn) defer conn.Close() @@ -778,7 +778,7 @@ func TestOptsAuth_PapSha256AuthForbit(t *testing.T) { papSha256Opts := opts papSha256Opts.Auth = PapSha256Auth - conn, err := Connect(server, papSha256Opts) + conn, err := Connect(context.Background(), server, papSha256Opts) if err == nil { t.Error("An error expected.") conn.Close() @@ -3408,7 +3408,7 @@ func TestConnectionProtocolVersionRequirementSuccess(t *testing.T) { Version: ProtocolVersion(3), } - conn, err := Connect(server, connOpts) + conn, err := Connect(context.Background(), server, connOpts) require.Nilf(t, err, "No errors on connect") require.NotNilf(t, conn, "Connect success") @@ -3424,7 +3424,7 @@ func TestConnectionProtocolVersionRequirementFail(t *testing.T) { Version: ProtocolVersion(3), } - conn, err := Connect(server, connOpts) + conn, err := Connect(context.Background(), server, connOpts) require.Nilf(t, conn, "Connect fail") require.NotNilf(t, err, "Got error on connect") @@ -3439,7 +3439,7 @@ func TestConnectionProtocolFeatureRequirementSuccess(t *testing.T) { Features: []ProtocolFeature{TransactionsFeature}, } - conn, err := Connect(server, connOpts) + conn, err := Connect(context.Background(), server, connOpts) require.NotNilf(t, conn, "Connect success") require.Nilf(t, err, "No errors on connect") @@ -3455,7 +3455,7 @@ func TestConnectionProtocolFeatureRequirementFail(t *testing.T) { Features: []ProtocolFeature{TransactionsFeature}, } - conn, err := Connect(server, connOpts) + conn, err := Connect(context.Background(), server, connOpts) require.Nilf(t, conn, "Connect fail") require.NotNilf(t, err, "Got error on connect") @@ -3471,7 +3471,7 @@ func TestConnectionProtocolFeatureRequirementManyFail(t *testing.T) { Features: []ProtocolFeature{TransactionsFeature, ProtocolFeature(15532)}, } - conn, err := Connect(server, connOpts) + conn, err := Connect(context.Background(), server, connOpts) require.Nilf(t, conn, "Connect fail") require.NotNilf(t, err, "Got error on connect") @@ -4003,7 +4003,7 @@ func TestConnect_schema_update(t *testing.T) { for i := 0; i < 100; i++ { fut := conn.Do(NewCallRequest("create_spaces")) - if conn, err := Connect(server, opts); err != nil { + if conn, err := Connect(context.Background(), server, opts); err != nil { if err.Error() != "concurrent schema update" { t.Errorf("unexpected error: %s", err) } @@ -4019,6 +4019,38 @@ func TestConnect_schema_update(t *testing.T) { } } +func TestConnect_context_cancel(t *testing.T) { + var connLongReconnectOpts = Opts{ + Timeout: 5 * time.Second, + User: "test", + Pass: "test", + Reconnect: time.Second, + MaxReconnects: 100, + } + + ctx, cancel := context.WithCancel(context.Background()) + + var conn *Connection + var err error + + var wg sync.WaitGroup + wg.Add(1) + + go func() { + defer wg.Done() + cancel() + conn, err = Connect(ctx, server, connLongReconnectOpts) + }() + + cancel() + + wg.Wait() + + if conn != nil || err == nil { + t.Fatalf("ConnectionPool was created after cancel") + } +} + // runTestMain is a body of TestMain function // (see https://pkg.go.dev/testing#hdr-Main). // Using defer + os.Exit is not works so TestMain body diff --git a/test_helpers/main.go b/test_helpers/main.go index 894ebb653..90145ef46 100644 --- a/test_helpers/main.go +++ b/test_helpers/main.go @@ -11,6 +11,7 @@ package test_helpers import ( + "context" "errors" "fmt" "io" @@ -97,7 +98,7 @@ func isReady(server string, opts *tarantool.Opts) error { var conn *tarantool.Connection var resp *tarantool.Response - conn, err = tarantool.Connect(server, *opts) + conn, err = tarantool.Connect(context.Background(), server, *opts) if err != nil { return err } diff --git a/test_helpers/pool_helper.go b/test_helpers/pool_helper.go index c44df2f6a..7028512b8 100644 --- a/test_helpers/pool_helper.go +++ b/test_helpers/pool_helper.go @@ -1,6 +1,7 @@ package test_helpers import ( + "context" "fmt" "reflect" "time" @@ -132,7 +133,7 @@ func Retry(f func(interface{}) error, args interface{}, count int, timeout time. func InsertOnInstance(server string, connOpts tarantool.Opts, space interface{}, tuple interface{}) error { - conn, err := tarantool.Connect(server, connOpts) + conn, err := tarantool.Connect(context.Background(), server, connOpts) if err != nil { return fmt.Errorf("fail to connect to %s: %s", server, err.Error()) } @@ -192,7 +193,7 @@ func InsertOnInstances(servers []string, connOpts tarantool.Opts, space interfac } func SetInstanceRO(server string, connOpts tarantool.Opts, isReplica bool) error { - conn, err := tarantool.Connect(server, connOpts) + conn, err := tarantool.Connect(context.Background(), server, connOpts) if err != nil { return err } diff --git a/test_helpers/utils.go b/test_helpers/utils.go index 3771a5f9e..8967da4a0 100644 --- a/test_helpers/utils.go +++ b/test_helpers/utils.go @@ -1,6 +1,7 @@ package test_helpers import ( + "context" "fmt" "testing" "time" @@ -17,7 +18,7 @@ func ConnectWithValidation(t testing.TB, opts tarantool.Opts) *tarantool.Connection { t.Helper() - conn, err := tarantool.Connect(server, opts) + conn, err := tarantool.Connect(context.Background(), server, opts) if err != nil { t.Fatalf("Failed to connect: %s", err.Error()) } diff --git a/uuid/example_test.go b/uuid/example_test.go index 632f620be..ed5b68ae8 100644 --- a/uuid/example_test.go +++ b/uuid/example_test.go @@ -9,6 +9,7 @@ package uuid_test import ( + "context" "fmt" "log" @@ -25,7 +26,8 @@ func Example() { User: "test", Pass: "test", } - client, err := tarantool.Connect("127.0.0.1:3013", opts) + ctx := context.Background() + client, err := tarantool.Connect(ctx, "127.0.0.1:3013", opts) if err != nil { log.Fatalf("Failed to connect: %s", err.Error()) }