Skip to content

Commit

Permalink
store the hostInfo on the connection
Browse files Browse the repository at this point in the history
  • Loading branch information
Zariel committed Feb 15, 2016
1 parent 2cbb8f1 commit 2d49157
Show file tree
Hide file tree
Showing 6 changed files with 40 additions and 10 deletions.
7 changes: 6 additions & 1 deletion conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,8 @@ type Conn struct {
currentKeyspace string
started bool

host *HostInfo

session *Session

closed int32
Expand All @@ -148,7 +150,9 @@ type Conn struct {
}

// Connect establishes a connection to a Cassandra node.
func Connect(addr string, cfg *ConnConfig, errorHandler ConnErrorHandler, session *Session) (*Conn, error) {
func Connect(host *HostInfo, addr string, cfg *ConnConfig,
errorHandler ConnErrorHandler, session *Session) (*Conn, error) {

var (
err error
conn net.Conn
Expand Down Expand Up @@ -196,6 +200,7 @@ func Connect(addr string, cfg *ConnConfig, errorHandler ConnErrorHandler, sessio
quit: make(chan struct{}),
session: session,
streams: streams.New(cfg.ProtoVersion),
host: host,
}

if cfg.Keepalive > 0 {
Expand Down
3 changes: 2 additions & 1 deletion conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -416,7 +416,8 @@ func TestStream0(t *testing.T) {
}
})

conn, err := Connect(srv.Address, &ConnConfig{ProtoVersion: int(srv.protocol)}, errorHandler, nil)
host := &HostInfo{peer: srv.Address}
conn, err := Connect(host, srv.Address, &ConnConfig{ProtoVersion: int(srv.protocol)}, errorHandler, nil)
if err != nil {
t.Fatal(err)
}
Expand Down
4 changes: 2 additions & 2 deletions connectionpool.go
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,7 @@ func (p *policyConnPool) addHost(host *HostInfo) {
pool = newHostConnPool(
p.session,
host,
host.Port(),
host.Port(), // TODO: if port == 0 use pool.port?
p.numConns,
p.keyspace,
p.connPolicy(),
Expand Down Expand Up @@ -506,7 +506,7 @@ func (pool *hostConnPool) connect() (err error) {
// try to connect
var conn *Conn
for i := 0; i < maxAttempts; i++ {
conn, err = pool.session.connect(pool.addr, pool)
conn, err = pool.session.connect(pool.addr, pool, pool.host)
if err == nil {
break
}
Expand Down
31 changes: 27 additions & 4 deletions control.go
Original file line number Diff line number Diff line change
Expand Up @@ -99,9 +99,28 @@ func (c *controlConn) shuffleDial(endpoints []string) (conn *Conn, err error) {
// shuffle endpoints so not all drivers will connect to the same initial
// node.
for _, addr := range shuffled {
conn, err = c.session.connect(JoinHostPort(addr, c.session.cfg.Port), c)
if addr == "" {
return nil, fmt.Errorf("control: invalid address: %q", addr)
}

port := c.session.cfg.Port
addr = JoinHostPort(addr, port)
host, portStr, err := net.SplitHostPort(addr)
if err != nil {
host = addr
port = c.session.cfg.Port
err = nil
} else {
port, err = strconv.Atoi(portStr)
if err != nil {
return nil, err
}
}

hostInfo, _ := c.session.ring.addHostIfMissing(&HostInfo{peer: host, port: port})
conn, err = c.session.connect(addr, c, hostInfo)
if err == nil {
return
return conn, err
}

log.Printf("gocql: unable to dial control conn %v: %v\n", addr, err)
Expand All @@ -111,6 +130,10 @@ func (c *controlConn) shuffleDial(endpoints []string) (conn *Conn, err error) {
}

func (c *controlConn) connect(endpoints []string) error {
if len(endpoints) == 0 {
return errors.New("control: no endpoints specified")
}

conn, err := c.shuffleDial(endpoints)
if err != nil {
return fmt.Errorf("control: unable to connect: %v", err)
Expand Down Expand Up @@ -200,7 +223,7 @@ func (c *controlConn) reconnect(refreshring bool) {
var newConn *Conn
if addr != "" {
// try to connect to the old host
conn, err := c.session.connect(addr, c)
conn, err := c.session.connect(addr, c, oldConn.host)
if err != nil {
// host is dead
// TODO: this is replicated in a few places
Expand All @@ -222,7 +245,7 @@ func (c *controlConn) reconnect(refreshring bool) {
}

var err error
newConn, err = c.session.connect(conn.addr, c)
newConn, err = c.session.connect(conn.addr, c, conn.host)
if err != nil {
// TODO: add log handler for things like this
return
Expand Down
1 change: 1 addition & 0 deletions events.go
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,7 @@ func (s *Session) handleNodeUp(ip net.IP, port int, waitForBinary bool) {
time.Sleep(t)
}

host.setPort(port)
s.pool.hostUp(host)
host.setState(NodeUp)
return
Expand Down
4 changes: 2 additions & 2 deletions session.go
Original file line number Diff line number Diff line change
Expand Up @@ -583,8 +583,8 @@ func (s *Session) MapExecuteBatchCAS(batch *Batch, dest map[string]interface{})
return applied, iter, iter.err
}

func (s *Session) connect(addr string, errorHandler ConnErrorHandler) (*Conn, error) {
return Connect(addr, s.connCfg, errorHandler, s)
func (s *Session) connect(addr string, errorHandler ConnErrorHandler, host *HostInfo) (*Conn, error) {
return Connect(host, addr, s.connCfg, errorHandler, s)
}

// Query represents a CQL statement that can be executed.
Expand Down

0 comments on commit 2d49157

Please sign in to comment.