diff --git a/cassandra_test.go b/cassandra_test.go
index 9277cb0c1..e74672c87 100644
--- a/cassandra_test.go
+++ b/cassandra_test.go
@@ -1019,6 +1019,14 @@ func TestBatchQueryInfo(t *testing.T) {
 	}
 }
 
+func getRandomConn(t *testing.T, session *Session) *Conn {
+	conn := session.getConn()
+	if conn == nil {
+		t.Fatal("unable to get a connection")
+	}
+	return conn
+}
+
 func injectInvalidPreparedStatement(t *testing.T, session *Session, table string) (string, *Conn) {
 	if err := createTable(session, `CREATE TABLE gocql_test.`+table+` (
 		foo   varchar,
@@ -1029,7 +1037,8 @@ func injectInvalidPreparedStatement(t *testing.T, session *Session, table string
 	}
 
 	stmt := "INSERT INTO " + table + " (foo, bar) VALUES (?, 7)"
-	_, conn := session.pool.Pick(nil)
+
+	conn := getRandomConn(t, session)
 
 	flight := new(inflightPrepare)
 	key := session.stmtsLRU.keyFor(conn.addr, "", stmt)
@@ -1060,7 +1069,7 @@ func injectInvalidPreparedStatement(t *testing.T, session *Session, table string
 
 func TestPrepare_MissingSchemaPrepare(t *testing.T) {
 	s := createSession(t)
-	_, conn := s.pool.Pick(nil)
+	conn := getRandomConn(t, s)
 	defer s.Close()
 
 	insertQry := &Query{stmt: "INSERT INTO invalidschemaprep (val) VALUES (?)", values: []interface{}{5}, cons: s.cons,
@@ -1108,7 +1117,7 @@ func TestQueryInfo(t *testing.T) {
 	session := createSession(t)
 	defer session.Close()
 
-	_, conn := session.pool.Pick(nil)
+	conn := getRandomConn(t, session)
 
 	info, err := conn.prepareStatement("SELECT release_version, host_id FROM system.local WHERE key = ?", nil)
 	if err != nil {
@@ -1982,18 +1991,7 @@ func TestNegativeStream(t *testing.T) {
 	session := createSession(t)
 	defer session.Close()
 
-	var conn *Conn
-	for i := 0; i < 5; i++ {
-		if conn != nil {
-			break
-		}
-
-		_, conn = session.pool.Pick(nil)
-	}
-
-	if conn == nil {
-		t.Fatal("no connections available in the pool")
-	}
+	conn := getRandomConn(t, session)
 
 	const stream = -50
 	writer := frameWriterFunc(func(f *framer, streamID int) error {
diff --git a/cluster.go b/cluster.go
index ed18ce94e..d311f3d40 100644
--- a/cluster.go
+++ b/cluster.go
@@ -16,24 +16,10 @@ type PoolConfig struct {
 	// HostSelectionPolicy sets the policy for selecting which host to use for a
 	// given query (default: RoundRobinHostPolicy())
 	HostSelectionPolicy HostSelectionPolicy
-
-	// ConnSelectionPolicy sets the policy factory for selecting a connection to use for
-	// each host for a query (default: RoundRobinConnPolicy())
-	ConnSelectionPolicy func() ConnSelectionPolicy
 }
 
 func (p PoolConfig) buildPool(session *Session) *policyConnPool {
-	hostSelection := p.HostSelectionPolicy
-	if hostSelection == nil {
-		hostSelection = RoundRobinHostPolicy()
-	}
-
-	connSelection := p.ConnSelectionPolicy
-	if connSelection == nil {
-		connSelection = RoundRobinConnPolicy()
-	}
-
-	return newPolicyConnPool(session, hostSelection, connSelection)
+	return newPolicyConnPool(session)
 }
 
 type DiscoveryConfig struct {
diff --git a/conn_test.go b/conn_test.go
index af1002e57..6b95f3bb9 100644
--- a/conn_test.go
+++ b/conn_test.go
@@ -284,7 +284,6 @@ func TestPolicyConnPoolSSL(t *testing.T) {
 
 	cluster := createTestSslCluster(srv.Address, defaultProto, true)
 	cluster.PoolConfig.HostSelectionPolicy = RoundRobinHostPolicy()
-	cluster.PoolConfig.ConnSelectionPolicy = RoundRobinConnPolicy()
 
 	db, err := cluster.CreateSession()
 	if err != nil {
diff --git a/connectionpool.go b/connectionpool.go
index f3c0ca1bd..a286de8bb 100644
--- a/connectionpool.go
+++ b/connectionpool.go
@@ -14,6 +14,7 @@ import (
 	"math/rand"
 	"net"
 	"sync"
+	"sync/atomic"
 	"time"
 )
 
@@ -65,8 +66,6 @@ type policyConnPool struct {
 	keyspace string
 
 	mu            sync.RWMutex
-	hostPolicy    HostSelectionPolicy
-	connPolicy    func() ConnSelectionPolicy
 	hostConnPools map[string]*hostConnPool
 
 	endpoints []string
@@ -99,17 +98,13 @@ func connConfig(session *Session) (*ConnConfig, error) {
 	}, nil
 }
 
-func newPolicyConnPool(session *Session, hostPolicy HostSelectionPolicy,
-	connPolicy func() ConnSelectionPolicy) *policyConnPool {
-
+func newPolicyConnPool(session *Session) *policyConnPool {
 	// create the pool
 	pool := &policyConnPool{
 		session:       session,
 		port:          session.cfg.Port,
 		numConns:      session.cfg.NumConns,
 		keyspace:      session.cfg.Keyspace,
-		hostPolicy:    hostPolicy,
-		connPolicy:    connPolicy,
 		hostConnPools: map[string]*hostConnPool{},
 	}
 
@@ -150,7 +145,6 @@ func (p *policyConnPool) SetHosts(hosts []*HostInfo) {
 				p.port,
 				p.numConns,
 				p.keyspace,
-				p.connPolicy(),
 			)
 		}(host)
 	}
@@ -170,13 +164,6 @@ func (p *policyConnPool) SetHosts(hosts []*HostInfo) {
 		delete(p.hostConnPools, addr)
 		go pool.Close()
 	}
-
-	// update the policy
-	p.hostPolicy.SetHosts(hosts)
-}
-
-func (p *policyConnPool) SetPartitioner(partitioner string) {
-	p.hostPolicy.SetPartitioner(partitioner)
 }
 
 func (p *policyConnPool) Size() int {
@@ -197,41 +184,10 @@ func (p *policyConnPool) getPool(addr string) (pool *hostConnPool, ok bool) {
 	return
 }
 
-func (p *policyConnPool) Pick(qry *Query) (SelectedHost, *Conn) {
-	nextHost := p.hostPolicy.Pick(qry)
-
-	var (
-		host SelectedHost
-		conn *Conn
-	)
-
-	p.mu.RLock()
-	defer p.mu.RUnlock()
-	for conn == nil {
-		host = nextHost()
-		if host == nil {
-			break
-		} else if host.Info() == nil {
-			panic(fmt.Sprintf("policy %T returned no host info: %+v", p.hostPolicy, host))
-		}
-
-		pool, ok := p.hostConnPools[host.Info().Peer()]
-		if !ok {
-			continue
-		}
-
-		conn = pool.Pick(qry)
-	}
-	return host, conn
-}
-
 func (p *policyConnPool) Close() {
 	p.mu.Lock()
 	defer p.mu.Unlock()
 
-	// remove the hosts from the policy
-	p.hostPolicy.SetHosts(nil)
-
 	// close the pools
 	for addr, pool := range p.hostConnPools {
 		delete(p.hostConnPools, addr)
@@ -249,7 +205,6 @@ func (p *policyConnPool) addHost(host *HostInfo) {
 		host.Port(), // TODO: if port == 0 use pool.port?
 		p.numConns,
 		p.keyspace,
-		p.connPolicy(),
 	)
 
 	p.hostConnPools[host.Peer()] = pool
@@ -257,17 +212,10 @@
 	p.mu.Unlock()
 
 	pool.fill()
-
-	// update policy
-	// TODO: policy should not have conns, it should have hosts and return a host
-	// iter which the pool will use to serve conns
-	p.hostPolicy.AddHost(host)
 }
 
 func (p *policyConnPool) removeHost(addr string) {
-	p.hostPolicy.RemoveHost(addr)
 	p.mu.Lock()
-
 	pool, ok := p.hostConnPools[addr]
 	if !ok {
 		p.mu.Unlock()
@@ -301,12 +249,13 @@ type hostConnPool struct {
 	addr     string
 	size     int
 	keyspace string
-	policy   ConnSelectionPolicy
 	// protection for conns, closed, filling
 	mu      sync.RWMutex
 	conns   []*Conn
 	closed  bool
 	filling bool
+
+	pos uint32
 }
 
 func (h *hostConnPool) String() string {
@@ -317,7 +266,7 @@ func (h *hostConnPool) String() string {
 }
 
 func newHostConnPool(session *Session, host *HostInfo, port, size int,
-	keyspace string, policy ConnSelectionPolicy) *hostConnPool {
+	keyspace string) *hostConnPool {
 
 	pool := &hostConnPool{
 		session: session,
@@ -326,7 +275,6 @@ func newHostConnPool(session *Session, host *HostInfo, port, size int,
 		addr:     JoinHostPort(host.Peer(), port),
 		size:     size,
 		keyspace: keyspace,
-		policy:   policy,
 		conns:    make([]*Conn, 0, size),
 		filling:  false,
 		closed:   false,
@@ -337,16 +285,15 @@ func newHostConnPool(session *Session, host *HostInfo, port, size int,
 }
 
 // Pick a connection from this connection pool for the given query.
-func (pool *hostConnPool) Pick(qry *Query) *Conn {
+func (pool *hostConnPool) Pick() *Conn {
 	pool.mu.RLock()
+	defer pool.mu.RUnlock()
+
 	if pool.closed {
-		pool.mu.RUnlock()
 		return nil
 	}
 
 	size := len(pool.conns)
-	pool.mu.RUnlock()
-
 	if size < pool.size {
 		// try to fill the pool
 		go pool.fill()
@@ -356,7 +303,23 @@
 		}
 	}
 
-	return pool.policy.Pick(qry)
+	pos := int(atomic.AddUint32(&pool.pos, 1) - 1)
+
+	var (
+		leastBusyConn    *Conn
+		streamsAvailable int
+	)
+
+	// find the conn which has the most available streams, this is racy
+	for i := 0; i < size; i++ {
+		conn := pool.conns[(pos+i)%size]
+		if streams := conn.AvailableStreams(); streams > streamsAvailable {
+			leastBusyConn = conn
+			streamsAvailable = streams
+		}
+	}
+
+	return leastBusyConn
 }
 
 //Size returns the number of connections currently active in the pool
@@ -543,10 +506,6 @@ func (pool *hostConnPool) connect() (err error) {
 
 	pool.conns = append(pool.conns, conn)
 
-	conns := make([]*Conn, len(pool.conns))
-	copy(conns, pool.conns)
-	pool.policy.SetConns(conns)
-
 	return nil
 }
 
@@ -573,11 +532,6 @@ func (pool *hostConnPool) HandleError(conn *Conn, err error, closed bool) {
 			// remove the connection, not preserving order
 			pool.conns[i], pool.conns = pool.conns[len(pool.conns)-1], pool.conns[:len(pool.conns)-1]
 
-			// update the policy
-			conns := make([]*Conn, len(pool.conns))
-			copy(conns, pool.conns)
-			pool.policy.SetConns(conns)
-
 			// lost a connection, so fill the pool
 			go pool.fill()
 			break
@@ -590,9 +544,6 @@ func (pool *hostConnPool) drainLocked() {
 	conns := pool.conns
 	pool.conns = nil
 
-	// update the policy
-	pool.policy.SetConns(nil)
-
 	// close the connections
 	for _, conn := range conns {
 		conn.Close()
diff --git a/control.go b/control.go
index 968587f6d..160683c2b 100644
--- a/control.go
+++ b/control.go
@@ -238,14 +238,14 @@ func (c *controlConn) reconnect(refreshring bool) {
 	// TODO: should have our own roundrobbin for hosts so that we can try each
 	// in succession and guantee that we get a different host each time.
 	if newConn == nil {
-		_, conn := c.session.pool.Pick(nil)
-		if conn == nil {
+		host := c.session.ring.rrHost()
+		if host == nil {
 			c.connect(c.session.ring.endpoints)
 			return
 		}
 
 		var err error
-		newConn, err = c.session.connect(conn.addr, c, conn.host)
+		newConn, err = c.session.connect(host.Peer(), c, host)
 		if err != nil {
 			// TODO: add log handler for things like this
 			return
diff --git a/events.go b/events.go
index ab8de8fef..7304647bb 100644
--- a/events.go
+++ b/events.go
@@ -201,6 +201,7 @@ func (s *Session) handleNewNode(host net.IP, port int, waitForBinary bool) {
 	}
 
 	s.pool.addHost(hostInfo)
+	s.policy.AddHost(hostInfo)
 	hostInfo.setState(NodeUp)
 
 	if s.control != nil {
@@ -222,6 +223,7 @@ func (s *Session) handleRemovedNode(ip net.IP, port int) {
 	}
 
 	host.setState(NodeDown)
+	s.policy.RemoveHost(addr)
 	s.pool.removeHost(addr)
 	s.ring.removeHost(addr)
 
@@ -251,6 +253,7 @@ func (s *Session) handleNodeUp(ip net.IP, port int, waitForBinary bool) {
 		host.setPort(port)
 
 		s.pool.hostUp(host)
+		s.policy.HostUp(host)
 		host.setState(NodeUp)
 		return
 	}
@@ -270,5 +273,6 @@ func (s *Session) handleNodeDown(ip net.IP, port int) {
 	}
 
 	host.setState(NodeDown)
+	s.policy.HostDown(addr)
 	s.pool.hostDown(addr)
 }
diff --git a/host_source.go b/host_source.go
index b2bd8965b..29ab4d452 100644
--- a/host_source.go
+++ b/host_source.go
@@ -390,6 +390,6 @@ func (r *ringDescriber) refreshRing() error {
 		}
 	}
 
-	r.session.pool.SetPartitioner(partitioner)
+	r.session.metadata.setPartitioner(partitioner)
 	return nil
 }
diff --git a/policies.go b/policies.go
index 8a647f342..f03f9f482 100644
--- a/policies.go
+++ b/policies.go
@@ -162,17 +162,17 @@ func (s *SimpleRetryPolicy) Attempt(q RetryableQuery) bool {
 type HostStateNotifier interface {
 	AddHost(host *HostInfo)
 	RemoveHost(addr string)
-	// TODO(zariel): add host up/down
+	HostUp(host *HostInfo)
+	HostDown(addr string)
 }
 
 // HostSelectionPolicy is an interface for selecting
 // the most appropriate host to execute a given query.
 type HostSelectionPolicy interface {
 	HostStateNotifier
-	SetHosts
 	SetPartitioner
 	//Pick returns an iteration function over selected hosts
-	Pick(*Query) NextHost
+	Pick(ExecutableQuery) NextHost
 }
 
 // SelectedHost is an interface returned when picking a host from a host
@@ -182,6 +182,14 @@ type SelectedHost interface {
 	Mark(error)
 }
 
+type selectedHost HostInfo
+
+func (host *selectedHost) Info() *HostInfo {
+	return (*HostInfo)(host)
+}
+
+func (host *selectedHost) Mark(err error) {}
+
 // NextHost is an iteration function over picked hosts
 type NextHost func() SelectedHost
 
@@ -197,15 +205,11 @@ type roundRobinHostPolicy struct {
 	mu    sync.RWMutex
 }
 
-func (r *roundRobinHostPolicy) SetHosts(hosts []*HostInfo) {
-	r.hosts.set(hosts)
-}
-
 func (r *roundRobinHostPolicy) SetPartitioner(partitioner string) {
 	// noop
 }
 
-func (r *roundRobinHostPolicy) Pick(qry *Query) NextHost {
+func (r *roundRobinHostPolicy) Pick(qry ExecutableQuery) NextHost {
 	// i is used to limit the number of attempts to find a host
 	// to the number of hosts known to this policy
 	var i int
@@ -223,7 +227,7 @@ func (r *roundRobinHostPolicy) Pick(qry *Query) NextHost {
 		}
 		host := hosts[(pos)%uint32(len(hosts))]
 		i++
-		return selectedRoundRobinHost{host}
+		return (*selectedHost)(host)
 	}
 }
 
@@ -235,18 +239,12 @@ func (r *roundRobinHostPolicy) RemoveHost(addr string) {
 	r.hosts.remove(addr)
 }
 
-// selectedRoundRobinHost is a host returned by the roundRobinHostPolicy and
-// implements the SelectedHost interface
-type selectedRoundRobinHost struct {
-	info *HostInfo
+func (r *roundRobinHostPolicy) HostUp(host *HostInfo) {
+	r.AddHost(host)
 }
 
-func (host selectedRoundRobinHost) Info() *HostInfo {
-	return host.info
-}
-
-func (host selectedRoundRobinHost) Mark(err error) {
-	// noop
+func (r *roundRobinHostPolicy) HostDown(addr string) {
+	r.RemoveHost(addr)
 }
 
 // TokenAwareHostPolicy is a token aware host selection policy, where hosts are
@@ -264,18 +262,6 @@ type tokenAwareHostPolicy struct {
 	fallback HostSelectionPolicy
 }
 
-func (t *tokenAwareHostPolicy) SetHosts(hosts []*HostInfo) {
-	t.hosts.set(hosts)
-
-	t.mu.Lock()
-	defer t.mu.Unlock()
-
-	// always update the fallback
-	t.fallback.SetHosts(hosts)
-
-	t.resetTokenRing()
-}
-
 func (t *tokenAwareHostPolicy) SetPartitioner(partitioner string) {
 	t.mu.Lock()
 	defer t.mu.Unlock()
@@ -299,12 +285,21 @@ func (t *tokenAwareHostPolicy) AddHost(host *HostInfo) {
 
 func (t *tokenAwareHostPolicy) RemoveHost(addr string) {
 	t.hosts.remove(addr)
+	t.fallback.RemoveHost(addr)
 
 	t.mu.Lock()
 	t.resetTokenRing()
 	t.mu.Unlock()
 }
 
+func (t *tokenAwareHostPolicy) HostUp(host *HostInfo) {
+	t.AddHost(host)
+}
+
+func (t *tokenAwareHostPolicy) HostDown(addr string) {
+	t.RemoveHost(addr)
+}
+
 func (t *tokenAwareHostPolicy) resetTokenRing() {
 	if t.partitioner == "" {
 		// partitioner not yet set
@@ -323,14 +318,9 @@ func (t *tokenAwareHostPolicy) resetTokenRing() {
 	t.tokenRing = tokenRing
 }
 
-func (t *tokenAwareHostPolicy) Pick(qry *Query) NextHost {
+func (t *tokenAwareHostPolicy) Pick(qry ExecutableQuery) NextHost {
 	if qry == nil {
 		return t.fallback.Pick(qry)
-	} else if qry.binding != nil && len(qry.values) == 0 {
-		// If this query was created using session.Bind we wont have the query
-		// values yet, so we have to pass down to the next policy.
-		// TODO: Remove this and handle this case
-		return t.fallback.Pick(qry)
 	}
 
 	routingKey, err := qry.GetRoutingKey()
@@ -359,7 +349,7 @@
 	return func() SelectedHost {
 		if !hostReturned {
 			hostReturned = true
-			return selectedTokenAwareHost{host}
+			return (*selectedHost)(host)
 		}
 
 		// fallback
@@ -378,20 +368,6 @@
 	}
 }
 
-// selectedTokenAwareHost is a host returned by the tokenAwareHostPolicy and
-// implements the SelectedHost interface
-type selectedTokenAwareHost struct {
-	info *HostInfo
-}
-
-func (host selectedTokenAwareHost) Info() *HostInfo {
-	return host.info
-}
-
-func (host selectedTokenAwareHost) Mark(err error) {
-	// noop
-}
-
 // HostPoolHostPolicy is a host policy which uses the bitly/go-hostpool library
 // to distribute queries between hosts and prevent sending queries to
 // unresponsive hosts. When creating the host pool that is passed to the policy
@@ -466,11 +442,19 @@ func (r *hostPoolHostPolicy) RemoveHost(addr string) {
 	r.hp.SetHosts(hosts)
 }
 
+func (r *hostPoolHostPolicy) HostUp(host *HostInfo) {
+	r.AddHost(host)
+}
+
+func (r *hostPoolHostPolicy) HostDown(addr string) {
+	r.RemoveHost(addr)
+}
+
 func (r *hostPoolHostPolicy) SetPartitioner(partitioner string) {
 	// noop
 }
 
-func (r *hostPoolHostPolicy) Pick(qry *Query) NextHost {
+func (r *hostPoolHostPolicy) Pick(qry ExecutableQuery) NextHost {
 	return func() SelectedHost {
 		r.mu.RLock()
 		defer r.mu.RUnlock()
@@ -516,56 +500,3 @@ func (host selectedHostPoolHost) Mark(err error) {
 
 	host.hostR.Mark(err)
 }
-
-//ConnSelectionPolicy is an interface for selecting an
-//appropriate connection for executing a query
-type ConnSelectionPolicy interface {
-	SetConns(conns []*Conn)
-	Pick(*Query) *Conn
-}
-
-type roundRobinConnPolicy struct {
-	// pos is still used to evenly distribute queries amongst connections.
-	pos   uint32
-	conns atomic.Value // *[]*Conn
-}
-
-func RoundRobinConnPolicy() func() ConnSelectionPolicy {
-	return func() ConnSelectionPolicy {
-		p := &roundRobinConnPolicy{}
-		var conns []*Conn
-		p.conns.Store(&conns)
-		return p
-	}
-}
-
-func (r *roundRobinConnPolicy) SetConns(conns []*Conn) {
-	// NOTE: we do not need to lock here due to the conneciton pool is already
-	// holding its own mutex over the conn seleciton policy
-	r.conns.Store(&conns)
-}
-
-func (r *roundRobinConnPolicy) Pick(qry *Query) *Conn {
-	conns := *(r.conns.Load().(*[]*Conn))
-	if len(conns) == 0 {
-		return nil
-	}
-
-	pos := int(atomic.AddUint32(&r.pos, 1) - 1)
-
-	var (
-		leastBusyConn    *Conn
-		streamsAvailable int
-	)
-
-	// find the conn which has the most available streams, this is racy
-	for i := 0; i < len(conns); i++ {
-		conn := conns[(pos+i)%len(conns)]
-		if streams := conn.AvailableStreams(); streams > streamsAvailable {
-			leastBusyConn = conn
-			streamsAvailable = streams
-		}
-	}
-
-	return leastBusyConn
-}
diff --git a/policies_test.go b/policies_test.go
index fee4a6ee3..f2cc54367 100644
--- a/policies_test.go
+++ b/policies_test.go
@@ -6,7 +6,6 @@ package gocql
 
 import (
 	"fmt"
-	"github.com/gocql/gocql/internal/streams"
 	"testing"
 
 	"github.com/hailocab/go-hostpool"
@@ -16,12 +15,14 @@
 func TestRoundRobinHostPolicy(t *testing.T) {
 	policy := RoundRobinHostPolicy()
 
-	hosts := []*HostInfo{
+	hosts := [...]*HostInfo{
 		{hostId: "0"},
 		{hostId: "1"},
 	}
 
-	policy.SetHosts(hosts)
+	for _, host := range hosts {
+		policy.AddHost(host)
+	}
 
 	// interleaved iteration should always increment the host
 	iterA := policy.Pick(nil)
@@ -65,13 +66,15 @@ func TestTokenAwareHostPolicy(t *testing.T) {
 	}
 
 	// set the hosts
-	hosts := []*HostInfo{
+	hosts := [...]*HostInfo{
 		{peer: "0", tokens: []string{"00"}},
 		{peer: "1", tokens: []string{"25"}},
 		{peer: "2", tokens: []string{"50"}},
 		{peer: "3", tokens: []string{"75"}},
 	}
-	policy.SetHosts(hosts)
+	for _, host := range hosts {
+		policy.AddHost(host)
+	}
 
 	// the token ring is not setup without the partitioner, but the fallback
 	// should work
@@ -108,12 +111,14 @@
 func TestHostPoolHostPolicy(t *testing.T) {
 	policy := HostPoolHostPolicy(hostpool.New(nil))
 
-	hosts := []*HostInfo{
+	hosts := [...]*HostInfo{
 		{hostId: "0", peer: "0"},
 		{hostId: "1", peer: "1"},
 	}
 
-	policy.SetHosts(hosts)
+	for _, host := range hosts {
+		policy.AddHost(host)
+	}
 
 	// the first host selected is actually at [1], but this is ok for RR
 	// interleaved iteration should always increment the host
@@ -143,35 +148,11 @@ func TestHostPoolHostPolicy(t *testing.T) {
 	actualD.Mark(nil)
 }
 
-// Tests of the round-robin connection selection policy implementation
-func TestRoundRobinConnPolicy(t *testing.T) {
-	policy := RoundRobinConnPolicy()()
-
-	conn0 := &Conn{streams: streams.New(1)}
-	conn1 := &Conn{streams: streams.New(1)}
-	conn := []*Conn{
-		conn0,
-		conn1,
-	}
-
-	policy.SetConns(conn)
-
-	if actual := policy.Pick(nil); actual != conn0 {
-		t.Error("Expected conn1")
-	}
-	if actual := policy.Pick(nil); actual != conn1 {
-		t.Error("Expected conn0")
-	}
-	if actual := policy.Pick(nil); actual != conn0 {
-		t.Error("Expected conn1")
-	}
-}
-
 func TestRoundRobinNilHostInfo(t *testing.T) {
 	policy := RoundRobinHostPolicy()
 
 	host := &HostInfo{hostId: "host-1"}
-	policy.SetHosts([]*HostInfo{host})
+	policy.AddHost(host)
 
 	iter := policy.Pick(nil)
 	next := iter()
@@ -195,13 +176,15 @@ func TestRoundRobinNilHostInfo(t *testing.T) {
 func TestTokenAwareNilHostInfo(t *testing.T) {
 	policy := TokenAwareHostPolicy(RoundRobinHostPolicy())
 
-	hosts := []*HostInfo{
+	hosts := [...]*HostInfo{
 		{peer: "0", tokens: []string{"00"}},
 		{peer: "1", tokens: []string{"25"}},
 		{peer: "2", tokens: []string{"50"}},
 		{peer: "3", tokens: []string{"75"}},
 	}
-	policy.SetHosts(hosts)
+	for _, host := range hosts {
+		policy.AddHost(host)
+	}
 	policy.SetPartitioner("OrderedPartitioner")
 
 	query := &Query{}
@@ -218,8 +201,9 @@ func TestTokenAwareNilHostInfo(t *testing.T) {
 	}
 
 	// Empty the hosts to trigger the panic when using the fallback.
-	hosts = []*HostInfo{}
-	policy.SetHosts(hosts)
+	for _, host := range hosts {
+		policy.RemoveHost(host.Peer())
+	}
 
 	next = iter()
 	if next != nil {
diff --git a/query_executor.go b/query_executor.go
new file mode 100644
index 000000000..1aa397506
--- /dev/null
+++ b/query_executor.go
@@ -0,0 +1,65 @@
+package gocql
+
+import (
+	"time"
+)
+
+type ExecutableQuery interface {
+	execute(conn *Conn) *Iter
+	attempt(time.Duration)
+	retryPolicy() RetryPolicy
+	GetRoutingKey() ([]byte, error)
+	RetryableQuery
+}
+
+type queryExecutor struct {
+	pool   *policyConnPool
+	policy HostSelectionPolicy
+}
+
+func (q *queryExecutor) executeQuery(qry ExecutableQuery) (*Iter, error) {
+	rt := qry.retryPolicy()
+	hostIter := q.policy.Pick(qry)
+
+	var iter *Iter
+	for hostResponse := hostIter(); hostResponse != nil; hostResponse = hostIter() {
+		host := hostResponse.Info()
+		if !host.IsUp() {
+			continue
+		}
+
+		pool, ok := q.pool.getPool(host.Peer())
+		if !ok {
+			continue
+		}
+
+		conn := pool.Pick()
+		if conn == nil {
+			continue
+		}
+
+		start := time.Now()
+		iter = qry.execute(conn)
+
+		qry.attempt(time.Since(start))
+
+		// Update host
+		hostResponse.Mark(iter.err)
+
+		// Exit for loop if the query was successful
+		if iter.err == nil {
+			return iter, nil
+		}
+
+		if rt == nil || !rt.Attempt(qry) {
+			// What do here? Should we just return an error here?
+			break
+		}
+	}
+
+	if iter == nil {
+		return nil, ErrNoConnections
+	}
+
+	return iter, nil
+}
diff --git a/ring.go b/ring.go
index fa1b3d3c3..43fe94e25 100644
--- a/ring.go
+++ b/ring.go
@@ -2,6 +2,7 @@ package gocql
 
 import (
 	"sync"
+	"sync/atomic"
 )
 
 type ring struct {
@@ -9,13 +10,27 @@ type ring struct {
 	// to in the case it can not reach any of its hosts. They are also used to boot
 	// strap the initial connection.
 	endpoints []string
+
 	// hosts are the set of all hosts in the cassandra ring that we know of
 	mu    sync.RWMutex
 	hosts map[string]*HostInfo
 
+	hostList []*HostInfo
+	pos      uint32
+
 	// TODO: we should store the ring metadata here also.
 }
 
+func (r *ring) rrHost() *HostInfo {
+	// TODO: should we filter hosts that get used here? These hosts will be used
+	// for the control connection, should we also provide an iterator?
+	r.mu.RLock()
+	defer r.mu.RUnlock()
+
+	pos := int(atomic.AddUint32(&r.pos, 1) - 1)
+	return r.hostList[pos%len(r.hostList)]
+}
+
 func (r *ring) getHost(addr string) *HostInfo {
 	r.mu.RLock()
 	host := r.hosts[addr]
@@ -73,3 +88,18 @@ func (r *ring) removeHost(addr string) bool {
 	r.mu.Unlock()
 	return ok
 }
+
+type clusterMetadata struct {
+	mu          sync.RWMutex
+	partitioner string
+}
+
+func (c *clusterMetadata) setPartitioner(partitioner string) {
+	c.mu.RLock()
+	defer c.mu.RUnlock()
+
+	if c.partitioner != partitioner {
+		// TODO: update other things now
+		c.partitioner = partitioner
+	}
+}
diff --git a/session.go b/session.go
index 40396d70d..3a3ffc0f5 100644
--- a/session.go
+++ b/session.go
@@ -31,7 +31,6 @@ import (
 // and automatically sets a default consinstency level on all operations
 // that do not have a consistency level set.
 type Session struct {
-	pool                *policyConnPool
 	cons                Consistency
 	pageSize            int
 	prefetch            float64
@@ -39,11 +38,17 @@ type Session struct {
 	schemaDescriber     *schemaDescriber
 	trace               Tracer
 	hostSource          *ringDescriber
-	ring                ring
 	stmtsLRU            *preparedLRU
 
 	connCfg *ConnConfig
 
+	executor *queryExecutor
+	pool     *policyConnPool
+	policy   HostSelectionPolicy
+
+	ring     ring
+	metadata clusterMetadata
+
 	mu      sync.RWMutex
 	control *controlConn
 
@@ -116,7 +121,17 @@ func NewSession(cfg ClusterConfig) (*Session, error) {
 		closeChan: make(chan bool),
 	}
 
-	s.pool = cfg.PoolConfig.buildPool(s)
+	pool := cfg.PoolConfig.buildPool(s)
+	if cfg.PoolConfig.HostSelectionPolicy == nil {
+		cfg.PoolConfig.HostSelectionPolicy = RoundRobinHostPolicy()
+	}
+
+	s.pool = pool
+	s.policy = cfg.PoolConfig.HostSelectionPolicy
+	s.executor = &queryExecutor{
+		pool:   pool,
+		policy: cfg.PoolConfig.HostSelectionPolicy,
+	}
 
 	var hosts []*HostInfo
 
@@ -284,44 +299,17 @@ func (s *Session) Closed() bool {
 }
 
 func (s *Session) executeQuery(qry *Query) *Iter {
-	// fail fast
 	if s.Closed() {
 		return &Iter{err: ErrSessionClosed}
 	}
 
-	var iter *Iter
-	qry.attempts = 0
-	qry.totalLatency = 0
-	for {
-		host, conn := s.pool.Pick(qry)
-
-		qry.attempts++
-		//Assign the error unavailable to the iterator
-		if conn == nil {
-			if qry.rt == nil || !qry.rt.Attempt(qry) {
-				iter = &Iter{err: ErrNoConnections}
-				break
-			}
-
-			continue
-		}
-
-		t := time.Now()
-		iter = conn.executeQuery(qry)
-		qry.totalLatency += time.Now().Sub(t).Nanoseconds()
-
-		// Update host
-		host.Mark(iter.err)
-
-		// Exit for loop if the query was successful
-		if iter.err == nil {
-			break
-		}
-
-		if qry.rt == nil || !qry.rt.Attempt(qry) {
-			break
-		}
+	iter, err := s.executor.executeQuery(qry)
+	if err != nil {
+		return &Iter{err: err}
+	}
+	if iter == nil {
+		panic("nil iter")
 	}
 
 	return iter
@@ -348,6 +336,28 @@ func (s *Session) KeyspaceMetadata(keyspace string) (*KeyspaceMetadata, error) {
 	return s.schemaDescriber.getSchema(keyspace)
 }
 
+func (s *Session) getConn() *Conn {
+	hosts := s.ring.allHosts()
+	var conn *Conn
+	for _, host := range hosts {
+		if !host.IsUp() {
+			continue
+		}
+
+		pool, ok := s.pool.getPool(host.Peer())
+		if !ok {
+			continue
+		}
+
+		conn = pool.Pick()
+		if conn != nil {
+			return conn
+		}
+	}
+
+	return nil
+}
+
 // returns routing key indexes and type info
 func (s *Session) routingKeyInfo(stmt string) (*routingKeyInfo, error) {
 	s.routingKeyInfoCache.mu.Lock()
@@ -384,26 +394,23 @@ func (s *Session) routingKeyInfo(stmt string) (*routingKeyInfo, error) {
 		partitionKey []*ColumnMetadata
 	)
 
-	// get the query info for the statement
-	host, conn := s.pool.Pick(nil)
+	conn := s.getConn()
 	if conn == nil {
-		// no connections
-		inflight.err = ErrNoConnections
-		// don't cache this error
-		s.routingKeyInfoCache.Remove(stmt)
+		// TODO: better error?
+		inflight.err = errors.New("gocql: unable to fetch prepared info: no connection available")
 		return nil, inflight.err
 	}
 
+	// get the query info for the statement
 	info, inflight.err = conn.prepareStatement(stmt, nil)
 	if inflight.err != nil {
 		// don't cache this error
 		s.routingKeyInfoCache.Remove(stmt)
-		host.Mark(inflight.err)
 		return nil, inflight.err
 	}
 
-	// Mark host as OK
-	host.Mark(nil)
+	// TODO: it would be nice to mark hosts here but as we are not using the policies
+	// to fetch hosts we cant
 
 	if info.request.colCount == 0 {
 		// no arguments, no routing key, and no error
@@ -455,6 +462,7 @@ func (s *Session) routingKeyInfo(stmt string) (*routingKeyInfo, error) {
 		indexes: make([]int, size),
 		types:   make([]TypeInfo, size),
 	}
+
 	for keyIndex, keyColumn := range partitionKey {
 		// set an indicator for checking if the mapping is missing
 		routingKeyInfo.indexes[keyIndex] = -1
@@ -482,6 +490,10 @@ func (s *Session) routingKeyInfo(stmt string) (*routingKeyInfo, error) {
 	return routingKeyInfo, nil
 }
 
+func (b *Batch) execute(conn *Conn) *Iter {
+	return conn.executeBatch(b)
+}
+
 func (s *Session) executeBatch(batch *Batch) *Iter {
 	// fail fast
 	if s.Closed() {
@@ -495,45 +507,9 @@ func (s *Session) executeBatch(batch *Batch) *Iter {
 		return &Iter{err: ErrTooManyStmts}
 	}
 
-	var iter *Iter
-	batch.attempts = 0
-	batch.totalLatency = 0
-	for {
-		host, conn := s.pool.Pick(nil)
-
-		batch.attempts++
-		if conn == nil {
-			if batch.rt == nil || !batch.rt.Attempt(batch) {
-				// Assign the error unavailable and break loop
-				iter = &Iter{err: ErrNoConnections}
-				break
-			}
-
-			continue
-		}
-
-		if conn == nil {
-			iter = &Iter{err: ErrNoConnections}
-			break
-		}
-
-		t := time.Now()
-
-		iter = conn.executeBatch(batch)
-
-		batch.totalLatency += time.Since(t).Nanoseconds()
-		// Exit loop if operation executed correctly
-		if iter.err == nil {
-			host.Mark(nil)
-			break
-		}
-
-		// Mark host with error if returned from Close
-		host.Mark(iter.Close())
-
-		if batch.rt == nil || !batch.rt.Attempt(batch) {
-			break
-		}
+	iter, err := s.executor.executeQuery(batch)
+	if err != nil {
+		return &Iter{err: err}
 	}
 
 	return iter
@@ -680,6 +656,20 @@ func (q *Query) RoutingKey(routingKey []byte) *Query {
 	return q
 }
 
+func (q *Query) execute(conn *Conn) *Iter {
+	return conn.executeQuery(q)
+}
+
+func (q *Query) attempt(d time.Duration) {
+	q.attempts++
+	q.totalLatency += d.Nanoseconds()
+	// TODO: track latencies per host and things as well instead of just total
+}
+
+func (q *Query) retryPolicy() RetryPolicy {
+	return q.rt
+}
+
 // GetRoutingKey gets the routing key to use for routing this query. If
 // a routing key has not been explicitly set, then the routing key will
 // be constructed if possible using the keyspace's schema and the query
@@ -689,6 +679,11 @@ func (q *Query) RoutingKey(routingKey []byte) *Query {
 func (q *Query) GetRoutingKey() ([]byte, error) {
 	if q.routingKey != nil {
 		return q.routingKey, nil
+	} else if q.binding != nil && len(q.values) == 0 {
+		// If this query was created using session.Bind we won't have the query
+		// values yet, so we have to pass down to the next policy.
+		// TODO: Remove this and handle this case
+		return nil, nil
 	}
 
 	// try to determine the routing key
@@ -816,8 +811,7 @@ func (q *Query) NoSkipMetadata() *Query {
 
 // Exec executes the query without returning any rows.
 func (q *Query) Exec() error {
-	iter := q.Iter()
-	return iter.Close()
+	return q.Iter().Close()
 }
 
 func isUseStatement(stmt string) bool {
@@ -1107,6 +1101,10 @@ func (b *Batch) Bind(stmt string, bind func(q *QueryInfo) ([]interface{}, error)
 	b.Entries = append(b.Entries, BatchEntry{Stmt: stmt, binding: bind})
 }
 
+func (b *Batch) retryPolicy() RetryPolicy {
+	return b.rt
+}
+
 // RetryPolicy sets the retry policy to use when executing the batch operation
 func (b *Batch) RetryPolicy(r RetryPolicy) *Batch {
 	b.rt = r
@@ -1141,6 +1139,17 @@ func (b *Batch) DefaultTimestamp(enable bool) *Batch {
 	return b
 }
 
+func (b *Batch) attempt(d time.Duration) {
+	b.attempts++
+	b.totalLatency += d.Nanoseconds()
+	// TODO: track latencies per host and things as well instead of just total
+}
+
+func (b *Batch) GetRoutingKey() ([]byte, error) {
+	// TODO: use the first statement in the batch as the routing key?
+	return nil, nil
+}
+
 type BatchType byte
 
 const (
@@ -1285,7 +1294,7 @@ var (
 	ErrTooManyStmts  = errors.New("too many statements")
 	ErrUseStmt       = errors.New("use statements aren't supported. Please see https://github.com/gocql/gocql for explaination.")
 	ErrSessionClosed = errors.New("session has been closed")
-	ErrNoConnections = errors.New("no connections available")
+	ErrNoConnections = errors.New("gocql: no hosts available in the pool")
 	ErrNoKeyspace    = errors.New("no keyspace provided")
 	ErrNoMetadata    = errors.New("no metadata available")
 )
diff --git a/session_test.go b/session_test.go
index c15a90562..75761e9f0 100644
--- a/session_test.go
+++ b/session_test.go
@@ -12,11 +12,16 @@ func TestSessionAPI(t *testing.T) {
 
 	cfg := &ClusterConfig{}
 	s := &Session{
-		cfg:  *cfg,
-		cons: Quorum,
+		cfg:    *cfg,
+		cons:   Quorum,
+		policy: RoundRobinHostPolicy(),
 	}
 	s.pool = cfg.PoolConfig.buildPool(s)
+	s.executor = &queryExecutor{
+		pool:   s.pool,
+		policy: s.policy,
+	}
 	defer s.Close()
 
 	s.SetConsistency(All)
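
Note: with ConnSelectionPolicy removed, callers only configure a HostSelectionPolicy; per-host connection choice is handled internally by the least-busy logic in hostConnPool.Pick, and queries flow through the new queryExecutor. A minimal usage sketch against the public API after this change (the contact point and keyspace below are illustrative placeholders, not part of the patch):

package main

import (
	"log"

	"github.com/gocql/gocql"
)

func main() {
	// "127.0.0.1" and "example" are placeholder values for this sketch.
	cluster := gocql.NewCluster("127.0.0.1")
	cluster.Keyspace = "example"

	// Only the host selection policy is configurable now; the old
	// PoolConfig.ConnSelectionPolicy field no longer exists.
	cluster.PoolConfig.HostSelectionPolicy = gocql.RoundRobinHostPolicy()

	session, err := cluster.CreateSession()
	if err != nil {
		log.Fatal(err)
	}
	defer session.Close()

	// The query is routed by the host policy, then executed on a connection
	// picked from that host's pool.
	if err := session.Query("SELECT release_version FROM system.local").Exec(); err != nil {
		log.Fatal(err)
	}
}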