diff --git a/control.go b/control.go index e18fc72bf..8178cf101 100644 --- a/control.go +++ b/control.go @@ -89,6 +89,22 @@ func (c *controlConn) heartBeat() { } } +func hostInfo(addr string, defaultPort int) (*HostInfo, error) { + var port int + host, portStr, err := net.SplitHostPort(addr) + if err != nil { + host = addr + port = defaultPort + } else { + port, err = strconv.Atoi(portStr) + if err != nil { + return nil, err + } + } + + return &HostInfo{peer: host, port: port}, nil +} + func (c *controlConn) shuffleDial(endpoints []string) (conn *Conn, err error) { perm := randr.Perm(len(endpoints)) shuffled := make([]string, len(endpoints)) @@ -101,24 +117,19 @@ func (c *controlConn) shuffleDial(endpoints []string) (conn *Conn, err error) { // node. for _, addr := range shuffled { if addr == "" { - return nil, fmt.Errorf("control: invalid address: %q", addr) + return nil, fmt.Errorf("invalid address: %q", addr) } port := c.session.cfg.Port addr = JoinHostPort(addr, port) - host, portStr, err := net.SplitHostPort(addr) + + var host *HostInfo + host, err = hostInfo(addr, port) if err != nil { - host = addr - port = c.session.cfg.Port - err = nil - } else { - port, err = strconv.Atoi(portStr) - if err != nil { - return nil, err - } + return nil, fmt.Errorf("invalid address: %q: %v", addr, err) } - hostInfo, _ := c.session.ring.addHostIfMissing(&HostInfo{peer: host, port: port}) + hostInfo, _ := c.session.ring.addHostIfMissing(host) conn, err = c.session.connect(addr, c, hostInfo) if err == nil { return conn, err @@ -127,7 +138,11 @@ func (c *controlConn) shuffleDial(endpoints []string) (conn *Conn, err error) { log.Printf("gocql: unable to dial control conn %v: %v\n", addr, err) } - return + if err != nil { + return nil, err + } + + return conn, nil } func (c *controlConn) connect(endpoints []string) error { @@ -137,9 +152,7 @@ func (c *controlConn) connect(endpoints []string) error { conn, err := c.shuffleDial(endpoints) if err != nil { - return fmt.Errorf("control: unable to connect: %v", err) - } else if conn == nil { - return errors.New("control: unable to connect to initial endpoints") + return fmt.Errorf("control: unable to connect to initial hosts: %v", err) } if err := c.setupConn(conn); err != nil {