diff --git a/common_test.go b/common_test.go
index f071453c3..58acaf8f8 100644
--- a/common_test.go
+++ b/common_test.go
@@ -224,36 +224,31 @@ func staticAddressTranslator(newAddr net.IP, newPort int) AddressTranslator {
 }
 
 func assertTrue(t *testing.T, description string, value bool) {
-	t.Helper()
 	if !value {
-		t.Fatalf("expected %s to be true", description)
+		t.Errorf("expected %s to be true", description)
 	}
 }
 
 func assertEqual(t *testing.T, description string, expected, actual interface{}) {
-	t.Helper()
 	if expected != actual {
-		t.Fatalf("expected %s to be (%+v) but was (%+v) instead", description, expected, actual)
+		t.Errorf("expected %s to be (%+v) but was (%+v) instead", description, expected, actual)
 	}
 }
 
 func assertDeepEqual(t *testing.T, description string, expected, actual interface{}) {
-	t.Helper()
 	if !reflect.DeepEqual(expected, actual) {
-		t.Fatalf("expected %s to be (%+v) but was (%+v) instead", description, expected, actual)
+		t.Errorf("expected %s to be (%+v) but was (%+v) instead", description, expected, actual)
 	}
 }
 
 func assertNil(t *testing.T, description string, actual interface{}) {
-	t.Helper()
 	if actual != nil {
-		t.Fatalf("expected %s to be (nil) but was (%+v) instead", description, actual)
+		t.Errorf("expected %s to be (nil) but was (%+v) instead", description, actual)
 	}
 }
 
 func assertNotNil(t *testing.T, description string, actual interface{}) {
-	t.Helper()
 	if actual == nil {
-		t.Fatalf("expected %s not to be (nil)", description)
+		t.Errorf("expected %s not to be (nil)", description)
 	}
 }
diff --git a/conn_test.go b/conn_test.go
index 13605d1df..c7e5c4997 100644
--- a/conn_test.go
+++ b/conn_test.go
@@ -48,8 +48,8 @@ func TestApprove(t *testing.T) {
 
 func TestJoinHostPort(t *testing.T) {
 	tests := map[string]string{
-		"127.0.0.1:0": JoinHostPort("127.0.0.1", 0),
-		"127.0.0.1:1": JoinHostPort("127.0.0.1:1", 9142),
+		"127.0.0.1:0":                                 JoinHostPort("127.0.0.1", 0),
+		"127.0.0.1:1":                                 JoinHostPort("127.0.0.1:1", 9142),
 		"[2001:0db8:85a3:0000:0000:8a2e:0370:7334]:0": JoinHostPort("2001:0db8:85a3:0000:0000:8a2e:0370:7334", 0),
 		"[2001:0db8:85a3:0000:0000:8a2e:0370:7334]:1": JoinHostPort("[2001:0db8:85a3:0000:0000:8a2e:0370:7334]:1", 9142),
 	}
diff --git a/control.go b/control.go
index aa5cf3570..443b1ccfe 100644
--- a/control.go
+++ b/control.go
@@ -149,14 +149,15 @@ func hostInfo(addr string, defaultPort int) ([]*HostInfo, error) {
 }
 
 func shuffleHosts(hosts []*HostInfo) []*HostInfo {
-	shuffled := make([]*HostInfo, len(hosts))
-	copy(shuffled, hosts)
 	mutRandr.Lock()
-	randr.Shuffle(len(hosts), func(i, j int) {
-		shuffled[i], shuffled[j] = shuffled[j], shuffled[i]
-	})
+	perm := randr.Perm(len(hosts))
 	mutRandr.Unlock()
+	shuffled := make([]*HostInfo, len(hosts))
+
+	for i, host := range hosts {
+		shuffled[perm[i]] = host
+	}
 	return shuffled
 }
diff --git a/host_source.go b/host_source.go
index f8ab3c109..ccc408889 100644
--- a/host_source.go
+++ b/host_source.go
@@ -110,7 +110,7 @@ type HostInfo struct {
 	// TODO(zariel): reduce locking maybe, not all values will change, but to ensure
 	// that we are thread safe use a mutex to access all fields.
 	mu               sync.RWMutex
-	hostname         string
+	hostname string
 	peer             net.IP
 	broadcastAddress net.IP
 	listenAddress    net.IP
@@ -128,7 +128,7 @@ type HostInfo struct {
 	clusterName   string
 	version       cassVersion
 	state         nodeState
-	schemaVersion string
+	schemaVersion   string
 	tokens        []string
 }
 
diff --git a/marshal.go b/marshal.go
index 0592457fc..644e9d886 100644
--- a/marshal.go
+++ b/marshal.go
@@ -1301,15 +1301,15 @@ func unmarshalDate(info TypeInfo, data []byte, value interface{}) error {
 		*v = time.Unix(0, timestamp*int64(time.Millisecond)).In(time.UTC)
 		return nil
 	case *string:
-		if len(data) == 0 {
-			*v = ""
-			return nil
-		}
-		var origin uint32 = 1 << 31
-		var current uint32 = binary.BigEndian.Uint32(data)
-		timestamp := (int64(current) - int64(origin)) * 86400000
+        if len(data) == 0 {
+            *v = ""
+            return nil
+        }
+        var origin uint32 = 1 << 31
+        var current uint32 = binary.BigEndian.Uint32(data)
+        timestamp := (int64(current) - int64(origin)) * 86400000
 		*v = time.Unix(0, timestamp*int64(time.Millisecond)).In(time.UTC).Format("2006-01-02")
-		return nil
+        return nil
 	}
 	return unmarshalErrorf("can not unmarshal %s into %T", info, value)
 }
diff --git a/marshal_test.go b/marshal_test.go
index 932d22638..025b36ba9 100644
--- a/marshal_test.go
+++ b/marshal_test.go
@@ -1871,58 +1871,58 @@ func TestReadCollectionSize(t *testing.T) {
 	}
 
 	tests := []struct {
-		name         string
-		info         CollectionType
-		data         []byte
-		isError      bool
+		name string
+		info CollectionType
+		data []byte
+		isError bool
 		expectedSize int
 	}{
 		{
-			name:    "short read 0 proto 2",
-			info:    listV2,
-			data:    []byte{},
+			name: "short read 0 proto 2",
+			info: listV2,
+			data: []byte{},
 			isError: true,
 		},
 		{
-			name:    "short read 1 proto 2",
-			info:    listV2,
-			data:    []byte{0x01},
+			name: "short read 1 proto 2",
+			info: listV2,
+			data: []byte{0x01},
 			isError: true,
 		},
 		{
-			name:         "good read proto 2",
-			info:         listV2,
-			data:         []byte{0x01, 0x38},
+			name: "good read proto 2",
+			info: listV2,
+			data: []byte{0x01, 0x38},
 			expectedSize: 0x0138,
 		},
 		{
-			name:    "short read 0 proto 3",
-			info:    listV3,
-			data:    []byte{},
+			name: "short read 0 proto 3",
+			info: listV3,
+			data: []byte{},
 			isError: true,
 		},
 		{
-			name:    "short read 1 proto 3",
-			info:    listV3,
-			data:    []byte{0x01},
+			name: "short read 1 proto 3",
+			info: listV3,
+			data: []byte{0x01},
 			isError: true,
 		},
 		{
-			name:    "short read 2 proto 3",
-			info:    listV3,
-			data:    []byte{0x01, 0x38},
+			name: "short read 2 proto 3",
+			info: listV3,
+			data: []byte{0x01, 0x38},
 			isError: true,
 		},
 		{
-			name:    "short read 3 proto 3",
-			info:    listV3,
-			data:    []byte{0x01, 0x38, 0x42},
+			name: "short read 3 proto 3",
+			info: listV3,
+			data: []byte{0x01, 0x38, 0x42},
 			isError: true,
 		},
 		{
-			name:         "good read proto 3",
-			info:         listV3,
-			data:         []byte{0x01, 0x38, 0x42, 0x22},
+			name: "good read proto 3",
+			info: listV3,
+			data: []byte{0x01, 0x38, 0x42, 0x22},
 			expectedSize: 0x01384222,
 		},
 	}
diff --git a/policies.go b/policies.go
index a45c1c8e0..63fe6b82e 100644
--- a/policies.go
+++ b/policies.go
@@ -424,14 +424,14 @@ func TokenAwareHostPolicy(fallback HostSelectionPolicy, opts ...func(*tokenAware
 // and the pointer in clusterMeta updated to point to the new value.
 type clusterMeta struct {
 	// replicas is map[keyspace]map[token]hosts
-	replicas  map[string]tokenRingReplicas
+	replicas map[string]map[token][]*HostInfo
 	tokenRing *tokenRing
 }
 
 type tokenAwareHostPolicy struct {
-	fallback            HostSelectionPolicy
+	fallback HostSelectionPolicy
 	getKeyspaceMetadata func(keyspace string) (*KeyspaceMetadata, error)
-	getKeyspaceName     func() string
+	getKeyspaceName func() string
 
 	shuffleReplicas          bool
 	nonLocalReplicasFallback bool
@@ -446,7 +446,7 @@ type tokenAwareHostPolicy struct {
 
 func (t *tokenAwareHostPolicy) Init(s *Session) {
 	t.getKeyspaceMetadata = s.KeyspaceMetadata
-	t.getKeyspaceName = func() string { return s.cfg.Keyspace }
+	t.getKeyspaceName = func() string {return s.cfg.Keyspace}
 }
 
 func (t *tokenAwareHostPolicy) IsLocal(host *HostInfo) bool {
@@ -465,14 +465,15 @@ func (t *tokenAwareHostPolicy) KeyspaceChanged(update KeyspaceUpdateEvent) {
 // It must be called with t.mu mutex locked.
 // meta must not be nil and it's replicas field will be updated.
 func (t *tokenAwareHostPolicy) updateReplicas(meta *clusterMeta, keyspace string) {
-	newReplicas := make(map[string]tokenRingReplicas, len(meta.replicas))
+	newReplicas := make(map[string]map[token][]*HostInfo, len(meta.replicas))
 
 	ks, err := t.getKeyspaceMetadata(keyspace)
 	if err == nil {
 		strat := getStrategy(ks)
 		if strat != nil {
 			if meta != nil && meta.tokenRing != nil {
-				newReplicas[keyspace] = strat.replicaMap(meta.tokenRing)
+				hosts := t.hosts.get()
+				newReplicas[keyspace] = strat.replicaMap(hosts, meta.tokenRing.tokens)
 			}
 		}
 	}
@@ -593,6 +594,14 @@ func (m *clusterMeta) resetTokenRing(partitioner string, hosts []*HostInfo) {
 	m.tokenRing = tokenRing
 }
 
+func (m *clusterMeta) getReplicas(keyspace string, token token) ([]*HostInfo, bool) {
+	if m.replicas == nil {
+		return nil, false
+	}
+	replicas, ok := m.replicas[keyspace][token]
+	return replicas, ok
+}
+
 func (t *tokenAwareHostPolicy) Pick(qry ExecutableQuery) NextHost {
 	if qry == nil {
 		return t.fallback.Pick(qry)
@@ -610,23 +619,22 @@ func (t *tokenAwareHostPolicy) Pick(qry ExecutableQuery) NextHost {
 		return t.fallback.Pick(qry)
 	}
 
-	token := meta.tokenRing.partitioner.Hash(routingKey)
-	ht := meta.replicas[qry.Keyspace()].replicasFor(token)
+	primaryEndpoint, token, endToken := meta.tokenRing.GetHostForPartitionKey(routingKey)
+	if primaryEndpoint == nil || endToken == nil {
+		return t.fallback.Pick(qry)
+	}
 
-	var replicas []*HostInfo
-	if ht == nil {
-		host, _ := meta.tokenRing.GetHostForToken(token)
-		replicas = []*HostInfo{host}
+	replicas, ok := meta.getReplicas(qry.Keyspace(), endToken)
+	if !ok {
+		replicas = []*HostInfo{primaryEndpoint}
 	} else if t.shuffleReplicas {
 		replicas = shuffleHosts(replicas)
-	} else {
-		replicas = ht.hosts
 	}
 
 	var (
 		fallbackIter NextHost
-		i, j         int
-		remote       []*HostInfo
+		i            int
+		j            int
 	)
 
 	used := make(map[*HostInfo]bool, len(replicas))
@@ -633,25 +641,20 @@ func (t *tokenAwareHostPolicy) Pick(qry ExecutableQuery) NextHost {
 	return func() SelectedHost {
 		for i < len(replicas) {
 			h := replicas[i]
 			i++
 
-			if !t.fallback.IsLocal(h) {
-				remote = append(remote, h)
-				continue
-			}
-
-			if h.IsUp() {
+			if h.IsUp() && t.fallback.IsLocal(h) {
 				used[h] = true
 				return selectedHost{info: h, token: token}
 			}
 		}
 
 		if t.nonLocalReplicasFallback {
-			for j < len(remote) {
-				h := remote[j]
+			for j < len(replicas) {
+				h := replicas[j]
 				j++
 
-				if h.IsUp() {
+				if h.IsUp() && !t.fallback.IsLocal(h) {
 					used[h] = true
 					return selectedHost{info: h, token: token}
 				}
@@ -666,11 +669,9 @@ func (t *tokenAwareHostPolicy) Pick(qry ExecutableQuery) NextHost {
 		// filter the token aware selected hosts from the fallback hosts
 		for fallbackHost := fallbackIter(); fallbackHost != nil; fallbackHost = fallbackIter() {
 			if !used[fallbackHost.Info()] {
-				used[fallbackHost.Info()] = true
 				return fallbackHost
 			}
 		}
-
 		return nil
 	}
 }
diff --git a/policies_test.go b/policies_test.go
index b0d3acd35..7ede38fa0 100644
--- a/policies_test.go
+++ b/policies_test.go
@@ -84,22 +84,22 @@ func TestHostPolicy_TokenAware_SimpleStrategy(t *testing.T) {
 		return &KeyspaceMetadata{
 			Name:          keyspace,
 			StrategyClass: "SimpleStrategy",
-			StrategyOptions: map[string]interface{}{
-				"class":              "SimpleStrategy",
+			StrategyOptions: map[string]interface{} {
+				"class": "SimpleStrategy",
 				"replication_factor": 2,
 			},
 		}, nil
 	}
 
-	policy.KeyspaceChanged(KeyspaceUpdateEvent{Keyspace: keyspace})
+	policy.KeyspaceChanged(KeyspaceUpdateEvent{Keyspace: "myKeyspace"})
 
 	// The SimpleStrategy above should generate the following replicas.
 	// It's handy to have as reference here.
-	assertDeepEqual(t, "replicas", map[string]tokenRingReplicas{
+	assertDeepEqual(t, "replicas", map[string]map[token][]*HostInfo{
 		"myKeyspace": {
-			{orderedToken("00"), []*HostInfo{hosts[0], hosts[1]}},
-			{orderedToken("25"), []*HostInfo{hosts[1], hosts[2]}},
-			{orderedToken("50"), []*HostInfo{hosts[2], hosts[3]}},
-			{orderedToken("75"), []*HostInfo{hosts[3], hosts[0]}},
+			orderedToken("00"): {hosts[0], hosts[1]},
+			orderedToken("25"): {hosts[1], hosts[2]},
+			orderedToken("50"): {hosts[2], hosts[3]},
+			orderedToken("75"): {hosts[3], hosts[0]},
 		},
 	}, policyInternal.getMetadataReadOnly().replicas)
@@ -179,7 +179,7 @@ func TestHostPolicy_RoundRobin_NilHostInfo(t *testing.T) {
 func TestHostPolicy_TokenAware_NilHostInfo(t *testing.T) {
 	policy := TokenAwareHostPolicy(RoundRobinHostPolicy())
 	policyInternal := policy.(*tokenAwareHostPolicy)
-	policyInternal.getKeyspaceName = func() string { return "myKeyspace" }
+	policyInternal.getKeyspaceName = func() string {return "myKeyspace"}
 	policyInternal.getKeyspaceMetadata = func(ks string) (*KeyspaceMetadata, error) {
 		return nil, errors.New("not initialized")
 	}
@@ -196,7 +196,7 @@ func TestHostPolicy_TokenAware_NilHostInfo(t *testing.T) {
 	policy.SetPartitioner("OrderedPartitioner")
 
 	query := &Query{}
-	query.getKeyspace = func() string { return "myKeyspace" }
+	query.getKeyspace = func() string {return "myKeyspace"}
 	query.RoutingKey([]byte("20"))
 
 	iter := policy.Pick(query)
@@ -426,20 +426,20 @@ func TestHostPolicy_DCAwareRR(t *testing.T) {
 }
 
+
 // Tests of the token-aware host selection policy implementation with a
 // DC aware round-robin host selection policy fallback
 // with {"class": "NetworkTopologyStrategy", "a": 1, "b": 1, "c": 1} replication.
-func TestHostPolicy_TokenAware(t *testing.T) {
-	const keyspace = "myKeyspace"
+func TestHostPolicy_TokenAware_DCAwareRR(t *testing.T) {
 	policy := TokenAwareHostPolicy(DCAwareRoundRobinPolicy("local"))
 	policyInternal := policy.(*tokenAwareHostPolicy)
-	policyInternal.getKeyspaceName = func() string { return keyspace }
+	policyInternal.getKeyspaceName = func() string {return "myKeyspace"}
 	policyInternal.getKeyspaceMetadata = func(ks string) (*KeyspaceMetadata, error) {
 		return nil, errors.New("not initialized")
 	}
 
 	query := &Query{}
-	query.getKeyspace = func() string { return keyspace }
+	query.getKeyspace = func() string {return "myKeyspace"}
 
 	iter := policy.Pick(nil)
 	if iter == nil {
@@ -482,16 +482,17 @@ func TestHostPolicy_TokenAware(t *testing.T) {
 	policy.SetPartitioner("OrderedPartitioner")
 
+
 	policyInternal.getKeyspaceMetadata = func(keyspaceName string) (*KeyspaceMetadata, error) {
-		if keyspaceName != keyspace {
+		if keyspaceName != "myKeyspace" {
 			return nil, fmt.Errorf("unknown keyspace: %s", keyspaceName)
 		}
 		return &KeyspaceMetadata{
-			Name:          keyspace,
+			Name:          "myKeyspace",
 			StrategyClass: "NetworkTopologyStrategy",
-			StrategyOptions: map[string]interface{}{
-				"class":   "NetworkTopologyStrategy",
-				"local":   1,
+			StrategyOptions: map[string]interface{} {
+				"class": "NetworkTopologyStrategy",
+				"local": 1,
 				"remote1": 1,
 				"remote2": 1,
 			},
 		}, nil
@@ -501,20 +502,20 @@ func TestHostPolicy_TokenAware(t *testing.T) {
 
 	// The NetworkTopologyStrategy above should generate the following replicas.
 	// It's handy to have as reference here.
-	assertDeepEqual(t, "replicas", map[string]tokenRingReplicas{
+	assertDeepEqual(t, "replicas", map[string]map[token][]*HostInfo{
 		"myKeyspace": {
-			{orderedToken("05"), []*HostInfo{hosts[0], hosts[1], hosts[2]}},
-			{orderedToken("10"), []*HostInfo{hosts[1], hosts[2], hosts[3]}},
-			{orderedToken("15"), []*HostInfo{hosts[2], hosts[3], hosts[4]}},
-			{orderedToken("20"), []*HostInfo{hosts[3], hosts[4], hosts[5]}},
-			{orderedToken("25"), []*HostInfo{hosts[4], hosts[5], hosts[6]}},
-			{orderedToken("30"), []*HostInfo{hosts[5], hosts[6], hosts[7]}},
-			{orderedToken("35"), []*HostInfo{hosts[6], hosts[7], hosts[8]}},
-			{orderedToken("40"), []*HostInfo{hosts[7], hosts[8], hosts[9]}},
-			{orderedToken("45"), []*HostInfo{hosts[8], hosts[9], hosts[10]}},
-			{orderedToken("50"), []*HostInfo{hosts[9], hosts[10], hosts[11]}},
-			{orderedToken("55"), []*HostInfo{hosts[10], hosts[11], hosts[0]}},
-			{orderedToken("60"), []*HostInfo{hosts[11], hosts[0], hosts[1]}},
+			orderedToken("05"): {hosts[0], hosts[1], hosts[2]},
+			orderedToken("10"): {hosts[1], hosts[2], hosts[3]},
+			orderedToken("15"): {hosts[2], hosts[3], hosts[4]},
+			orderedToken("20"): {hosts[3], hosts[4], hosts[5]},
+			orderedToken("25"): {hosts[4], hosts[5], hosts[6]},
+			orderedToken("30"): {hosts[5], hosts[6], hosts[7]},
+			orderedToken("35"): {hosts[6], hosts[7], hosts[8]},
+			orderedToken("40"): {hosts[7], hosts[8], hosts[9]},
+			orderedToken("45"): {hosts[8], hosts[9], hosts[10]},
+			orderedToken("50"): {hosts[9], hosts[10], hosts[11]},
+			orderedToken("55"): {hosts[10], hosts[11], hosts[0]},
+			orderedToken("60"): {hosts[11], hosts[0], hosts[1]},
 		},
 	}, policyInternal.getMetadataReadOnly().replicas)
@@ -578,9 +579,9 @@ func TestHostPolicy_TokenAware_NetworkStrategy(t *testing.T) {
 		return &KeyspaceMetadata{
 			Name:          keyspace,
 			StrategyClass: "NetworkTopologyStrategy",
-			StrategyOptions: map[string]interface{}{
-				"class":   "NetworkTopologyStrategy",
-				"local":   2,
+			StrategyOptions: map[string]interface{} {
+				"class": "NetworkTopologyStrategy",
+ "local": 2, "remote1": 2, "remote2": 2, }, diff --git a/session.go b/session.go index 48db95e73..1b527c0d1 100644 --- a/session.go +++ b/session.go @@ -761,7 +761,7 @@ func (qm *queryMetrics) latency() int64 { qm.l.Lock() var ( attempts int - latency int64 + latency int64 ) for _, metric := range qm.m { attempts += metric.Attempts @@ -1549,16 +1549,16 @@ func NewBatch(typ BatchType) *Batch { func (s *Session) NewBatch(typ BatchType) *Batch { s.mu.RLock() batch := &Batch{ - Type: typ, - rt: s.cfg.RetryPolicy, - serialCons: s.cfg.SerialConsistency, - observer: s.batchObserver, - session: s, - Cons: s.cons, - defaultTimestamp: s.cfg.DefaultTimestamp, - keyspace: s.cfg.Keyspace, - metrics: &queryMetrics{m: make(map[string]*hostMetrics)}, - spec: &NonSpeculativeExecution{}, + Type: typ, + rt: s.cfg.RetryPolicy, + serialCons: s.cfg.SerialConsistency, + observer: s.batchObserver, + session: s, + Cons: s.cons, + defaultTimestamp: s.cfg.DefaultTimestamp, + keyspace: s.cfg.Keyspace, + metrics: &queryMetrics{m: make(map[string]*hostMetrics)}, + spec: &NonSpeculativeExecution{}, } s.mu.RUnlock() diff --git a/token.go b/token.go index 92f0f7dca..8a3d8c60b 100644 --- a/token.go +++ b/token.go @@ -131,13 +131,10 @@ func (ht hostToken) String() string { type tokenRing struct { partitioner partitioner tokens []hostToken - hosts []*HostInfo } func newTokenRing(partitioner string, hosts []*HostInfo) (*tokenRing, error) { - tokenRing := &tokenRing{ - hosts: hosts, - } + tokenRing := &tokenRing{} if strings.HasSuffix(partitioner, "Murmur3Partitioner") { tokenRing.partitioner = murmur3Partitioner{} @@ -215,15 +212,15 @@ func (t *tokenRing) GetHostForToken(token token) (host *HostInfo, endToken token } // find the primary replica - p := sort.Search(len(t.tokens), func(i int) bool { + ringIndex := sort.Search(len(t.tokens), func(i int) bool { return !t.tokens[i].token.Less(token) }) - if p == len(t.tokens) { + if ringIndex == len(t.tokens) { // wrap around to the first in the ring - p = 0 + ringIndex = 0 } - v := t.tokens[p] + v := t.tokens[ringIndex] return v.host, v.token } diff --git a/topology.go b/topology.go index 008f4a7a1..6e4d2aa36 100644 --- a/topology.go +++ b/topology.go @@ -2,51 +2,12 @@ package gocql import ( "fmt" - "sort" "strconv" "strings" ) -type hostTokens struct { - token token - hosts []*HostInfo -} - -type tokenRingReplicas []hostTokens - -func (h tokenRingReplicas) Less(i, j int) bool { return h[i].token.Less(h[j].token) } -func (h tokenRingReplicas) Len() int { return len(h) } -func (h tokenRingReplicas) Swap(i, j int) { h[i], h[j] = h[j], h[i] } - -func (h tokenRingReplicas) replicasFor(t token) *hostTokens { - if len(h) == 0 { - return nil - } - - p := sort.Search(len(h), func(i int) bool { - return !h[i].token.Less(t) - }) - - // TODO: simplify this - if p < len(h) && h[p].token == t { - return &h[p] - } - - p-- - - if p >= len(h) { - // rollover - p = 0 - } else if p < 0 { - // rollunder - p = len(h) - 1 - } - - return &h[p] -} - type placementStrategy interface { - replicaMap(tokenRing *tokenRing) tokenRingReplicas + replicaMap(hosts []*HostInfo, tokens []hostToken) map[token][]*HostInfo replicationFactor(dc string) int } @@ -102,28 +63,19 @@ func (s *simpleStrategy) replicationFactor(dc string) int { return s.rf } -func (s *simpleStrategy) replicaMap(tokenRing *tokenRing) tokenRingReplicas { - tokens := tokenRing.tokens - ring := make(tokenRingReplicas, len(tokens)) +func (s *simpleStrategy) replicaMap(_ []*HostInfo, tokens []hostToken) map[token][]*HostInfo { + tokenRing := 
 
 	for i, th := range tokens {
 		replicas := make([]*HostInfo, 0, s.rf)
-		seen := make(map[*HostInfo]bool)
-
 		for j := 0; j < len(tokens) && len(replicas) < s.rf; j++ {
 			h := tokens[(i+j)%len(tokens)]
-			if !seen[h.host] {
-				replicas = append(replicas, h.host)
-				seen[h.host] = true
-			}
+			replicas = append(replicas, h.host)
 		}
-
-		ring[i] = hostTokens{th.token, replicas}
+		tokenRing[th.token] = replicas
 	}
 
-	sort.Sort(ring)
-
-	return ring
+	return tokenRing
 }
 
 type networkTopology struct {
@@ -148,16 +100,10 @@ func (n *networkTopology) haveRF(replicaCounts map[string]int) bool {
 	return true
 }
 
-func (n *networkTopology) replicaMap(tokenRing *tokenRing) tokenRingReplicas {
-	dcRacks := make(map[string]map[string]struct{}, len(n.dcs))
-	// skipped hosts in a dc
-	skipped := make(map[string][]*HostInfo, len(n.dcs))
-	// number of replicas per dc
-	replicasInDC := make(map[string]int, len(n.dcs))
-	// dc -> racks
-	seenDCRacks := make(map[string]map[string]struct{}, len(n.dcs))
+func (n *networkTopology) replicaMap(hosts []*HostInfo, tokens []hostToken) map[token][]*HostInfo {
+	dcRacks := make(map[string]map[string]struct{})
 
-	for _, h := range tokenRing.hosts {
+	for _, h := range hosts {
 		dc := h.DataCenter()
 		rack := h.Rack()
@@ -169,30 +115,21 @@
 		racks[rack] = struct{}{}
 	}
 
-	for dc, racks := range dcRacks {
-		replicasInDC[dc] = 0
-		seenDCRacks[dc] = make(map[string]struct{}, len(racks))
-	}
-
-	tokens := tokenRing.tokens
-	replicaRing := make(tokenRingReplicas, len(tokens))
+	tokenRing := make(map[token][]*HostInfo, len(tokens))
 
 	var totalRF int
 	for _, rf := range n.dcs {
 		totalRF += rf
 	}
 
-	for i, th := range tokenRing.tokens {
-		for k, v := range skipped {
-			skipped[k] = v[:0]
-		}
-
-		for dc := range n.dcs {
-			replicasInDC[dc] = 0
-			for rack := range seenDCRacks[dc] {
-				delete(seenDCRacks[dc], rack)
-			}
-		}
+	for i, th := range tokens {
+		// number of replicas per dc
+		// TODO: recycle these
+		replicasInDC := make(map[string]int, len(n.dcs))
+		// dc -> racks
+		seenDCRacks := make(map[string]map[string]struct{}, len(n.dcs))
+		// skipped hosts in a dc
+		skipped := make(map[string][]*HostInfo, len(n.dcs))
 
 		replicas := make([]*HostInfo, 0, totalRF)
 		for j := 0; j < len(tokens) && (len(replicas) < totalRF && !n.haveRF(replicasInDC)); j++ {
@@ -259,18 +196,16 @@
 			}
 		}
 
-		if len(replicas) == 0 {
-			panic(fmt.Sprintf("no replicas for token: %v", th.token))
-		} else if !replicas[0].Equal(th.host) {
-			panic(fmt.Sprintf("first replica is not the primary replica for the token: expected %v got %v", replicas[0].ConnectAddress(), th.host.ConnectAddress()))
+		if len(replicas) == 0 || replicas[0] != th.host {
+			panic("first replica is not the primary replica for the token")
 		}
 
-		replicaRing[i] = hostTokens{th.token, replicas}
+		tokenRing[th.token] = replicas
 	}
 
-	if len(replicaRing) != len(tokens) {
-		panic(fmt.Sprintf("token map different size to token ring: got %d expected %d", len(replicaRing), len(tokens)))
+	if len(tokenRing) != len(tokens) {
+		panic(fmt.Sprintf("token map different size to token ring: got %d expected %d", len(tokenRing), len(tokens)))
 	}
 
-	return replicaRing
+	return tokenRing
 }
diff --git a/topology_test.go b/topology_test.go
index cbc0f83a6..57229196b 100644
--- a/topology_test.go
+++ b/topology_test.go
@@ -12,7 +12,7 @@ func TestPlacementStrategy_SimpleStrategy(t *testing.T) {
 	host50 := &HostInfo{hostId: "50"}
"50"} host75 := &HostInfo{hostId: "75"} - tokens := []hostToken{ + tokenRing := []hostToken{ {intToken(0), host0}, {intToken(25), host25}, {intToken(50), host50}, @@ -22,27 +22,27 @@ func TestPlacementStrategy_SimpleStrategy(t *testing.T) { hosts := []*HostInfo{host0, host25, host50, host75} strat := &simpleStrategy{rf: 2} - tokenReplicas := strat.replicaMap(&tokenRing{hosts: hosts, tokens: tokens}) - if len(tokenReplicas) != len(tokens) { - t.Fatalf("expected replica map to have %d items but has %d", len(tokens), len(tokenReplicas)) + tokenReplicas := strat.replicaMap(hosts, tokenRing) + if len(tokenReplicas) != len(tokenRing) { + t.Fatalf("expected replica map to have %d items but has %d", len(tokenRing), len(tokenReplicas)) } - for _, replicas := range tokenReplicas { - if len(replicas.hosts) != strat.rf { - t.Errorf("expected to have %d replicas got %d for token=%v", strat.rf, len(replicas.hosts), replicas.token) + for token, replicas := range tokenReplicas { + if len(replicas) != strat.rf { + t.Errorf("expected to have %d replicas got %d for token=%v", strat.rf, len(replicas), token) } } - for i, token := range tokens { - ht := tokenReplicas.replicasFor(token.token) - if ht.token != token.token { - t.Errorf("token %v not in replica map: %v", token, ht.hosts) + for i, token := range tokenRing { + replicas, ok := tokenReplicas[token.token] + if !ok { + t.Errorf("token %v not in replica map", token) } - for j, replica := range ht.hosts { - exp := tokens[(i+j)%len(tokens)].host + for j, replica := range replicas { + exp := tokenRing[(i+j)%len(tokenRing)].host if exp != replica { - t.Errorf("expected host %v to be a replica of %v got %v", exp.hostId, token, replica.hostId) + t.Errorf("expected host %v to be a replica of %v got %v", exp, token, replica) } } } @@ -103,7 +103,7 @@ func TestPlacementStrategy_NetworkStrategy(t *testing.T) { expReplicas += rf } - tokenReplicas := strat.replicaMap(&tokenRing{hosts: hosts, tokens: tokens}) + tokenReplicas := strat.replicaMap(hosts, tokens) if len(tokenReplicas) != len(tokens) { t.Fatalf("expected replica map to have %d items but has %d", len(tokens), len(tokenReplicas)) } @@ -112,8 +112,8 @@ func TestPlacementStrategy_NetworkStrategy(t *testing.T) { } for token, replicas := range tokenReplicas { - if len(replicas.hosts) != expReplicas { - t.Fatalf("expected to have %d replicas got %d for token=%v", expReplicas, len(replicas.hosts), token) + if len(replicas) != expReplicas { + t.Fatalf("expected to have %d replicas got %d for token=%v", expReplicas, len(replicas), token) } } @@ -121,13 +121,13 @@ func TestPlacementStrategy_NetworkStrategy(t *testing.T) { dcTokens := dcRing[dc] for i, th := range dcTokens { token := th.token - allReplicas := tokenReplicas.replicasFor(token) - if allReplicas.token != token { + allReplicas, ok := tokenReplicas[token] + if !ok { t.Fatalf("token %v not in replica map", token) } var replicas []*HostInfo - for _, replica := range allReplicas.hosts { + for _, replica := range allReplicas { if replica.dataCenter == dc { replicas = append(replicas, replica) }