Merge pull request apache#504 from Zariel/ring-discovery-proxy
host_source: use system.local rpc_address
Zariel committed Oct 23, 2015
2 parents 8041a37 + a70d1ac commit 9935df5
Showing 4 changed files with 259 additions and 48 deletions.
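
For orientation, here is a minimal sketch (not part of this commit) of the setup the change targets, modelled on the new TestDiscoverViaProxy test below: the driver is pointed at a single proxy address and relies on ring discovery to find the real nodes via their rpc_address. The proxy address is an illustrative placeholder.

package main

import (
	"log"
	"time"

	"github.com/gocql/gocql"
)

func main() {
	// illustrative placeholder: a TCP proxy (or NAT endpoint) that fronts the
	// real Cassandra ring; it is the only contact point the driver is given
	proxyAddr := "10.0.0.5:9042"

	cluster := gocql.NewCluster(proxyAddr)
	cluster.DiscoverHosts = true               // discover the ring behind the proxy
	cluster.Discovery.Sleep = 30 * time.Second // how often to re-poll the ring
	cluster.NumConns = 1

	session, err := cluster.CreateSession()
	if err != nil {
		log.Fatal(err)
	}
	defer session.Close()

	// with this change the connection pool is keyed by each node's
	// rpc_address (from system.local and system.peers), not by proxyAddr
}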
141 changes: 141 additions & 0 deletions cassandra_test.go
@@ -6,6 +6,7 @@ import (
"bytes"
"flag"
"fmt"
"io"
"log"
"math"
"math/big"
@@ -2288,3 +2289,143 @@ func TestUDF(t *testing.T) {
t.Fatal(err)
}
}

func TestDiscoverViaProxy(t *testing.T) {
// This (complicated) test checks that when the driver is given an initial host
// that is in fact a proxy, it discovers the rest of the ring behind the proxy
// and does not store the proxy's address as a host in its connection pool.
// See https://github.com/gocql/gocql/issues/481
proxy, err := net.Listen("tcp", ":0")
if err != nil {
t.Fatalf("unable to create proxy listener: %v", err)
}

var (
wg sync.WaitGroup
mu sync.Mutex
proxyConns []net.Conn
closed bool
)

go func(wg *sync.WaitGroup) {
cassandraAddr := JoinHostPort(clusterHosts[0], 9042)

cassandra := func() (net.Conn, error) {
return net.Dial("tcp", cassandraAddr)
}

proxyFn := func(wg *sync.WaitGroup, from, to net.Conn) {
defer wg.Done()

_, err := io.Copy(to, from)
if err != nil {
mu.Lock()
if !closed {
t.Error(err)
}
mu.Unlock()
}
}

// handle dials cassandra and then proxies requests and responses. It waits
// for both the read and write side of the TCP connection to close before
// returning.
handle := func(conn net.Conn) error {
defer conn.Close()

cass, err := cassandra()
if err != nil {
return err
}

mu.Lock()
proxyConns = append(proxyConns, cass)
mu.Unlock()

defer cass.Close()

var wg sync.WaitGroup
wg.Add(1)
go proxyFn(&wg, conn, cass)

wg.Add(1)
go proxyFn(&wg, cass, conn)

wg.Wait()

return nil
}

for {
// the proxy just accepts connections and proxies them to cassandra;
// it runs until it is closed.
conn, err := proxy.Accept()
if err != nil {
mu.Lock()
if !closed {
t.Error(err)
}
mu.Unlock()
return
}

mu.Lock()
proxyConns = append(proxyConns, conn)
mu.Unlock()

wg.Add(1)
go func(conn net.Conn) {
defer wg.Done()

if err := handle(conn); err != nil {
t.Error(err)
return
}
}(conn)
}
}(&wg)

defer wg.Wait()

proxyAddr := proxy.Addr().String()

cluster := createCluster()
cluster.DiscoverHosts = true
cluster.NumConns = 1
cluster.Discovery.Sleep = 100 * time.Millisecond
// initial host is the proxy address
cluster.Hosts = []string{proxyAddr}

session := createSessionFromCluster(cluster, t)
defer session.Close()

if !session.hostSource.localHasRpcAddr {
t.Skip("Target cluster does not have rpc_address in system.local.")
goto close
}

// we shouldn't need this but to be safe
time.Sleep(1 * time.Second)

session.pool.mu.RLock()
for _, host := range clusterHosts {
if _, ok := session.pool.hostConnPools[host]; !ok {
t.Errorf("missing host in pool after discovery: %q", host)
}
}
session.pool.mu.RUnlock()

close:
if err := proxy.Close(); err != nil {
t.Log(err)
}

mu.Lock()
closed = true
for _, conn := range proxyConns {
if err := conn.Close(); err != nil {
t.Log(err)
}
}
mu.Unlock()
}
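
The proxy machinery above is intertwined with test bookkeeping; stripped down, the accept-and-pipe pattern it uses looks roughly like the sketch below. Names such as runProxy and the backend address are illustrative assumptions, not part of the commit.

package main

import (
	"io"
	"log"
	"net"
	"sync"
)

// runProxy accepts connections on l and pipes each one, in both directions,
// to the backend at backendAddr until the listener is closed.
func runProxy(l net.Listener, backendAddr string) {
	for {
		client, err := l.Accept()
		if err != nil {
			return // listener closed
		}

		go func(client net.Conn) {
			defer client.Close()

			backend, err := net.Dial("tcp", backendAddr)
			if err != nil {
				return
			}
			defer backend.Close()

			// copy both directions and wait for both sides to finish,
			// mirroring what the test's handle function does
			var wg sync.WaitGroup
			wg.Add(2)
			go func() { defer wg.Done(); io.Copy(backend, client) }()
			go func() { defer wg.Done(); io.Copy(client, backend) }()
			wg.Wait()
		}(client)
	}
}

func main() {
	l, err := net.Listen("tcp", ":0")
	if err != nil {
		log.Fatal(err)
	}
	// backend address is an assumption for the sketch
	go runProxy(l, "127.0.0.1:9042")
	select {}
}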
24 changes: 18 additions & 6 deletions control.go
@@ -26,7 +26,6 @@ func createControlConn(session *Session) *controlConn {
}

control.conn.Store((*Conn)(nil))
control.reconnect()
go control.heartBeat()

return control
@@ -55,14 +54,14 @@ func (c *controlConn) heartBeat() {
}

reconn:
c.reconnect()
time.Sleep(5 * time.Second)
c.reconnect(true)
// time.Sleep(5 * time.Second)
continue

}
}

func (c *controlConn) reconnect() {
func (c *controlConn) reconnect(refreshring bool) {
if !atomic.CompareAndSwapUint64(&c.connecting, 0, 1) {
return
}
@@ -101,6 +100,10 @@ func (c *controlConn) reconnect() {
if oldConn != nil {
oldConn.Close()
}

if refreshring {
c.session.hostSource.refreshRing()
}
}

func (c *controlConn) HandleError(conn *Conn, err error, closed bool) {
@@ -113,7 +116,7 @@ func (c *controlConn) HandleError(conn *Conn, err error, closed bool) {
return
}

c.reconnect()
c.reconnect(true)
}

func (c *controlConn) writeFrame(w frameWriter) (frame, error) {
@@ -146,7 +149,7 @@ func (c *controlConn) query(statement string, values ...interface{}) (iter *Iter

connectAttempts++

c.reconnect()
c.reconnect(false)
continue
}

@@ -212,6 +215,15 @@ func (c *controlConn) awaitSchemaAgreement() (err error) {
// not exported
return errors.New("gocql: cluster schema versions not consistent")
}

func (c *controlConn) addr() string {
conn := c.conn.Load().(*Conn)
if conn == nil {
return ""
}
return conn.addr
}

func (c *controlConn) close() {
// TODO: handle more gracefully
close(c.quit)
115 changes: 78 additions & 37 deletions host_source.go
@@ -1,6 +1,7 @@
package gocql

import (
"fmt"
"log"
"net"
"time"
Expand All @@ -14,6 +15,10 @@ type HostInfo struct {
Tokens []string
}

func (h HostInfo) String() string {
return fmt.Sprintf("[hostinfo peer=%q data_centre=%q rack=%q host_id=%q num_tokens=%d]", h.Peer, h.DataCenter, h.Rack, h.HostId, len(h.Tokens))
}

// Polls system.peers at a specific interval to find new hosts
type ringDescriber struct {
dcFilter string
@@ -22,46 +27,78 @@ type ringDescriber struct {
prevPartitioner string
session *Session
closeChan chan bool
// indicates that we can use system.local to get the connection's remote address
localHasRpcAddr bool
}

func checkSystemLocal(control *controlConn) (bool, error) {
iter := control.query("SELECT rpc_address FROM system.local")
if err := iter.err; err != nil {
if errf, ok := err.(*errorFrame); ok {
if errf.code == errSyntax {
return false, nil
}
}

return false, err
}

return true, nil
}

func (r *ringDescriber) GetHosts() (hosts []HostInfo, partitioner string, err error) {
// we need the same conn for both queries because system.peers and system.local
// must be queried on the same node to see the whole cluster

iter := r.session.control.query("SELECT data_center, rack, host_id, tokens, partitioner FROM system.local")
if iter == nil {
return r.prevHosts, r.prevPartitioner, nil
}
const (
legacyLocalQuery = "SELECT data_center, rack, host_id, tokens, partitioner FROM system.local"
// only supported in 2.2.0, 2.1.6, 2.0.16
localQuery = "SELECT rpc_address, data_center, rack, host_id, tokens, partitioner FROM system.local"
)

var localHost HostInfo
if r.localHasRpcAddr {
iter := r.session.control.query(localQuery)
if iter == nil {
return r.prevHosts, r.prevPartitioner, nil
}

conn := r.session.pool.Pick(nil)
if conn == nil {
return r.prevHosts, r.prevPartitioner, nil
}
iter.Scan(&localHost.Peer, &localHost.DataCenter, &localHost.Rack,
&localHost.HostId, &localHost.Tokens, &partitioner)

host := HostInfo{}
iter.Scan(&host.DataCenter, &host.Rack, &host.HostId, &host.Tokens, &partitioner)
if err = iter.Close(); err != nil {
return nil, "", err
}
} else {
iter := r.session.control.query(legacyLocalQuery)
if iter == nil {
return r.prevHosts, r.prevPartitioner, nil
}

if err = iter.Close(); err != nil {
return nil, "", err
}
iter.Scan(&localHost.DataCenter, &localHost.Rack, &localHost.HostId, &localHost.Tokens, &partitioner)

addr, _, err := net.SplitHostPort(conn.Address())
if err != nil {
// this should not happen, ever, as this is the address that was dialed by conn;
// a panic makes sense here, please report a bug if it occurs.
panic(err)
}
if err = iter.Close(); err != nil {
return nil, "", err
}

host.Peer = addr
addr, _, err := net.SplitHostPort(r.session.control.addr())
if err != nil {
// this should not happen, ever, as this is the address that was dialed by conn;
// a panic makes sense here, please report a bug if it occurs.
panic(err)
}

localHost.Peer = addr
}

hosts = []HostInfo{host}
hosts = []HostInfo{localHost}

iter = r.session.control.query("SELECT peer, data_center, rack, host_id, tokens FROM system.peers")
iter := r.session.control.query("SELECT peer, data_center, rack, host_id, tokens FROM system.peers")
if iter == nil {
return r.prevHosts, r.prevPartitioner, nil
}

host = HostInfo{}
host := HostInfo{}
for iter.Scan(&host.Peer, &host.DataCenter, &host.Rack, &host.HostId, &host.Tokens) {
if r.matchFilter(&host) {
hosts = append(hosts, host)
@@ -92,28 +129,32 @@ func (r *ringDescriber) matchFilter(host *HostInfo) bool {
return true
}

func (h *ringDescriber) run(sleep time.Duration) {
func (r *ringDescriber) refreshRing() {
// if we have 0 hosts this will return the previous list of hosts to
// attempt to reconnect to the cluster, otherwise we would never find
// downed hosts again; could possibly have an optimisation to only
// try to add new hosts if GetHosts didn't error and the hosts didn't change.
hosts, partitioner, err := r.GetHosts()
if err != nil {
log.Println("RingDescriber: unable to get ring topology:", err)
return
}

r.session.pool.SetHosts(hosts)
r.session.pool.SetPartitioner(partitioner)
}

func (r *ringDescriber) run(sleep time.Duration) {
if sleep == 0 {
sleep = 30 * time.Second
}

for {
// if we have 0 hosts this will return the previous list of hosts to
// attempt to reconnect to the cluster, otherwise we would never find
// downed hosts again; could possibly have an optimisation to only
// try to add new hosts if GetHosts didn't error and the hosts didn't change.
hosts, partitioner, err := h.GetHosts()
if err != nil {
log.Println("RingDescriber: unable to get ring topology:", err)
continue
}

h.session.pool.SetHosts(hosts)
h.session.pool.SetPartitioner(partitioner)
r.refreshRing()

select {
case <-time.After(sleep):
case <-h.closeChan:
case <-r.closeChan:
return
}
}