Skip to content

Commit

Permalink
host_source: correctly get local hostinfo (apache#932)
Browse files Browse the repository at this point in the history
The system.local table might not have the address info for the connected
host, if this is the case use the connect address from the control
connection.

Add more checks that the connect address is valid to not end up adding
invalid hosts.
  • Loading branch information
Zariel authored Jul 2, 2017
1 parent cf5e3dd commit c9b7799
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 12 deletions.
41 changes: 29 additions & 12 deletions host_source.go
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,11 @@ func (h *HostInfo) setPeer(peer net.IP) *HostInfo {
return h
}

func (h *HostInfo) invalidConnectAddr() bool {
addr := h.ConnectAddress()
return addr == nil || addr.IsUnspecified()
}

// Returns the address that should be used to connect to the host.
// If you wish to override this, use an AddressTranslator or
// use a HostFilter to SetConnectAddress()
Expand Down Expand Up @@ -472,7 +477,16 @@ func (r *ringDescriber) GetLocalHostInfo() (*HostInfo, error) {
if it == nil {
return nil, errors.New("Attempted to query 'system.local' on a closed control connection")
}
return r.extractHostInfo(it)
host, err := r.extractHostInfo(it)
if err != nil {
return nil, err
}

if host.invalidConnectAddr() {
host.SetConnectAddress(r.session.control.GetHostInfo().ConnectAddress())
}

return host, nil
}

// Given an ip address and port, return a peer that matched the ip address
Expand Down Expand Up @@ -550,6 +564,8 @@ func (r *ringDescriber) GetHosts() ([]*HostInfo, string, error) {
localHost, err := r.GetLocalHostInfo()
if err != nil {
return r.prevHosts, r.prevPartitioner, err
} else if localHost.invalidConnectAddr() {
panic(fmt.Sprintf("unable to get localhost connect address: %v", localHost))
}

// Update our list of hosts by querying the cluster
Expand Down Expand Up @@ -602,13 +618,6 @@ func (r *ringDescriber) GetHostInfo(ip net.IP, port int) (*HostInfo, error) {
// If we are asking about the same node our control connection has a connection too
if controlHost.ConnectAddress().Equal(ip) {
host, err = r.GetLocalHostInfo()

// Always respect the provided control node address and disregard the ip address
// the cassandra node provides. We do this as we are already connected and have a
// known valid ip address. This insulates gocql from client connection issues stemming
// from node misconfiguration. For instance when a node is run from a container, by
// default the node will report its ip address as 127.0.0.1 which is typically invalid.
host.SetConnectAddress(ip)
} else {
host, err = r.GetPeerHostInfo(ip, port)
}
Expand All @@ -618,12 +627,20 @@ func (r *ringDescriber) GetHostInfo(ip net.IP, port int) (*HostInfo, error) {
return nil, err
}

// Apply host filter to the result
if r.session.cfg.HostFilter != nil && r.session.cfg.HostFilter.Accept(host) != true {
return nil, err
if controlHost.ConnectAddress().Equal(ip) {
// Always respect the provided control node address and disregard the ip address
// the cassandra node provides. We do this as we are already connected and have a
// known valid ip address. This insulates gocql from client connection issues stemming
// from node misconfiguration. For instance when a node is run from a container, by
// default the node will report its ip address as 127.0.0.1 which is typically invalid.
host.SetConnectAddress(ip)
}

if host.invalidConnectAddr() {
return nil, fmt.Errorf("host ConnectAddress invalid: %v", host)
}

return host, err
return host, nil
}

func (r *ringDescriber) refreshRing() error {
Expand Down
7 changes: 7 additions & 0 deletions ring.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package gocql

import (
"fmt"
"net"
"sync"
"sync/atomic"
Expand Down Expand Up @@ -53,6 +54,9 @@ func (r *ring) allHosts() []*HostInfo {
}

func (r *ring) addHost(host *HostInfo) bool {
if host.invalidConnectAddr() {
panic(fmt.Sprintf("invalid host: %v", host))
}
ip := host.ConnectAddress().String()

r.mu.Lock()
Expand All @@ -79,6 +83,9 @@ func (r *ring) addOrUpdate(host *HostInfo) *HostInfo {
}

func (r *ring) addHostIfMissing(host *HostInfo) (*HostInfo, bool) {
if host.invalidConnectAddr() {
panic(fmt.Sprintf("invalid host: %v", host))
}
ip := host.ConnectAddress().String()

r.mu.Lock()
Expand Down

0 comments on commit c9b7799

Please sign in to comment.