From 2d49157aa7797ffa5540ffbfde3f814ea48c9b88 Mon Sep 17 00:00:00 2001 From: Chris Bannister Date: Sun, 14 Feb 2016 13:14:13 +0000 Subject: [PATCH] store the hostInfo on the connection --- conn.go | 7 ++++++- conn_test.go | 3 ++- connectionpool.go | 4 ++-- control.go | 31 +++++++++++++++++++++++++++---- events.go | 1 + session.go | 4 ++-- 6 files changed, 40 insertions(+), 10 deletions(-) diff --git a/conn.go b/conn.go index 7612ae5cd..5f40f7822 100644 --- a/conn.go +++ b/conn.go @@ -139,6 +139,8 @@ type Conn struct { currentKeyspace string started bool + host *HostInfo + session *Session closed int32 @@ -148,7 +150,9 @@ type Conn struct { } // Connect establishes a connection to a Cassandra node. -func Connect(addr string, cfg *ConnConfig, errorHandler ConnErrorHandler, session *Session) (*Conn, error) { +func Connect(host *HostInfo, addr string, cfg *ConnConfig, + errorHandler ConnErrorHandler, session *Session) (*Conn, error) { + var ( err error conn net.Conn @@ -196,6 +200,7 @@ func Connect(addr string, cfg *ConnConfig, errorHandler ConnErrorHandler, sessio quit: make(chan struct{}), session: session, streams: streams.New(cfg.ProtoVersion), + host: host, } if cfg.Keepalive > 0 { diff --git a/conn_test.go b/conn_test.go index ed639952b..6caf49c38 100644 --- a/conn_test.go +++ b/conn_test.go @@ -416,7 +416,8 @@ func TestStream0(t *testing.T) { } }) - conn, err := Connect(srv.Address, &ConnConfig{ProtoVersion: int(srv.protocol)}, errorHandler, nil) + host := &HostInfo{peer: srv.Address} + conn, err := Connect(host, srv.Address, &ConnConfig{ProtoVersion: int(srv.protocol)}, errorHandler, nil) if err != nil { t.Fatal(err) } diff --git a/connectionpool.go b/connectionpool.go index 5b22e0111..6c2b5dcf7 100644 --- a/connectionpool.go +++ b/connectionpool.go @@ -246,7 +246,7 @@ func (p *policyConnPool) addHost(host *HostInfo) { pool = newHostConnPool( p.session, host, - host.Port(), + host.Port(), // TODO: if port == 0 use pool.port? p.numConns, p.keyspace, p.connPolicy(), @@ -506,7 +506,7 @@ func (pool *hostConnPool) connect() (err error) { // try to connect var conn *Conn for i := 0; i < maxAttempts; i++ { - conn, err = pool.session.connect(pool.addr, pool) + conn, err = pool.session.connect(pool.addr, pool, pool.host) if err == nil { break } diff --git a/control.go b/control.go index b579c61f1..1948187a7 100644 --- a/control.go +++ b/control.go @@ -99,9 +99,28 @@ func (c *controlConn) shuffleDial(endpoints []string) (conn *Conn, err error) { // shuffle endpoints so not all drivers will connect to the same initial // node. for _, addr := range shuffled { - conn, err = c.session.connect(JoinHostPort(addr, c.session.cfg.Port), c) + if addr == "" { + return nil, fmt.Errorf("control: invalid address: %q", addr) + } + + port := c.session.cfg.Port + addr = JoinHostPort(addr, port) + host, portStr, err := net.SplitHostPort(addr) + if err != nil { + host = addr + port = c.session.cfg.Port + err = nil + } else { + port, err = strconv.Atoi(portStr) + if err != nil { + return nil, err + } + } + + hostInfo, _ := c.session.ring.addHostIfMissing(&HostInfo{peer: host, port: port}) + conn, err = c.session.connect(addr, c, hostInfo) if err == nil { - return + return conn, err } log.Printf("gocql: unable to dial control conn %v: %v\n", addr, err) @@ -111,6 +130,10 @@ func (c *controlConn) shuffleDial(endpoints []string) (conn *Conn, err error) { } func (c *controlConn) connect(endpoints []string) error { + if len(endpoints) == 0 { + return errors.New("control: no endpoints specified") + } + conn, err := c.shuffleDial(endpoints) if err != nil { return fmt.Errorf("control: unable to connect: %v", err) @@ -200,7 +223,7 @@ func (c *controlConn) reconnect(refreshring bool) { var newConn *Conn if addr != "" { // try to connect to the old host - conn, err := c.session.connect(addr, c) + conn, err := c.session.connect(addr, c, oldConn.host) if err != nil { // host is dead // TODO: this is replicated in a few places @@ -222,7 +245,7 @@ func (c *controlConn) reconnect(refreshring bool) { } var err error - newConn, err = c.session.connect(conn.addr, c) + newConn, err = c.session.connect(conn.addr, c, conn.host) if err != nil { // TODO: add log handler for things like this return diff --git a/events.go b/events.go index 4dcb8dcdd..ab8de8fef 100644 --- a/events.go +++ b/events.go @@ -249,6 +249,7 @@ func (s *Session) handleNodeUp(ip net.IP, port int, waitForBinary bool) { time.Sleep(t) } + host.setPort(port) s.pool.hostUp(host) host.setState(NodeUp) return diff --git a/session.go b/session.go index 185fa710f..7dfc8b211 100644 --- a/session.go +++ b/session.go @@ -583,8 +583,8 @@ func (s *Session) MapExecuteBatchCAS(batch *Batch, dest map[string]interface{}) return applied, iter, iter.err } -func (s *Session) connect(addr string, errorHandler ConnErrorHandler) (*Conn, error) { - return Connect(addr, s.connCfg, errorHandler, s) +func (s *Session) connect(addr string, errorHandler ConnErrorHandler, host *HostInfo) (*Conn, error) { + return Connect(host, addr, s.connCfg, errorHandler, s) } // Query represents a CQL statement that can be executed.